From 74f7a0c9d6ea5b3a6d37dd61d0a83557a90b1d03 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Tue, 12 Dec 2023 19:02:51 -0800 Subject: [PATCH 001/283] Upstream the ONNX importer. (#2636) This is part 1 of 2, which will also include upstreaming the FX importer. I started with ONNX because it forces some project layout updates and is more self contained/easier as a first step. Deviating somewhat from the RFCs on project layout, I made the following decisions: * Locating the `onnx_importer.py` into `torch_mlir.extras` as Maks already has opened up that namespace and it seemed to fit. Better to have fewer things at that level. * Setup the build so that the root project only contains MLIR Python and pure Python deps (like the importers), but this can be augmented with the `projects/` adding more depending on which features are enabled. * The default build continues to build everything whereas in `TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS=1` mode, it builds a `torch-mlir-core` wheel with the pure contents only. `onnx_importer.py` and `importer_smoke_test.py` are almost verbatim copies from SHARK-Turbine. I made some minor local alterations to adapt to paths and generalize the way they interact with the outer project. I expect I can copy these back to Turbine verbatim from here. I also updated the license boilerplate (they have the same license but slightly different project norms for the headers) but retained the correct copyright. Other updates: * Added the ONNX importer unit test (which also can generate test data) in lit, conditioned on the availability of the Python `onnx` package. In a followup once I know everything is stable, I'll add another env var that the CI can set to always enable this so we know conclusively if tests pass. * Moved the ONNX conversion readme to `docs/`. * Renamed CMake option `TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS` -> `TORCH_MLIR_ENABLE_PYTORCH_EXTENSIONS` and inverted the sense. Made the JitIR importer and LTC options `cmake_dependent_options` for robustness. --- CMakeLists.txt | 37 +- .../python_deploy/build_linux_packages.sh | 2 +- .../importers/onnx_importer.md | 16 +- projects/CMakeLists.txt | 26 +- projects/pt1/python/CMakeLists.txt | 128 +--- .../torch_mlir/jit_ir_importer/CMakeLists.txt | 4 +- python/CMakeLists.txt | 94 +++ .../pt1/python => python}/TorchMLIRModule.cpp | 0 .../torch_mlir/dialects/TorchBinding.td | 0 .../torch_mlir/dialects/torch/__init__.py | 0 python/torch_mlir/extras/onnx_importer.py | 607 ++++++++++++++++++ setup.py | 41 +- test-requirements.txt | 1 + test/python/lit.local.cfg | 2 + test/python/onnx_importer/.gitignore | 1 + .../onnx_importer/_torch_mlir_config.py | 19 + .../python/onnx_importer/import_smoke_test.py | 374 +++++++++++ test/python/onnx_importer/lit.local.cfg | 5 + 18 files changed, 1208 insertions(+), 149 deletions(-) rename include/torch-mlir/Conversion/TorchOnnxToTorch/README.md => docs/importers/onnx_importer.md (90%) create mode 100644 python/CMakeLists.txt rename {projects/pt1/python => python}/TorchMLIRModule.cpp (100%) rename {projects/pt1/python => python}/torch_mlir/dialects/TorchBinding.td (100%) rename {projects/pt1/python => python}/torch_mlir/dialects/torch/__init__.py (100%) create mode 100644 python/torch_mlir/extras/onnx_importer.py create mode 100644 test/python/lit.local.cfg create mode 100644 test/python/onnx_importer/.gitignore create mode 100644 test/python/onnx_importer/_torch_mlir_config.py create mode 100644 test/python/onnx_importer/import_smoke_test.py create mode 100644 test/python/onnx_importer/lit.local.cfg diff --git a/CMakeLists.txt b/CMakeLists.txt index ccbe7ccb3a98..f821d60034c8 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -25,6 +25,8 @@ project(torch-mlir LANGUAGES CXX C) set(CMAKE_C_STANDARD 11) set(CMAKE_CXX_STANDARD 17) +include(CMakeDependentOption) + #------------------------------------------------------------------------------- # Project options #------------------------------------------------------------------------------- @@ -43,24 +45,11 @@ endif() option(TORCH_MLIR_OUT_OF_TREE_BUILD "Specifies an out of tree build" OFF) -# PT1 options. -option(TORCH_MLIR_ENABLE_PROJECT_PT1 "Enables the PyTorch1 project under projects/pt1" OFF) -# TODO: Rename/scope these. They use historic names for now to ease migration -# burden. -option(TORCH_MLIR_ENABLE_JIT_IR_IMPORTER "Enables JIT IR Importer" ON) -option(TORCH_MLIR_ENABLE_LTC "Enables LTC backend" OFF) -option(TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS "Build Torch dialect MLIR Python bindings but neither JIT IR Importer nor LTC backend" OFF) -if(TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS) - set(TORCH_MLIR_ENABLE_JIT_IR_IMPORTER OFF) - set(TORCH_MLIR_ENABLE_LTC OFF) -endif() -# Force enable the PT1 project if either the JIT_IR_IMPORTER or LTC is enabled. -if(NOT TORCH_MLIR_ENABLE_PROJECT_PT1) - if(TORCH_MLIR_ENABLE_JIT_IR_IMPORTER OR TORCH_MLIR_ENABLE_LTC) - message(STATUS "Enabling projects/pt1 because features requiring it are enabled") - set(TORCH_MLIR_ENABLE_PROJECT_PT1 ON) - endif() -endif() +# PyTorch native extension gate. If OFF, then no features which depend on +# native extensions will be built. +option(TORCH_MLIR_ENABLE_PYTORCH_EXTENSIONS "Enables PyTorch native extension features" ON) +cmake_dependent_option(TORCH_MLIR_ENABLE_JIT_IR_IMPORTER "Enables JIT IR Importer" ON TORCH_MLIR_ENABLE_PYTORCH_EXTENSIONS OFF) +cmake_dependent_option(TORCH_MLIR_ENABLE_LTC "Enables LTC backend" OFF TORCH_MLIR_ENABLE_PYTORCH_EXTENSIONS OFF) #------------------------------------------------------------------------------- # Configure out-of-tree vs in-tree build @@ -235,4 +224,16 @@ endif() # Sub-projects #------------------------------------------------------------------------------- +# Sub-projects can bundle additional PyTorch extensions by adding them to this +# source target. It is typically empty unless if features are enabled. +if(MLIR_ENABLE_BINDINGS_PYTHON) + declare_mlir_python_sources(TorchMLIRPythonTorchExtensionsSources) +endif() + +# Build projects first as it may populate additional Python deps. add_subdirectory(projects) + +# Finish with top-level Python bindings so it can handle additional deps. +if(MLIR_ENABLE_BINDINGS_PYTHON) + add_subdirectory(python) +endif() \ No newline at end of file diff --git a/build_tools/python_deploy/build_linux_packages.sh b/build_tools/python_deploy/build_linux_packages.sh index 2a909266f43a..f0336b2a1a4b 100755 --- a/build_tools/python_deploy/build_linux_packages.sh +++ b/build_tools/python_deploy/build_linux_packages.sh @@ -351,7 +351,6 @@ function setup_venv() { echo ":::: Using stable dependencies" python3 -m pip install --no-cache-dir torch torchvision --index-url https://download.pytorch.org/whl/cpu python3 -m pip install --no-cache-dir -r /main_checkout/torch-mlir/build-requirements.txt - python3 -m pip install --no-cache-dir -r /main_checkout/torch-mlir/test-requirements.txt ;; *) echo "Unrecognized torch version '$torch_version'" @@ -359,6 +358,7 @@ function setup_venv() { ;; esac + python3 -m pip install --no-cache-dir -r /main_checkout/torch-mlir/test-requirements.txt } function build_out_of_tree() { diff --git a/include/torch-mlir/Conversion/TorchOnnxToTorch/README.md b/docs/importers/onnx_importer.md similarity index 90% rename from include/torch-mlir/Conversion/TorchOnnxToTorch/README.md rename to docs/importers/onnx_importer.md index 6de1cc923411..acc45bb2e602 100644 --- a/include/torch-mlir/Conversion/TorchOnnxToTorch/README.md +++ b/docs/importers/onnx_importer.md @@ -3,11 +3,8 @@ We enable the direct representation of many ONNX features directly in the `torch` dialect as `torch.operator` custom ops with names like `onnx.{OperatorName}`. The majority of ONNX operators are represented -with a systematic transformation. See -[onnx_importer.py](https://github.com/nod-ai/SHARK-Turbine/blob/main/python/shark_turbine/importers/onnx_importer.py) -for the reference importer which complies with the rules below -(this is planned to be upstreamed to torch-mlir proper in the near -future). +with a systematic transformation. `torch_mlir.extras.onnx_importer` +for the reference importer which complies with the rules below. ## Adding new ONNX operators @@ -26,10 +23,11 @@ are relatively straight-forward to map, following this general procedure: * Open the corresponding implementation file `DefaultDomainXtoY.cpp` corresponding with the alphabetic sort of the op and add a conversion. * Generate successful test cases: - * Either run the Turbine importer to produce MLIR output for all - ops/models in the ONNX test suite or use a dump that someone has - generated: - * [2023-Nov-21](https://drive.google.com/file/d/1P6QaRXGnCeApjdjNmykLxWa-yqMmIO-d/view?usp=sharing) + * All `onnx_importer.py` tests are dumped to the test temp dir (success + or failure). This is typically located under + `tools/torch-mlir/test/python/onnx_importer/Output`. The `.mlir` files + under there should provide good variants to drive lit test coverage of + conversion. * There are often many variants of tests for checking conformance of different historic ONNX encodings, but these are often not load bearing at the MLIR level. diff --git a/projects/CMakeLists.txt b/projects/CMakeLists.txt index 4b54be65a79d..d4fead890269 100644 --- a/projects/CMakeLists.txt +++ b/projects/CMakeLists.txt @@ -1,7 +1,31 @@ include(AddMLIRPython) +################################################################################ +# PyTorch # Configure PyTorch if we have any features enabled which require it. +################################################################################ if(TORCH_MLIR_ENABLE_JIT_IR_IMPORTER OR TORCH_MLIR_ENABLE_LTC) + + if (NOT TORCH_MLIR_USE_INSTALLED_PYTORCH) + # Source builds + message(STATUS "Building libtorch from source (features depend on it and NOT TORCH_MLIR_USE_INSTALLED_PYTORCH)") + set(ENV{TORCH_MLIR_SRC_PYTORCH_REPO} ${TORCH_MLIR_SRC_PYTORCH_REPO}) + set(ENV{TORCH_MLIR_SRC_PYTORCH_BRANCH} ${TORCH_MLIR_SRC_PYTORCH_BRANCH}) + set(ENV{TM_PYTORCH_INSTALL_WITHOUT_REBUILD} ${TM_PYTORCH_INSTALL_WITHOUT_REBUILD}) + set(ENV{MACOSX_DEPLOYMENT_TARGET} ${MACOSX_DEPLOYMENT_TARGET}) + set(ENV{CMAKE_OSX_ARCHITECTURES} ${CMAKE_OSX_ARCHITECTURES}) + set(ENV{CMAKE_C_COMPILER_LAUNCHER} ${CMAKE_C_COMPILER_LAUNCHER}) + set(ENV{CMAKE_CXX_COMPILER_LAUNCHER} ${CMAKE_CXX_COMPILER_LAUNCHER}) + execute_process( + COMMAND ${TORCH_MLIR_SOURCE_DIR}/build_tools/build_libtorch.sh + RESULT_VARIABLE _result + ) + if(_result) + message(FATAL_ERROR "Failed to run `build_libtorch.sh`") + endif() + set(TORCH_INSTALL_PREFIX "libtorch") + endif() + message(STATUS "Enabling PyTorch C++ dep (features depend on it)") include(TorchMLIRPyTorch) @@ -48,6 +72,6 @@ if(TORCH_MLIR_ENABLE_LTC) endif() # Include overall PT1 project. -if(TORCH_MLIR_ENABLE_PROJECT_PT1) +if(TORCH_MLIR_ENABLE_PYTORCH_EXTENSIONS) add_subdirectory(pt1) endif() diff --git a/projects/pt1/python/CMakeLists.txt b/projects/pt1/python/CMakeLists.txt index e951772df935..ce40426988a7 100644 --- a/projects/pt1/python/CMakeLists.txt +++ b/projects/pt1/python/CMakeLists.txt @@ -7,79 +7,22 @@ set(CMAKE_PLATFORM_NO_VERSIONED_SONAME ON) # argument. set(TORCH_MLIR_PYTHON_ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/torch_mlir") - # We vendor our own MLIR instance in the `torch_mlir` namespace. add_compile_definitions("MLIR_PYTHON_PACKAGE_PREFIX=torch_mlir.") -################################################################################ -# PyTorch -################################################################################ - -if (NOT TORCH_MLIR_USE_INSTALLED_PYTORCH) - # Source builds - set(ENV{TORCH_MLIR_SRC_PYTORCH_REPO} ${TORCH_MLIR_SRC_PYTORCH_REPO}) - set(ENV{TORCH_MLIR_SRC_PYTORCH_BRANCH} ${TORCH_MLIR_SRC_PYTORCH_BRANCH}) - set(ENV{TM_PYTORCH_INSTALL_WITHOUT_REBUILD} ${TM_PYTORCH_INSTALL_WITHOUT_REBUILD}) - set(ENV{MACOSX_DEPLOYMENT_TARGET} ${MACOSX_DEPLOYMENT_TARGET}) - set(ENV{CMAKE_OSX_ARCHITECTURES} ${CMAKE_OSX_ARCHITECTURES}) - set(ENV{CMAKE_C_COMPILER_LAUNCHER} ${CMAKE_C_COMPILER_LAUNCHER}) - set(ENV{CMAKE_CXX_COMPILER_LAUNCHER} ${CMAKE_CXX_COMPILER_LAUNCHER}) - execute_process( - COMMAND ${TORCH_MLIR_SOURCE_DIR}/build_tools/build_libtorch.sh - RESULT_VARIABLE _result - ) - if(_result) - message(FATAL_ERROR "Failed to run `build_libtorch.sh`") - endif() - set(TORCH_INSTALL_PREFIX "libtorch") -endif() - -################################################################################ -# Sources -################################################################################ - -declare_mlir_python_sources(TorchMLIRPythonSources) -declare_mlir_python_sources(TorchMLIRPythonExtensions) - -if (NOT TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS) - declare_mlir_python_sources(TorchMLIRPythonSources.TopLevel - ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}" - ADD_TO_PARENT TorchMLIRPythonSources - SOURCES - __init__.py - _dynamo_fx_importer.py - compiler_utils.py - dynamo.py - _version.py - ) -endif() - -declare_mlir_python_sources(TorchMLIRPythonSources.Dialects - ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}" - ADD_TO_PARENT TorchMLIRPythonSources -) +# ################################################################################ +# # Sources +# ################################################################################ -declare_mlir_dialect_python_bindings( - ADD_TO_PARENT TorchMLIRPythonSources.Dialects +declare_mlir_python_sources(TorchMLIRPythonSources.TopLevel ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}" - TD_FILE dialects/TorchBinding.td - SOURCES dialects/torch/__init__.py - DIALECT_NAME torch -) - -################################################################################ -# Extensions -################################################################################ - -declare_mlir_python_extension(TorchMLIRPythonExtensions.Main - MODULE_NAME _torchMlir - ADD_TO_PARENT TorchMLIRPythonExtensions + ADD_TO_PARENT TorchMLIRPythonTorchExtensionsSources SOURCES - TorchMLIRModule.cpp - EMBED_CAPI_LINK_LIBS - TorchMLIRCAPI - PRIVATE_LINK_LIBS - LLVMSupport + __init__.py + _dynamo_fx_importer.py + compiler_utils.py + dynamo.py + _version.py ) ################################################################################ @@ -110,56 +53,23 @@ endif() # add_subdirectory(torch_mlir/_torch_mlir_custom_op_example) -################################################################################ -# Generate packages and shared library -# Downstreams typically will not use these, but they are useful for local -# testing. -################################################################################ - -set(_source_components - # TODO: Core is now implicitly building/registering all dialects, increasing - # build burden by ~5x. Make it stop. - # TODO: Reduce dependencies. We need ExecutionEngine and a bunch of passes - # for the reference backend, but logically they can be separate. But seemingly - # the only way to handle that is to create a separate mlir python package - # tree, which seems excessive. - MLIRPythonSources - MLIRPythonExtension.Core - MLIRPythonExtension.RegisterEverything - TorchMLIRPythonSources - TorchMLIRPythonExtensions -) - -add_mlir_python_common_capi_library(TorchMLIRAggregateCAPI - INSTALL_COMPONENT TorchMLIRPythonModules - INSTALL_DESTINATION python_packages/torch_mlir/torch_mlir/_mlir_libs - OUTPUT_DIRECTORY "${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir/_mlir_libs" - RELATIVE_INSTALL_ROOT "../../../.." - DECLARED_SOURCES ${_source_components} -) - -add_mlir_python_modules(TorchMLIRPythonModules - ROOT_PREFIX "${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir" - INSTALL_PREFIX "python_packages/torch_mlir/torch_mlir" - DECLARED_SOURCES ${_source_components} - COMMON_CAPI_LINK_LIBS - TorchMLIRAggregateCAPI - ) - # TODO: Find a cleaner way to do this. # Can we build the JIT IR importer with `declare_mlir_python_extension`? # Then it would "just work". if(TORCH_MLIR_ENABLE_JIT_IR_IMPORTER) - add_dependencies(TorchMLIRPythonModules TorchMLIRJITIRImporter) - add_dependencies(TorchMLIRPythonModules TorchMLIRJITIRImporterPybind) - # Build the E2E Tests (which depend on the JIT IR importer now). - add_dependencies(TorchMLIRPythonModules TorchMLIRE2ETestPythonModules) + add_dependencies(TorchMLIRPythonTorchExtensionsSources + TorchMLIRJITIRImporter + TorchMLIRJITIRImporterPybind + TorchMLIRE2ETestPythonModules + ) endif() if(TORCH_MLIR_ENABLE_LTC) # Add Torch-MLIR LTC backend as dependency - add_dependencies(TorchMLIRPythonModules torch_mlir_ltc_backend) - add_dependencies(TorchMLIRPythonModules reference_lazy_backend) + add_dependencies(TorchMLIRPythonTorchExtensionsSources + torch_mlir_ltc_backend + reference_lazy_backend + ) endif() add_subdirectory(test) diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/CMakeLists.txt b/projects/pt1/python/torch_mlir/jit_ir_importer/CMakeLists.txt index c2883b3dca84..6c2ccf62eb78 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/CMakeLists.txt +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/CMakeLists.txt @@ -4,9 +4,9 @@ ## Declare the sources of the Python module. -declare_mlir_python_sources(TorchMLIRPythonSources.JitIRImporter +declare_mlir_python_sources(TorchMLIRPythonTorchExtensionsSources.JitIRImporter ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}" - ADD_TO_PARENT TorchMLIRPythonSources + ADD_TO_PARENT TorchMLIRPythonTorchExtensionsSources SOURCES_GLOB jit_ir_importer/*.py ) diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt new file mode 100644 index 000000000000..7b9bf12f2b8f --- /dev/null +++ b/python/CMakeLists.txt @@ -0,0 +1,94 @@ +# Disables generation of "version soname" (i.e. libFoo.so.), which +# causes pure duplication as part of Python wheels. +set(CMAKE_PLATFORM_NO_VERSIONED_SONAME ON) + +# The directory at which the Python import tree begins. +# See documentation for `declare_mlir_python_sources`'s ROOT_DIR +# argument. +set(TORCH_MLIR_PYTHON_ROOT_DIR "${CMAKE_CURRENT_SOURCE_DIR}/torch_mlir") + + +# We vendor our own MLIR instance in the `torch_mlir` namespace. +add_compile_definitions("MLIR_PYTHON_PACKAGE_PREFIX=torch_mlir.") + +################################################################################ +# Sources +################################################################################ + +declare_mlir_python_sources(TorchMLIRPythonSources) +declare_mlir_python_sources(TorchMLIRPythonExtensions) + +declare_mlir_python_sources(TorchMLIRPythonSources.Dialects + ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}" + ADD_TO_PARENT TorchMLIRPythonSources +) + +declare_mlir_dialect_python_bindings( + ADD_TO_PARENT TorchMLIRPythonSources.Dialects + ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}" + TD_FILE dialects/TorchBinding.td + SOURCES dialects/torch/__init__.py + DIALECT_NAME torch +) + +declare_mlir_python_sources(TorchMLIRPythonSources.Importers + ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}" + ADD_TO_PARENT TorchMLIRPythonSources + SOURCES + extras/onnx_importer.py +) + +################################################################################ +# Extensions +################################################################################ + +declare_mlir_python_extension(TorchMLIRPythonExtensions.Main + MODULE_NAME _torchMlir + ADD_TO_PARENT TorchMLIRPythonExtensions + SOURCES + TorchMLIRModule.cpp + EMBED_CAPI_LINK_LIBS + TorchMLIRCAPI + PRIVATE_LINK_LIBS + LLVMSupport +) + +################################################################################ +# Generate packages and shared library +# Downstreams typically will not use these, but they are useful for local +# testing. +################################################################################ + +set(_source_components + # TODO: Core is now implicitly building/registering all dialects, increasing + # build burden by ~5x. Make it stop. + # TODO: Reduce dependencies. We need ExecutionEngine and a bunch of passes + # for the reference backend, but logically they can be separate. But seemingly + # the only way to handle that is to create a separate mlir python package + # tree, which seems excessive. + MLIRPythonSources + MLIRPythonExtension.Core + MLIRPythonExtension.RegisterEverything + TorchMLIRPythonSources + TorchMLIRPythonExtensions + + # Sources related to optional Torch extension dependent features. Typically + # empty unless if project features are enabled. + TorchMLIRPythonTorchExtensionsSources +) + +add_mlir_python_common_capi_library(TorchMLIRAggregateCAPI + INSTALL_COMPONENT TorchMLIRPythonModules + INSTALL_DESTINATION python_packages/torch_mlir/torch_mlir/_mlir_libs + OUTPUT_DIRECTORY "${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir/_mlir_libs" + RELATIVE_INSTALL_ROOT ".." + DECLARED_SOURCES ${_source_components} +) + +add_mlir_python_modules(TorchMLIRPythonModules + ROOT_PREFIX "${TORCH_MLIR_PYTHON_PACKAGES_DIR}/torch_mlir/torch_mlir" + INSTALL_PREFIX "python_packages/torch_mlir/torch_mlir" + DECLARED_SOURCES ${_source_components} + COMMON_CAPI_LINK_LIBS + TorchMLIRAggregateCAPI + ) diff --git a/projects/pt1/python/TorchMLIRModule.cpp b/python/TorchMLIRModule.cpp similarity index 100% rename from projects/pt1/python/TorchMLIRModule.cpp rename to python/TorchMLIRModule.cpp diff --git a/projects/pt1/python/torch_mlir/dialects/TorchBinding.td b/python/torch_mlir/dialects/TorchBinding.td similarity index 100% rename from projects/pt1/python/torch_mlir/dialects/TorchBinding.td rename to python/torch_mlir/dialects/TorchBinding.td diff --git a/projects/pt1/python/torch_mlir/dialects/torch/__init__.py b/python/torch_mlir/dialects/torch/__init__.py similarity index 100% rename from projects/pt1/python/torch_mlir/dialects/torch/__init__.py rename to python/torch_mlir/dialects/torch/__init__.py diff --git a/python/torch_mlir/extras/onnx_importer.py b/python/torch_mlir/extras/onnx_importer.py new file mode 100644 index 000000000000..a9dd52253601 --- /dev/null +++ b/python/torch_mlir/extras/onnx_importer.py @@ -0,0 +1,607 @@ +# Based on code Copyright (c) Advanced Micro Devices, Inc. +# +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. + +"""Imports ONNX graphs to `torch` dialect ops. + +See documentation: + https://github.com/llvm/torch-mlir/blob/main/docs/importers/onnx_importer.md + +This file is distributed/forked verbatim into various downstream projects, and +it must abide by several rules above and beyond the rest of the codebase: + - It must be standalone, only depending on: + - `onnx` + - `..ir` relative imports to the main IR directory + - `..dialects.func` relative import to the `func` dialect (TODO: + we are looking to eliminate this dep). + - Python standard library + - It does not directly use the ODS generated `torch` dialect Python + wrappers. This allows it to be used in contexts that only build a C++ + compiler with minimal IR Python bindings. + - It is intended as an enabler for full onnx compilation, only handling + the import from ONNX -> the `torch` dialect. Testing, full pipelines, + and utilities belong elsewhere. +""" + +try: + import onnx +except ModuleNotFoundError as e: + raise ModuleNotFoundError( + "The onnx package (`pip install onnx`) is required to use the onnx importer" + ) from e + +from typing import Optional + +from dataclasses import dataclass + +import numpy as np + +from ..ir import ( + ArrayAttr, + Attribute, + Block, + Context, + DenseElementsAttr, + DenseResourceElementsAttr, + DictAttr, + FloatAttr, + BF16Type, + ComplexType, + F16Type, + F32Type, + F64Type, + Float8E4M3FNType, + Float8E5M2FNUZType, + Float8E5M2Type, + FunctionType, + InsertionPoint, + IntegerAttr, + IntegerType, + MLIRError, + RankedTensorType, + Location, + Module, + Operation, + StringAttr, + Type as IrType, + Value, +) + +from ..dialects import ( + func as func_dialect, +) + +@dataclass +class Config: + """Various configuration settings for the importer.""" + + # Ancient ONNX exporters would often add a model input for anything that + # might be mutable, providing an initializer for it as well. More modern + # tools tools realized this is a really bad idea for a lot of reasons. + # We choose to assume more recent norms, even if encountering older + # models. Setting this to False probably won't do what you want but + # should produce interesting errors to waste your time deciphering. + # We mainly use it as a way to document in the code that we are + # making an assumption. + elide_initialized_inputs: bool = True + + +class ModelInfo: + """Top-level accounting and accessors for an ONNX model.""" + + def __init__(self, model_proto: onnx.ModelProto, *, config: Config = Config()): + self.config = config + self.model_proto = model_proto + assert model_proto.graph, "Model must contain a main Graph" + self.main_graph = GraphInfo(self, model_proto.graph) + + def create_module(self, context: Optional[Context] = None) -> Operation: + if not context: + context = Context() + module_op = Module.create(Location.unknown(context)).operation + # TODO: Populate module level metadata from the ModelProto + return module_op + + +class GraphInfo: + """Information about a Graph within a model.""" + + def __init__(self, model_info: ModelInfo, graph_proto: onnx.GraphProto): + self.model_info = model_info + self.graph_proto = graph_proto + self.initializer_map: dict[str, onnx.TensorProto] = { + n.name: n for n in graph_proto.initializer + } + self.value_info_map: dict[str, onnx.ValueInfoProto] = { + n.name: n for n in graph_proto.value_info + } + self.declared_input_map: dict[str, onnx.ValueInfoProto] = { + n.name: n for n in graph_proto.input + } + self.output_map: dict[str, onnx.ValueInfoProto] = { + n.name: n for n in graph_proto.output + } + + # Generate the effective input map, which for old models can be a + # subset of the input map. + if model_info.config.elide_initialized_inputs: + self.input_map = { + k: v + for k, v in self.declared_input_map.items() + if k not in self.initializer_map + } + else: + self.input_map = self.declared_input_map + illegal_input_keys = self.input_map.keys() - ( + self.input_map.keys() - self.initializer_map.keys() + ) + assert self.input_map.keys().isdisjoint(self.initializer_map.keys()), ( + f"When not in elide_initialized_inputs=True, we expect inputs to not " + f"have an initial value (got {illegal_input_keys})." + ) + + def find_type_proto_for_name(self, name: str) -> onnx.TypeProto: + # Node outputs don't typically have type information, but shape inference + # will associate them in the value_info. If not there, it may be a + # graph output, which must have type information. + value_info = self.value_info_map.get(name) or self.output_map.get(name) + if value_info is not None: + return value_info.type + raise OnnxImportError( + f"No type information associated with '{name}'. Run shape inference?" + ) + + +class OnnxImportError(Exception): + ... + + +class NodeImporter: + """Imports graph nodes into MLIR. + + Typically, the top level graph will be imported into a func whereas dependent + graphs may just be imported with references to pre-existing values. + + Note that ONNX requires that graphs be sorted topologically and free of cycles, + so we don't take any special steps to order them for dominance. + """ + + __slots__ = [ + "_c", + "_cc", + "_gi", + "_p", + "_b", + "_nv_map", + ] + + def __init__( + self, + graph_info: GraphInfo, + *, + parent_op: Operation, + block: Block, + context_cache: "ContextCache", + ): + self._c = parent_op.context + self._cc = context_cache + self._gi = graph_info + self._p = parent_op + self._b = block + self._nv_map: dict[str, Value] = {} + + @classmethod + def define_function( + cls, graph_info: GraphInfo, module_op: Operation + ) -> "NodeImporter": + cc = ContextCache(module_op.context) + with module_op.context, Location.name(f"graph:{graph_info.graph_proto.name}"): + body = module_op.regions[0].blocks[0] + func_name = graph_info.graph_proto.name + input_types = [ + cc.type_proto_to_type(inp.type) for inp in graph_info.input_map.values() + ] + output_types = [ + cc.type_proto_to_type(out.type) + for out in graph_info.output_map.values() + ] + ftype = FunctionType.get(input_types, output_types) + func_op = func_dialect.FuncOp(func_name, ftype, ip=InsertionPoint(body)) + block = func_op.add_entry_block( + [Location.name(k) for k in graph_info.input_map.keys()] + ) + imp = NodeImporter(graph_info, parent_op=func_op, block=block, context_cache=cc) + for node_name, input_value in zip(graph_info.input_map.keys(), block.arguments): + imp._nv_map[node_name] = input_value + imp._populate_graph_attrs(func_op) + return imp + + def _populate_graph_attrs(self, container_op: Operation): + """Populates graph level meta attributes on the given container op.""" + m = self._gi.model_info.model_proto + with container_op.context: + i64_type = IntegerType.get_signed(64) + default_opset_version = 0 + opset_versions: dict[str, IntegerAttr] = {} + for opset_import in m.opset_import: + if opset_import.domain: + opset_versions[opset_import.domain] = IntegerAttr.get( + i64_type, opset_import.version + ) + else: + default_opset_version = opset_import.version + if default_opset_version: + container_op.attributes[ + "torch.onnx_meta.opset_version" + ] = IntegerAttr.get(i64_type, default_opset_version) + if opset_versions: + container_op.attributes[ + "torch.onnx_meta.opset_versions" + ] = DictAttr.get(opset_versions) + container_op.attributes["torch.onnx_meta.ir_version"] = IntegerAttr.get( + IntegerType.get_signed(64), m.ir_version + ) + container_op.attributes["torch.onnx_meta.producer_name"] = StringAttr.get( + m.producer_name + ) + container_op.attributes[ + "torch.onnx_meta.producer_version" + ] = StringAttr.get(m.producer_version) + + def import_all(self): + """Imports all nodes topologically.""" + # TODO: Consider pulling in initializers on demand since there can be so + # much unused crap. + for init in self._gi.initializer_map.values(): + self.import_initializer(init) + for node in self._gi.graph_proto.node: + self.import_node(node) + + outputs = [] + for output_name in self._gi.output_map.keys(): + try: + outputs.append(self._nv_map[output_name]) + except KeyError: + raise OnnxImportError( + f"Non topologically produced ONNX graph output '{output_name}'" + ) + with InsertionPoint(self._b), Location.unknown(): + func_dialect.ReturnOp(outputs) + + def import_node(self, node: onnx.NodeProto): + with InsertionPoint(self._b), Location.name(node.name): + op_type = node.op_type + # Handle special op types that materialize to non-op IR constructs. + special_key = f"_handle_node_{op_type}" + if hasattr(self, special_key): + getattr(self, special_key)(node) + return + + # General node import. + input_values = [] + for input_name in node.input: + try: + input_values.append(self._nv_map[input_name]) + except KeyError: + raise OnnxImportError( + f"Non topologically produced ONNX node input '{input_name}': {node}" + ) + + output_names = list(node.output) + output_types = [ + self._cc.type_proto_to_type(self._gi.find_type_proto_for_name(n)) + for n in output_names + ] + + # TODO: Attributes. + attrs = { + "name": StringAttr.get(f"onnx.{op_type}"), + } + self.import_attributes(node.attribute, attrs) + custom_op = Operation.create( + name="torch.operator", + results=output_types, + operands=input_values, + attributes=attrs, + ) + for output_name, output_value in zip(output_names, custom_op.results): + self._nv_map[output_name] = output_value + + def import_attributes( + self, onnx_attrs: list[onnx.AttributeProto], attrs: dict[str, Attribute] + ): + for onnx_attr in onnx_attrs: + attr_type = onnx_attr.type + if attr_type not in ATTRIBUTE_TYPE_HANDLERS: + raise OnnxImportError( + f"Unhandled ONNX attribute type code {attr_type}: {onnx_attr}" + ) + handler = ATTRIBUTE_TYPE_HANDLERS[attr_type] + if handler is None: + # Active skip. + continue + elif handler is False: + # Active error. + raise OnnxImportError( + f"ONNX importer does not support generic node attribute type {attr_type}. " + f"This likely means that this is a special node which requires specific " + f"handling in the importer: {onnx_attr}" + ) + attrs[f"torch.onnx.{onnx_attr.name}"] = handler(onnx_attr, self._cc) + + def import_initializer(self, initializer: onnx.TensorProto) -> Value: + with InsertionPoint(self._b), Location.name(initializer.name): + value_attr = self._cc.tensor_proto_to_attr(initializer) + vtensor_type = self._cc.tensor_proto_to_type(initializer) + literal_op = Operation.create( + name="torch.vtensor.literal", + results=[vtensor_type], + attributes={"value": value_attr}, + ) + self._nv_map[initializer.name] = literal_op.result + return literal_op.result + + def _get_immediate_tensor(self, name: str) -> np.array: + try: + initializer = self._gi.initializer_map[name] + except KeyError: + raise OnnxImportError( + f"An immediate value for '{name}' was required but it is dynamically produced." + ) + try: + dtype = ELEM_TYPE_TO_NUMPY_DTYPE[initializer.data_type] + except KeyError: + raise OnnxImportError( + f"Unknown ONNX tensor element type to numpy dtype mapping: {initializer.data_type}" + ) + raw_data = initializer.raw_data + if raw_data: + return np.frombuffer(raw_data, dtype=dtype).reshape(tuple(initializer.dims)) + else: + raise OnnxImportError( + f"Unhandled ONNX TensorProto immediate data: {initializer}" + ) + + def _handle_node_ConstantOfShape(self, node: onnx.NodeProto): + # This op is special: It has an input of the shape, and in full generality + # could involve eager production of constants of variable size. In + # practice, the DNN profile for ONNX makes this very difficult to do + # and we hard-assert that the input can be resolved to an immediate + # value. + assert len(node.input) == 1 + assert len(node.output) == 1 + shape = self._get_immediate_tensor(node.input[0]).astype(np.int64) + value_proto = _get_attr(node, "value") + assert value_proto.type == onnx.AttributeProto.AttributeType.TENSOR + tensor_proto = value_proto.t + element_type = self._cc.tensor_element_type(tensor_proto.data_type) + vtensor_type = self._cc.get_vtensor_type(tuple(shape), element_type) + assert len(tensor_proto.dims) == 1 and tensor_proto.dims[0] == 1 + try: + cb = ELEM_TYPE_SPLAT_TENSOR_PROTO_CB[tensor_proto.data_type] + except KeyError: + raise OnnxImportError( + f"Unhandled splat type for ConstantOfShape: {node} (possible missing mapping in ELEM_TYPE_SPLAT_TENSOR_PROTO_CB)" + ) + value_attr = cb(tensor_proto, tuple(shape)) + literal_op = Operation.create( + name="torch.vtensor.literal", + results=[vtensor_type], + attributes={"value": value_attr}, + ) + self._nv_map[node.output[0]] = literal_op.result + + +class ContextCache: + """Caches per-context lookups of various things.""" + + __slots__ = [ + "_c", + "_elem_type_map", + "_vtensor_type_map", + ] + + def __init__(self, context: Context): + self._c = context + self._elem_type_map: dict[int, IrType] = {} + self._vtensor_type_map: dict[tuple[tuple[Optional[int]], IrType], IrType] = {} + + def tensor_element_type(self, elem_type: int) -> IrType: + t = self._elem_type_map.get(elem_type) + if t is None: + try: + with self._c: + t = ELEM_TYPE_TO_IR_TYPE_CB[elem_type]() + except KeyError: + raise OnnxImportError(f"Unknown ONNX tensor element type: {elem_type}") + self._elem_type_map[elem_type] = t + return t + + def get_vtensor_type( + self, dims: tuple[Optional[int]], element_type: IrType + ) -> IrType: + key = (dims, element_type) + t = self._vtensor_type_map.get(key) + if t is None: + shape_asm = ",".join("?" if d is None else str(d) for d in dims) + asm = f"!torch.vtensor<[{shape_asm}],{str(element_type)}>" + try: + t = IrType.parse(asm, context=self._c) + except MLIRError as e: + raise OnnxImportError( + f"Unparseable torch type (MLIR asm format bug?): {asm}" + ) from e + self._vtensor_type_map[key] = t + return t + + def tensor_proto_to_type(self, tp: onnx.TensorProto) -> IrType: + element_type = self.tensor_element_type(tp.data_type) + return self.get_vtensor_type(tuple(tp.dims), element_type) + + def tensor_proto_to_builtin_type(self, tp: onnx.TensorProto) -> IrType: + element_type = self.tensor_element_type(tp.data_type) + # TODO: Fixme upstream: RankedTensorType.get should not require a location. + with Location.unknown(): + return RankedTensorType.get(tuple(tp.dims), element_type) + + def type_proto_to_type(self, tp: onnx.TypeProto) -> IrType: + if tp.tensor_type: + tt = tp.tensor_type + if not tt.shape: + raise OnnxImportError( + f"Unsupported Tensor type without shape (run shape inference?): {tp}" + ) + element_type = self.tensor_element_type(tt.elem_type) + dims = tuple( + (d.dim_value if not d.dim_param else None) for d in tt.shape.dim + ) + return self.get_vtensor_type(dims, element_type) + else: + # TODO: Others if ever needed. Or we consider ourselves DNN-only. + # See TypeProto: sequence_type, map_type, optional_type, sparse_tensor_type. + raise OnnxImportError(f"Unsupported ONNX TypeProto: {tp}") + + def tensor_proto_to_attr(self, tp: onnx.TensorProto) -> Attribute: + tensor_type = self.tensor_proto_to_builtin_type(tp) + if tp.HasField("raw_data"): + # Conveniently, DenseResourceElementsAttr shares the raw data + # format. We just give it maximum numeric alignment. + return DenseResourceElementsAttr.get_from_buffer( + tp.raw_data, tp.name, tensor_type, alignment=8 + ) + else: + # We have to do a data type specific instantiation from proto fields. + # Since this is typically used for small tensor constants, we instantiate + # as a DenseElementsAttr. + handler = ELEM_TYPE_INLINE_TENSOR_PROTO_CB.get(tp.data_type) + if handler is None: + raise OnnxImportError(f"Unhandled ONNX TensorProto data: {tp}") + return handler(tp) + + +ELEM_TYPE_TO_IR_TYPE_CB = { + onnx.TensorProto.DataType.FLOAT: lambda: F32Type.get(), + onnx.TensorProto.DataType.UINT8: lambda: IntegerType.get_unsigned(8), + onnx.TensorProto.DataType.INT8: lambda: IntegerType.get_signed(8), + onnx.TensorProto.DataType.UINT16: lambda: IntegerType.get_unsigned(16), + onnx.TensorProto.DataType.INT16: lambda: IntegerType.get_signed(16), + onnx.TensorProto.DataType.INT32: lambda: IntegerType.get_signed(32), + onnx.TensorProto.DataType.INT64: lambda: IntegerType.get_signed(64), + onnx.TensorProto.DataType.BOOL: lambda: IntegerType.get_signless(1), + onnx.TensorProto.DataType.FLOAT16: lambda: F16Type.get(), + onnx.TensorProto.DataType.DOUBLE: lambda: F64Type.get(), + onnx.TensorProto.DataType.UINT32: lambda: IntegerType.get_unsigned(32), + onnx.TensorProto.DataType.UINT64: lambda: IntegerType.get_unsigned(64), + onnx.TensorProto.DataType.COMPLEX64: lambda: ComplexType.get(F32Type.get()), + onnx.TensorProto.DataType.COMPLEX128: lambda: ComplexType.get(F64Type.get()), + onnx.TensorProto.DataType.BFLOAT16: lambda: BF16Type.get(), + onnx.TensorProto.DataType.FLOAT8E4M3FN: lambda: Float8E4M3FNType.get(), + onnx.TensorProto.DataType.FLOAT8E4M3FNUZ: lambda: Float8E5M2FNUZType.get(), + onnx.TensorProto.DataType.FLOAT8E5M2: lambda: Float8E5M2Type.get(), + onnx.TensorProto.DataType.FLOAT8E5M2FNUZ: lambda: Float8E5M2FNUZType.get(), + # Ommitted: STRING, +} + +ELEM_TYPE_SPLAT_TENSOR_PROTO_CB = { + onnx.TensorProto.DataType.FLOAT: lambda tp, shape: DenseElementsAttr.get_splat( + RankedTensorType.get(shape, F32Type.get()), FloatAttr.get_f32(tp.float_data[0]) + ), + # TODO: All the rest from ELEM_TYPE_TO_IR_TYPE_CB +} + +# Mapping of TensorProto.DataType to lambda TensorProto, returning a DenseElementsAttr +# of the builtin tensor type for cases where the tensor data is inlined as typed +# values instead of raw_data. +ELEM_TYPE_INLINE_TENSOR_PROTO_CB = { + onnx.TensorProto.DataType.FLOAT: lambda tp: DenseElementsAttr.get( + np.asarray(tp.float_data, dtype=np.float32).reshape(tp.dims), signless=False + ), + onnx.TensorProto.DataType.INT32: lambda tp: DenseElementsAttr.get( + np.asarray(tp.int32_data, dtype=np.int32).reshape(tp.dims), signless=False + ), + onnx.TensorProto.DataType.INT64: lambda tp: DenseElementsAttr.get( + np.asarray(tp.int64_data, dtype=np.int64).reshape(tp.dims), signless=False + ), + onnx.TensorProto.DataType.DOUBLE: lambda tp: DenseElementsAttr.get( + np.asarray(tp.double_data, dtype=np.float64).reshape(tp.dims) + ), + onnx.TensorProto.DataType.UINT32: lambda tp: DenseElementsAttr.get( + # Special case. See proto + np.asarray(tp.uint64_data, dtype=np.uint32).reshape(tp.dims), + signless=False, + ), + onnx.TensorProto.DataType.UINT64: lambda tp: DenseElementsAttr.get( + np.asarray(tp.uint64_data, dtype=np.uint64).reshape(tp.dims), signless=False + ) + # Intentionally unsupported: STRING +} + +ELEM_TYPE_TO_NUMPY_DTYPE = { + onnx.TensorProto.DataType.FLOAT: np.float32, + onnx.TensorProto.DataType.UINT8: np.uint8, + onnx.TensorProto.DataType.INT8: np.int8, + onnx.TensorProto.DataType.UINT16: np.uint16, + onnx.TensorProto.DataType.INT16: np.int16, + onnx.TensorProto.DataType.INT32: np.int32, + onnx.TensorProto.DataType.INT64: np.int64, + onnx.TensorProto.DataType.BOOL: np.bool_, + onnx.TensorProto.DataType.FLOAT16: np.float16, + onnx.TensorProto.DataType.DOUBLE: np.float64, + onnx.TensorProto.DataType.UINT32: np.uint32, + onnx.TensorProto.DataType.UINT64: np.uint64, + onnx.TensorProto.DataType.COMPLEX64: np.complex64, + onnx.TensorProto.DataType.COMPLEX128: np.complex128, + # onnx.TensorProto.DataType.BFLOAT16: + # onnx.TensorProto.DataType.FLOAT8E4M3FN: + # onnx.TensorProto.DataType.FLOAT8E4M3FNUZ: + # onnx.TensorProto.DataType.FLOAT8E5M2: + # onnx.TensorProto.DataType.FLOAT8E5M2FNUZ: + # Ommitted: STRING, +} + +# Mapping of AttributeType code to one of: +# None: Ignore attribute and do not output to MLIR +# False: Error if an attribute of this type is present +# lambda a:AttributeProto, cc: ContextCache that returns an MLIR Attribute +ATTRIBUTE_TYPE_HANDLERS = { + onnx.AttributeProto.AttributeType.UNDEFINED: False, + onnx.AttributeProto.AttributeType.FLOAT: lambda a, cc: FloatAttr.get( + F32Type.get(), a.f + ), + onnx.AttributeProto.AttributeType.INT: lambda a, cc: IntegerAttr.get( + IntegerType.get_signed(64), a.i + ), + onnx.AttributeProto.AttributeType.STRING: lambda a, cc: StringAttr.get(a.s), + onnx.AttributeProto.AttributeType.TENSOR: lambda a, cc: cc.tensor_proto_to_attr( + a.t + ), + onnx.AttributeProto.AttributeType.GRAPH: False, + onnx.AttributeProto.AttributeType.SPARSE_TENSOR: False, + onnx.AttributeProto.AttributeType.TYPE_PROTO: False, + onnx.AttributeProto.AttributeType.FLOATS: lambda a, cc: ArrayAttr.get( + [FloatAttr.get(F32Type.get(), f) for f in a.floats] + ), + onnx.AttributeProto.AttributeType.INTS: lambda a, cc: ArrayAttr.get( + [IntegerAttr.get(IntegerType.get_signed(64), i) for i in a.ints] + ), + onnx.AttributeProto.AttributeType.STRINGS: lambda a, cc: ArrayAttr.get( + [StringAttr.get(s) for s in a.strings] + ), + onnx.AttributeProto.AttributeType.TENSORS: lambda a, cc: ArrayAttr.get( + [cc.tensor_proto_to_attr(t) for t in a.tensors] + ), + onnx.AttributeProto.AttributeType.GRAPHS: False, + onnx.AttributeProto.AttributeType.SPARSE_TENSORS: False, + onnx.AttributeProto.AttributeType.TYPE_PROTOS: False, +} + + +def _get_attr(node: onnx.NodeProto, attr_name: str) -> onnx.AttributeProto: + for attr in node.attribute: + if attr.name == attr_name: + return attr + else: + raise OnnxImportError(f"Required attribute {attr_name} not found in {node}") diff --git a/setup.py b/setup.py index 46217d30718d..a4b42309d755 100644 --- a/setup.py +++ b/setup.py @@ -47,8 +47,6 @@ # If true, enable LTC build by default TORCH_MLIR_ENABLE_LTC_DEFAULT = True TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS = int(os.environ.get('TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS', False)) -if not TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS: - import torch # Build phase discovery is unreliable. Just tell it what phases to run. class CustomBuild(_build): @@ -91,7 +89,7 @@ def run(self): f"-DCMAKE_C_VISIBILITY_PRESET=hidden", f"-DCMAKE_CXX_VISIBILITY_PRESET=hidden", f"-DTORCH_MLIR_ENABLE_LTC={'ON' if enable_ltc else 'OFF'}", - f"-DTORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS={'ON' if TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS else 'OFF'}", + f"-DTORCH_MLIR_ENABLE_PYTORCH_EXTENSIONS={'OFF' if TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS else 'ON'}", ] os.makedirs(cmake_build_dir, exist_ok=True) @@ -145,8 +143,31 @@ def build_extension(self, ext): long_description = fh.read() +# Requires and extension modules depend on whether building PyTorch +# extensions. +INSTALL_REQUIRES = [ + "numpy", + "packaging", +] +EXT_MODULES = [ + CMakeExtension("torch_mlir._mlir_libs._torchMlir"), +] +NAME = "torch-mlir-core" + +# If building PyTorch extensions, customize. +if not TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS: + import torch + NAME = "torch-mlir" + INSTALL_REQUIRES.extend([ + f"torch=={torch.__version__}".split("+", 1)[0], + ]) + EXT_MODULES.extend([ + CMakeExtension("torch_mlir._mlir_libs._jit_ir_importer"), + ]) + + setup( - name="torch-mlir" if not TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS else "torch-mlir-core", + name=NAME, version=f"{PACKAGE_VERSION}", author="Sean Silva", author_email="silvasean@google.com", @@ -159,10 +180,12 @@ def build_extension(self, ext): "built_ext": NoopBuildExtension, "build_py": CMakeBuild, }, - ext_modules=[ - CMakeExtension("torch_mlir._mlir_libs._jit_ir_importer"), - ] if not TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS else [CMakeExtension("torch_mlir._mlir_libs._torchMlir")], - install_requires=["numpy", "packaging"] + ( - [f"torch=={torch.__version__}".split("+", 1)[0], ] if not TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS else []), + ext_modules=EXT_MODULES, + install_requires=INSTALL_REQUIRES, + extras_require={ + "onnx": [ + "onnx>=1.15.0", + ], + } zip_safe=False, ) diff --git a/test-requirements.txt b/test-requirements.txt index 523772ddeeb0..315e021308e8 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -1,3 +1,4 @@ pillow dill multiprocess +onnx==1.15.0 \ No newline at end of file diff --git a/test/python/lit.local.cfg b/test/python/lit.local.cfg new file mode 100644 index 000000000000..4cfe04325d94 --- /dev/null +++ b/test/python/lit.local.cfg @@ -0,0 +1,2 @@ +if not config.enable_bindings_python: + config.unsupported = True diff --git a/test/python/onnx_importer/.gitignore b/test/python/onnx_importer/.gitignore new file mode 100644 index 000000000000..ea1472ec1f38 --- /dev/null +++ b/test/python/onnx_importer/.gitignore @@ -0,0 +1 @@ +output/ diff --git a/test/python/onnx_importer/_torch_mlir_config.py b/test/python/onnx_importer/_torch_mlir_config.py new file mode 100644 index 000000000000..f597b63b4dec --- /dev/null +++ b/test/python/onnx_importer/_torch_mlir_config.py @@ -0,0 +1,19 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. + +# RUN: %PYTHON %s + +"""This file exists so that the tests can find/configure torch_mlir. + +It allows the test file to be standalone and used verbatim in other +projects (i.e. by just providing this file on the side). +""" + +from torch_mlir import ir +from torch_mlir.extras import onnx_importer + +def configure_context(context): + from torch_mlir.dialects import torch as torch_d + torch_d.register_dialect(context) diff --git a/test/python/onnx_importer/import_smoke_test.py b/test/python/onnx_importer/import_smoke_test.py new file mode 100644 index 000000000000..39a0b3098150 --- /dev/null +++ b/test/python/onnx_importer/import_smoke_test.py @@ -0,0 +1,374 @@ +# Based on code Copyright (c) Advanced Micro Devices, Inc. +# +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. + +# RUN: %PYTHON %s --output %t + +from glob import glob +from pathlib import Path + +import logging +import sys +import unittest + +import onnx + +from _torch_mlir_config import ( + configure_context, + ir, + onnx_importer, +) + +# Accept the output path on the command line or default to a sibling +# to this file. We have to pop this off explicitly or else unittest +# won't understand. +if len(sys.argv) > 1 and sys.argv[1] == "--output": + OUTPUT_PATH = Path(sys.argv[2]) + del sys.argv[1:3] +else: + OUTPUT_PATH = Path(__file__).resolve().parent / "output" + + +# TODO: Add some verification and overrides. For now, just use the +# onnx package install for onnx test files, since they were nice +# enough to include the test suite in the deployable. +import onnx.backend.test.data + +ONNX_TEST_DATA_DIR = Path(onnx.backend.test.__file__).resolve().parent / "data" +print(f"ONNX Test Data Dir: {ONNX_TEST_DATA_DIR}") +ONNX_REL_PATHS = glob(f"**/*.onnx", root_dir=ONNX_TEST_DATA_DIR, recursive=True) + +OUTPUT_PATH.mkdir(parents=True, exist_ok=True) + +TEST_CAST_XFAILS = [ + "light_light_bvlc_alexnet", + "light_light_inception_v1", + "light_light_squeezenet", + "light_light_vgg19", + "node_test_affine_grid_2d_align_corners_expanded_model", + "node_test_affine_grid_2d_expanded_model", + "node_test_affine_grid_3d_align_corners_expanded_model", + "node_test_affine_grid_3d_expanded_model", + "node_test_ai_onnx_ml_label_encoder_string_int_model", + "node_test_ai_onnx_ml_label_encoder_string_int_no_default_model", + "node_test_ai_onnx_ml_label_encoder_tensor_mapping_model", + "node_test_ai_onnx_ml_label_encoder_tensor_value_only_mapping_model", + "node_test_cast_FLOAT16_to_FLOAT8E4M3FNUZ_model", + "node_test_cast_FLOAT16_to_FLOAT8E4M3FN_model", + "node_test_cast_FLOAT16_to_FLOAT8E5M2FNUZ_model", + "node_test_cast_FLOAT16_to_FLOAT8E5M2_model", + "node_test_cast_FLOAT8E4M3FNUZ_to_FLOAT16_model", + "node_test_cast_FLOAT8E4M3FNUZ_to_FLOAT_model", + "node_test_cast_FLOAT8E4M3FN_to_FLOAT16_model", + "node_test_cast_FLOAT8E4M3FN_to_FLOAT_model", + "node_test_cast_FLOAT8E5M2FNUZ_to_FLOAT16_model", + "node_test_cast_FLOAT8E5M2FNUZ_to_FLOAT_model", + "node_test_cast_FLOAT8E5M2_to_FLOAT16_model", + "node_test_cast_FLOAT8E5M2_to_FLOAT_model", + "node_test_cast_FLOAT_to_FLOAT8E4M3FNUZ_model", + "node_test_cast_FLOAT_to_FLOAT8E4M3FN_model", + "node_test_cast_FLOAT_to_FLOAT8E5M2FNUZ_model", + "node_test_cast_FLOAT_to_FLOAT8E5M2_model", + "node_test_cast_FLOAT_to_STRING_model", + "node_test_cast_STRING_to_FLOAT_model", + "node_test_cast_no_saturate_FLOAT16_to_FLOAT8E4M3FNUZ_model", + "node_test_cast_no_saturate_FLOAT16_to_FLOAT8E4M3FN_model", + "node_test_cast_no_saturate_FLOAT16_to_FLOAT8E5M2FNUZ_model", + "node_test_cast_no_saturate_FLOAT16_to_FLOAT8E5M2_model", + "node_test_cast_no_saturate_FLOAT_to_FLOAT8E4M3FNUZ_model", + "node_test_cast_no_saturate_FLOAT_to_FLOAT8E4M3FN_model", + "node_test_cast_no_saturate_FLOAT_to_FLOAT8E5M2FNUZ_model", + "node_test_cast_no_saturate_FLOAT_to_FLOAT8E5M2_model", + "node_test_castlike_FLOAT8E4M3FNUZ_to_FLOAT_expanded_model", + "node_test_castlike_FLOAT8E4M3FNUZ_to_FLOAT_model", + "node_test_castlike_FLOAT8E4M3FN_to_FLOAT_expanded_model", + "node_test_castlike_FLOAT8E4M3FN_to_FLOAT_model", + "node_test_castlike_FLOAT8E5M2FNUZ_to_FLOAT_expanded_model", + "node_test_castlike_FLOAT8E5M2FNUZ_to_FLOAT_model", + "node_test_castlike_FLOAT8E5M2_to_FLOAT_expanded_model", + "node_test_castlike_FLOAT8E5M2_to_FLOAT_model", + "node_test_castlike_FLOAT_to_FLOAT8E4M3FNUZ_expanded_model", + "node_test_castlike_FLOAT_to_FLOAT8E4M3FNUZ_model", + "node_test_castlike_FLOAT_to_FLOAT8E4M3FN_expanded_model", + "node_test_castlike_FLOAT_to_FLOAT8E4M3FN_model", + "node_test_castlike_FLOAT_to_FLOAT8E5M2FNUZ_expanded_model", + "node_test_castlike_FLOAT_to_FLOAT8E5M2FNUZ_model", + "node_test_castlike_FLOAT_to_FLOAT8E5M2_expanded_model", + "node_test_castlike_FLOAT_to_FLOAT8E5M2_model", + "node_test_castlike_FLOAT_to_STRING_expanded_model", + "node_test_castlike_FLOAT_to_STRING_model", + "node_test_castlike_STRING_to_FLOAT_expanded_model", + "node_test_castlike_STRING_to_FLOAT_model", + "node_test_center_crop_pad_crop_axes_chw_expanded_model", + "node_test_center_crop_pad_crop_axes_hwc_expanded_model", + "node_test_center_crop_pad_crop_negative_axes_hwc_expanded_model", + "node_test_clip_default_inbounds_model", + "node_test_clip_default_int8_inbounds_model", + "node_test_clip_default_int8_max_model", + "node_test_clip_default_max_model", + "node_test_constantofshape_float_ones_model", + "node_test_constantofshape_int_shape_zero_model", + "node_test_constantofshape_int_zeros_model", + "node_test_dequantizelinear_e4m3fn_model", + "node_test_dequantizelinear_e4m3fn_zero_point_model", + "node_test_dequantizelinear_e5m2_model", + "node_test_dft_axis_model", + "node_test_dft_inverse_model", + "node_test_dft_model", + "node_test_equal_string_broadcast_model", + "node_test_equal_string_model", + "node_test_gru_defaults_model", + "node_test_gru_seq_length_model", + "node_test_gru_with_initial_bias_model", + "node_test_identity_opt_model", + "node_test_identity_sequence_model", + "node_test_if_model", + "node_test_if_opt_model", + "node_test_if_seq_model", + "node_test_layer_normalization_2d_axis0_expanded_model", + "node_test_layer_normalization_2d_axis0_expanded_ver18_model", + "node_test_layer_normalization_2d_axis1_expanded_model", + "node_test_layer_normalization_2d_axis1_expanded_ver18_model", + "node_test_layer_normalization_2d_axis_negative_1_expanded_model", + "node_test_layer_normalization_2d_axis_negative_1_expanded_ver18_model", + "node_test_layer_normalization_2d_axis_negative_2_expanded_model", + "node_test_layer_normalization_2d_axis_negative_2_expanded_ver18_model", + "node_test_layer_normalization_3d_axis0_epsilon_expanded_model", + "node_test_layer_normalization_3d_axis0_epsilon_expanded_ver18_model", + "node_test_layer_normalization_3d_axis1_epsilon_expanded_model", + "node_test_layer_normalization_3d_axis1_epsilon_expanded_ver18_model", + "node_test_layer_normalization_3d_axis2_epsilon_expanded_model", + "node_test_layer_normalization_3d_axis2_epsilon_expanded_ver18_model", + "node_test_layer_normalization_3d_axis_negative_1_epsilon_expanded_model", + "node_test_layer_normalization_3d_axis_negative_1_epsilon_expanded_ver18_model", + "node_test_layer_normalization_3d_axis_negative_2_epsilon_expanded_model", + "node_test_layer_normalization_3d_axis_negative_2_epsilon_expanded_ver18_model", + "node_test_layer_normalization_3d_axis_negative_3_epsilon_expanded_model", + "node_test_layer_normalization_3d_axis_negative_3_epsilon_expanded_ver18_model", + "node_test_layer_normalization_4d_axis0_expanded_model", + "node_test_layer_normalization_4d_axis0_expanded_ver18_model", + "node_test_layer_normalization_4d_axis1_expanded_model", + "node_test_layer_normalization_4d_axis1_expanded_ver18_model", + "node_test_layer_normalization_4d_axis2_expanded_model", + "node_test_layer_normalization_4d_axis2_expanded_ver18_model", + "node_test_layer_normalization_4d_axis3_expanded_model", + "node_test_layer_normalization_4d_axis3_expanded_ver18_model", + "node_test_layer_normalization_4d_axis_negative_1_expanded_model", + "node_test_layer_normalization_4d_axis_negative_1_expanded_ver18_model", + "node_test_layer_normalization_4d_axis_negative_2_expanded_model", + "node_test_layer_normalization_4d_axis_negative_2_expanded_ver18_model", + "node_test_layer_normalization_4d_axis_negative_3_expanded_model", + "node_test_layer_normalization_4d_axis_negative_3_expanded_ver18_model", + "node_test_layer_normalization_4d_axis_negative_4_expanded_model", + "node_test_layer_normalization_4d_axis_negative_4_expanded_ver18_model", + "node_test_layer_normalization_default_axis_expanded_model", + "node_test_layer_normalization_default_axis_expanded_ver18_model", + "node_test_loop11_model", + "node_test_loop13_seq_model", + "node_test_loop16_seq_none_model", + "node_test_lstm_defaults_model", + "node_test_lstm_with_initial_bias_model", + "node_test_lstm_with_peepholes_model", + "node_test_optional_get_element_optional_sequence_model", + "node_test_optional_get_element_optional_tensor_model", + "node_test_optional_get_element_sequence_model", + "node_test_optional_has_element_empty_no_input_name_optional_input_model", + "node_test_optional_has_element_empty_no_input_name_tensor_input_model", + "node_test_optional_has_element_empty_optional_input_model", + "node_test_optional_has_element_optional_input_model", + "node_test_optional_has_element_tensor_input_model", + "node_test_quantizelinear_e4m3fn_model", + "node_test_quantizelinear_e5m2_model", + "node_test_range_float_type_positive_delta_expanded_model", + "node_test_range_int32_type_negative_delta_expanded_model", + "node_test_regex_full_match_basic_model", + "node_test_regex_full_match_email_domain_model", + "node_test_regex_full_match_empty_model", + "node_test_resize_downsample_scales_cubic_A_n0p5_exclude_outside_model", + "node_test_resize_downsample_scales_cubic_align_corners_model", + "node_test_resize_downsample_scales_cubic_antialias_model", + "node_test_resize_downsample_scales_cubic_model", + "node_test_resize_downsample_scales_linear_align_corners_model", + "node_test_resize_downsample_scales_linear_antialias_model", + "node_test_resize_downsample_scales_linear_half_pixel_symmetric_model", + "node_test_resize_downsample_scales_linear_model", + "node_test_resize_downsample_scales_nearest_model", + "node_test_resize_downsample_sizes_cubic_antialias_model", + "node_test_resize_downsample_sizes_cubic_model", + "node_test_resize_downsample_sizes_linear_antialias_model", + "node_test_resize_downsample_sizes_linear_pytorch_half_pixel_model", + "node_test_resize_downsample_sizes_nearest_model", + "node_test_resize_downsample_sizes_nearest_not_larger_model", + "node_test_resize_downsample_sizes_nearest_not_smaller_model", + "node_test_resize_tf_crop_and_resize_axes_2_3_model", + "node_test_resize_tf_crop_and_resize_axes_3_2_model", + "node_test_resize_tf_crop_and_resize_model", + "node_test_resize_upsample_scales_cubic_A_n0p5_exclude_outside_model", + "node_test_resize_upsample_scales_cubic_align_corners_model", + "node_test_resize_upsample_scales_cubic_asymmetric_model", + "node_test_resize_upsample_scales_cubic_model", + "node_test_resize_upsample_scales_linear_align_corners_model", + "node_test_resize_upsample_scales_linear_half_pixel_symmetric_model", + "node_test_resize_upsample_scales_linear_model", + "node_test_resize_upsample_scales_nearest_axes_2_3_model", + "node_test_resize_upsample_scales_nearest_axes_3_2_model", + "node_test_resize_upsample_scales_nearest_model", + "node_test_resize_upsample_sizes_cubic_model", + "node_test_resize_upsample_sizes_nearest_axes_2_3_model", + "node_test_resize_upsample_sizes_nearest_axes_3_2_model", + "node_test_resize_upsample_sizes_nearest_ceil_half_pixel_model", + "node_test_resize_upsample_sizes_nearest_floor_align_corners_model", + "node_test_resize_upsample_sizes_nearest_model", + "node_test_resize_upsample_sizes_nearest_not_larger_model", + "node_test_resize_upsample_sizes_nearest_round_prefer_ceil_asymmetric_model", + "node_test_rnn_seq_length_model", + "node_test_scan9_sum_model", + "node_test_scan_sum_model", + "node_test_sequence_insert_at_back_model", + "node_test_sequence_insert_at_front_model", + "node_test_sequence_map_add_1_sequence_1_tensor_expanded_model", + "node_test_sequence_map_add_1_sequence_1_tensor_model", + "node_test_sequence_map_add_2_sequences_expanded_model", + "node_test_sequence_map_add_2_sequences_model", + "node_test_sequence_map_extract_shapes_expanded_model", + "node_test_sequence_map_extract_shapes_model", + "node_test_sequence_map_identity_1_sequence_1_tensor_expanded_model", + "node_test_sequence_map_identity_1_sequence_1_tensor_model", + "node_test_sequence_map_identity_1_sequence_expanded_model", + "node_test_sequence_map_identity_1_sequence_model", + "node_test_sequence_map_identity_2_sequences_expanded_model", + "node_test_sequence_map_identity_2_sequences_model", + "node_test_simple_rnn_defaults_model", + "node_test_simple_rnn_with_initial_bias_model", + "node_test_split_to_sequence_1_model", + "node_test_split_to_sequence_2_model", + "node_test_split_to_sequence_nokeepdims_model", + "node_test_stft_model", + "node_test_string_concat_broadcasting_model", + "node_test_string_concat_empty_string_model", + "node_test_string_concat_model", + "node_test_string_concat_utf8_model", + "node_test_string_concat_zero_dimensional_model", + "node_test_string_split_basic_model", + "node_test_string_split_consecutive_delimiters_model", + "node_test_string_split_empty_string_delimiter_model", + "node_test_string_split_empty_tensor_model", + "node_test_string_split_maxsplit_model", + "node_test_string_split_no_delimiter_model", + "node_test_strnormalizer_export_monday_casesensintive_lower_model", + "node_test_strnormalizer_export_monday_casesensintive_nochangecase_model", + "node_test_strnormalizer_export_monday_casesensintive_upper_model", + "node_test_strnormalizer_export_monday_empty_output_model", + "node_test_strnormalizer_export_monday_insensintive_upper_twodim_model", + "node_test_strnormalizer_nostopwords_nochangecase_model", + "simple_test_sequence_model1_model", + "simple_test_sequence_model2_model", + "simple_test_sequence_model3_model", + "simple_test_sequence_model4_model", + "simple_test_sequence_model5_model", + "simple_test_sequence_model6_model", + "simple_test_sequence_model7_model", + "simple_test_sequence_model8_model", + "simple_test_strnorm_model_monday_casesensintive_lower_model", + "simple_test_strnorm_model_monday_casesensintive_nochangecase_model", + "simple_test_strnorm_model_monday_casesensintive_upper_model", + "simple_test_strnorm_model_monday_empty_output_model", + "simple_test_strnorm_model_monday_insensintive_upper_twodim_model", + "simple_test_strnorm_model_nostopwords_nochangecase_model", +] + + +class ImportSmokeTest(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.unexpected_failure_count = 0 + ImportSmokeTest.actual_failures = [] + + @classmethod + def tearDownClass(cls): + if cls.unexpected_failure_count: + # Print a helpful message with copy-paste XFAIL def. + failure_report_path = OUTPUT_PATH / "import_smoke_test_report.txt" + print( + "Unexpected failures. Writing copy/paste report to:", + failure_report_path, + ) + with open(failure_report_path, "wt") as f: + lines = [f' "{s}",' for s in ImportSmokeTest.actual_failures] + print( + f"Unexpected failures in the following. Copy/paste to update `TEST_CAST_XFAILS`:", + file=f, + ) + print(f"TEST_CAST_XFAILS = [", file=f) + [print(l, file=f) for l in lines] + print(f"]", file=f) + + ImportSmokeTest.actual_failures.clear() + + def load_onnx_model(self, file_path: Path) -> onnx.ModelProto: + raw_model = onnx.load(file_path) + try: + inferred_model = onnx.shape_inference.infer_shapes(raw_model) + except onnx.onnx_cpp2py_export.shape_inference.InferenceError as e: + print("WARNING: Shape inference failure (skipping test):", e) + self.skipTest(reason="shape inference failure") + + # inferred_model = raw_model + return inferred_model + + def run_import_test(self, norm_name: str, rel_path: str): + context = ir.Context() + configure_context(context) + + model_info = onnx_importer.ModelInfo( + self.load_onnx_model(ONNX_TEST_DATA_DIR / rel_path), + ) + m = model_info.create_module(context=context) + try: + imp = onnx_importer.NodeImporter.define_function(model_info.main_graph, m) + imp.import_all() + m.verify() + finally: + # Use a ".txt" extension to avoid lit test discovery. + with open(OUTPUT_PATH / f"{norm_name}.mlir", "wt") as f: + print(m.get_asm(), file=f) + + def testExists(self): + # We expect a lot of test cases. Die if not the case (i.e. if paths change + # or something). + self.assertGreater(len(ONNX_REL_PATHS), 10) + + +# Generate test methods for each onnx file. +for _rel_path in ONNX_REL_PATHS: + + def attach_test(rel_path): + norm_name = rel_path.removesuffix(".onnx").replace("/", "_") + + def test_method(self: ImportSmokeTest): + try: + self.run_import_test(norm_name, rel_path) + except onnx_importer.OnnxImportError as e: + # All legitimate failures should be caught and reported + # as an OnnxImportError. + ImportSmokeTest.actual_failures.append(norm_name) + if norm_name not in TEST_CAST_XFAILS: + ImportSmokeTest.unexpected_failure_count += 1 + raise e + + test_method.__name__ = f"test_{norm_name}" + + if norm_name in TEST_CAST_XFAILS: + test_method = unittest.expectedFailure(test_method) + + setattr(ImportSmokeTest, test_method.__name__, test_method) + + attach_test(_rel_path) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.DEBUG) + unittest.main() diff --git a/test/python/onnx_importer/lit.local.cfg b/test/python/onnx_importer/lit.local.cfg new file mode 100644 index 000000000000..8e0adb7c1c49 --- /dev/null +++ b/test/python/onnx_importer/lit.local.cfg @@ -0,0 +1,5 @@ +try: + import onnx +except ModuleNotFoundError: + print("Skipping onnx tests.. no onnx") + config.unsupported = True From 7cf52ae73f59ae90abcdbe3e7f60bef95e1a59eb Mon Sep 17 00:00:00 2001 From: JianzheXiao Date: Tue, 12 Dec 2023 19:05:12 -0800 Subject: [PATCH 002/283] [Torch Dialect]Add Support for AtenGroupNormOp and AtenNativeGroupNormOp (#2591) Co-authored-by: LiuYuanqiang --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 28 +++ .../Transforms/AbstractInterpLibrary.cpp | 40 +++++ .../Torch/Transforms/DecomposeComplexOps.cpp | 161 ++++++++++++++++++ .../Transforms/LowerToBackendContract.cpp | 2 + projects/pt1/e2e_testing/xfail_sets.py | 5 + .../build_tools/abstract_interp_lib_gen.py | 18 ++ .../build_tools/torch_ods_gen.py | 3 + .../test_suite/__init__.py | 1 - .../test_suite/norm_like.py | 43 ++++- 9 files changed, 298 insertions(+), 3 deletions(-) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 731a4971036a..947e258108f6 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -5640,6 +5640,34 @@ def Torch_AtenNativeGroupNormOp : Torch_Op<"aten.native_group_norm", [ }]; } +def Torch_AtenGroupNormOp : Torch_Op<"aten.group_norm", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::group_norm : (Tensor, int, Tensor?, Tensor?, float, bool) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$input, + Torch_IntType:$num_groups, + AnyTorchOptionalTensorType:$weight, + AnyTorchOptionalTensorType:$bias, + Torch_FloatType:$eps, + Torch_BoolType:$cudnn_enabled + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenGroupNormOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 6, 1); + } + void AtenGroupNormOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 6, 1); + } + }]; +} + def Torch_AtenLayerNormOp : Torch_Op<"aten.layer_norm", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index 1bad23b76e73..a1db752b55f0 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -8074,6 +8074,17 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.batch_norm(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8) : (!torch.list, !torch.optional>, !torch.optional>, !torch.optional>, !torch.optional>, !torch.bool, !torch.float, !torch.float, !torch.bool) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.group_norm\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.optional>, %arg3: !torch.optional>, %arg4: !torch.float, %arg5: !torch.bool) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.native_group_norm\"(%arg0: !torch.list, %arg1: !torch.optional>, %arg2: !torch.optional>, %arg3: !torch.int, %arg4: !torch.int, %arg5: !torch.int, %arg6: !torch.int, %arg7: !torch.float) -> !torch.tuple, list, list> {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" %1 = torch.prim.ListConstruct %arg3, %arg6 : (!torch.int, !torch.int) -> !torch.list\n" +" %2 = torch.prim.ListConstruct %arg3, %arg6 : (!torch.int, !torch.int) -> !torch.list\n" +" %3 = torch.prim.TupleConstruct %0, %1, %2 : !torch.list, !torch.list, !torch.list -> !torch.tuple, list, list>\n" +" return %3 : !torch.tuple, list, list>\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.slice.Tensor\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.int) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.slice(%arg0, %arg1, %arg2, %arg3, %arg4) : (!torch.list, !torch.int, !torch.optional, !torch.optional, !torch.int) -> !torch.list\n" " return %0 : !torch.list\n" @@ -8748,6 +8759,35 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.group_norm\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.optional>, %arg3: !torch.optional>, %arg4: !torch.float, %arg5: !torch.bool) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" return %0#1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.native_group_norm\"(%arg0: !torch.tuple, %arg1: !torch.optional>, %arg2: !torch.optional>, %arg3: !torch.int, %arg4: !torch.int, %arg5: !torch.int, %arg6: !torch.int, %arg7: !torch.float) -> !torch.tuple {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %3 = torch.prim.TupleConstruct %0#1, %0#1, %0#1 : !torch.int, !torch.int, !torch.int -> !torch.tuple\n" +" return %3 : !torch.tuple\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.bernoulli_.float\"(%arg0: !torch.tuple, %arg1: !torch.float, %arg2: !torch.any) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 0e8cad63e536..d7e7a2af6dd2 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -3753,6 +3753,165 @@ class DecomposeConstantTensorAllocLikeOp : public OpRewritePattern { }; } // namespace +namespace { +class DecomposeAtenGroupNormOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenGroupNormOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + MLIRContext *context = op.getContext(); + + Value input = op.getInput(); + Value weight = op.getWeight(); + Value bias = op.getBias(); + Value numGroups = op.getNumGroups(); + Value eps = op.getEps(); + + Value cstZero = + rewriter.create(loc, rewriter.getI64IntegerAttr(0)); + Value cstOne = + rewriter.create(loc, rewriter.getI64IntegerAttr(1)); + auto baseType = ValueTensorType::getWithLeastStaticInformation(context); + + Value N = rewriter.create(loc, input, cstZero); + Value C = rewriter.create(loc, input, cstOne); + Value numElements = rewriter.create(loc, input); + Value numElementsDivN = + rewriter.create(loc, numElements, N); + Value HxW = rewriter.create(loc, numElementsDivN, C); + + AtenNativeGroupNormOp newOp = rewriter.create( + loc, ArrayRef{op.getResult().getType(), baseType, baseType}, + input, weight, bias, N, C, HxW, numGroups, eps); + + rewriter.replaceOp(op, newOp.getResult0()); + return success(); + } +}; +} // namespace + +namespace { +class DecomposeAtenNativeGroupNormOp + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenNativeGroupNormOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + MLIRContext *context = op.getContext(); + + Value input = op.getInput(); + Value weight = op.getWeight(); + Value bias = op.getBias(); + Value numGroups = op.getGroup(); + Value eps = op.getEps(); + + // Check the rank of the input/outputs tensor. + auto inputType = input.getType().cast(); + auto outputType = op.getResult0().getType().cast(); + auto meanType = op.getResult1().getType().cast(); + auto rsqrtVarType = op.getResult2().getType().cast(); + if (!inputType.hasSizes() || !outputType.hasSizes() || + !meanType.hasSizes() || !rsqrtVarType.hasSizes()) { + return rewriter.notifyMatchFailure( + op, "input/outputs tensor should have known sizes."); + } + + Value none = rewriter.create(loc); + Value cstZero = + rewriter.create(loc, rewriter.getI64IntegerAttr(0)); + Value cstOne = + rewriter.create(loc, rewriter.getI64IntegerAttr(1)); + Value cstNegtiveOne = + rewriter.create(loc, rewriter.getI64IntegerAttr(-1)); + Value cstTrue = rewriter.create(loc, true); + Value cstFalse = rewriter.create(loc, false); + auto baseType = ValueTensorType::getWithLeastStaticInformation(context); + + // GroupNorm requires the channel dimension (C) to be exactly divisible by + // the number of groups. + Value channel = rewriter.create(loc, input, cstOne); + Value remainder = + rewriter.create(loc, channel, numGroups); + Value eqOrNot = rewriter.create(loc, remainder, cstZero); + rewriter.create( + loc, eqOrNot, + rewriter.getStringAttr("the number of channels must be divisible by " + "the number of groups")); + + // Reshape the input tensor to (N, numGroups, -1) to apply normalization. + SmallVector newShape; + newShape.push_back(rewriter.create(loc, input, cstZero)); + newShape.push_back(numGroups); + newShape.push_back(cstNegtiveOne); + Value reshapedInput = rewriter.create( + loc, baseType, input, + rewriter.create( + loc, Torch::ListType::get(IntType::get(context)), newShape)); + + // Now we proceed with the normalization steps across the 'groupSize' + // Compute the mean and variance for each group + Value dimList = rewriter.create( + loc, Torch::ListType::get(Torch::IntType::get(op.getContext())), + ArrayRef{cstNegtiveOne}); + auto mean = rewriter.create( + loc, baseType, reshapedInput, /*dims=*/dimList, /*keepdim=*/cstTrue, + /*dtype=*/none); + auto var = rewriter.create( + loc, baseType, reshapedInput, /*dims=*/dimList, /*unbiased=*/cstFalse, + /*keepdim=*/cstTrue); + + // Compute the normalized output: (input - mean) * rsqrt(var + eps) + auto varPlusEps = rewriter.create(loc, baseType, var, eps, + /*alpha=*/cstOne); + auto invStd = rewriter.create(loc, baseType, varPlusEps); + auto inputSubMean = rewriter.create( + loc, baseType, reshapedInput, mean, /*alpha=*/cstOne); + auto normalizedOutput = + rewriter.create(loc, baseType, inputSubMean, invStd); + + // Reshape normalized output back to the original input shape + auto inputShape = rewriter.create( + loc, Torch::ListType::get(IntType::get(context)), input); + auto reshapedOutput = rewriter.create( + loc, inputType, normalizedOutput, /*shape=*/inputShape); + + // Apply weight and bias if they are not None + // Reshape weight and bias to C,1,1,... + SmallVector viewShape = {channel}; + for (unsigned i = 2; i < inputType.getSizes().size(); i++) { + viewShape.push_back(cstOne); + } + Value viewShapeSizeList = rewriter.create( + loc, ListType::get(IntType::get(context)), viewShape); + + Value groupNormOutput = reshapedOutput; + if (!weight.getType().isa()) { + auto weightReshaped = rewriter.create( + loc, baseType, weight, /*shape=*/viewShapeSizeList); + groupNormOutput = rewriter.create( + loc, inputType, groupNormOutput, weightReshaped); + } + if (!bias.getType().isa()) { + auto biasReshaped = rewriter.create( + loc, baseType, bias, /*shape=*/viewShapeSizeList); + groupNormOutput = rewriter.create( + loc, inputType, groupNormOutput, biasReshaped, + /*alpha=*/cstOne); + } + + Value squeezedMean = + rewriter.create(loc, meanType, mean, cstNegtiveOne); + Value squeezedRsqrtVar = rewriter.create( + loc, rsqrtVarType, invStd, cstNegtiveOne); + + rewriter.replaceOp( + op, ArrayRef{groupNormOutput, squeezedMean, squeezedRsqrtVar}); + + return success(); + } +}; +} // namespace + namespace { class DecomposeAtenNativeBatchNormOp : public OpRewritePattern { @@ -6204,6 +6363,8 @@ class DecomposeComplexOpsPass DecomposeAtenAddCLikeOp>(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal< DecomposeAten_ConvolutionLikeOp>(patterns); diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp index d4c4c00f3723..48a3e5c65ad6 100644 --- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp +++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp @@ -407,6 +407,8 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 0e5bb5747615..98ee0fa6ccaf 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -306,6 +306,10 @@ # ERROR: shape (torch.Size([12])) is not equal to golden shape (torch.Size([3, 4])) "ArangeStartOutViewModule_basic", + + # ERROR: 'torch.aten.add.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float' + "GroupNormModule_basic", + "GroupNormNoWeightAndBiasModule_basic", } TORCHDYNAMO_CRASHING_SET = { @@ -586,6 +590,7 @@ "NewFullModuleInt2DStatic_basic", "NewFullModuleInt2D_basic", "NewFullModuleInt3D_basic", + "GroupNormModule_basic", "GatherStaticModule_basic", "GatherModule_basic", "Gather2DInputModdule_basic", diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index ba70122aab0d..6d175712b2d0 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -1130,6 +1130,12 @@ def aten〇convolution_backward〡shape(grad_output: List[int], input: List[int] def aten〇batch_norm〡shape(input: List[int], weight: Optional[List[int]], bias: Optional[List[int]], running_mean: Optional[List[int]], running_var: Optional[List[int]], training: bool, momentum: float, eps: float, cudnn_enabled: bool) -> List[int]: return upstream_shape_functions.batch_norm(input, weight, bias, running_mean, running_var, training, momentum, eps, cudnn_enabled) +def aten〇group_norm〡shape(input: List[int], num_groups: int, weight: Optional[List[int]] = None, bias: Optional[List[int]] = None, eps: float = 1.0000000000000001e-05, cudnn_enabled: bool = True) -> List[int]: + return upstream_shape_functions.unary(input) + +def aten〇native_group_norm〡shape(input: List[int], weight: Optional[List[int]], bias: Optional[List[int]], N: int, C: int, HxW: int, group: int, eps: float) -> Tuple[List[int], List[int], List[int]]: + return upstream_shape_functions.unary(input), [N, group], [N, group] + def aten〇slice〇Tensor〡shape(self: List[int], dim: int = 0, start: Optional[int] = None, end: Optional[int] = None, step: int = 1) -> List[int]: return upstream_shape_functions.slice(self, dim, start, end, step) @@ -1671,6 +1677,18 @@ def aten〇batch_norm〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dty input_rank, input_dtype = input_rank_dtype return input_dtype +@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7)], error_types={*all_integer_dtypes()}, num_groups=1)) +def aten〇group_norm〡dtype(input_rank_dtype: Tuple[int, int], num_groups: int, weight_rank_dtype: Optional[Tuple[int, int]] = None, bias_rank_dtype: Optional[Tuple[int, int]] = None, eps: float = 1.0000000000000001e-05, cudnn_enabled: bool = True) -> int: + input_rank, input_dtype = input_rank_dtype + assert not is_integer_dtype(input_dtype) + return input_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7), (3,), (3,)], error_types={*all_integer_dtypes()}, N=2, C=3, HxW=35, group=1, eps=0.000001)) +def aten〇native_group_norm〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Optional[Tuple[int, int]], bias_rank_dtype: Optional[Tuple[int, int]], N: int, C: int, HxW: int, group: int, eps: float) -> Tuple[int, int, int]: + input_rank, input_dtype = input_rank_dtype + assert not is_integer_dtype(input_dtype) + return input_dtype, input_dtype, input_dtype + @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) def aten〇bernoulli_〇float〡dtype(self_rank_dtype: Tuple[int, int], p: float = 0.5, generator: Any = None) -> int: self_rank, self_dtype = self_rank_dtype diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index 51aa1debfbeb..fd6bfbc968eb 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -421,6 +421,9 @@ def emit_with_mutating_variants(key, **kwargs): emit( "aten::native_group_norm : (Tensor, Tensor?, Tensor?, int, int, int, int, float) -> (Tensor, Tensor, Tensor)" ) + emit( + 'aten::group_norm : (Tensor, int, Tensor?, Tensor?, float, bool) -> (Tensor)' + ) emit( "aten::layer_norm : (Tensor, int[], Tensor?, Tensor?, float, bool) -> (Tensor)" ) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/__init__.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/__init__.py index ee6878701f1c..79712a16f65b 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/__init__.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/__init__.py @@ -10,7 +10,6 @@ from torch_mlir._version import torch_version_for_comparison, version COMMON_TORCH_MLIR_LOWERING_XFAILS = { - "NativeGroupNormModule_basic", "NativeGroupNormBackwardModule_basic", "QuantizedMLP_basic", "ReduceMaxAlongDimUnsignedInt_basic", diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/norm_like.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/norm_like.py index f59695620064..59a251082303 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/norm_like.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/norm_like.py @@ -243,6 +243,42 @@ def NativeBatchNormNoneWeightModule_basic(module, tu: TestUtils): # ============================================================================== +class GroupNormModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([2, 4, 6, 7], torch.float32, True), + ([4], torch.float32, True), + ([4], torch.float32, True), + ]) + def forward(self, x, weight, bias): + return torch.ops.aten.group_norm(x, 2, weight, bias, 1.0000000000000001e-05, False) + +@register_test_case(module_factory=lambda: GroupNormModule()) +def GroupNormModule_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 4, 6, 7), tu.rand(4), tu.rand(4)) + +class GroupNormNoWeightAndBiasModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([2, 4, 6, 7], torch.float32, True), + ]) + def forward(self, x): + return torch.ops.aten.group_norm(x, 2, None, None, 1.0000000000000001e-05, False) + +@register_test_case(module_factory=lambda: GroupNormNoWeightAndBiasModule()) +def GroupNormNoWeightAndBiasModule_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 4, 6, 7)) + +# ============================================================================== + class NativeGroupNormModule(torch.nn.Module): def __init__(self): super().__init__() @@ -257,13 +293,15 @@ def __init__(self): def forward(self, x, weight, bias): return torch.ops.aten.native_group_norm( x, weight, bias, - 2, 6, 4, 3, 0.000001); + 2, 6, 4, 3, 0.000001) @register_test_case(module_factory=lambda: NativeGroupNormModule()) def NativeGroupNormModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 6, 2, 2), tu.rand(6), tu.rand(6)) +# ============================================================================== + class NativeGroupNormBackwardModule(torch.nn.Module): def __init__(self): super().__init__() @@ -280,7 +318,7 @@ def __init__(self): def forward(self, grad_out, x, mean, rstd, weight): return torch.ops.aten.native_group_norm_backward( grad_out, x, mean, rstd, weight, - 2, 6, 4, 3, [True, True, True]); + 2, 6, 4, 3, [True, True, True]) @register_test_case(module_factory=lambda: NativeGroupNormBackwardModule()) @@ -450,3 +488,4 @@ def forward(self, x): @register_test_case(module_factory=lambda: LayerNormNormalizeOverAllDimsModule()) def LayerNormNormalizeOverAllDimsModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 2, 3)) + From ed4df38e8d86083c4dcc1b58f7d59a4c8cf6ab85 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Tue, 12 Dec 2023 22:01:30 -0800 Subject: [PATCH 003/283] [onnx] Add torch-mlir-import-onnx tool. (#2637) Simple Python console script to import an ONNX protobuf to the torch dialect for additional processing. For installed wheels, this can be used with something like: ``` torch-mlir-import-onnx test/python/onnx_importer/LeakyReLU.onnx ``` Or from a dev setup: ``` python -m torch_mlir.tools.import_onnx ... ``` --- python/CMakeLists.txt | 7 ++ .../torch_mlir/tools/import_onnx/__main__.py | 77 +++++++++++++++++++ setup.py | 7 +- test/lit.cfg.py | 2 +- test/python/onnx_importer/LeakyReLU.onnx | 15 ++++ .../onnx_importer/import_onnx_tool.runlit | 3 + 6 files changed, 109 insertions(+), 2 deletions(-) create mode 100644 python/torch_mlir/tools/import_onnx/__main__.py create mode 100644 test/python/onnx_importer/LeakyReLU.onnx create mode 100644 test/python/onnx_importer/import_onnx_tool.runlit diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt index 7b9bf12f2b8f..f29429b7246c 100644 --- a/python/CMakeLists.txt +++ b/python/CMakeLists.txt @@ -38,6 +38,13 @@ declare_mlir_python_sources(TorchMLIRPythonSources.Importers extras/onnx_importer.py ) +declare_mlir_python_sources(TorchMLIRPythonSources.Tools + ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}" + ADD_TO_PARENT TorchMLIRPythonSources + SOURCES + tools/import_onnx/__main__.py +) + ################################################################################ # Extensions ################################################################################ diff --git a/python/torch_mlir/tools/import_onnx/__main__.py b/python/torch_mlir/tools/import_onnx/__main__.py new file mode 100644 index 000000000000..b300b4100b3e --- /dev/null +++ b/python/torch_mlir/tools/import_onnx/__main__.py @@ -0,0 +1,77 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. + +"""Console tool for converting an ONNX proto to torch IR. + +Typically, when installed from a wheel, this can be invoked as: + + torch-mlir-import-onnx some.pb + +Or from Python: + + python -m torch_mlir.tools.import_onnx ... +""" +import argparse +from pathlib import Path +import sys + +import onnx + +from ...extras import onnx_importer + +from ...dialects import torch as torch_d +from ...ir import ( + Context, +) + + +def main(args): + model_proto = load_onnx_model(args.input_file) + context = Context() + torch_d.register_dialect(context) + model_info = onnx_importer.ModelInfo(model_proto) + m = model_info.create_module(context=context) + imp = onnx_importer.NodeImporter.define_function(model_info.main_graph, m) + imp.import_all() + if not args.no_verify: + m.verify() + + # TODO: This isn't very efficient output. If these files ever + # get large, enable bytecode and direct binary emission to save + # some copies. + if args.output_file and args.output_file != "-": + with open(args.output_file, "wt") as f: + print(m.get_asm(assume_verified=not args.no_verify), file=f) + else: + print(m.get_asm(assume_verified=not args.no_verify)) + + +def load_onnx_model(file_path: Path) -> onnx.ModelProto: + raw_model = onnx.load(file_path) + inferred_model = onnx.shape_inference.infer_shapes(raw_model) + return inferred_model + + +def parse_arguments(argv=None): + parser = argparse.ArgumentParser(description="Torch-mlir ONNX import tool") + parser.add_argument("input_file", help="ONNX protobuf input", type=Path) + parser.add_argument( + "-o", dest="output_file", help="Output path (or '-' for stdout)" + ) + parser.add_argument( + "--no-verify", + action="store_true", + help="Disable verification prior to printing", + ) + args = parser.parse_args(argv) + return args + + +def _cli_main(): + sys.exit(main(parse_arguments())) + + +if __name__ == "__main__": + _cli_main() diff --git a/setup.py b/setup.py index a4b42309d755..77c8b2ad047d 100644 --- a/setup.py +++ b/setup.py @@ -186,6 +186,11 @@ def build_extension(self, ext): "onnx": [ "onnx>=1.15.0", ], - } + }, + entry_points={ + "console_scripts": [ + "torch-mlir-import-onnx = torch_mlir.tools.import_onnx:_cli_main", + ], + }, zip_safe=False, ) diff --git a/test/lit.cfg.py b/test/lit.cfg.py index a9753bf22719..4608dfb6c892 100644 --- a/test/lit.cfg.py +++ b/test/lit.cfg.py @@ -24,7 +24,7 @@ config.test_format = lit.formats.ShTest(not llvm_config.use_lit_shell) # suffixes: A list of file extensions to treat as test files. -config.suffixes = ['.mlir', '.py'] +config.suffixes = ['.mlir', '.py', '.runlit'] # test_source_root: The root path where tests are located. config.test_source_root = os.path.dirname(__file__) diff --git a/test/python/onnx_importer/LeakyReLU.onnx b/test/python/onnx_importer/LeakyReLU.onnx new file mode 100644 index 000000000000..f76bccbce92a --- /dev/null +++ b/test/python/onnx_importer/LeakyReLU.onnx @@ -0,0 +1,15 @@ +pytorch0.3:h +" +01" LeakyRelu* +alpha +×#< torch-jit-exportZ +0 + + + +b +1 + + + +B \ No newline at end of file diff --git a/test/python/onnx_importer/import_onnx_tool.runlit b/test/python/onnx_importer/import_onnx_tool.runlit new file mode 100644 index 000000000000..45b733f9da7a --- /dev/null +++ b/test/python/onnx_importer/import_onnx_tool.runlit @@ -0,0 +1,3 @@ +# RUN: %PYTHON -m torch_mlir.tools.import_onnx %S/LeakyReLU.onnx | FileCheck %s + +# CHECK: torch.operator "onnx.LeakyRelu" From 42392bc84524cff83f6e87e17b02e563aebf2122 Mon Sep 17 00:00:00 2001 From: John Wu Date: Wed, 13 Dec 2023 09:35:32 -0800 Subject: [PATCH 004/283] [MLIR][ONNX] Add OnnxToTorch support for matmul ops (#2629) This commit adds the OnnxToTorch support for Matmul. --- .../TorchOnnxToTorch/DefaultDomainGtoP.cpp | 15 ++++++++++- .../TorchOnnxToTorch/simple_ops_g_to_p.mlir | 26 +++++++++++++++++++ 2 files changed, 40 insertions(+), 1 deletion(-) create mode 100644 test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index af4f06fdef77..34e7068c8681 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -26,4 +26,17 @@ using namespace mlir::torch::onnx_c; // results in a lot of ONNX test cases that all reduce to the exact same // thing here, so we simplify. void mlir::torch::onnx_c::populateDefaultDomainGtoP( - OnnxCustomOpConversionPattern &patterns) {} + OnnxCustomOpConversionPattern &patterns) { + + patterns.onOp("MatMul", 13, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value lhs, rhs; + if (binder.tensorOperands(lhs, rhs) || + binder.tensorResultType(resultType)) + return failure(); + rewriter.replaceOpWithNewOp( + binder.op, resultType, lhs, rhs); + return success(); + }); +} diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir new file mode 100644 index 000000000000..28b180a1e541 --- /dev/null +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir @@ -0,0 +1,26 @@ +// RUN: torch-mlir-opt <%s -convert-torch-onnx-to-torch | FileCheck %s +// Generally, the test cases accumulated here come from running the importer +// over all included backend tests that involve simple ops with no model +// level constants. This is a pragmatic choice which lets us have a lot +// of tests in this file, whereas the others tend to be more bespoke. + +// CHECK-LABEL: @test_matmul_2d +func.func @test_matmul_2d(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[4,3],f32>) -> !torch.vtensor<[3,3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[3,4],f32>, !torch.vtensor<[4,3],f32> -> !torch.vtensor<[3,3],f32> + %0 = torch.operator "onnx.MatMul"(%arg0, %arg1) : (!torch.vtensor<[3,4],f32>, !torch.vtensor<[4,3],f32>) -> !torch.vtensor<[3,3],f32> + return %0 : !torch.vtensor<[3,3],f32> +} + +// CHECK-LABEL: @test_matmul_3d +func.func @test_matmul_3d(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torch.vtensor<[2,4,3],f32>) -> !torch.vtensor<[2,3,3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,4,3],f32> -> !torch.vtensor<[2,3,3],f32> + %0 = torch.operator "onnx.MatMul"(%arg0, %arg1) : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,4,3],f32>) -> !torch.vtensor<[2,3,3],f32> + return %0 : !torch.vtensor<[2,3,3],f32> +} + +// CHECK-LABEL: @test_matmul_4d +func.func @test_matmul_4d(%arg0: !torch.vtensor<[1,2,3,4],f32>, %arg1: !torch.vtensor<[1,2,4,3],f32>) -> !torch.vtensor<[1,2,3,3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[1,2,3,4],f32>, !torch.vtensor<[1,2,4,3],f32> -> !torch.vtensor<[1,2,3,3],f32> + %0 = torch.operator "onnx.MatMul"(%arg0, %arg1) : (!torch.vtensor<[1,2,3,4],f32>, !torch.vtensor<[1,2,4,3],f32>) -> !torch.vtensor<[1,2,3,3],f32> + return %0 : !torch.vtensor<[1,2,3,3],f32> +} \ No newline at end of file From 6ddeb1a6efe7f61e80be3791952dc7c730dc1c72 Mon Sep 17 00:00:00 2001 From: JianzheXiao Date: Wed, 13 Dec 2023 20:28:08 -0800 Subject: [PATCH 005/283] [torch] Add support for aten.selu (#2640) Add `aten.selu` operation to `torch` dialect. --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 45 +++++++++++++++++ .../Transforms/AbstractInterpLibrary.cpp | 26 ++++++++++ .../Torch/Transforms/DecomposeComplexOps.cpp | 50 +++++++++++++++++++ .../Transforms/LowerToBackendContract.cpp | 1 + projects/pt1/e2e_testing/xfail_sets.py | 2 + .../build_tools/abstract_interp_lib_gen.py | 11 ++++ .../build_tools/torch_ods_gen.py | 1 + .../test_suite/elementwise.py | 21 ++++++++ 8 files changed, 157 insertions(+) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 947e258108f6..7b684886d5ea 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -346,6 +346,51 @@ def Torch_AtenLog_Op : Torch_Op<"aten.log_", [ }]; } +def Torch_AtenSeluOp : Torch_Op<"aten.selu", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::selu : (Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenSeluOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenSeluOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + +def Torch_AtenSelu_Op : Torch_Op<"aten.selu_", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::selu_ : (Tensor) -> (Tensor)`"; + let arguments = (ins + Torch_NonValueTensorType:$self + ); + let results = (outs + Torch_NonValueTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenSelu_Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenSelu_Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + def Torch_AtenSigmoidOp : Torch_Op<"aten.sigmoid", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index a1db752b55f0..7089f4860531 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -6746,6 +6746,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.selu\"(%arg0: !torch.list) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.gather\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.list, %arg3: !torch.bool) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg2) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -10434,6 +10438,28 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %0#1 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.selu\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int11 = torch.constant.int 11\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.ne.int %0#1, %int11 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %2 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %3 = torch.aten.__not__ %2 : !torch.bool -> !torch.bool\n" +" torch.prim.If %3 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" return %0#1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.remainder.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.number) -> !torch.int {\n" " %none = torch.constant.none\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index d7e7a2af6dd2..d4712e547264 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -1937,6 +1937,55 @@ class DecomposeAtenEluOp : public OpRewritePattern { }; } // namespace +// Selu = scale * (max(0,x) + min(0,alpha * (exp(x) − 1))) +namespace { +class DecomposeAtenSeluOp : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenSeluOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value input = op.getSelf(); + auto resType = op.getType().cast(); + if (!resType.hasDtype()) { + return rewriter.notifyMatchFailure(op, "result should have dtype"); + } + + // Define λ and α + double scale = 1.0507009873554804934193349852946; + double alpha = 1.6732632423543772848170429916717; + + // Create constants for λ and α + Value scaleVal = rewriter.create(loc, rewriter.getF64FloatAttr(scale)); + Value alphaVal = rewriter.create(loc, rewriter.getF64FloatAttr(alpha)); + + // Create zero tensor for comparison + Value constantZero = + rewriter.create(loc, rewriter.getI64IntegerAttr(0)); + Value zeroTensor = createRank0Tensor(rewriter, loc, resType, constantZero); + + // Calculate positive and negative parts + Value constantOne = + rewriter.create(loc, rewriter.getF64FloatAttr(1.0)); + Value positiveOutput = rewriter.create(loc, resType, zeroTensor, input); + Value minZeroX = + rewriter.create(loc, resType, zeroTensor, input); + Value expInput = rewriter.create(loc, resType, minZeroX); + Value expInputMinusOne = rewriter.create(loc, resType, expInput, constantOne, constantOne); + Value negativeOutput = rewriter.create(loc, resType, expInputMinusOne, alphaVal); + + // Multiply the result by λ + Value seluOutput = rewriter.create( + loc, resType, positiveOutput, negativeOutput, constantOne); + seluOutput = rewriter.create(loc, resType, seluOutput, scaleVal); + + // Replace the original operation + rewriter.replaceOp(op, seluOutput); + return success(); + } +}; +} // namespace + namespace { class DecomposeAtenTOp : public OpRewritePattern { public: @@ -6460,6 +6509,7 @@ class DecomposeComplexOpsPass addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp index 48a3e5c65ad6..79f64ef32fbf 100644 --- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp +++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp @@ -437,6 +437,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 98ee0fa6ccaf..c70b01b47819 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -486,6 +486,7 @@ "ElementwiseLeakyReluModule_basic", "ElementwiseEluModule_basic", "ElementwiseEluNonDefaultModule_basic", + "ElementwiseSeluModule_basic", "ElementwiseLogModule_basic", "ElementwiseNegModule_basic", "ElementwiseRsqrtModule_basic", @@ -1115,6 +1116,7 @@ "ElementwiseRemainderScalarModule_Int_basic", "ElementwiseRemainderScalarModule_Int_basic", "ElementwiseRsqrtModule_basic", + "ElementwiseSeluModule_basic", "ElementwiseSigmoidModule_basic", "ElementwiseSignModule_basic", "ElementwiseSqrtIntModule_basic", diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index 6d175712b2d0..903b6c3bf36e 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -373,6 +373,9 @@ def aten〇elu〡shape(self: List[int], alpha: float = 1, scale: float = 1, inpu def aten〇prelu〡shape(self: List[int], weight: List[int]) -> List[int]: return upstream_shape_functions.unary(self) +def aten〇selu〡shape(self: List[int]) -> List[int]: + return upstream_shape_functions.unary(self) + def aten〇gather〡shape(self: List[int], dim: int, index: List[int], sparse_grad: bool = False) -> List[int]: return upstream_shape_functions.unary(index) @@ -3066,6 +3069,14 @@ def aten〇elu〡dtype(self_rank_dtype: Tuple[int, int], alpha: Union[int, float assert not is_integer_dtype(self_dtype) return self_dtype +@check_dtype_function( + _check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.bool, torch.int8, torch.uint8, torch.int16, torch.int32, torch.int64})) +def aten〇selu〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + assert self_dtype != torch.bool + assert not is_integer_dtype(self_dtype) + return self_dtype + @check_dtype_function( _check_tensors_with_the_same_dtype(num_of_tensors=1, other=1) + _check_tensors_with_the_same_dtype(num_of_tensors=1, other=1.0)) diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index fd6bfbc968eb..32f0d0df862d 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -262,6 +262,7 @@ def emit_with_mutating_variants(key, **kwargs): "aten::relu6 : (Tensor) -> (Tensor)", "aten::leaky_relu : (Tensor, Scalar) -> (Tensor)", "aten::log : (Tensor) -> (Tensor)", + "aten::selu : (Tensor) -> (Tensor)", "aten::sigmoid : (Tensor) -> (Tensor)", "aten::sign : (Tensor) -> (Tensor)", "aten::sgn : (Tensor) -> (Tensor)", diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py index 0cec9e5deaf8..0b5330b91380 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py @@ -564,6 +564,27 @@ def ElementwiseGeluModule_basic(module, tu: TestUtils): # ============================================================================== +class ElementwiseSeluModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True), + ]) + def forward(self, x): + return torch.ops.aten.selu(x) + +@register_test_case(module_factory=lambda: ElementwiseSeluModule()) +def ElementwiseSeluModule_basic(module, tu: TestUtils): + module.forward(tu.rand(5, 3, low=-1, high=1)) + + +# ============================================================================== + + class ElementwiseSigmoidModule(torch.nn.Module): def __init__(self): From 4857606ffe3a04f6f040ca3f829904d292afba2e Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Thu, 14 Dec 2023 08:53:47 -0800 Subject: [PATCH 006/283] [onnx] Lowerings from `onnx.selu` (#2634) Lowerings for `selu` lowerings for ONNX to the corresponding torch implementations. Torch's `selu` implementation has fewer features so we use the a generalized `elu` with the input scale set to `1.0`. --- .../Conversion/TorchOnnxToTorch/Patterns.h | 19 ++++++++++++ .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 31 ++++++++++++++++++- .../TorchOnnxToTorch/simple_ops_q_to_z.mlir | 16 ++++++++++ 3 files changed, 65 insertions(+), 1 deletion(-) create mode 100644 test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir diff --git a/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h b/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h index 5b144503c0ec..ccb0c033c617 100644 --- a/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h +++ b/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h @@ -113,6 +113,25 @@ struct OpBinder { return failure(); } + ParseResult f32FloatAttr(float &value, StringRef nameSuffix, + float defaultValue = 0.0f) { + SmallString<64> name("torch.onnx."); + name.append(nameSuffix); + auto attr = op->getAttr(name); + if (!attr) { + value = defaultValue; + return success(); + } + if (auto floatAttr = dyn_cast(attr)) { + FloatType t = cast(floatAttr.getType()); + if (t.getWidth() != 32) + return failure(); + value = floatAttr.getValueAsDouble(); + return success(); + } + return failure(); + } + ParseResult customOpNameStringAttr(std::string &value, StringRef nameSuffix, std::string defaultValue = "") { SmallString<64> name("torch.onnx."); diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 23af89f329ab..b9fd49bd33f7 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -26,4 +26,33 @@ using namespace mlir::torch::onnx_c; // results in a lot of ONNX test cases that all reduce to the exact same // thing here, so we simplify. void mlir::torch::onnx_c::populateDefaultDomainQtoZ( - OnnxCustomOpConversionPattern &patterns) {} + OnnxCustomOpConversionPattern &patterns) { + + patterns.onOp( + "Selu", 6, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + float alpha, gamma; + Value operand; + if (binder.tensorOperand(operand) || + binder.f32FloatAttr(alpha, "alpha") || + binder.f32FloatAttr(gamma, "gamma") || + binder.tensorResultType(resultType)) + return failure(); + + Value vAlpha = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getFloatAttr(rewriter.getF64Type(), alpha)); + + Value vScale = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getFloatAttr(rewriter.getF64Type(), gamma)); + + Value vInputScale = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getFloatAttr(rewriter.getF64Type(), 1.0)); + + rewriter.replaceOpWithNewOp( + binder.op, resultType, operand, vAlpha, vScale, vInputScale); + return success(); + }); +} diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir new file mode 100644 index 000000000000..8b98838dc769 --- /dev/null +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir @@ -0,0 +1,16 @@ +// RUN: torch-mlir-opt <%s --split-input-file -convert-torch-onnx-to-torch | FileCheck %s +// Generally, the test cases accumulated here come from running the importer +// over all included backend tests that involve simple ops with no model +// level constants. This is a pragmatic choice which lets us have a lot +// of tests in this file, whereas the others tend to be more bespoke. + + +// CHECK-LABEL: func.func @test_selu +func.func @test_selu(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.opset_version = 6 : si64} { + // CHECK-DAG: %[[F1:.+]] = torch.constant.float 1 + // CHECK-DAG: %[[F2:.+]] = torch.constant.float 2 + // CHECK-DAG: %[[F3:.+]] = torch.constant.float 3 + // CHECK: %[[ELU:.+]] = torch.aten.elu %arg0, %[[F2]], %[[F3]], %[[F1]] + %0 = torch.operator "onnx.Selu"(%arg0) {torch.onnx.alpha = 2.000000e+00 : f32, torch.onnx.gamma = 3.000000e+00 : f32} : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> +} \ No newline at end of file From 65f517b3d0956bcebeeb362c0b841216e18d314b Mon Sep 17 00:00:00 2001 From: Sungsoon Cho <1025787+godot73@users.noreply.github.com> Date: Thu, 14 Dec 2023 12:43:21 -0800 Subject: [PATCH 007/283] Bump LLVM version to 762964e97fd66ab7728ecc92aa153a61266fa9df. (#2645) --- externals/llvm-project | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/externals/llvm-project b/externals/llvm-project index f7250179e22c..762964e97fd6 160000 --- a/externals/llvm-project +++ b/externals/llvm-project @@ -1 +1 @@ -Subproject commit f7250179e22ce4aab96166493b27223fa28c2181 +Subproject commit 762964e97fd66ab7728ecc92aa153a61266fa9df From 4ec8b9fc02adfab64f363ee40e6949ea73b9a628 Mon Sep 17 00:00:00 2001 From: Andreas Falkenberg <149819731+afalkenberg1@users.noreply.github.com> Date: Thu, 14 Dec 2023 19:23:23 -0800 Subject: [PATCH 008/283] [onnx] add support for onnx.LessOrEqual (#2639) Added the less or equal operation to OnnxToTorch. onnx.LessOrEqual --------- Co-authored-by: root --- .../TorchOnnxToTorch/DefaultDomainGtoP.cpp | 15 ++++++++++++++- .../TorchOnnxToTorch/simple_ops_g_to_p.mlir | 9 +++++++++ 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index 34e7068c8681..9e2ca4e06c79 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -39,4 +39,17 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( binder.op, resultType, lhs, rhs); return success(); }); -} + patterns.onOp("LessOrEqual", 1, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value lhs, rhs; + if (binder.tensorOperands(lhs, rhs) || + binder.tensorResultType(resultType)) { + return failure(); + } + rewriter.replaceOpWithNewOp( + binder.op, resultType, lhs, rhs); + return success(); + }); + +} \ No newline at end of file diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir index 28b180a1e541..9a29de584327 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir @@ -23,4 +23,13 @@ func.func @test_matmul_4d(%arg0: !torch.vtensor<[1,2,3,4],f32>, %arg1: !torch.vt // CHECK: torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[1,2,3,4],f32>, !torch.vtensor<[1,2,4,3],f32> -> !torch.vtensor<[1,2,3,3],f32> %0 = torch.operator "onnx.MatMul"(%arg0, %arg1) : (!torch.vtensor<[1,2,3,4],f32>, !torch.vtensor<[1,2,4,3],f32>) -> !torch.vtensor<[1,2,3,3],f32> return %0 : !torch.vtensor<[1,2,3,3],f32> +} + +// CHECK-LABEL: func.func @test_less_or_equal +func.func @test_less_or_equal(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],i1> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[3,4,5],f32> + // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: !torch.vtensor<[3,4,5],f32> + // CHECK: torch.aten.le.Tensor %[[ARG0]], %[[ARG1]] : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],i1> + %0 = torch.operator "onnx.LessOrEqual"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],i1> + return %0 : !torch.vtensor<[3,4,5],i1> } \ No newline at end of file From f59c01fd2fbcb5162b8147f34c68cd970f2c26a9 Mon Sep 17 00:00:00 2001 From: saienduri <77521230+saienduri@users.noreply.github.com> Date: Fri, 15 Dec 2023 09:36:18 -0800 Subject: [PATCH 009/283] [MLIR][ONNX] Add OnnxToTorch support for q-z ops (specific ops in description) (#2601) This commit adds the OnnxToTorch support for Reciprocal, Round, ScatterElements, Sigmoid, Sin, Tanh, Sqrt, Sub, Sum, Where, Xor, Squeeze, Unsqueeze ops. For reviewers, the ops that weren't trivial and probably require extra review are Sum, Squeeze, and Unsqueeze. --- .../Conversion/TorchOnnxToTorch/Patterns.h | 14 + .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 404 +++++++++++++++ .../TorchOnnxToTorch/simple_ops_q_to_z.mlir | 472 +++++++++++++++++- 3 files changed, 889 insertions(+), 1 deletion(-) diff --git a/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h b/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h index ccb0c033c617..94451ac5c927 100644 --- a/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h +++ b/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h @@ -54,6 +54,20 @@ struct OpBinder { return success(); } + ParseResult tensorOperands(SmallVector &valueList, + int64_t numOperands) { + if (op->getNumOperands() != numOperands) + return failure(); + for (int i = 0; i < numOperands; i++) { + Value curr = op->getOperand(i); + if (!toValidTensorType(curr.getType())) { + return failure(); + } + valueList.push_back(curr); + } + return success(); + } + ParseResult tensorOperandAtIndex(Value &valueIdx, int64_t idx) { if (idx >= op->getNumOperands()) return failure(); diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index b9fd49bd33f7..a8fa8972fafc 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -8,6 +8,8 @@ //===----------------------------------------------------------------------===// #include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallVector.h" using namespace mlir; using namespace mlir::torch; @@ -27,6 +29,408 @@ using namespace mlir::torch::onnx_c; // thing here, so we simplify. void mlir::torch::onnx_c::populateDefaultDomainQtoZ( OnnxCustomOpConversionPattern &patterns) { + patterns.onOp("Reciprocal", 1, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value operand; + if (binder.tensorOperand(operand) || + binder.tensorResultType(resultType)) + return failure(); + rewriter.replaceOpWithNewOp( + binder.op, resultType, operand); + return success(); + }); + patterns.onOp( + "Relu", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value x; + if (binder.tensorOperand(x) || binder.tensorResultType(resultType)) + return failure(); + + rewriter.replaceOpWithNewOp(binder.op, resultType, + x); + return success(); + }); + patterns.onOp("Round", 11, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value operand; + if (binder.tensorOperand(operand) || + binder.tensorResultType(resultType)) + return failure(); + rewriter.replaceOpWithNewOp( + binder.op, resultType, operand); + return success(); + }); + patterns.onOp( + "ScatterElements", 18, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + SmallVector valList; + int64_t axis; + std::string reduction; + int64_t numOperands = binder.op->getNumOperands(); + if (binder.tensorOperands(valList, numOperands) || + binder.s64IntegerAttr(axis, "axis", 0) || + binder.customOpNameStringAttr(reduction, "reduction", "none") || + binder.tensorResultType(resultType)) + return failure(); + + Value data = valList[0]; + Value indices = valList[1]; + Value updates = valList[2]; + + // ONNX allows negative axis. + if (axis < 0) + axis += + cast(data.getType()).getSizes().size(); + + Value constAxis = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), axis)); + + if (reduction == "none") { + rewriter.replaceOpWithNewOp( + binder.op, resultType, data, constAxis, indices, updates); + return success(); + } + + // TODO: Implement max and min cases + if (reduction == "mul") { + reduction = "multiply"; + } else if (reduction == "max" || reduction == "min") { + return rewriter.notifyMatchFailure( + binder.op, "max/min reduction unsupported for scatter elements"); + } + + Value cstStrReduction = + rewriter.create(binder.getLoc(), reduction); + + rewriter.replaceOpWithNewOp( + binder.op, resultType, data, constAxis, indices, updates, + cstStrReduction); + return success(); + }); + patterns.onOp( + "Sigmoid", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value x; + if (binder.tensorOperand(x) || binder.tensorResultType(resultType)) + return failure(); + + rewriter.replaceOpWithNewOp(binder.op, resultType, + x); + return success(); + }); + patterns.onOp("Sin", 7, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value operand; + if (binder.tensorOperand(operand) || + binder.tensorResultType(resultType)) + return failure(); + rewriter.replaceOpWithNewOp( + binder.op, resultType, operand); + return success(); + }); + patterns.onOp("Tanh", 1, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value operand; + if (binder.tensorOperand(operand) || + binder.tensorResultType(resultType)) + return failure(); + rewriter.replaceOpWithNewOp( + binder.op, resultType, operand); + return success(); + }); + patterns.onOp("Sqrt", 1, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value operand; + if (binder.tensorOperand(operand) || + binder.tensorResultType(resultType)) + return failure(); + rewriter.replaceOpWithNewOp( + binder.op, resultType, operand); + return success(); + }); + patterns.onOp( + "Sub", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value x; + Value y; + if (binder.tensorOperands(x, y) || binder.tensorResultType(resultType)) + return failure(); + Value const1 = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), 1)); + rewriter.replaceOpWithNewOp( + binder.op, resultType, x, y, const1); + return success(); + }); + patterns.onOp( + "Sum", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + if (binder.op->getNumOperands() == 1) { + Torch::ValueTensorType resultType; + Value x; + if (binder.tensorOperand(x) || binder.tensorResultType(resultType)) + return failure(); + rewriter.replaceOp(binder.op, x); + return success(); + } + Torch::ValueTensorType resultType; + SmallVector valList; + int64_t numOperands = binder.op->getNumOperands(); + if (binder.tensorOperands(valList, numOperands) || + binder.tensorResultType(resultType)) + return failure(); + Value const1 = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), 1)); + // Short circuit to binary add + if (numOperands == 2) { + rewriter.replaceOpWithNewOp( + binder.op, resultType, valList[0], valList[1], const1); + return success(); + } + // When binder.op->getNumOperands() > 2 + auto baseType = Torch::ValueTensorType::getWithLeastStaticInformation( + binder.op->getContext()); + Value curr = rewriter.create( + binder.getLoc(), resultType, valList[0], valList[1], const1); + for (int i = 2; i < numOperands; i++) { + if (i == numOperands - 1) { + curr = rewriter.create( + binder.getLoc(), resultType, curr, valList[i], const1); + } else { + curr = rewriter.create( + binder.getLoc(), baseType, curr, valList[i], const1); + } + } + rewriter.replaceOp(binder.op, curr); + return success(); + }); + patterns.onOp("Where", 1, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + SmallVector valList; + int64_t numOperands = binder.op->getNumOperands(); + if (binder.tensorOperands(valList, numOperands) || + binder.tensorResultType(resultType)) + return failure(); + Value condition = valList[0]; + Value x = valList[1]; + Value y = valList[2]; + rewriter.replaceOpWithNewOp( + binder.op, resultType, condition, x, y); + return success(); + }); + patterns.onOp( + "Xor", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value x; + Value y; + if (binder.tensorOperands(x, y) || binder.tensorResultType(resultType)) + return failure(); + rewriter.replaceOpWithNewOp(binder.op, + resultType, x, y); + return success(); + }); + patterns.onOp( + "Squeeze", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value data; + Value axes; + Value result; + if (binder.tensorOperands(data, axes) || + binder.tensorResultType(resultType)) + return failure(); + Torch::BaseTensorType axesType = + axes.getType().cast(); + SmallVector dimList; + SmallVector selectSizes; + selectSizes.push_back(1); + Type selectResultType = axesType.getWithSizesAndDtype( + llvm::ArrayRef(selectSizes), axesType.getOptionalDtype()); + auto sizes = + dyn_cast(axes.getType()).getSizes(); + if (sizes.size() == 0) { + rewriter.replaceOpWithNewOp(binder.op, + resultType, data); + return success(); + } + Value zero = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); + int64_t adjustmentInt = + cast(data.getType()).getSizes().size(); + Value adjustment = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), + adjustmentInt)); + for (int i = 0; i < sizes[0]; i++) { + // Go through the axes list and get each dim in the list + Value selectIndex = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), i)); + Value extract = rewriter.create( + binder.getLoc(), selectResultType, axes, zero, selectIndex); + Value dim = rewriter.create( + binder.getLoc(), rewriter.getType(), extract); + // deal with neg axis: if (axis < 0) axis += rank + Value isNegative = + rewriter.create(binder.getLoc(), dim, zero); + isNegative = rewriter.create(binder.getLoc(), + isNegative); + Value finalOffset = rewriter.create( + binder.getLoc(), isNegative, adjustment); + Value finalDim = rewriter.create( + binder.getLoc(), dim, finalOffset); + dimList.push_back(finalDim); + } + Value dimValueList = rewriter.create( + binder.getLoc(), + Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), + dimList); + rewriter.replaceOpWithNewOp( + binder.op, resultType, data, dimValueList); + return success(); + }); + patterns.onOp( + "Unsqueeze", 13, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + // Unlike squeeze where we are able to lower to Torch::PrimsSqueezeOp, + // pytorch does not support torch.unsqueeze to insert multiple new dims. + // discussion can be found here: + // https://github.com/pytorch/pytorch/issues/9410 + // So, for now, we unroll into multiple unsqueezes. + Torch::ValueTensorType resultType; + Value data; + Value axes; + if (binder.tensorOperands(data, axes) || + binder.tensorResultType(resultType)) + return failure(); + Torch::BaseTensorType axesType = + axes.getType().cast(); + SmallVector dimList; + SmallVector selectSizes; + selectSizes.push_back(1); + Type selectResultType = axesType.getWithSizesAndDtype( + llvm::ArrayRef(selectSizes), axesType.getOptionalDtype()); + auto sizes = + dyn_cast(axes.getType()).getSizes(); + if (sizes.size() == 0) { + rewriter.replaceOp(binder.op, data); + return success(); + } + Value zero = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); + int64_t adjustmentInt = + cast(data.getType()).getSizes().size(); + Value adjustment = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), + adjustmentInt)); + for (int i = 0; i < sizes[0]; i++) { + // Go through the axes list and get each dim in the list + Value selectIndex = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), i)); + Value extract = rewriter.create( + binder.getLoc(), selectResultType, axes, zero, selectIndex); + Value dim = rewriter.create( + binder.getLoc(), rewriter.getType(), extract); + // deal with neg axis: if (axis < 0) axis += rank + Value isNegative = + rewriter.create(binder.getLoc(), dim, zero); + isNegative = rewriter.create(binder.getLoc(), + isNegative); + Value finalOffset = rewriter.create( + binder.getLoc(), isNegative, adjustment); + Value finalDim = rewriter.create( + binder.getLoc(), dim, finalOffset); + dimList.push_back(finalDim); + } + Value dimValueList = rewriter.create( + binder.getLoc(), + Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), + dimList); + Value cstFalse = + rewriter.create(binder.getLoc(), false); + Value noneVal = rewriter.create(binder.getLoc()); + Value updatedAxes = rewriter.create( + binder.getLoc(), + axesType.getWithSizesAndDtype(sizes, axesType.getOptionalDtype()), + dimValueList, /*dtype=*/noneVal, /*device=*/noneVal, cstFalse); + // Sort the list of dims, so we don't run into this situation: + // data.sizes = [2, 3, 4] + // dims = [4, 0] + // index 4 will be invalid to add a singleton dimension because + // data.sizes.size == 3 We have to work with sorted dims to avoid this + // situation. + auto sortIndicesType = axesType.getWithSizesAndDtype( + axesType.getOptionalSizes(), + IntegerType::get(binder.op->getContext(), 64, IntegerType::Signed)); + auto sortOpResult = rewriter.create( + binder.getLoc(), axes.getType(), sortIndicesType, updatedAxes, zero, + cstFalse); + Value result; + auto baseType = Torch::ValueTensorType::getWithLeastStaticInformation( + binder.op->getContext()); + // Go through the updated, sorted axes. Do unsqueeze for each dim. + for (int i = 0; i < sizes[0]; i++) { + Value selectIndex = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), i)); + Value extract = rewriter.create( + binder.getLoc(), selectResultType, sortOpResult->getResult(0), + zero, selectIndex); + Value dim = rewriter.create( + binder.getLoc(), rewriter.getType(), extract); + if (sizes[0] == 1) { + result = rewriter.create( + binder.getLoc(), resultType, data, dim); + } else if (i == 0) { + result = rewriter.create( + binder.getLoc(), baseType, data, dim); + } else if (i == sizes[0] - 1) { + result = rewriter.create( + binder.getLoc(), resultType, result, dim); + } else { + result = rewriter.create( + binder.getLoc(), baseType, result, dim); + } + } + rewriter.replaceOp(binder.op, result); + return success(); + }); + patterns.onOp( + "Softmax", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value input; + int64_t axis; + if (binder.tensorOperand(input) || + binder.s64IntegerAttr(axis, "axis", -1) || + binder.tensorResultType(resultType)) + return failure(); + + // ONNX allows negative axis. + if (axis < 0) + axis += + cast(input.getType()).getSizes().size(); + + Value constAxis = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), axis)); + + Value noneVal = rewriter.create(binder.getLoc()); + + rewriter.replaceOpWithNewOp( + binder.op, resultType, input, constAxis, /*dtype=*/noneVal); + return success(); + }); patterns.onOp( "Selu", 6, [](OpBinder binder, ConversionPatternRewriter &rewriter) { diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir index 8b98838dc769..f85221d971b4 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir @@ -4,6 +4,476 @@ // level constants. This is a pragmatic choice which lets us have a lot // of tests in this file, whereas the others tend to be more bespoke. +// CHECK-LABEL: func.func @test_reciprocal +func.func @test_reciprocal(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.reciprocal %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> + %0 = torch.operator "onnx.Reciprocal"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> +} + +// CHECK-LABEL: func.func @test_relu +func.func @test_relu(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.relu %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> + %0 = torch.operator "onnx.Relu"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> +} + +// CHECK-LABEL: func.func @test_round +func.func @test_round(%arg0: !torch.vtensor<[15],f32>) -> !torch.vtensor<[15],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + //CHECK: torch.aten.round %arg0 : !torch.vtensor<[15],f32> -> !torch.vtensor<[15],f32> + %0 = torch.operator "onnx.Round"(%arg0) : (!torch.vtensor<[15],f32>) -> !torch.vtensor<[15],f32> + return %0 : !torch.vtensor<[15],f32> +} + +// CHECK-LABEL: func.func @test_scatter_elements_with_axis +func.func @test_scatter_elements_with_axis(%arg0: !torch.vtensor<[1,5],f32>, %arg1: !torch.vtensor<[1,2],si64>, %arg2: !torch.vtensor<[1,2],f32>) -> !torch.vtensor<[1,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: torch.aten.scatter.src %arg0, %int1, %arg1, %arg2 : !torch.vtensor<[1,5],f32>, !torch.int, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32> -> !torch.vtensor<[1,5],f32> + %0 = torch.operator "onnx.ScatterElements"(%arg0, %arg1, %arg2) {torch.onnx.axis = 1 : si64} : (!torch.vtensor<[1,5],f32>, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32>) -> !torch.vtensor<[1,5],f32> + return %0 : !torch.vtensor<[1,5],f32> +} + +// CHECK-LABEL: func.func @test_scatter_elements_with_duplicate_indices +func.func @test_scatter_elements_with_duplicate_indices(%arg0: !torch.vtensor<[1,5],f32>, %arg1: !torch.vtensor<[1,2],si64>, %arg2: !torch.vtensor<[1,2],f32>) -> !torch.vtensor<[1,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: %[[STR:.*]] = torch.constant.str "add" + // CHECK: torch.aten.scatter.reduce %arg0, %int1, %arg1, %arg2, %str : !torch.vtensor<[1,5],f32>, !torch.int, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32>, !torch.str -> !torch.vtensor<[1,5],f32> + %0 = torch.operator "onnx.ScatterElements"(%arg0, %arg1, %arg2) {torch.onnx.axis = 1 : si64, torch.onnx.reduction = "add"} : (!torch.vtensor<[1,5],f32>, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32>) -> !torch.vtensor<[1,5],f32> + return %0 : !torch.vtensor<[1,5],f32> +} + +// CHECK-LABEL: func.func @test_scatter_elements_without_axis +func.func @test_scatter_elements_without_axis(%arg0: !torch.vtensor<[3,3],f32>, %arg1: !torch.vtensor<[2,3],si64>, %arg2: !torch.vtensor<[2,3],f32>) -> !torch.vtensor<[3,3],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT0:.*]] = torch.constant.int 0 + // CHECK: torch.aten.scatter.src %arg0, %int0, %arg1, %arg2 : !torch.vtensor<[3,3],f32>, !torch.int, !torch.vtensor<[2,3],si64>, !torch.vtensor<[2,3],f32> -> !torch.vtensor<[3,3],f32> + %0 = torch.operator "onnx.ScatterElements"(%arg0, %arg1, %arg2) : (!torch.vtensor<[3,3],f32>, !torch.vtensor<[2,3],si64>, !torch.vtensor<[2,3],f32>) -> !torch.vtensor<[3,3],f32> + return %0 : !torch.vtensor<[3,3],f32> +} + +// CHECK-LABEL: func.func @test_scatter_elements_with_reduction_mul +func.func @test_scatter_elements_with_reduction_mul(%arg0: !torch.vtensor<[1,5],f32>, %arg1: !torch.vtensor<[1,2],si64>, %arg2: !torch.vtensor<[1,2],f32>) -> !torch.vtensor<[1,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: %[[STR:.*]] = torch.constant.str "multiply" + // CHECK: torch.aten.scatter.reduce %arg0, %int1, %arg1, %arg2, %str : !torch.vtensor<[1,5],f32>, !torch.int, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32>, !torch.str -> !torch.vtensor<[1,5],f32> + %0 = torch.operator "onnx.ScatterElements"(%arg0, %arg1, %arg2) {torch.onnx.axis = 1 : si64, torch.onnx.reduction = "mul"} : (!torch.vtensor<[1,5],f32>, !torch.vtensor<[1,2],si64>, !torch.vtensor<[1,2],f32>) -> !torch.vtensor<[1,5],f32> + return %0 : !torch.vtensor<[1,5],f32> +} + +// CHECK-LABEL: func.func @test_sigmoid_example +func.func @test_sigmoid_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.sigmoid %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> + %0 = torch.operator "onnx.Sigmoid"(%arg0) : (!torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> + return %0 : !torch.vtensor<[3],f32> +} + +// CHECK-LABEL: func.func @test_sin_example +func.func @test_sin_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 7 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.sin %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> + %0 = torch.operator "onnx.Sin"(%arg0) : (!torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> + return %0 : !torch.vtensor<[3],f32> +} + +// CHECK-LABEL: func.func @test_tanh_example +func.func @test_tanh_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.tanh %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> + %0 = torch.operator "onnx.Tanh"(%arg0) : (!torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> + return %0 : !torch.vtensor<[3],f32> +} + +// CHECK-LABEL: func.func @test_sqrt_example +func.func @test_sqrt_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.sqrt %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> + %0 = torch.operator "onnx.Sqrt"(%arg0) : (!torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> + return %0 : !torch.vtensor<[3],f32> +} + +// CHECK-LABEL: func.func @test_sub_bcast +func.func @test_sub_bcast(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: torch.aten.sub.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[5],f32>, !torch.int -> !torch.vtensor<[3,4,5],f32> + %0 = torch.operator "onnx.Sub"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[5],f32>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> +} + +// CHECK-LABEL: func.func @test_sub_example +func.func @test_sub_example(%arg0: !torch.vtensor<[3],f32>, %arg1: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: torch.aten.sub.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor<[3],f32> + %0 = torch.operator "onnx.Sub"(%arg0, %arg1) : (!torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> + return %0 : !torch.vtensor<[3],f32> +} + +// CHECK-LABEL: func.func @test_sub +func.func @test_sub(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: torch.aten.sub.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.vtensor<[3,4,5],f32> + %0 = torch.operator "onnx.Sub"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> +} + +// CHECK-LABEL: func.func @test_sub_uint8 +func.func @test_sub_uint8(%arg0: !torch.vtensor<[3,4,5],ui8>, %arg1: !torch.vtensor<[3,4,5],ui8>) -> !torch.vtensor<[3,4,5],ui8> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: torch.aten.sub.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[3,4,5],ui8>, !torch.vtensor<[3,4,5],ui8>, !torch.int -> !torch.vtensor<[3,4,5],ui8> + %0 = torch.operator "onnx.Sub"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],ui8>, !torch.vtensor<[3,4,5],ui8>) -> !torch.vtensor<[3,4,5],ui8> + return %0 : !torch.vtensor<[3,4,5],ui8> +} + +// CHECK-LABEL: func.func @test_sum_example +func.func @test_sum_example(%arg0: !torch.vtensor<[3],f32>, %arg1: !torch.vtensor<[3],f32>, %arg2: !torch.vtensor<[3],f32>, %arg3: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: torch.aten.add.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor<[3],f32> + // CHECK: torch.aten.add.Tensor %0, %arg2, %int1 : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor + // CHECK: torch.aten.add.Tensor %1, %arg3, %int1 : !torch.vtensor, !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor<[3],f32> + %0 = torch.operator "onnx.Sum"(%arg0, %arg1, %arg2, %arg3) : (!torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> + return %0 : !torch.vtensor<[3],f32> +} + +// CHECK-LABEL: func.func @test_sum_one_input +func.func @test_sum_one_input(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + %0 = torch.operator "onnx.Sum"(%arg0) : (!torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> + return %0 : !torch.vtensor<[3],f32> +} + +// CHECK-LABEL: func.func @test_sum_two_inputs +func.func @test_sum_two_inputs(%arg0: !torch.vtensor<[3],f32>, %arg1: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: torch.aten.add.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor<[3],f32> + %0 = torch.operator "onnx.Sum"(%arg0, %arg1) : (!torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> + return %0 : !torch.vtensor<[3],f32> +} + +// CHECK-LABEL: func.func @test_where_example +func.func @test_where_example(%arg0: !torch.vtensor<[2,2],i1>, %arg1: !torch.vtensor<[2,2],f32>, %arg2: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 16 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.where.self %arg0, %arg1, %arg2 : !torch.vtensor<[2,2],i1>, !torch.vtensor<[2,2],f32>, !torch.vtensor<[2,2],f32> -> !torch.vtensor<[2,2],f32> + %0 = torch.operator "onnx.Where"(%arg0, %arg1, %arg2) : (!torch.vtensor<[2,2],i1>, !torch.vtensor<[2,2],f32>, !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2,2],f32> + return %0 : !torch.vtensor<[2,2],f32> +} + +// CHECK-LABEL: func.func @test_where_long_example +func.func @test_where_long_example(%arg0: !torch.vtensor<[2,2],i1>, %arg1: !torch.vtensor<[2,2],si64>, %arg2: !torch.vtensor<[2,2],si64>) -> !torch.vtensor<[2,2],si64> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 16 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.where.self %arg0, %arg1, %arg2 : !torch.vtensor<[2,2],i1>, !torch.vtensor<[2,2],si64>, !torch.vtensor<[2,2],si64> -> !torch.vtensor<[2,2],si64> + %0 = torch.operator "onnx.Where"(%arg0, %arg1, %arg2) : (!torch.vtensor<[2,2],i1>, !torch.vtensor<[2,2],si64>, !torch.vtensor<[2,2],si64>) -> !torch.vtensor<[2,2],si64> + return %0 : !torch.vtensor<[2,2],si64> +} + +// CHECK-LABEL: func.func @test_xor2d +func.func @test_xor2d(%arg0: !torch.vtensor<[3,4],i1>, %arg1: !torch.vtensor<[3,4],i1>) -> !torch.vtensor<[3,4],i1> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 7 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.logical_xor %arg0, %arg1 : !torch.vtensor<[3,4],i1>, !torch.vtensor<[3,4],i1> -> !torch.vtensor<[3,4],i1> + %0 = torch.operator "onnx.Xor"(%arg0, %arg1) : (!torch.vtensor<[3,4],i1>, !torch.vtensor<[3,4],i1>) -> !torch.vtensor<[3,4],i1> + return %0 : !torch.vtensor<[3,4],i1> +} + +// CHECK-LABEL: func.func @test_xor3d +func.func @test_xor3d(%arg0: !torch.vtensor<[3,4,5],i1>, %arg1: !torch.vtensor<[3,4,5],i1>) -> !torch.vtensor<[3,4,5],i1> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 7 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.logical_xor %arg0, %arg1 : !torch.vtensor<[3,4,5],i1>, !torch.vtensor<[3,4,5],i1> -> !torch.vtensor<[3,4,5],i1> + %0 = torch.operator "onnx.Xor"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],i1>, !torch.vtensor<[3,4,5],i1>) -> !torch.vtensor<[3,4,5],i1> + return %0 : !torch.vtensor<[3,4,5],i1> +} + +// CHECK-LABEL: func.func @test_xor4d +func.func @test_xor4d(%arg0: !torch.vtensor<[3,4,5,6],i1>, %arg1: !torch.vtensor<[3,4,5,6],i1>) -> !torch.vtensor<[3,4,5,6],i1> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 7 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.logical_xor %arg0, %arg1 : !torch.vtensor<[3,4,5,6],i1>, !torch.vtensor<[3,4,5,6],i1> -> !torch.vtensor<[3,4,5,6],i1> + %0 = torch.operator "onnx.Xor"(%arg0, %arg1) : (!torch.vtensor<[3,4,5,6],i1>, !torch.vtensor<[3,4,5,6],i1>) -> !torch.vtensor<[3,4,5,6],i1> + return %0 : !torch.vtensor<[3,4,5,6],i1> +} + +// CHECK-LABEL: func.func @test_xor_bcast3v1d +func.func @test_xor_bcast3v1d(%arg0: !torch.vtensor<[3,4,5],i1>, %arg1: !torch.vtensor<[5],i1>) -> !torch.vtensor<[3,4,5],i1> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 7 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.logical_xor %arg0, %arg1 : !torch.vtensor<[3,4,5],i1>, !torch.vtensor<[5],i1> -> !torch.vtensor<[3,4,5],i1> + %0 = torch.operator "onnx.Xor"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],i1>, !torch.vtensor<[5],i1>) -> !torch.vtensor<[3,4,5],i1> + return %0 : !torch.vtensor<[3,4,5],i1> +} + +// CHECK-LABEL: func.func @test_xor_bcast4v4d +func.func @test_xor_bcast4v4d(%arg0: !torch.vtensor<[1,4,1,6],i1>, %arg1: !torch.vtensor<[3,1,5,6],i1>) -> !torch.vtensor<[3,4,5,6],i1> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 7 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.logical_xor %arg0, %arg1 : !torch.vtensor<[1,4,1,6],i1>, !torch.vtensor<[3,1,5,6],i1> -> !torch.vtensor<[3,4,5,6],i1> + %0 = torch.operator "onnx.Xor"(%arg0, %arg1) : (!torch.vtensor<[1,4,1,6],i1>, !torch.vtensor<[3,1,5,6],i1>) -> !torch.vtensor<[3,4,5,6],i1> + return %0 : !torch.vtensor<[3,4,5,6],i1> +} + +// CHECK-LABEL: func.func @test_squeeze +func.func @test_squeeze(%arg0: !torch.vtensor<[1,3,4,5],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT0:.*]] = torch.constant.int 0 + // CHECK: %[[INT4:.*]] = torch.constant.int 4 + // CHECK: %[[INT0_0:.*]] = torch.constant.int 0 + // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool + // CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int + // CHECK: torch.aten.mul.int %3, %int4 : !torch.int, !torch.int -> !torch.int + // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int + // CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list + // CHECK: torch.prims.squeeze %arg0, %6 : !torch.vtensor<[1,3,4,5],f32>, !torch.list -> !torch.vtensor<[3,4,5],f32> + %0 = torch.operator "onnx.Squeeze"(%arg0, %arg1) : (!torch.vtensor<[1,3,4,5],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> +} + +// CHECK-LABEL: func.func @test_squeeze_two_axes +func.func @test_squeeze_two_axes(%arg0: !torch.vtensor<[3,1,4,5,1],f32>, %arg1: !torch.vtensor<[2],si64>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT0:.*]] = torch.constant.int 0 + // CHECK: %[[INT5:.*]] = torch.constant.int 5 + // CHECK: %[[INT0_0:.*]] = torch.constant.int 0 + // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[2],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool + // CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int + // CHECK: torch.aten.mul.int %3, %int5 : !torch.int, !torch.int -> !torch.int + // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: torch.aten.select.int %arg1, %int0, %int1 : !torch.vtensor<[2],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: torch.aten.item %6 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: torch.aten.lt.int %7, %int0 : !torch.int, !torch.int -> !torch.bool + // CHECK: torch.aten.Int.bool %8 : !torch.bool -> !torch.int + // CHECK: torch.aten.mul.int %9, %int5 : !torch.int, !torch.int -> !torch.int + // CHECK: torch.aten.add.int %7, %10 : !torch.int, !torch.int -> !torch.int + // CHECK: torch.prim.ListConstruct %5, %11 : (!torch.int, !torch.int) -> !torch.list + // CHECK: torch.prims.squeeze %arg0, %12 : !torch.vtensor<[3,1,4,5,1],f32>, !torch.list -> !torch.vtensor<[3,4,5],f32> + %0 = torch.operator "onnx.Squeeze"(%arg0, %arg1) : (!torch.vtensor<[3,1,4,5,1],f32>, !torch.vtensor<[2],si64>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> +} + +// CHECK-LABEL: func.func @test_unsqueeze_axis_0 +func.func @test_unsqueeze_axis_0(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[1,3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT0:.*]] = torch.constant.int 0 + // CHECK: %[[INT3:.*]] = torch.constant.int 3 + // CHECK: %[[INT0_0:.*]] = torch.constant.int 0 + // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool + // CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int + // CHECK: torch.aten.mul.int %3, %int3 : !torch.int, !torch.int -> !torch.int + // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int + // CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list + // CHECK: torch.constant.bool false + // CHECK: torch.constant.none + // CHECK: torch.aten.tensor %6, %none, %none, %false : !torch.list, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[1],si64> + // CHECK: torch.aten.sort %7, %int0, %false : !torch.vtensor<[1],si64>, !torch.int, !torch.bool -> !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64> + // CHECK: %[[INT0_1:.*]] = torch.constant.int 0 + // CHECK: torch.aten.select.int %values, %int0, %int0_1 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: torch.aten.item %8 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: torch.aten.unsqueeze %arg0, %9 : !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.vtensor<[1,3,4,5],f32> + %0 = torch.operator "onnx.Unsqueeze"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[1,3,4,5],f32> + return %0 : !torch.vtensor<[1,3,4,5],f32> +} + +// CHECK-LABEL: func.func @test_unsqueeze_axis_1 +func.func @test_unsqueeze_axis_1(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT0:.*]] = torch.constant.int 0 + // CHECK: %[[INT3:.*]] = torch.constant.int 3 + // CHECK: %[[INT0_0:.*]] = torch.constant.int 0 + // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool + // CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int + // CHECK: torch.aten.mul.int %3, %int3 : !torch.int, !torch.int -> !torch.int + // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int + // CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: torch.aten.tensor %6, %none, %none, %false : !torch.list, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[1],si64> + // CHECK: torch.aten.sort %7, %int0, %false : !torch.vtensor<[1],si64>, !torch.int, !torch.bool -> !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64> + // CHECK: %[[INT0_1:.*]] = torch.constant.int 0 + // CHECK: torch.aten.select.int %values, %int0, %int0_1 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: torch.aten.item %8 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: torch.aten.unsqueeze %arg0, %9 : !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.vtensor<[3,1,4,5],f32> + %0 = torch.operator "onnx.Unsqueeze"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,4,5],f32> + return %0 : !torch.vtensor<[3,1,4,5],f32> +} + +// CHECK-LABEL: func.func @test_unsqueeze_axis_2 +func.func @test_unsqueeze_axis_2(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,4,1,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT0:.*]] = torch.constant.int 0 + // CHECK: %[[INT3:.*]] = torch.constant.int 3 + // CHECK: %[[INT0_0:.*]] = torch.constant.int 0 + // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool + // CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int + // CHECK: torch.aten.mul.int %3, %int3 : !torch.int, !torch.int -> !torch.int + // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int + // CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: torch.aten.tensor %6, %none, %none, %false : !torch.list, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[1],si64> + // CHECK: torch.aten.sort %7, %int0, %false : !torch.vtensor<[1],si64>, !torch.int, !torch.bool -> !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64> + // CHECK: %[[INT0_1:.*]] = torch.constant.int 0 + // CHECK: torch.aten.select.int %values, %int0, %int0_1 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: torch.aten.item %8 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: torch.aten.unsqueeze %arg0, %9 : !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.vtensor<[3,4,1,5],f32> + %0 = torch.operator "onnx.Unsqueeze"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,4,1,5],f32> + return %0 : !torch.vtensor<[3,4,1,5],f32> +} + +// CHECK-LABEL: func.func @test_unsqueeze_negative_axes +func.func @test_unsqueeze_negative_axes(%arg0: !torch.vtensor<[1,3,1,5],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[1,3,1,1,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT0:.*]] = torch.constant.int 0 + // CHECK: %[[INT4:.*]] = torch.constant.int 4 + // CHECK: %[[INT0_0:.*]] = torch.constant.int 0 + // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool + // CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int + // CHECK: torch.aten.mul.int %3, %int4 : !torch.int, !torch.int -> !torch.int + // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int + // CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: torch.aten.tensor %6, %none, %none, %false : !torch.list, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[1],si64> + // CHECK: torch.aten.sort %7, %int0, %false : !torch.vtensor<[1],si64>, !torch.int, !torch.bool -> !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64> + // CHECK: %[[INT0_1:.*]] = torch.constant.int 0 + // CHECK: torch.aten.select.int %values, %int0, %int0_1 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: torch.aten.item %8 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: torch.aten.unsqueeze %arg0, %9 : !torch.vtensor<[1,3,1,5],f32>, !torch.int -> !torch.vtensor<[1,3,1,1,5],f32> + %0 = torch.operator "onnx.Unsqueeze"(%arg0, %arg1) : (!torch.vtensor<[1,3,1,5],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[1,3,1,1,5],f32> + return %0 : !torch.vtensor<[1,3,1,1,5],f32> +} + +// CHECK-LABEL: func.func @test_unsqueeze_three_axes +func.func @test_unsqueeze_three_axes(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3],si64>) -> !torch.vtensor<[3,4,1,5,1,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT0:.*]] = torch.constant.int 0 + // CHECK: %[[INT3:.*]] = torch.constant.int 3 + // CHECK: %[[INT0_0:.*]] = torch.constant.int 0 + // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool + // CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int + // CHECK: torch.aten.mul.int %3, %int3 : !torch.int, !torch.int -> !torch.int + // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: torch.aten.select.int %arg1, %int0, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: torch.aten.item %6 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: torch.aten.lt.int %7, %int0 : !torch.int, !torch.int -> !torch.bool + // CHECK: torch.aten.Int.bool %8 : !torch.bool -> !torch.int + // CHECK: torch.aten.mul.int %9, %int3 : !torch.int, !torch.int -> !torch.int + // CHECK: torch.aten.add.int %7, %10 : !torch.int, !torch.int -> !torch.int + // CHECK: %[[INT2:.*]] = torch.constant.int 2 + // CHECK: torch.aten.select.int %arg1, %int0, %int2 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: torch.aten.item %12 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: torch.aten.lt.int %13, %int0 : !torch.int, !torch.int -> !torch.bool + // CHECK: torch.aten.Int.bool %14 : !torch.bool -> !torch.int + // CHECK: torch.aten.mul.int %15, %int3 : !torch.int, !torch.int -> !torch.int + // CHECK: torch.aten.add.int %13, %16 : !torch.int, !torch.int -> !torch.int + // CHECK: torch.prim.ListConstruct %5, %11, %17 : (!torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: torch.aten.tensor %18, %none, %none, %false : !torch.list, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[3],si64> + // CHECK: torch.aten.sort %19, %int0, %false : !torch.vtensor<[3],si64>, !torch.int, !torch.bool -> !torch.vtensor<[3],si64>, !torch.vtensor<[3],si64> + // CHECK: %[[INT0_1:.*]] = torch.constant.int 0 + // CHECK: torch.aten.select.int %values, %int0, %int0_1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: torch.aten.item %20 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: torch.aten.unsqueeze %arg0, %21 : !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.vtensor + // CHECK: %[[INT1_2:.*]] = torch.constant.int 1 + // CHECK: torch.aten.select.int %values, %int0, %int1_2 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: torch.aten.item %23 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: torch.aten.unsqueeze %22, %24 : !torch.vtensor, !torch.int -> !torch.vtensor + // CHECK: %[[INT2_3:.*]] = torch.constant.int 2 + // CHECK: torch.aten.select.int %values, %int0, %int2_3 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: torch.aten.item %26 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: torch.aten.unsqueeze %25, %27 : !torch.vtensor, !torch.int -> !torch.vtensor<[3,4,1,5,1,1],f32> + %0 = torch.operator "onnx.Unsqueeze"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[3,4,1,5,1,1],f32> + return %0 : !torch.vtensor<[3,4,1,5,1,1],f32> +} + +// CHECK-LABEL: func.func @test_unsqueeze_unsorted_axes +func.func @test_unsqueeze_unsorted_axes(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3],si64>) -> !torch.vtensor<[3,4,1,5,1,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT0:.*]] = torch.constant.int 0 + // CHECK: %[[INT3:.*]] = torch.constant.int 3 + // CHECK: %[[INT0_0:.*]] = torch.constant.int 0 + // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool + // CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int + // CHECK: torch.aten.mul.int %3, %int3 : !torch.int, !torch.int -> !torch.int + // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: torch.aten.select.int %arg1, %int0, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: torch.aten.item %6 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: torch.aten.lt.int %7, %int0 : !torch.int, !torch.int -> !torch.bool + // CHECK: torch.aten.Int.bool %8 : !torch.bool -> !torch.int + // CHECK: torch.aten.mul.int %9, %int3 : !torch.int, !torch.int -> !torch.int + // CHECK: torch.aten.add.int %7, %10 : !torch.int, !torch.int -> !torch.int + // CHECK: %[[INT2:.*]] = torch.constant.int 2 + // CHECK: torch.aten.select.int %arg1, %int0, %int2 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: torch.aten.item %12 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: torch.aten.lt.int %13, %int0 : !torch.int, !torch.int -> !torch.bool + // CHECK: torch.aten.Int.bool %14 : !torch.bool -> !torch.int + // CHECK: torch.aten.mul.int %15, %int3 : !torch.int, !torch.int -> !torch.int + // CHECK: torch.aten.add.int %13, %16 : !torch.int, !torch.int -> !torch.int + // CHECK: torch.prim.ListConstruct %5, %11, %17 : (!torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: torch.aten.tensor %18, %none, %none, %false : !torch.list, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[3],si64> + // CHECK: torch.aten.sort %19, %int0, %false : !torch.vtensor<[3],si64>, !torch.int, !torch.bool -> !torch.vtensor<[3],si64>, !torch.vtensor<[3],si64> + // CHECK: %[[INT0_1:.*]] = torch.constant.int 0 + // CHECK: torch.aten.select.int %values, %int0, %int0_1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: torch.aten.item %20 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: torch.aten.unsqueeze %arg0, %21 : !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.vtensor + // CHECK: %[[INT1_2:.*]] = torch.constant.int 1 + // CHECK: torch.aten.select.int %values, %int0, %int1_2 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: torch.aten.item %23 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: torch.aten.unsqueeze %22, %24 : !torch.vtensor, !torch.int -> !torch.vtensor + // CHECK: %[[INT2_3:.*]] = torch.constant.int 2 + // CHECK: torch.aten.select.int %values, %int0, %int2_3 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: torch.aten.item %26 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: torch.aten.unsqueeze %25, %27 : !torch.vtensor, !torch.int -> !torch.vtensor<[3,4,1,5,1,1],f32> + %0 = torch.operator "onnx.Unsqueeze"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[3,4,1,5,1,1],f32> + return %0 : !torch.vtensor<[3,4,1,5,1,1],f32> +} + +// CHECK-LABEL: func.func @test_softmax_axis_0 +func.func @test_softmax_axis_0(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT0:.*]] = torch.constant.int 0 + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: torch.aten.softmax.int %arg0, %int0, %none : !torch.vtensor<[3,4,5],f32>, !torch.int, !torch.none -> !torch.vtensor<[3,4,5],f32> + %0 = torch.operator "onnx.Softmax"(%arg0) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> +} + +// CHECK-LABEL: func.func @test_softmax_axis_1 +func.func @test_softmax_axis_1(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: torch.aten.softmax.int %arg0, %int1, %none : !torch.vtensor<[3,4,5],f32>, !torch.int, !torch.none -> !torch.vtensor<[3,4,5],f32> + %0 = torch.operator "onnx.Softmax"(%arg0) {torch.onnx.axis = 1 : si64} : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> +} + +// CHECK-LABEL: func.func @test_softmax_axis_2 +func.func @test_softmax_axis_2(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT2:.*]] = torch.constant.int 2 + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: torch.aten.softmax.int %arg0, %int2, %none : !torch.vtensor<[3,4,5],f32>, !torch.int, !torch.none -> !torch.vtensor<[3,4,5],f32> + %0 = torch.operator "onnx.Softmax"(%arg0) {torch.onnx.axis = 2 : si64} : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> +} + +// CHECK-LABEL: func.func @test_softmax_default_axis +func.func @test_softmax_default_axis(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT2:.*]] = torch.constant.int 2 + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: torch.aten.softmax.int %arg0, %int2, %none : !torch.vtensor<[3,4,5],f32>, !torch.int, !torch.none -> !torch.vtensor<[3,4,5],f32> + %0 = torch.operator "onnx.Softmax"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> +} + +// CHECK-LABEL: func.func @test_softmax_large_number +func.func @test_softmax_large_number(%arg0: !torch.vtensor<[2,4],f32>) -> !torch.vtensor<[2,4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: torch.aten.softmax.int %arg0, %int1, %none : !torch.vtensor<[2,4],f32>, !torch.int, !torch.none -> !torch.vtensor<[2,4],f32> + %0 = torch.operator "onnx.Softmax"(%arg0) : (!torch.vtensor<[2,4],f32>) -> !torch.vtensor<[2,4],f32> + return %0 : !torch.vtensor<[2,4],f32> +} + +// CHECK-LABEL: func.func @test_softmax_negative_axis +func.func @test_softmax_negative_axis(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT2:.*]] = torch.constant.int 2 + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: torch.aten.softmax.int %arg0, %int2, %none : !torch.vtensor<[3,4,5],f32>, !torch.int, !torch.none -> !torch.vtensor<[3,4,5],f32> + %0 = torch.operator "onnx.Softmax"(%arg0) {torch.onnx.axis = -1 : si64} : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> +} // CHECK-LABEL: func.func @test_selu func.func @test_selu(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.opset_version = 6 : si64} { @@ -13,4 +483,4 @@ func.func @test_selu(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4, // CHECK: %[[ELU:.+]] = torch.aten.elu %arg0, %[[F2]], %[[F3]], %[[F1]] %0 = torch.operator "onnx.Selu"(%arg0) {torch.onnx.alpha = 2.000000e+00 : f32, torch.onnx.gamma = 3.000000e+00 : f32} : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> return %0 : !torch.vtensor<[3,4,5],f32> -} \ No newline at end of file +} From d9f4a80b10ae849a0615674e87751b78d00be241 Mon Sep 17 00:00:00 2001 From: Quinn Dawkins Date: Fri, 15 Dec 2023 12:47:22 -0500 Subject: [PATCH 010/283] Bump LLVM version to fcd54b368e6713acd236dc47401b5292755900d7 (#2654) This bumps the llvm submodule to HEAD to pick up recent fixes. --- externals/llvm-project | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/externals/llvm-project b/externals/llvm-project index 762964e97fd6..fcd54b368e67 160000 --- a/externals/llvm-project +++ b/externals/llvm-project @@ -1 +1 @@ -Subproject commit 762964e97fd66ab7728ecc92aa153a61266fa9df +Subproject commit fcd54b368e6713acd236dc47401b5292755900d7 From eb9249e601b0a3b01650f07b8c99b2b8a3a3a190 Mon Sep 17 00:00:00 2001 From: Gaurav Shukla Date: Sat, 16 Dec 2023 00:48:28 +0530 Subject: [PATCH 011/283] [ONNX][MLIR] Add support for LeakyRelu and GatherElements op (#2655) This commit adds support for `LeakyRelu and GatherElements` op in the onnx pipeline. Signed-off-by: Gaurav Shukla --- .../Conversion/TorchOnnxToTorch/Patterns.h | 2 +- .../TorchOnnxToTorch/DefaultDomainGtoP.cpp | 41 +++++++++++++++++-- .../TorchOnnxToTorch/simple_ops_g_to_p.mlir | 19 ++++++++- 3 files changed, 57 insertions(+), 5 deletions(-) diff --git a/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h b/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h index 94451ac5c927..d8d519294c54 100644 --- a/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h +++ b/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h @@ -140,7 +140,7 @@ struct OpBinder { FloatType t = cast(floatAttr.getType()); if (t.getWidth() != 32) return failure(); - value = floatAttr.getValueAsDouble(); + value = floatAttr.getValue().convertToFloat(); return success(); } return failure(); diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index 9e2ca4e06c79..d97964bcf608 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -27,7 +27,6 @@ using namespace mlir::torch::onnx_c; // thing here, so we simplify. void mlir::torch::onnx_c::populateDefaultDomainGtoP( OnnxCustomOpConversionPattern &patterns) { - patterns.onOp("MatMul", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; @@ -51,5 +50,41 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( binder.op, resultType, lhs, rhs); return success(); }); - -} \ No newline at end of file + patterns.onOp( + "GatherElements", 13, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value data, indices; + int64_t axis; + if (binder.tensorOperandAtIndex(data, 0) || + binder.tensorOperandAtIndex(indices, 1) || + binder.tensorResultType(resultType) || + binder.s64IntegerAttr(axis, "axis", 0)) + return failure(); + Value constAxis = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), axis)); + Value sparseGrad = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getBoolAttr(false)); + rewriter.replaceOpWithNewOp( + binder.op, resultType, data, constAxis, indices, sparseGrad); + return success(); + }); + patterns.onOp("LeakyRelu", 16, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value operand; + float alpha; + if (binder.tensorOperand(operand) || + binder.tensorResultType(resultType) || + binder.f32FloatAttr(alpha, "alpha", 0.01)) + return failure(); + Value constAlpha = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getF64FloatAttr(alpha)); + rewriter.replaceOpWithNewOp( + binder.op, resultType, operand, constAlpha); + return success(); + }); +} diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir index 9a29de584327..1ab2db46d1d2 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir @@ -4,6 +4,23 @@ // level constants. This is a pragmatic choice which lets us have a lot // of tests in this file, whereas the others tend to be more bespoke. +// CHECK-LABEL: func.func @test_gather_elements +func.func @test_gather_elements(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5], si64>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.opset_version = 13 : si64} { + // CHECK-DAG: %[[INT0:.+]] = torch.constant.int 0 + // CHECK-DAG: %[[FALSE:.+]] = torch.constant.bool false + // CHECK: %[[GATHER:.+]] = torch.aten.gather %arg0, %[[INT0]], %arg1, %[[FALSE]] + %0 = torch.operator "onnx.GatherElements"(%arg0, %arg1) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5], si64>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> +} + +// CHECK-LABEL: func.func @test_leaky_relu +func.func @test_leaky_relu(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.opset_version = 16 : si64} { + // CHECK-DAG: %[[F2:.+]] = torch.constant.float 2 + // CHECK: %[[LRELU:.+]] = torch.aten.leaky_relu %arg0, %[[F2]] + %0 = torch.operator "onnx.LeakyRelu"(%arg0) {torch.onnx.alpha = 2.000000e+00 : f32} : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> +} + // CHECK-LABEL: @test_matmul_2d func.func @test_matmul_2d(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[4,3],f32>) -> !torch.vtensor<[3,3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[3,4],f32>, !torch.vtensor<[4,3],f32> -> !torch.vtensor<[3,3],f32> @@ -32,4 +49,4 @@ func.func @test_less_or_equal(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch. // CHECK: torch.aten.le.Tensor %[[ARG0]], %[[ARG1]] : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],i1> %0 = torch.operator "onnx.LessOrEqual"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],i1> return %0 : !torch.vtensor<[3,4,5],i1> -} \ No newline at end of file +} From 55e9401c5ca13aaa4a327bf5ed91583a9e5a9f4d Mon Sep 17 00:00:00 2001 From: Sungsoon Cho <1025787+godot73@users.noreply.github.com> Date: Fri, 15 Dec 2023 11:19:26 -0800 Subject: [PATCH 012/283] Implement lowering of aten.cosh op. (#2635) --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 45 +++++++++++++++++++ .../TorchToLinalg/Uncategorized.cpp | 8 +++- .../Transforms/AbstractInterpLibrary.cpp | 11 ++++- .../build_tools/abstract_interp_lib_gen.py | 10 +++++ .../build_tools/torch_ods_gen.py | 1 + .../test_suite/elementwise.py | 44 ++++++++++++++++++ 6 files changed, 116 insertions(+), 3 deletions(-) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 7b684886d5ea..f9c878874cf5 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -64,6 +64,51 @@ def Torch_AtenTanh_Op : Torch_Op<"aten.tanh_", [ }]; } +def Torch_AtenCoshOp : Torch_Op<"aten.cosh", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::cosh : (Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenCoshOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenCoshOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + +def Torch_AtenCosh_Op : Torch_Op<"aten.cosh_", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::cosh_ : (Tensor) -> (Tensor)`"; + let arguments = (ins + Torch_NonValueTensorType:$self + ); + let results = (outs + Torch_NonValueTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenCosh_Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenCosh_Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + def Torch_AtenHardtanhOp : Torch_Op<"aten.hardtanh", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index 2cc37a88313a..5e38ec1a1490 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -220,6 +220,10 @@ static Value createLinalgPayloadCalculationForElementwiseOp( return createCalculationForMathOpWithDtypeConversion( b, converter, payloadArgs[0], op); } + if (isa(op)) { + return createCalculationForMathOpWithDtypeConversion( + b, converter, payloadArgs[0], op); + } if (isa(op)) { return createCalculationForMathOpWithDtypeConversion( b, converter, payloadArgs[0], op); @@ -1311,7 +1315,7 @@ class ConvertElementwiseOp : public ConversionPattern { LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - if (!isa) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.cosh\"(%arg0: !torch.list) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.tanh\"(%arg0: !torch.list) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -8523,7 +8527,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" " }\n" -" func.func @\"__torch_mlir_dtype_fn.aten.tanh\"(%arg0: !torch.tuple) -> !torch.int {\n" +" func.func @\"__torch_mlir_dtype_fn.aten.cosh\"(%arg0: !torch.tuple) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" " return %1 : !torch.int\n" @@ -8565,6 +8569,11 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = torch.prim.ListConstruct %int5, %int15, %int6, %int7 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.tanh\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" +" return %1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.exp\"(%arg0: !torch.tuple) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index 903b6c3bf36e..66e47bd45ef8 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -62,6 +62,9 @@ def aten〇tril〡shape(self: List[int], diagonal: int = 0) -> List[int]: def aten〇atan〡shape(self: List[int]) -> List[int]: return upstream_shape_functions.unary(self) +def aten〇cosh〡shape(self: List[int]) -> List[int]: + return upstream_shape_functions.unary(self) + def aten〇tanh〡shape(self: List[int]) -> List[int]: return upstream_shape_functions.unary(self) @@ -1538,6 +1541,13 @@ def prims〇split_dim〡dtype(a_rank_dtype: Tuple[int, int], dim: int, outer_len _, a_dtype = a_rank_dtype return a_dtype + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇cosh〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return _get_dtype_of_floating_point_op(self_dtype) + + @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) def aten〇tanh〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index 32f0d0df862d..97480bbd4b09 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -256,6 +256,7 @@ def emit_with_mutating_variants(key, **kwargs): # Elementwise tensor compute ops for key in [ "aten::tanh : (Tensor) -> (Tensor)", + "aten::cosh : (Tensor) -> (Tensor)", "aten::hardtanh : (Tensor, Scalar, Scalar) -> (Tensor)", "aten::elu : (Tensor, Scalar, Scalar, Scalar) -> (Tensor)", "aten::relu : (Tensor) -> (Tensor)", diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py index 0b5330b91380..0b45a151c681 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py @@ -63,6 +63,50 @@ def ElementwiseUnaryIntModule_basic(module, tu: TestUtils): # ============================================================================== +class ElementwiseCoshModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True), + ]) + def forward(self, a): + return torch.cosh(a) + + +@register_test_case(module_factory=lambda: ElementwiseCoshModule()) +def ElementwiseCoshModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4)) + + +# ============================================================================== + + +class ElementwiseCoshIntModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int32, True), + ]) + def forward(self, a): + return torch.cosh(a) + + +@register_test_case(module_factory=lambda: ElementwiseCoshIntModule()) +def ElementwiseCoshIntModule_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 4, low=1, high=10).to(torch.int32)) + + +# ============================================================================== + + class ElementwiseBinaryModule(torch.nn.Module): def __init__(self): From 061af696ce94c932152bdf64ca7eba3b4034b367 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Fri, 15 Dec 2023 11:37:49 -0800 Subject: [PATCH 013/283] [onnx] Lowering for `onnx.shape` to `torch` and `tensor` (#2648) Includes the lowering from the `aten` equivalent to `tensor` operations. --- include/torch-mlir/Conversion/Passes.td | 9 ++ .../Conversion/TorchToTensor/TorchToTensor.h | 23 +++++ lib/Conversion/CMakeLists.txt | 2 + lib/Conversion/Passes.cpp | 7 +- .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 13 +++ lib/Conversion/TorchToTensor/CMakeLists.txt | 18 ++++ .../TorchToTensor/TorchToTensor.cpp | 93 +++++++++++++++++++ .../TorchToTensor/torch_to_tensor.mlir | 8 ++ 8 files changed, 170 insertions(+), 3 deletions(-) create mode 100644 include/torch-mlir/Conversion/TorchToTensor/TorchToTensor.h create mode 100644 lib/Conversion/TorchToTensor/CMakeLists.txt create mode 100644 lib/Conversion/TorchToTensor/TorchToTensor.cpp create mode 100644 test/Conversion/TorchToTensor/torch_to_tensor.mlir diff --git a/include/torch-mlir/Conversion/Passes.td b/include/torch-mlir/Conversion/Passes.td index 3a130f472b3b..ed58c699559c 100644 --- a/include/torch-mlir/Conversion/Passes.td +++ b/include/torch-mlir/Conversion/Passes.td @@ -105,6 +105,15 @@ def ConvertTorchToLinalg : Pass<"convert-torch-to-linalg", "func::FuncOp"> { let constructor = "mlir::torch::createConvertTorchToLinalgPass()"; } +def ConvertTorchToTensor : Pass<"convert-torch-to-tensor", "func::FuncOp"> { + let summary = "Convert Torch ops to the Tensor dialect"; + let description = [{ + Converts any `Torch` operators that were expressible as `Tensor` dialect + operations. + }]; + let constructor = "mlir::torch::createConvertTorchToTensorPass()"; +} + def ConvertTorchToTosa : Pass<"convert-torch-to-tosa", "func::FuncOp"> { let summary = "Convert Torch ops to TOSA ops"; let description = [{ diff --git a/include/torch-mlir/Conversion/TorchToTensor/TorchToTensor.h b/include/torch-mlir/Conversion/TorchToTensor/TorchToTensor.h new file mode 100644 index 000000000000..9dd5a65429ed --- /dev/null +++ b/include/torch-mlir/Conversion/TorchToTensor/TorchToTensor.h @@ -0,0 +1,23 @@ +//===------------------------------------------------------------*- C++ -*-===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#ifndef TORCHMLIR_CONVERSION_TORCHTOTENSOR_TORCHTOTENSOR_H +#define TORCHMLIR_CONVERSION_TORCHTOTENSOR_TORCHTOTENSOR_H + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Pass/Pass.h" +#include + +namespace mlir { +namespace torch { +std::unique_ptr> createConvertTorchToTensorPass(); +} // namespace torch +} // namespace mlir + +#endif // TORCHMLIR_CONVERSION_TORCHTOTENSOR_TORCHTOTENSOR_H diff --git a/lib/Conversion/CMakeLists.txt b/lib/Conversion/CMakeLists.txt index afbe775d3a20..dd9e94a50080 100644 --- a/lib/Conversion/CMakeLists.txt +++ b/lib/Conversion/CMakeLists.txt @@ -2,6 +2,7 @@ add_subdirectory(TorchOnnxToTorch) add_subdirectory(TorchToLinalg) add_subdirectory(TorchToSCF) add_subdirectory(TorchToArith) +add_subdirectory(TorchToTensor) add_subdirectory(TorchToTosa) if(TORCH_MLIR_ENABLE_STABLEHLO) add_subdirectory(TorchToStablehlo) @@ -14,6 +15,7 @@ add_subdirectory(Utils) set(linked_libs TorchMLIRTorchToLinalg TorchMLIRTorchToSCF TorchMLIRTorchToArith + TorchMLIRTorchToTensor TorchMLIRTorchToTosa TorchMLIRTorchToTMTensor TorchMLIRTorchConversionToMLProgram diff --git a/lib/Conversion/Passes.cpp b/lib/Conversion/Passes.cpp index 0dae24678a4b..b9af2afa3f81 100644 --- a/lib/Conversion/Passes.cpp +++ b/lib/Conversion/Passes.cpp @@ -13,12 +13,13 @@ #include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h" #endif // TORCH_MLIR_ENABLE_STABLEHLO +#include "torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h" +#include "torch-mlir/Conversion/TorchToArith/TorchToArith.h" #include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h" #include "torch-mlir/Conversion/TorchToSCF/TorchToSCF.h" -#include "torch-mlir/Conversion/TorchToArith/TorchToArith.h" -#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h" #include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h" -#include "torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h" +#include "torch-mlir/Conversion/TorchToTensor/TorchToTensor.h" +#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h" //===----------------------------------------------------------------------===// // Pass registration diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index a8fa8972fafc..482e20d6a7d3 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -459,4 +459,17 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( binder.op, resultType, operand, vAlpha, vScale, vInputScale); return success(); }); + + patterns.onOp("Shape", 9, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value operand; + if (binder.tensorOperand(operand) || + binder.tensorResultType(resultType)) + return failure(); + + rewriter.replaceOpWithNewOp( + binder.op, resultType, operand); + return success(); + }); } diff --git a/lib/Conversion/TorchToTensor/CMakeLists.txt b/lib/Conversion/TorchToTensor/CMakeLists.txt new file mode 100644 index 000000000000..21082d1d1258 --- /dev/null +++ b/lib/Conversion/TorchToTensor/CMakeLists.txt @@ -0,0 +1,18 @@ +add_mlir_conversion_library(TorchMLIRTorchToTensor + TorchToTensor.cpp + + ADDITIONAL_HEADER_DIRS + ${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchToTensor + + DEPENDS + TorchMLIRConversionPassIncGen + + LINK_LIBS PUBLIC + MLIRIR + MLIRPass + MLIRTensorDialect + TorchMLIRTorchDialect + TorchMLIRConversionUtils +) + +torch_mlir_target_includes(TorchMLIRTorchToTensor) diff --git a/lib/Conversion/TorchToTensor/TorchToTensor.cpp b/lib/Conversion/TorchToTensor/TorchToTensor.cpp new file mode 100644 index 000000000000..417fd17fcb86 --- /dev/null +++ b/lib/Conversion/TorchToTensor/TorchToTensor.cpp @@ -0,0 +1,93 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v3.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-1.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#include "torch-mlir/Conversion/TorchToTensor/TorchToTensor.h" + +#include "../PassDetail.h" + +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Traits.h" +#include "mlir/IR/Matchers.h" +#include "mlir/Transforms/DialectConversion.h" +#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" +#include "torch-mlir/Dialect/Torch/IR/TorchOps.h" +#include "torch-mlir/Dialect/Torch/Utils/Utils.h" +#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h" +#include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h" + +using namespace mlir; +using namespace mlir::torch; +using namespace mlir::torch::Torch; + +namespace { + +class ConvertAtenShapeToTensorPatternOp + : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename Aten_ShapeAsTensorOp::Adaptor; + LogicalResult + matchAndRewrite(Aten_ShapeAsTensorOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto operand = adaptor.getOperands()[0]; + auto operandTy = operand.getType().cast(); + auto resultTy = + getTypeConverter()->convertType(op.getType()).cast(); + + int64_t rank = operandTy.getRank(); + SmallVector dims; + for (int i = 0; i < rank; ++i) { + Value dim = rewriter.createOrFold(loc, operand, i); + dim = rewriter.createOrFold( + loc, resultTy.getElementType(), dim); + dims.push_back(dim); + } + + Value tensor = + rewriter.createOrFold(op.getLoc(), dims); + rewriter.replaceOp(op, tensor); + return success(); + } +}; + +class ConvertTorchToTensor + : public ConvertTorchToTensorBase { +public: + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + TorchConversion::getBackendTypeConversionDependentDialects(registry); + } + void runOnOperation() override { + MLIRContext *context = &getContext(); + ConversionTarget target(*context); + target.addLegalDialect(); + target.addLegalDialect(); + target.addIllegalOp(); + + TypeConverter typeConverter; + typeConverter.addConversion([](Type type) { return type; }); + TorchConversion::setupBackendTypeConversion(target, typeConverter); + + RewritePatternSet patterns(context); + patterns.add(typeConverter, context); + + if (failed(applyPartialConversion(getOperation(), target, + std::move(patterns)))) + return signalPassFailure(); + } +}; + +} // namespace + +std::unique_ptr> +mlir::torch::createConvertTorchToTensorPass() { + return std::make_unique(); +} diff --git a/test/Conversion/TorchToTensor/torch_to_tensor.mlir b/test/Conversion/TorchToTensor/torch_to_tensor.mlir new file mode 100644 index 000000000000..277dabc3b891 --- /dev/null +++ b/test/Conversion/TorchToTensor/torch_to_tensor.mlir @@ -0,0 +1,8 @@ +// RUN: torch-mlir-opt <%s -convert-torch-to-tensor | FileCheck %s + +// CHECK-LABEL: func.func @test_shape +func.func @test_shape(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3],si64> { + // CHECK: %[[SHAPE:.+]] = arith.constant dense<[3, 4, 5]> : tensor<3xi64> + %0 = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3],si64> + return %0 : !torch.vtensor<[3],si64> +} From 030b0140d45559743dff85573ca00ba10cce7a5a Mon Sep 17 00:00:00 2001 From: Quinn Dawkins Date: Fri, 15 Dec 2023 15:45:32 -0500 Subject: [PATCH 014/283] [TorchToLinalg] Lower aten.cat to tensor.concat (#2650) This replaces the lowering of aten.cat with tensor.concat, allowing more efficient handling of concatenations in downstream flows. The refbackend populates concat decomposition patterns that can be used to recover the previous lowering. --- include/torch-mlir/RefBackend/Passes.h | 2 + include/torch-mlir/RefBackend/Passes.td | 5 ++ lib/Conversion/TorchToLinalg/DataMovement.cpp | 51 +++---------------- lib/RefBackend/RefBackend.cpp | 27 +++++++++- .../linalg_on_tensors_backends/refbackend.py | 1 + test/Conversion/TorchToLinalg/basic.mlir | 38 ++++++++++++++ 6 files changed, 79 insertions(+), 45 deletions(-) diff --git a/include/torch-mlir/RefBackend/Passes.h b/include/torch-mlir/RefBackend/Passes.h index 8f1b2b525a22..be5e43a1e63c 100644 --- a/include/torch-mlir/RefBackend/Passes.h +++ b/include/torch-mlir/RefBackend/Passes.h @@ -31,6 +31,8 @@ std::unique_ptr> createMLProgramBufferizePass(); std::unique_ptr> createMungeMemrefCopyPass(); +std::unique_ptr> createGeneralizeTensorConcatPass(); + std::unique_ptr> createGeneralizeTensorPadPass(); } // namespace RefBackend } // namespace torch diff --git a/include/torch-mlir/RefBackend/Passes.td b/include/torch-mlir/RefBackend/Passes.td index 12d182e49e3a..3d8b7fd41b1b 100644 --- a/include/torch-mlir/RefBackend/Passes.td +++ b/include/torch-mlir/RefBackend/Passes.td @@ -35,6 +35,11 @@ def MungeMemrefCopy : Pass<"refback-munge-memref-copy", "func::FuncOp"> { let dependentDialects = ["memref::MemRefDialect"]; } +def GeneralizeTensorConcat : Pass<"refback-generalize-tensor-concat", "func::FuncOp"> { + let summary = "Convert tensor.concat to other tensor ops"; + let constructor = "mlir::torch::RefBackend::createGeneralizeTensorConcatPass()"; +} + def GeneralizeTensorPad : Pass<"refback-generalize-tensor-pad", "func::FuncOp"> { let summary = "Convert tensor.pad to linalg ops"; let constructor = "mlir::torch::RefBackend::createGeneralizeTensorPadPass()"; diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp index 4eb02215a8bf..dae387422b52 100644 --- a/lib/Conversion/TorchToLinalg/DataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp @@ -1033,8 +1033,11 @@ class ConvertAtenCatOp : public OpConversionPattern { auto outElemType = newResultType.getElementType(); for (size_t i = 0; i < tensors.size(); ++i) { - tensors[i] = torch_to_linalg::convertTensorToElementType( - rewriter, loc, tensors[i], outElemType); + auto inputType = cast(tensors[i].getType()); + if (inputType.getElementType() != outElemType) { + tensors[i] = torch_to_linalg::convertTensorToElementType( + rewriter, loc, tensors[i], outElemType); + } } int rank = newResultType.getRank(); @@ -1046,48 +1049,8 @@ class ConvertAtenCatOp : public OpConversionPattern { if (!isValidDim(dim, rank)) return rewriter.notifyMatchFailure(op, "dim is statically invalid"); - SmallVector offsets, sizes, strides; - sizes.reserve(rank); - strides.resize(rank, rewriter.create(loc, 1)); - offsets.resize(rank, rewriter.create(loc, 0)); - - for (int i = 0; i < rank; ++i) - sizes.push_back(rewriter.createOrFold(loc, tensors[0], i)); - - // Calculate the size of the `dim` result dimension by adding the dim size - // of each tensor together. - Value resultDimSize = sizes[dim]; - - Value dimIndex = rewriter.createOrFold( - loc, rewriter.getIndexAttr(dim)); - for (auto tensor : ArrayRef(tensors).drop_front()) { - auto size = rewriter.createOrFold(loc, tensor, dimIndex); - resultDimSize = - rewriter.createOrFold(loc, resultDimSize, size); - } - sizes[dim] = resultDimSize; - - auto toOpFoldResult = [](Value v) -> OpFoldResult { - auto op = v.getDefiningOp(); - if (!op) - return v; - return op.getValue(); - }; - - Value result = rewriter.create( - loc, getAsOpFoldResult(sizes), newResultType.getElementType()); - for (auto tensor : tensors) { - SmallVector sizes = getTensorSizes(rewriter, loc, tensor); - result = rewriter.createOrFold( - loc, tensor, result, - llvm::to_vector(llvm::map_range(offsets, toOpFoldResult)), - llvm::to_vector(llvm::map_range(sizes, toOpFoldResult)), - llvm::to_vector(llvm::map_range(strides, toOpFoldResult))); - offsets[dim] = - rewriter.createOrFold(loc, offsets[dim], sizes[dim]); - } - - rewriter.replaceOpWithNewOp(op, newResultType, result); + rewriter.replaceOpWithNewOp(op, newResultType, dim, + tensors); return success(); } }; diff --git a/lib/RefBackend/RefBackend.cpp b/lib/RefBackend/RefBackend.cpp index 481bdf3426d8..4ada196e944c 100644 --- a/lib/RefBackend/RefBackend.cpp +++ b/lib/RefBackend/RefBackend.cpp @@ -20,10 +20,12 @@ #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Transforms/Transforms.h" +#include "mlir/Dialect/MLProgram/IR/MLProgram.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/Math/Transforms/Approximation.h" #include "mlir/Dialect/Math/Transforms/Passes.h" -#include "mlir/Dialect/MLProgram/IR/MLProgram.h" +#include "mlir/Dialect/Tensor/IR/Tensor.h" +#include "mlir/Dialect/Tensor/Transforms/Transforms.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h" @@ -436,6 +438,29 @@ mlir::torch::RefBackend::createMungeMemrefCopyPass() { return std::make_unique(); } +namespace { +class GeneralizeTensorConcat + : public GeneralizeTensorConcatBase { + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + void runOnOperation() override { + RewritePatternSet patterns(&getContext()); + tensor::populateDecomposeTensorConcatPatterns(patterns); + if (failed(applyPatternsAndFoldGreedily(getOperation(), + std::move(patterns)))) { + return signalPassFailure(); + } + } +}; +} // namespace + +std::unique_ptr> +mlir::torch::RefBackend::createGeneralizeTensorConcatPass() { + return std::make_unique(); +} + namespace { class GeneralizeTensorPad : public GeneralizeTensorPadBase { diff --git a/projects/pt1/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py b/projects/pt1/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py index 1b9dbb0d2c51..266459e00b0c 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py +++ b/projects/pt1/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py @@ -123,6 +123,7 @@ def invoke(*args): LOWERING_PIPELINE = "builtin.module(" + ",".join([ "func.func(refback-generalize-tensor-pad)", + "func.func(refback-generalize-tensor-concat)", # Apply some optimizations. It would be great if MLIR had more useful # optimizations that worked out of the box here. # Note: When measured, this doesn't seem to actually help that much diff --git a/test/Conversion/TorchToLinalg/basic.mlir b/test/Conversion/TorchToLinalg/basic.mlir index eba7546655e9..0aaca941b0d9 100644 --- a/test/Conversion/TorchToLinalg/basic.mlir +++ b/test/Conversion/TorchToLinalg/basic.mlir @@ -287,3 +287,41 @@ func.func @torch.aten.neg.f16(%arg0: !torch.vtensor<[?,?],f16>) -> !torch.vtenso %0 = torch.aten.neg %arg0 : !torch.vtensor<[?,?],f16> -> !torch.vtensor<[?,?],f16> return %0 : !torch.vtensor<[?,?],f16> } + +// ----- + +// CHECK-LABEL: func.func @torch.aten.cat$convert( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, %[[ARG1:.*]]: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],f32> { +// CHECK: %[[INT0:.*]] = torch.constant.int 0 +// CHECK: %[[T0:.*]] = torch.prim.ListConstruct %[[ARG0]], %[[ARG1]] : (!torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],si32>) -> !torch.list +// CHECK: %[[T1:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %[[T2:.*]] = torch_c.to_builtin_tensor %[[ARG1]] : !torch.vtensor<[?,?],si32> -> tensor +// CHECK: %[[T3:.*]] = linalg.generic {{.*}} ins(%[[T2]] : tensor) outs(%{{.*}}: tensor) +// CHECK: %[[T4:.*]] = tensor.concat dim(0) %[[T1]], %[[T3]] : (tensor, tensor) -> tensor +// CHECK: %[[T5:.*]] = torch_c.from_builtin_tensor %[[T4]] : tensor -> !torch.vtensor<[?,?],f32> +// CHECK: return %[[T5]] : !torch.vtensor<[?,?],f32> +func.func @torch.aten.cat$convert(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],f32> { + %int0 = torch.constant.int 0 + %0 = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],si32>) -> !torch.list + %1 = torch.aten.cat %0, %int0 : !torch.list, !torch.int -> !torch.vtensor<[?,?],f32> + return %1 : !torch.vtensor<[?,?],f32> +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.cat( +// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[?,?],f32>, +// CHECK-SAME: %[[ARG_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { +// CHECK: %int0 = torch.constant.int 0 +// CHECK: %[[VAL_0:.*]] = torch.prim.ListConstruct %[[ARG_0]], %[[ARG_1]] : (!torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>) -> !torch.list +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[ARG_1]] : !torch.vtensor<[?,?],f32> -> tensor +// CHECK: %[[VAL_3:.*]] = tensor.concat dim(0) %[[VAL_1]], %[[VAL_2]] : (tensor, tensor) -> tensor +// CHECK: %[[VAL_4:.*]] = torch_c.from_builtin_tensor %[[VAL_3]] : tensor -> !torch.vtensor<[?,?],f32> +// CHECK: return %[[VAL_4]] : !torch.vtensor<[?,?],f32> +func.func @torch.aten.cat(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { + %int0 = torch.constant.int 0 + %0 = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>) -> !torch.list + %1 = torch.aten.cat %0, %int0 : !torch.list, !torch.int -> !torch.vtensor<[?,?],f32> + return %1 : !torch.vtensor<[?,?],f32> +} From 705ea958ae1406e2f7cbff442c13db05587d3b33 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Fri, 15 Dec 2023 15:30:05 -0800 Subject: [PATCH 015/283] [onnx] Lowerings from `onnx.transpose` (#2641) Lowerings for `transpose` from ONNX to `aten`. Implementation depends on making multiple `aten.transpose` operations swapping pairs of dimensions. As `onnx.transpose` can swap around any dimensions it may require constructing multiple `aten.transpose`. --- .../Conversion/TorchOnnxToTorch/Patterns.h | 25 +++++++ .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 69 +++++++++++++++++++ .../TorchOnnxToTorch/simple_ops_q_to_z.mlir | 31 +++++++++ 3 files changed, 125 insertions(+) diff --git a/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h b/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h index d8d519294c54..85d6f805f3f6 100644 --- a/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h +++ b/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h @@ -146,6 +146,31 @@ struct OpBinder { return failure(); } + ParseResult s64IntegerArrayAttr(llvm::SmallVector &values, + StringRef nameSuffix, + ArrayRef defaults) { + SmallString<64> name("torch.onnx."); + name.append(nameSuffix); + auto attr = op->getAttr(name); + if (!attr) { + values.append(defaults.begin(), defaults.end()); + return success(); + } + if (auto arrayAttr = dyn_cast(attr)) { + for (auto element : arrayAttr) { + auto integerAttr = element.dyn_cast(); + if (!integerAttr) + return failure(); + IntegerType t = cast(integerAttr.getType()); + if (!t.isSigned() || t.getWidth() != 64) + return failure(); + values.push_back(integerAttr.getSInt()); + } + return success(); + } + return failure(); + } + ParseResult customOpNameStringAttr(std::string &value, StringRef nameSuffix, std::string defaultValue = "") { SmallString<64> name("torch.onnx."); diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 482e20d6a7d3..de1ef97de7c3 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -472,4 +472,73 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( binder.op, resultType, operand); return success(); }); + + patterns.onOp( + "Transpose", 13, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + auto loc = binder.getLoc(); + Torch::ValueTensorType resultType; + Value operand; + if (binder.tensorOperand(operand) || + binder.tensorResultType(resultType)) + return failure(); + auto operandType = operand.getType().cast(); + TensorType tensorType = operandType.toBuiltinTensor(); + if (!tensorType || !tensorType.hasRank()) + return failure(); + + // Default permutation is to reverse orders: + int64_t rank = tensorType.getRank(); + llvm::SmallVector reverse(rank); + for (int64_t i = 0; i < rank; ++i) { + reverse[i] = rank - i - 1; + } + + llvm::SmallVector permutations; + if (failed(binder.s64IntegerArrayAttr(permutations, "perm", reverse))) + return rewriter.notifyMatchFailure(binder.op, + "Failed to obtain permutations"); + + if (static_cast(permutations.size()) != rank) + return rewriter.notifyMatchFailure( + binder.op, "Permutation length does not match operand rank"); + + llvm::SmallVector shape(tensorType.getShape()); + llvm::SmallVector current(rank); + for (int64_t i = 0; i < rank; ++i) { + current[i] = i; + } + + for (int64_t i = 0; i < rank; ++i) { + if (current[i] == permutations[i]) + continue; + + int64_t target = i + 1; + for (; target < rank; ++target) { + if (current[target] == permutations[i]) + break; + } + + std::swap(shape[i], shape[target]); + std::swap(current[i], current[target]); + + Value dim0 = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), i)); + + Value dim1 = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), target)); + + operand = rewriter.create( + loc, + Torch::ValueTensorType::get(tensorType.getContext(), shape, + operandType.getOptionalDtype()), + operand, dim0, dim1); + } + + rewriter.replaceOp(binder.op, operand); + + return success(); + }); } diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir index f85221d971b4..d4ca317feab5 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir @@ -475,6 +475,8 @@ func.func @test_softmax_negative_axis(%arg0: !torch.vtensor<[3,4,5],f32>) -> !to return %0 : !torch.vtensor<[3,4,5],f32> } +// ----- + // CHECK-LABEL: func.func @test_selu func.func @test_selu(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.opset_version = 6 : si64} { // CHECK-DAG: %[[F1:.+]] = torch.constant.float 1 @@ -484,3 +486,32 @@ func.func @test_selu(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4, %0 = torch.operator "onnx.Selu"(%arg0) {torch.onnx.alpha = 2.000000e+00 : f32, torch.onnx.gamma = 3.000000e+00 : f32} : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> return %0 : !torch.vtensor<[3,4,5],f32> } + +// ----- + +// CHECK-LABEL: func.func @test_transpose_default +func.func @test_transpose_default(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[4,3,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64} { + // CHECK-DAG: %[[I0:.+]] = torch.constant.int 0 + // CHECK-DAG: %[[I2:.+]] = torch.constant.int 2 + // CHECK: %[[TRANSPOSE:.+]] = torch.aten.transpose.int %arg0, %[[I0]], %[[I2]] : !torch.vtensor<[2,3,4],f32>, !torch.int, !torch.int -> !torch.vtensor<[4,3,2],f32> + %0 = torch.operator "onnx.Transpose"(%arg0) : (!torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[4,3,2],f32> + + // CHECK: return %[[TRANSPOSE]] + return %0 : !torch.vtensor<[4,3,2],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_transpose_all_permutations_4 +func.func @test_transpose_all_permutations_4(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[4,2,3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64} { + // CHECK-DAG: %[[I0:.+]] = torch.constant.int 0 + // CHECK-DAG: %[[I2:.+]] = torch.constant.int 2 + // CHECK: %[[TRANSPOSE0:.+]] = torch.aten.transpose.int %arg0, %[[I0]], %[[I2]] : !torch.vtensor<[2,3,4],f32>, !torch.int, !torch.int -> !torch.vtensor<[4,3,2],f32> + // CHECK-DAG: %[[I1:.+]] = torch.constant.int 1 + // CHECK-DAG: %[[I2:.+]] = torch.constant.int 2 + // CHECK: %[[TRANSPOSE1:.+]] = torch.aten.transpose.int %[[TRANSPOSE0]], %[[I1]], %[[I2]] : !torch.vtensor<[4,3,2],f32>, !torch.int, !torch.int -> !torch.vtensor<[4,2,3],f32> + %0 = torch.operator "onnx.Transpose"(%arg0) {torch.onnx.perm = [2 : si64, 0 : si64, 1 : si64]} : (!torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[4,2,3],f32> + + // CHECK: return %[[TRANSPOSE1]] + return %0 : !torch.vtensor<[4,2,3],f32> +} From b3e94208a891a85cffc5eec1b267ce4e7a761b74 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Fri, 15 Dec 2023 16:41:45 -0800 Subject: [PATCH 016/283] Bump LLVM version to aa165edca8545b212de084d5b18c3d30347f774a (#2658) --- externals/llvm-project | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/externals/llvm-project b/externals/llvm-project index fcd54b368e67..aa165edca854 160000 --- a/externals/llvm-project +++ b/externals/llvm-project @@ -1 +1 @@ -Subproject commit fcd54b368e6713acd236dc47401b5292755900d7 +Subproject commit aa165edca8545b212de084d5b18c3d30347f774a From 61888690bba0b766cd3e4fc16ce1d43b3b70b44f Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Fri, 15 Dec 2023 21:23:51 -0800 Subject: [PATCH 017/283] [onnx] Add support for `onnx.sinh` (#2643) Adds a lowering from `onnx.sinh` to `aten.sinh`. This includes adding the `aten.sinh` operator. --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 45 +++++++++++++++++++ .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 14 +++++- .../build_tools/torch_ods_gen.py | 1 + .../TorchOnnxToTorch/simple_ops_q_to_z.mlir | 9 ++++ 4 files changed, 68 insertions(+), 1 deletion(-) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index f9c878874cf5..4fb5b5cd3b18 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -526,6 +526,51 @@ def Torch_AtenSign_Op : Torch_Op<"aten.sign_", [ }]; } +def Torch_AtenSinhOp : Torch_Op<"aten.sinh", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::sinh : (Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenSinhOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenSinhOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + +def Torch_AtenSinh_Op : Torch_Op<"aten.sinh_", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::sinh_ : (Tensor) -> (Tensor)`"; + let arguments = (ins + Torch_NonValueTensorType:$self + ); + let results = (outs + Torch_NonValueTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenSinh_Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenSinh_Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + def Torch_AtenSgnOp : Torch_Op<"aten.sgn", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index de1ef97de7c3..6d2fb81533fa 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -467,11 +467,23 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( if (binder.tensorOperand(operand) || binder.tensorResultType(resultType)) return failure(); - rewriter.replaceOpWithNewOp( binder.op, resultType, operand); return success(); }); + + patterns.onOp("Sinh", 9, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value operand; + if (binder.tensorOperand(operand) || + binder.tensorResultType(resultType)) + return failure(); + + rewriter.replaceOpWithNewOp( + binder.op, resultType, operand); + return success(); + }); patterns.onOp( "Transpose", 13, diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index 97480bbd4b09..ef1d707e3c00 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -266,6 +266,7 @@ def emit_with_mutating_variants(key, **kwargs): "aten::selu : (Tensor) -> (Tensor)", "aten::sigmoid : (Tensor) -> (Tensor)", "aten::sign : (Tensor) -> (Tensor)", + "aten::sinh : (Tensor) -> (Tensor)", "aten::sgn : (Tensor) -> (Tensor)", "aten::hardsigmoid : (Tensor) -> (Tensor)", "aten::hardswish : (Tensor) -> (Tensor)", diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir index d4ca317feab5..e4c95fe2bbc9 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir @@ -489,6 +489,15 @@ func.func @test_selu(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4, // ----- +// CHECK-LABEL: func.func @test_sinh +func.func @test_sinh_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 4 : si64, torch.onnx_meta.opset_version = 9 : si64} { + // CHECK: torch.aten.sinh %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> + %0 = torch.operator "onnx.Sinh"(%arg0) : (!torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> + return %0 : !torch.vtensor<[3],f32> +} + +// ----- + // CHECK-LABEL: func.func @test_transpose_default func.func @test_transpose_default(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[4,3,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64} { // CHECK-DAG: %[[I0:.+]] = torch.constant.int 0 From cee8563060c79766ae530e539167bd8c872561db Mon Sep 17 00:00:00 2001 From: Andreas Falkenberg <149819731+afalkenberg1@users.noreply.github.com> Date: Sat, 16 Dec 2023 09:42:11 -0800 Subject: [PATCH 018/283] [onnx] Support of onnx.Greater, onnx.Less, onnx.GreaterOrEqual to Torch (#2649) The three remaining compare operations onnx.Greater onnx.Less onnx.GreaterOrEqual Are also added with this push request. This concludes a set of basic tensor compare functions. --- .../TorchOnnxToTorch/DefaultDomainGtoP.cpp | 40 ++++++++++++++++++- .../TorchOnnxToTorch/simple_ops_g_to_p.mlir | 31 +++++++++++++- 2 files changed, 67 insertions(+), 4 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index d97964bcf608..1de350905a44 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -38,7 +38,43 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( binder.op, resultType, lhs, rhs); return success(); }); - patterns.onOp("LessOrEqual", 1, + patterns.onOp("Greater", 16, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value lhs, rhs; + std::string direction; + if (binder.tensorOperands(lhs, rhs) || + binder.tensorResultType(resultType)) + return failure(); + rewriter.replaceOpWithNewOp( + binder.op, resultType, lhs, rhs); + return success(); + }); + patterns.onOp("GreaterOrEqual", 16, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value lhs, rhs; + std::string direction; + if (binder.tensorOperands(lhs, rhs) || + binder.tensorResultType(resultType)) + return failure(); + rewriter.replaceOpWithNewOp( + binder.op, resultType, lhs, rhs); + return success(); + }); + patterns.onOp("Less", 13, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value lhs, rhs; + if (binder.tensorOperands(lhs, rhs) || + binder.tensorResultType(resultType)) { + return failure(); + } + rewriter.replaceOpWithNewOp( + binder.op, resultType, lhs, rhs); + return success(); + }); + patterns.onOp("LessOrEqual", 16, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value lhs, rhs; @@ -87,4 +123,4 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( binder.op, resultType, operand, constAlpha); return success(); }); -} +} \ No newline at end of file diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir index 1ab2db46d1d2..d6b99b62f72f 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir @@ -4,6 +4,33 @@ // level constants. This is a pragmatic choice which lets us have a lot // of tests in this file, whereas the others tend to be more bespoke. +// CHECK-LABEL: func.func @test_greater +func.func @test_greater(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],i1> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 16 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[3,4,5],f32> + // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: !torch.vtensor<[3,4,5],f32> + // CHECK: torch.aten.gt.Tensor %[[ARG0]], %[[ARG1]] : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],i1> + %0 = torch.operator "onnx.Greater"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],i1> + return %0 : !torch.vtensor<[3,4,5],i1> +} + +// CHECK-LABEL: func.func @test_greater_or_equal +func.func @test_greater_or_equal(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],i1> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 16 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[3,4,5],f32> + // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: !torch.vtensor<[3,4,5],f32> + // CHECK: torch.aten.ge.Tensor %[[ARG0]], %[[ARG1]] : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],i1> + %0 = torch.operator "onnx.GreaterOrEqual"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],i1> + return %0 : !torch.vtensor<[3,4,5],i1> +} + +// CHECK-LABEL: func.func @test_less +func.func @test_less(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],i1> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[3,4,5],f32> + // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: !torch.vtensor<[3,4,5],f32> + // CHECK: torch.aten.lt.Tensor %[[ARG0]], %[[ARG1]] : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],i1> + %0 = torch.operator "onnx.Less"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],i1> + return %0 : !torch.vtensor<[3,4,5],i1> +} + // CHECK-LABEL: func.func @test_gather_elements func.func @test_gather_elements(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5], si64>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.opset_version = 13 : si64} { // CHECK-DAG: %[[INT0:.+]] = torch.constant.int 0 @@ -43,10 +70,10 @@ func.func @test_matmul_4d(%arg0: !torch.vtensor<[1,2,3,4],f32>, %arg1: !torch.vt } // CHECK-LABEL: func.func @test_less_or_equal -func.func @test_less_or_equal(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],i1> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { +func.func @test_less_or_equal(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],i1> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 16 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[3,4,5],f32> // CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: !torch.vtensor<[3,4,5],f32> // CHECK: torch.aten.le.Tensor %[[ARG0]], %[[ARG1]] : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],i1> %0 = torch.operator "onnx.LessOrEqual"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],i1> return %0 : !torch.vtensor<[3,4,5],i1> -} +} \ No newline at end of file From ae1a6e4a5a058170dd7b63001322844d04d12aac Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Sat, 16 Dec 2023 10:47:58 -0800 Subject: [PATCH 019/283] [onnx] Lower `onnx.Gemm` to `torch` (#2663) General lowering for `onnx.Gemm` to `torch` --- .../TorchOnnxToTorch/DefaultDomainGtoP.cpp | 74 ++++++++++++++++++ .../TorchOnnxToTorch/simple_ops_g_to_p.mlir | 76 +++++++++++++++++++ 2 files changed, 150 insertions(+) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index 1de350905a44..732f05b4cf95 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -107,6 +107,80 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( binder.op, resultType, data, constAxis, indices, sparseGrad); return success(); }); + patterns.onOp( + "Gemm", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value a, b, c; + float alpha, beta; + int64_t transA, transB; + if (binder.tensorOperandAtIndex(a, 0) || + binder.tensorOperandAtIndex(b, 1) || + binder.tensorOperandAtIndex(c, 2) || + binder.s64IntegerAttr(transA, "transA", 0) || + binder.s64IntegerAttr(transB, "transB", 0) || + binder.f32FloatAttr(alpha, "alpha", 1.0) || + binder.f32FloatAttr(beta, "beta", 1.0) || + binder.tensorResultType(resultType)) + return failure(); + + Value zero = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); + Value one = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), 1)); + + auto transpose = [&](Value m) -> Value { + auto tty = m.getType().cast(); + auto shape = tty.getOptionalSizes(); + if (shape.has_value()) { + llvm::SmallVector newShape(shape.value()); + std::reverse(newShape.begin(), newShape.end()); + shape = std::move(newShape); + } + auto oty = Torch::ValueTensorType::get(tty.getContext(), shape, + tty.getOptionalDtype()); + return rewriter.create(binder.getLoc(), + oty, m, zero, one); + }; + + if (transA) { + a = transpose(a); + } + + if (transB) { + b = transpose(b); + } + + Value mm = + rewriter.create(binder.getLoc(), resultType, a, b); + if (alpha == 1.0 && beta == 1.0) { + rewriter.replaceOpWithNewOp( + binder.op, resultType, mm, c, one); + return success(); + } + + if (alpha != 1.0 && beta != 1.0) { + Value constAlpha = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getF64FloatAttr(alpha)); + mm = rewriter.create( + binder.getLoc(), resultType, mm, constAlpha); + alpha = 1.0; + } + + if (alpha != 1.0) { + std::swap(alpha, beta); + std::swap(mm, c); + } + + Value constBeta = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getF64FloatAttr(beta)); + rewriter.replaceOpWithNewOp( + binder.op, resultType, mm, c, constBeta); + return success(); + }); patterns.onOp("LeakyRelu", 16, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir index d6b99b62f72f..8bb287fb8823 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir @@ -40,6 +40,82 @@ func.func @test_gather_elements(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torc return %0 : !torch.vtensor<[3,4,5],f32> } +// ----- + +// CHECK-LABEL: func.func @test_gemm_default +func.func @test_gemm_default(%arg0: !torch.vtensor<[3,5],f32>, %arg1: !torch.vtensor<[5,4],f32>, %arg2: !torch.vtensor<[1,4],f32>) -> !torch.vtensor<[3,4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64} { + // CHECK: %[[I1:.+]] = torch.constant.int 1 + // CHECK: %[[MM:.+]] = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[3,5],f32>, !torch.vtensor<[5,4],f32> -> !torch.vtensor<[3,4],f32> + // CHECK: torch.aten.add.Tensor %[[MM]], %arg2, %[[I1]] : !torch.vtensor<[3,4],f32>, !torch.vtensor<[1,4],f32>, !torch.int -> !torch.vtensor<[3,4],f32> + %0 = torch.operator "onnx.Gemm"(%arg0, %arg1, %arg2) : (!torch.vtensor<[3,5],f32>, !torch.vtensor<[5,4],f32>, !torch.vtensor<[1,4],f32>) -> !torch.vtensor<[3,4],f32> + return %0 : !torch.vtensor<[3,4],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_gemm_transposeA +func.func @test_gemm_transposeA(%arg0: !torch.vtensor<[5,3],f32>, %arg1: !torch.vtensor<[5,4],f32>, %arg2: !torch.vtensor<[1,4],f32>) -> !torch.vtensor<[3,4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64} { + // CHECK: %[[I0:.+]] = torch.constant.int 0 + // CHECK: %[[I1:.+]] = torch.constant.int 1 + // CHECK: %[[TRANS:.+]] = torch.aten.transpose.int %arg0, %[[I0]], %[[I1]] : !torch.vtensor<[5,3],f32>, !torch.int, !torch.int -> !torch.vtensor<[3,5],f32> + // CHECK: %[[MM:.+]] = torch.aten.mm %[[TRANS]], %arg1 : !torch.vtensor<[3,5],f32>, !torch.vtensor<[5,4],f32> -> !torch.vtensor<[3,4],f32> + // CHECK: torch.aten.add.Tensor %[[MM]], %arg2, %[[I1]] : !torch.vtensor<[3,4],f32>, !torch.vtensor<[1,4],f32>, !torch.int -> !torch.vtensor<[3,4],f32> + %0 = torch.operator "onnx.Gemm"(%arg0, %arg1, %arg2) {torch.onnx.transA = 1 : si64} : (!torch.vtensor<[5,3],f32>, !torch.vtensor<[5,4],f32>, !torch.vtensor<[1,4],f32>) -> !torch.vtensor<[3,4],f32> + return %0 : !torch.vtensor<[3,4],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_gemm_transposeB +func.func @test_gemm_transposeB(%arg0: !torch.vtensor<[3,5],f32>, %arg1: !torch.vtensor<[4,5],f32>, %arg2: !torch.vtensor<[1,4],f32>) -> !torch.vtensor<[3,4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64} { + // CHECK: %[[I0:.+]] = torch.constant.int 0 + // CHECK: %[[I1:.+]] = torch.constant.int 1 + // CHECK: %[[TRANS:.+]] = torch.aten.transpose.int %arg1, %[[I0]], %[[I1]] : !torch.vtensor<[4,5],f32>, !torch.int, !torch.int -> !torch.vtensor<[5,4],f32> + // CHECK: %[[MM:.+]] = torch.aten.mm %arg0, %[[TRANS]] : !torch.vtensor<[3,5],f32>, !torch.vtensor<[5,4],f32> -> !torch.vtensor<[3,4],f32> + // CHECK: torch.aten.add.Tensor %[[MM]], %arg2, %[[I1]] : !torch.vtensor<[3,4],f32>, !torch.vtensor<[1,4],f32>, !torch.int -> !torch.vtensor<[3,4],f32> + %0 = torch.operator "onnx.Gemm"(%arg0, %arg1, %arg2) {torch.onnx.transB = 1 : si64} : (!torch.vtensor<[3,5],f32>, !torch.vtensor<[4,5],f32>, !torch.vtensor<[1,4],f32>) -> !torch.vtensor<[3,4],f32> + return %0 : !torch.vtensor<[3,4],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_gemm_alpha +func.func @test_gemm_alpha(%arg0: !torch.vtensor<[3,5],f32>, %arg1: !torch.vtensor<[5,4],f32>, %arg2: !torch.vtensor<[1,4],f32>) -> !torch.vtensor<[3,4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64} { + // CHECK-DAG: %[[MM:.+]] = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[3,5],f32>, !torch.vtensor<[5,4],f32> -> !torch.vtensor<[3,4],f32> + // CHECK-DAG: %[[ALPHA:.+]] = torch.constant.float 5.000000e-01 + // CHECK: torch.aten.add.Tensor %arg2, %[[MM]], %[[ALPHA]] : !torch.vtensor<[1,4],f32>, !torch.vtensor<[3,4],f32>, !torch.float -> !torch.vtensor<[3,4],f32> + %0 = torch.operator "onnx.Gemm"(%arg0, %arg1, %arg2) {torch.onnx.alpha = 5.000000e-01 : f32} : (!torch.vtensor<[3,5],f32>, !torch.vtensor<[5,4],f32>, !torch.vtensor<[1,4],f32>) -> !torch.vtensor<[3,4],f32> + return %0 : !torch.vtensor<[3,4],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_gemm_beta +func.func @test_gemm_beta(%arg0: !torch.vtensor<[3,5],f32>, %arg1: !torch.vtensor<[5,4],f32>, %arg2: !torch.vtensor<[1,4],f32>) -> !torch.vtensor<[3,4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64} { + // CHECK-DAG: %[[MM:.+]] = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[3,5],f32>, !torch.vtensor<[5,4],f32> -> !torch.vtensor<[3,4],f32> + // CHECK-DAG: %[[BETA:.+]] = torch.constant.float 5.000000e-01 + // CHECK: torch.aten.add.Tensor %[[MM]], %arg2, %[[BETA]] : !torch.vtensor<[3,4],f32>, !torch.vtensor<[1,4],f32>, !torch.float -> !torch.vtensor<[3,4],f32> + %0 = torch.operator "onnx.Gemm"(%arg0, %arg1, %arg2) {torch.onnx.beta = 5.000000e-01 : f32} : (!torch.vtensor<[3,5],f32>, !torch.vtensor<[5,4],f32>, !torch.vtensor<[1,4],f32>) -> !torch.vtensor<[3,4],f32> + return %0 : !torch.vtensor<[3,4],f32> +} + + // ----- + +// CHECK-LABEL: func.func @test_gemm_alpha_beta +func.func @test_gemm_alpha_beta(%arg0: !torch.vtensor<[3,5],f32>, %arg1: !torch.vtensor<[5,4],f32>, %arg2: !torch.vtensor<[1,4],f32>) -> !torch.vtensor<[3,4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64} { + // CHECK-DAG: %[[I0:.+]] = torch.constant.int 0 + // CHECK-DAG: %[[I1:.+]] = torch.constant.int 1 + // CHECK-DAG: %[[MM:.+]] = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[3,5],f32>, !torch.vtensor<[5,4],f32> -> !torch.vtensor<[3,4],f32> + // CHECK-DAG: %[[ALPHA:.+]] = torch.constant.float 5.000000e-01 + // CHECK-DAG: %[[BETA:.+]] = torch.constant.float 2.500000e-01 + // CHECK-DAG: %[[MUL:.+]] = torch.aten.mul.Scalar %[[MM]], %[[ALPHA]] : !torch.vtensor<[3,4],f32>, !torch.float -> !torch.vtensor<[3,4],f32> + // CHECK: torch.aten.add.Tensor %[[MUL]], %arg2, %[[BETA]] : !torch.vtensor<[3,4],f32>, !torch.vtensor<[1,4],f32>, !torch.float -> !torch.vtensor<[3,4],f32> + %0 = torch.operator "onnx.Gemm"(%arg0, %arg1, %arg2) {torch.onnx.alpha = 5.000000e-01 : f32, torch.onnx.beta = 2.500000e-01 : f32} : (!torch.vtensor<[3,5],f32>, !torch.vtensor<[5,4],f32>, !torch.vtensor<[1,4],f32>) -> !torch.vtensor<[3,4],f32> + return %0 : !torch.vtensor<[3,4],f32> +} + +// ----- + // CHECK-LABEL: func.func @test_leaky_relu func.func @test_leaky_relu(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.opset_version = 16 : si64} { // CHECK-DAG: %[[F2:.+]] = torch.constant.float 2 From 9c655d0bfb166785dc17e51b0afa1f937c227cef Mon Sep 17 00:00:00 2001 From: Sambhav Jain Date: Sun, 17 Dec 2023 06:07:43 -0800 Subject: [PATCH 020/283] [Bazel] Add conversion targets for `TorchToTensor` (#2666) Adapts bazel build per https://github.com/llvm/torch-mlir/pull/2648. https://github.com/sjain-stanford/torch-mlir/actions/runs/7233207462/job/19708228888 --- utils/bazel/torch-mlir-overlay/BUILD.bazel | 24 ++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/utils/bazel/torch-mlir-overlay/BUILD.bazel b/utils/bazel/torch-mlir-overlay/BUILD.bazel index 2a9edaac503c..138dcbefb6ac 100644 --- a/utils/bazel/torch-mlir-overlay/BUILD.bazel +++ b/utils/bazel/torch-mlir-overlay/BUILD.bazel @@ -420,6 +420,29 @@ cc_library( ], ) +cc_library( + name = "TorchMLIRTorchToTensor", + srcs = glob([ + "lib/Conversion/*.h", + "lib/Conversion/TorchToTensor/*.cpp", + ]), + hdrs = glob([ + "include/torch-mlir/Conversion/TorchToTensor/*.h", + ]), + strip_include_prefix = "include", + deps = [ + ":TorchMLIRConversionPassesIncGen", + ":TorchMLIRConversionUtils", + ":TorchMLIRTorchBackendTypeConversion", + ":TorchMLIRTorchConversionDialect", + ":TorchMLIRTorchDialect", + "@llvm-project//mlir:Dialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:TensorDialect", + ], +) + cc_library( name = "TorchMLIRTorchConversionToMLProgram", srcs = glob([ @@ -515,6 +538,7 @@ cc_library( ":TorchMLIRTorchToSCF", ":TorchMLIRTorchToStablehlo", ":TorchMLIRTorchToTMTensor", + ":TorchMLIRTorchToTensor", ":TorchMLIRTorchToTosa", ], ) From 791c66647927a35c610ecd2ca2d93401f761b422 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Mon, 18 Dec 2023 09:15:12 -0800 Subject: [PATCH 021/283] [torch] Lower `torch.aten.sinh` to `linalg` (#2662) --- .../TorchToLinalg/Uncategorized.cpp | 42 ++++++++++--------- .../Conversion/TorchToLinalg/elementwise.mlir | 18 ++++++++ 2 files changed, 41 insertions(+), 19 deletions(-) diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index 5e38ec1a1490..2b4a95984425 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -220,6 +220,10 @@ static Value createLinalgPayloadCalculationForElementwiseOp( return createCalculationForMathOpWithDtypeConversion( b, converter, payloadArgs[0], op); } + if (isa(op)) { + return createCalculationForMathOpWithDtypeConversion( + b, converter, payloadArgs[0], op); + } if (isa(op)) { return createCalculationForMathOpWithDtypeConversion( b, converter, payloadArgs[0], op); @@ -1315,8 +1319,8 @@ class ConvertElementwiseOp : public ConversionPattern { LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - if (!isa(); patterns.add(typeConverter, context); diff --git a/test/Conversion/TorchToLinalg/elementwise.mlir b/test/Conversion/TorchToLinalg/elementwise.mlir index 00e408388b2c..bed94f98da2b 100644 --- a/test/Conversion/TorchToLinalg/elementwise.mlir +++ b/test/Conversion/TorchToLinalg/elementwise.mlir @@ -19,6 +19,8 @@ func.func @elementwise$unary(%arg0: !torch.vtensor<[],f32>) -> !torch.vtensor<[] return %0 : !torch.vtensor<[],f32> } +// ----- + // CHECK-LABEL: func.func @elementwise$binary( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>, // CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,?],f32> { @@ -46,6 +48,8 @@ func.func @elementwise$binary(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vt return %0 : !torch.vtensor<[?,?],f32> } +// ----- + // CHECK-LABEL: func.func @elementwise$ternary( // CHECK: linalg.generic {indexing_maps = [ // CHECK-SAME: affine_map<(d0, d1, d2) -> (d0, d1, d2)>, @@ -57,6 +61,8 @@ func.func @elementwise$ternary(%arg0: !torch.vtensor<[?,?,?],f32>, %arg1: !torch return %0 : !torch.vtensor<[?,?,?],f32> } +// ----- + // CHECK-LABEL: func.func @elementwise$with_scalar_capture( // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?],f32>, // CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[?],f32> { @@ -75,6 +81,8 @@ func.func @elementwise$with_scalar_capture(%arg0: !torch.vtensor<[?],f32>, %arg1 return %0 : !torch.vtensor<[?],f32> } +// ----- + // CHECK-LABEL: func.func @elementwise$static_1( // CHECK: linalg.generic {indexing_maps = [ // CHECK-SAME: affine_map<(d0) -> (d0)>, @@ -84,3 +92,13 @@ func.func @elementwise$static_1(%arg0: !torch.vtensor<[?],f32>, %arg1: !torch.vt %1 = torch.aten.mul.Tensor %arg0, %arg1 : !torch.vtensor<[?],f32>, !torch.vtensor<[1],f32> -> !torch.vtensor<[?],f32> return %1 : !torch.vtensor<[?],f32> } + +// ----- + +// CHECK-LABEL: func.func @elementwise_sinh +// CHECK: linalg.generic +// CHECK: math.sinh +func.func @elementwise_sinh(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> { + %0 = torch.aten.sinh %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> + return %0 : !torch.vtensor<[3],f32> +} From deacb8ef38757386e4303d780c7a74fb06b87e39 Mon Sep 17 00:00:00 2001 From: John Wu Date: Mon, 18 Dec 2023 10:57:08 -0800 Subject: [PATCH 022/283] [MLIR][ONNX] Add OnnxToTorch support for Gelu (#2647) This commit adds the OnnxToTorch support for Gelu op. --------- Co-authored-by: Rob Suderman --- .../TorchOnnxToTorch/DefaultDomainGtoP.cpp | 20 ++++++++ .../TorchOnnxToTorch/simple_ops_g_to_p.mlir | 48 ++++++++++++++++++- 2 files changed, 67 insertions(+), 1 deletion(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index 732f05b4cf95..0191fcf3619e 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -27,6 +27,26 @@ using namespace mlir::torch::onnx_c; // thing here, so we simplify. void mlir::torch::onnx_c::populateDefaultDomainGtoP( OnnxCustomOpConversionPattern &patterns) { + + patterns.onOp( + "Gelu", 20, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Value operand; + Torch::ValueTensorType resultType; + std::string approximate; + + if (binder.tensorOperand(operand) || + binder.tensorResultType(resultType) || + binder.customOpNameStringAttr(approximate, "approximate", "none")) + return failure(); + + Value vApproximate = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getStringAttr(approximate)); + + rewriter.replaceOpWithNewOp(binder.op, resultType, + operand, vApproximate); + return success(); + }); patterns.onOp("MatMul", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir index 8bb287fb8823..27e7f2c6aa26 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir @@ -1,4 +1,4 @@ -// RUN: torch-mlir-opt <%s -convert-torch-onnx-to-torch | FileCheck %s +// RUN: torch-mlir-opt <%s -convert-torch-onnx-to-torch --split-input-file | FileCheck %s // Generally, the test cases accumulated here come from running the importer // over all included backend tests that involve simple ops with no model // level constants. This is a pragmatic choice which lets us have a lot @@ -131,6 +131,8 @@ func.func @test_matmul_2d(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtenso return %0 : !torch.vtensor<[3,3],f32> } +// ----- + // CHECK-LABEL: @test_matmul_3d func.func @test_matmul_3d(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torch.vtensor<[2,4,3],f32>) -> !torch.vtensor<[2,3,3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2,4,3],f32> -> !torch.vtensor<[2,3,3],f32> @@ -138,6 +140,8 @@ func.func @test_matmul_3d(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torch.vten return %0 : !torch.vtensor<[2,3,3],f32> } +// ----- + // CHECK-LABEL: @test_matmul_4d func.func @test_matmul_4d(%arg0: !torch.vtensor<[1,2,3,4],f32>, %arg1: !torch.vtensor<[1,2,4,3],f32>) -> !torch.vtensor<[1,2,3,3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[1,2,3,4],f32>, !torch.vtensor<[1,2,4,3],f32> -> !torch.vtensor<[1,2,3,3],f32> @@ -145,6 +149,48 @@ func.func @test_matmul_4d(%arg0: !torch.vtensor<[1,2,3,4],f32>, %arg1: !torch.vt return %0 : !torch.vtensor<[1,2,3,3],f32> } +// ----- + +// CHECK-LABEL: @test_gelu_default_1 +func.func @test_gelu_default_1(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[STR1:.*]] = torch.constant.str "none" + // CHECK: torch.aten.gelu %arg0, %[[STR1]] : !torch.vtensor<[3],f32>, !torch.str -> !torch.vtensor<[3],f32> + %0 = torch.operator "onnx.Gelu"(%arg0) : (!torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> + return %0 : !torch.vtensor<[3],f32> +} + +// ----- + +// CHECK-LABEL: @test_gelu_default_2 +func.func @test_gelu_default_2(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[STR1:.*]] = torch.constant.str "none" + // CHECK: torch.aten.gelu %arg0, %[[STR1]] : !torch.vtensor<[3,4,5],f32>, !torch.str -> !torch.vtensor<[3,4,5],f32> + %0 = torch.operator "onnx.Gelu"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> +} + +// ----- + +// CHECK-LABEL: @test_gelu_tanh_1 +func.func @test_gelu_tanh_1(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[STR1:.*]] = torch.constant.str "tanh" + // CHECK: torch.aten.gelu %arg0, %[[STR1]] : !torch.vtensor<[3],f32>, !torch.str -> !torch.vtensor<[3],f32> + %0 = torch.operator "onnx.Gelu"(%arg0) {torch.onnx.approximate = "tanh"} : (!torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> + return %0 : !torch.vtensor<[3],f32> +} + +// ----- + +// CHECK-LABEL: @test_gelu_tanh_2 +func.func @test_gelu_tanh_2(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[STR1:.*]] = torch.constant.str "tanh" + // CHECK: torch.aten.gelu %arg0, %[[STR1]] : !torch.vtensor<[3,4,5],f32>, !torch.str -> !torch.vtensor<[3,4,5],f32> + %0 = torch.operator "onnx.Gelu"(%arg0) {torch.onnx.approximate = "tanh"} : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> +} + +// ----- + // CHECK-LABEL: func.func @test_less_or_equal func.func @test_less_or_equal(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],i1> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 16 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[3,4,5],f32> From 698ff3a736be3dcc201b4f5a7297f417587f3a99 Mon Sep 17 00:00:00 2001 From: saienduri <77521230+saienduri@users.noreply.github.com> Date: Mon, 18 Dec 2023 12:37:31 -0800 Subject: [PATCH 023/283] [MLIR][ONNX] Add OnnxToTorch support for Reduction Ops (#2657) This commit adds the OnnxToTorch support for ReduceSum, ReduceMean, and ReduceMin ops. --- .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 311 +++++++++++++++++- .../TorchOnnxToTorch/simple_ops_q_to_z.mlir | 297 +++++++++++++++++ 2 files changed, 607 insertions(+), 1 deletion(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 6d2fb81533fa..3637f7f35327 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -459,6 +459,316 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( binder.op, resultType, operand, vAlpha, vScale, vInputScale); return success(); }); + patterns.onOp( + "ReduceSum", 13, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value data; + Value axes; + int64_t keepDims; + int64_t noop_with_empty_axes; + if (binder.tensorOperands(data, axes) || + binder.tensorResultType(resultType) || + binder.s64IntegerAttr(keepDims, "keepdims", 1) || + binder.s64IntegerAttr(noop_with_empty_axes, "noop_with_empty_axes", + 0)) + return failure(); + Torch::BaseTensorType axesType = + axes.getType().cast(); + SmallVector dimList; + SmallVector selectSizes; + selectSizes.push_back(1); + Type selectResultType = axesType.getWithSizesAndDtype( + llvm::ArrayRef(selectSizes), axesType.getOptionalDtype()); + auto sizes = + dyn_cast(axes.getType()).getSizes(); + Value noneVal = rewriter.create(binder.getLoc()); + // Deal with case when axes is empty + if (sizes.size() == 1 && sizes[0] == 0) { + if (noop_with_empty_axes == 0) { + Value keepDimsConstInt = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), keepDims)); + Value keepDimsBool = rewriter.create( + binder.getLoc(), keepDimsConstInt); + rewriter.replaceOpWithNewOp( + binder.op, resultType, data, /*dim=*/noneVal, + /*keepdim=*/keepDimsBool, /*dtype=*/noneVal); + } else { + rewriter.replaceOp(binder.op, data); + } + return success(); + } + Value zero = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); + int64_t adjustmentInt = + cast(data.getType()).getSizes().size(); + Value adjustment = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), + adjustmentInt)); + // convert axes (tensor) into torch int list while dealing with neg axis + for (int i = 0; i < sizes[0]; i++) { + // Go through the axes list and get each dim in the list + Value selectIndex = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), i)); + Value extract = rewriter.create( + binder.getLoc(), selectResultType, axes, zero, selectIndex); + Value dim = rewriter.create( + binder.getLoc(), rewriter.getType(), extract); + // deal with neg axis: if (axis < 0) axis += rank + Value isNegative = + rewriter.create(binder.getLoc(), dim, zero); + isNegative = rewriter.create(binder.getLoc(), + isNegative); + Value finalOffset = rewriter.create( + binder.getLoc(), isNegative, adjustment); + Value finalDim = rewriter.create( + binder.getLoc(), dim, finalOffset); + dimList.push_back(finalDim); + } + Value dimValueList = rewriter.create( + binder.getLoc(), + Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), + dimList); + Value keepDimBool; + if (keepDims == 1) { + keepDimBool = + rewriter.create(binder.getLoc(), true); + } else { + keepDimBool = + rewriter.create(binder.getLoc(), false); + } + rewriter.replaceOpWithNewOp( + binder.op, resultType, data, dimValueList, keepDimBool, + /*dtype=*/noneVal); + return success(); + }); + patterns.onOp( + "ReduceMean", 13, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value data; + Value axes; + int64_t keepDims; + int64_t noop_with_empty_axes; + if (binder.tensorOperands(data, axes) || + binder.tensorResultType(resultType) || + binder.s64IntegerAttr(keepDims, "keepdims", 1) || + binder.s64IntegerAttr(noop_with_empty_axes, "noop_with_empty_axes", + 0)) + return failure(); + Torch::BaseTensorType axesType = + axes.getType().cast(); + SmallVector dimList; + SmallVector selectSizes; + selectSizes.push_back(1); + Type selectResultType = axesType.getWithSizesAndDtype( + llvm::ArrayRef(selectSizes), axesType.getOptionalDtype()); + auto sizes = + dyn_cast(axes.getType()).getSizes(); + Value noneVal = rewriter.create(binder.getLoc()); + // deal with case when axes is empty + if (sizes.size() == 1 && sizes[0] == 0) { + if (noop_with_empty_axes == 0) { + Value keepDimsConstInt = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), keepDims)); + Value keepDimsBool = rewriter.create( + binder.getLoc(), keepDimsConstInt); + rewriter.replaceOpWithNewOp( + binder.op, resultType, data, /*dim=*/noneVal, keepDimsBool, + /*dtype=*/noneVal); + } else { + rewriter.replaceOp(binder.op, data); + } + return success(); + } + Value zero = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); + int64_t adjustmentInt = + cast(data.getType()).getSizes().size(); + Value adjustment = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), + adjustmentInt)); + // convert axes (tensor) into torch int list while dealing with neg axis + for (int i = 0; i < sizes[0]; i++) { + // Go through the axes list and get each dim in the list + Value selectIndex = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), i)); + Value extract = rewriter.create( + binder.getLoc(), selectResultType, axes, zero, selectIndex); + Value dim = rewriter.create( + binder.getLoc(), rewriter.getType(), extract); + // deal with neg axis: if (axis < 0) axis += rank + Value isNegative = + rewriter.create(binder.getLoc(), dim, zero); + isNegative = rewriter.create(binder.getLoc(), + isNegative); + Value finalOffset = rewriter.create( + binder.getLoc(), isNegative, adjustment); + Value finalDim = rewriter.create( + binder.getLoc(), dim, finalOffset); + dimList.push_back(finalDim); + } + Value dimValueList = rewriter.create( + binder.getLoc(), + Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), + dimList); + Value keepDimBool; + if (keepDims == 1) { + keepDimBool = + rewriter.create(binder.getLoc(), true); + } else { + keepDimBool = + rewriter.create(binder.getLoc(), false); + } + rewriter.replaceOpWithNewOp( + binder.op, resultType, data, dimValueList, keepDimBool, + /*dtype=*/noneVal); + return success(); + }); + patterns.onOp( + "ReduceMin", 13, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + // AtenAminOp allows us to pass a list of dims + Torch::ValueTensorType resultType; + Value data; + Value axes; + int64_t keepDims; + int64_t noop_with_empty_axes; + // Deal with case when no axes arg is passed + if (binder.op->getNumOperands() == 1) { + if (binder.tensorOperand(data) || + binder.tensorResultType(resultType) || + binder.s64IntegerAttr(keepDims, "keepdims", 1) || + binder.s64IntegerAttr(noop_with_empty_axes, + "noop_with_empty_axes", 0)) + return failure(); + if (noop_with_empty_axes == 0) { + Value keepDimsConstInt = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), keepDims)); + Value keepDimsBool = rewriter.create( + binder.getLoc(), keepDimsConstInt); + int64_t numDims = dyn_cast(data.getType()) + .getSizes() + .size(); + SmallVector axesList; + for (int i = 0; i < numDims; i++) { + Value curr = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), i)); + axesList.push_back(curr); + } + Value axesValueList = rewriter.create( + binder.getLoc(), + Torch::ListType::get( + Torch::IntType::get(binder.op->getContext())), + axesList); + rewriter.replaceOpWithNewOp( + binder.op, resultType, data, axesValueList, keepDimsBool); + } else { + rewriter.replaceOp(binder.op, data); + } + return success(); + } + if (binder.tensorOperands(data, axes) || + binder.tensorResultType(resultType) || + binder.s64IntegerAttr(keepDims, "keepdims", 1) || + binder.s64IntegerAttr(noop_with_empty_axes, "noop_with_empty_axes", + 0)) + return failure(); + Torch::BaseTensorType axesType = + axes.getType().cast(); + SmallVector dimList; + SmallVector selectSizes; + selectSizes.push_back(1); + Type selectResultType = axesType.getWithSizesAndDtype( + llvm::ArrayRef(selectSizes), axesType.getOptionalDtype()); + auto sizes = + dyn_cast(axes.getType()).getSizes(); + // deal with case when axes is empty + if (sizes.size() == 1 && sizes[0] == 0) { + if (noop_with_empty_axes == 0) { + // create dims list with all dims [0, data.getSizes().size()) + Value keepDimsConstInt = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), keepDims)); + Value keepDimsBool = rewriter.create( + binder.getLoc(), keepDimsConstInt); + int64_t numDims = dyn_cast(data.getType()) + .getSizes() + .size(); + for (int i = 0; i < numDims; i++) { + Value curr = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), i)); + dimList.push_back(curr); + } + Value dimValueList = rewriter.create( + binder.getLoc(), + Torch::ListType::get( + Torch::IntType::get(binder.op->getContext())), + dimList); + rewriter.replaceOpWithNewOp( + binder.op, resultType, data, dimValueList, keepDimsBool); + } else { + rewriter.replaceOp(binder.op, data); + } + return success(); + } + Value zero = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); + int64_t adjustmentInt = + cast(data.getType()).getSizes().size(); + Value adjustment = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), + adjustmentInt)); + // convert axes (tensor) into torch int list while dealing with neg axis + for (int i = 0; i < sizes[0]; i++) { + // Go through the axes list and get each dim in the list + Value selectIndex = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), i)); + Value extract = rewriter.create( + binder.getLoc(), selectResultType, axes, zero, selectIndex); + Value dim = rewriter.create( + binder.getLoc(), rewriter.getType(), extract); + // deal with neg axis: if (axis < 0) axis += rank + Value isNegative = + rewriter.create(binder.getLoc(), dim, zero); + isNegative = rewriter.create(binder.getLoc(), + isNegative); + Value finalOffset = rewriter.create( + binder.getLoc(), isNegative, adjustment); + Value finalDim = rewriter.create( + binder.getLoc(), dim, finalOffset); + dimList.push_back(finalDim); + } + Value dimValueList = rewriter.create( + binder.getLoc(), + Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), + dimList); + Value keepDimBool; + if (keepDims == 1) { + keepDimBool = + rewriter.create(binder.getLoc(), true); + } else { + keepDimBool = + rewriter.create(binder.getLoc(), false); + } + rewriter.replaceOpWithNewOp( + binder.op, resultType, data, dimValueList, keepDimBool); + return success(); + }); patterns.onOp("Shape", 9, [](OpBinder binder, ConversionPatternRewriter &rewriter) { @@ -550,7 +860,6 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( } rewriter.replaceOp(binder.op, operand); - return success(); }); } diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir index e4c95fe2bbc9..da2a5c44aa95 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir @@ -11,6 +11,8 @@ func.func @test_reciprocal(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor return %0 : !torch.vtensor<[3,4,5],f32> } +// ----- + // CHECK-LABEL: func.func @test_relu func.func @test_relu(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.relu %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> @@ -18,6 +20,8 @@ func.func @test_relu(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4, return %0 : !torch.vtensor<[3,4,5],f32> } +// ----- + // CHECK-LABEL: func.func @test_round func.func @test_round(%arg0: !torch.vtensor<[15],f32>) -> !torch.vtensor<[15],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { //CHECK: torch.aten.round %arg0 : !torch.vtensor<[15],f32> -> !torch.vtensor<[15],f32> @@ -25,6 +29,8 @@ func.func @test_round(%arg0: !torch.vtensor<[15],f32>) -> !torch.vtensor<[15],f3 return %0 : !torch.vtensor<[15],f32> } +// ----- + // CHECK-LABEL: func.func @test_scatter_elements_with_axis func.func @test_scatter_elements_with_axis(%arg0: !torch.vtensor<[1,5],f32>, %arg1: !torch.vtensor<[1,2],si64>, %arg2: !torch.vtensor<[1,2],f32>) -> !torch.vtensor<[1,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT1:.*]] = torch.constant.int 1 @@ -59,6 +65,8 @@ func.func @test_scatter_elements_with_reduction_mul(%arg0: !torch.vtensor<[1,5], return %0 : !torch.vtensor<[1,5],f32> } +// ----- + // CHECK-LABEL: func.func @test_sigmoid_example func.func @test_sigmoid_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.sigmoid %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> @@ -66,6 +74,8 @@ func.func @test_sigmoid_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtenso return %0 : !torch.vtensor<[3],f32> } +// ----- + // CHECK-LABEL: func.func @test_sin_example func.func @test_sin_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 7 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.sin %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> @@ -73,6 +83,8 @@ func.func @test_sin_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3 return %0 : !torch.vtensor<[3],f32> } +// ----- + // CHECK-LABEL: func.func @test_tanh_example func.func @test_tanh_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.tanh %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> @@ -80,6 +92,8 @@ func.func @test_tanh_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[ return %0 : !torch.vtensor<[3],f32> } +// ----- + // CHECK-LABEL: func.func @test_sqrt_example func.func @test_sqrt_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.sqrt %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> @@ -87,6 +101,8 @@ func.func @test_sqrt_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[ return %0 : !torch.vtensor<[3],f32> } +// ----- + // CHECK-LABEL: func.func @test_sub_bcast func.func @test_sub_bcast(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT1:.*]] = torch.constant.int 1 @@ -119,6 +135,8 @@ func.func @test_sub_uint8(%arg0: !torch.vtensor<[3,4,5],ui8>, %arg1: !torch.vten return %0 : !torch.vtensor<[3,4,5],ui8> } +// ----- + // CHECK-LABEL: func.func @test_sum_example func.func @test_sum_example(%arg0: !torch.vtensor<[3],f32>, %arg1: !torch.vtensor<[3],f32>, %arg2: !torch.vtensor<[3],f32>, %arg3: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT1:.*]] = torch.constant.int 1 @@ -143,6 +161,8 @@ func.func @test_sum_two_inputs(%arg0: !torch.vtensor<[3],f32>, %arg1: !torch.vte return %0 : !torch.vtensor<[3],f32> } +// ----- + // CHECK-LABEL: func.func @test_where_example func.func @test_where_example(%arg0: !torch.vtensor<[2,2],i1>, %arg1: !torch.vtensor<[2,2],f32>, %arg2: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 16 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.where.self %arg0, %arg1, %arg2 : !torch.vtensor<[2,2],i1>, !torch.vtensor<[2,2],f32>, !torch.vtensor<[2,2],f32> -> !torch.vtensor<[2,2],f32> @@ -157,6 +177,8 @@ func.func @test_where_long_example(%arg0: !torch.vtensor<[2,2],i1>, %arg1: !torc return %0 : !torch.vtensor<[2,2],si64> } +// ----- + // CHECK-LABEL: func.func @test_xor2d func.func @test_xor2d(%arg0: !torch.vtensor<[3,4],i1>, %arg1: !torch.vtensor<[3,4],i1>) -> !torch.vtensor<[3,4],i1> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 7 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.logical_xor %arg0, %arg1 : !torch.vtensor<[3,4],i1>, !torch.vtensor<[3,4],i1> -> !torch.vtensor<[3,4],i1> @@ -192,6 +214,8 @@ func.func @test_xor_bcast4v4d(%arg0: !torch.vtensor<[1,4,1,6],i1>, %arg1: !torch return %0 : !torch.vtensor<[3,4,5,6],i1> } +// ----- + // CHECK-LABEL: func.func @test_squeeze func.func @test_squeeze(%arg0: !torch.vtensor<[1,3,4,5],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT0:.*]] = torch.constant.int 0 @@ -233,6 +257,8 @@ func.func @test_squeeze_two_axes(%arg0: !torch.vtensor<[3,1,4,5,1],f32>, %arg1: return %0 : !torch.vtensor<[3,4,5],f32> } +// ----- + // CHECK-LABEL: func.func @test_unsqueeze_axis_0 func.func @test_unsqueeze_axis_0(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[1,3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT0:.*]] = torch.constant.int 0 @@ -421,6 +447,8 @@ func.func @test_unsqueeze_unsorted_axes(%arg0: !torch.vtensor<[3,4,5],f32>, %arg return %0 : !torch.vtensor<[3,4,1,5,1,1],f32> } +// ----- + // CHECK-LABEL: func.func @test_softmax_axis_0 func.func @test_softmax_axis_0(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT0:.*]] = torch.constant.int 0 @@ -489,6 +517,275 @@ func.func @test_selu(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4, // ----- +// CHECK-LABEL: func.func @test_reduce_sum_default_axes_keepdims_example +func.func @test_reduce_sum_default_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[NONE:.+]] = torch.constant.none + // CHECK: %[[INT1:.+]] = torch.constant.int 1 + // CHECK: torch.aten.Bool.int %int1 : !torch.int -> !torch.bool + // CHECK: torch.aten.sum.dim_IntList %arg0, %none, %0, %none : !torch.vtensor<[3,2,2],f32>, !torch.none, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1],f32> + %0 = torch.operator "onnx.ReduceSum"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> + return %0 : !torch.vtensor<[1,1,1],f32> +} + +// CHECK-LABEL: func.func @test_reduce_sum_do_not_keepdims_example +func.func @test_reduce_sum_do_not_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[NONE:.+]] = torch.constant.none + // CHECK: %[[INT0:.+]] = torch.constant.int 0 + // CHECK: %[[INT3:.+]] = torch.constant.int 3 + // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 + // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool + // CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int + // CHECK: torch.aten.mul.int %3, %int3 : !torch.int, !torch.int -> !torch.int + // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int + // CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list + // CHECK: %[[FALSE:.+]] = torch.constant.bool false + // CHECK: torch.aten.sum.dim_IntList %arg0, %6, %false, %none : !torch.vtensor<[3,2,2],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[3,2],f32> + %0 = torch.operator "onnx.ReduceSum"(%arg0, %arg1) {torch.onnx.keepdims = 0 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32> + return %0 : !torch.vtensor<[3,2],f32> +} + +// CHECK-LABEL: func.func @test_reduce_sum_empty_axes_input_noop_example +func.func @test_reduce_sum_empty_axes_input_noop_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[0],si64>) -> !torch.vtensor<[3,2,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[NONE:.+]] = torch.constant.none + %0 = torch.operator "onnx.ReduceSum"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64, torch.onnx.noop_with_empty_axes = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[0],si64>) -> !torch.vtensor<[3,2,2],f32> + return %0 : !torch.vtensor<[3,2,2],f32> +} + +// CHECK-LABEL: func.func @test_reduce_sum_empty_set_non_reduced_axis_zero +func.func @test_reduce_sum_empty_set_non_reduced_axis_zero(%arg0: !torch.vtensor<[2,0,4],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[2,0,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[NONE:.+]] = torch.constant.none + // CHECK: %[[INT0:.+]] = torch.constant.int 0 + // CHECK: %[[INT3:.+]] = torch.constant.int 3 + // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 + // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool + // CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int + // CHECK: torch.aten.mul.int %3, %int3 : !torch.int, !torch.int -> !torch.int + // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int + // CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list + // CHECK: %[[TRUE:.+]] = torch.constant.bool true + // CHECK: torch.aten.sum.dim_IntList %arg0, %6, %true, %none : !torch.vtensor<[2,0,4],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[2,0,1],f32> + %0 = torch.operator "onnx.ReduceSum"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[2,0,4],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[2,0,1],f32> + return %0 : !torch.vtensor<[2,0,1],f32> +} + +// CHECK-LABEL: func.func @test_reduce_sum_keepdims_example +func.func @test_reduce_sum_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[NONE:.+]] = torch.constant.none + // CHECK: %[[INT0:.+]] = torch.constant.int 0 + // CHECK: %[[INT3:.+]] = torch.constant.int 3 + // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 + // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool + // CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int + // CHECK: torch.aten.mul.int %3, %int3 : !torch.int, !torch.int -> !torch.int + // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int + // CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list + // CHECK: %[[TRUE:.+]] = torch.constant.bool true + // CHECK: torch.aten.sum.dim_IntList %arg0, %6, %true, %none : !torch.vtensor<[3,2,2],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[3,1,2],f32> + %0 = torch.operator "onnx.ReduceSum"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32> + return %0 : !torch.vtensor<[3,1,2],f32> +} + +// CHECK-LABEL: func.func @test_reduce_sum_negative_axes_keepdims_example +func.func @test_reduce_sum_negative_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[NONE:.+]] = torch.constant.none + // CHECK: %[[INT0:.+]] = torch.constant.int 0 + // CHECK: %[[INT3:.+]] = torch.constant.int 3 + // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 + // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool + // CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int + // CHECK: torch.aten.mul.int %3, %int3 : !torch.int, !torch.int -> !torch.int + // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int + // CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list + // CHECK: %[[TRUE:.+]] = torch.constant.bool true + // CHECK: torch.aten.sum.dim_IntList %arg0, %6, %true, %none : !torch.vtensor<[3,2,2],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[3,1,2],f32> + %0 = torch.operator "onnx.ReduceSum"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32> + return %0 : !torch.vtensor<[3,1,2],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_reduce_mean_default_axes_keepdims_example +func.func @test_reduce_mean_default_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[NONE:.+]] = torch.constant.none + // CHECK: %[[INT1:.+]] = torch.constant.int 1 + // CHECK: torch.aten.Bool.int %int1 : !torch.int -> !torch.bool + // CHECK: torch.aten.mean.dim %arg0, %none, %0, %none : !torch.vtensor<[3,2,2],f32>, !torch.none, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1],f32> + %0 = torch.operator "onnx.ReduceMean"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> + return %0 : !torch.vtensor<[1,1,1],f32> +} + +// CHECK-LABEL: func.func @test_reduce_mean_do_not_keepdims_example +func.func @test_reduce_mean_do_not_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[NONE:.+]] = torch.constant.none + // CHECK: %[[INT0:.+]] = torch.constant.int 0 + // CHECK: %[[INT3:.+]] = torch.constant.int 3 + // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 + // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool + // CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int + // CHECK: torch.aten.mul.int %3, %int3 : !torch.int, !torch.int -> !torch.int + // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int + // CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list + // CHECK: %[[FALSE:.+]] = torch.constant.bool false + // CHECK: torch.aten.mean.dim %arg0, %6, %false, %none : !torch.vtensor<[3,2,2],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[3,2],f32> + %0 = torch.operator "onnx.ReduceMean"(%arg0, %arg1) {torch.onnx.keepdims = 0 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32> + return %0 : !torch.vtensor<[3,2],f32> +} + +// CHECK-LABEL: func.func @test_reduce_mean_keepdims_example +func.func @test_reduce_mean_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[NONE:.+]] = torch.constant.none + // CHECK: %[[INT0:.+]] = torch.constant.int 0 + // CHECK: %[[INT3:.+]] = torch.constant.int 3 + // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 + // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool + // CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int + // CHECK: torch.aten.mul.int %3, %int3 : !torch.int, !torch.int -> !torch.int + // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int + // CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list + // CHECK: %[[TRUE:.+]] = torch.constant.bool true + // CHECK: torch.aten.mean.dim %arg0, %6, %true, %none : !torch.vtensor<[3,2,2],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[3,1,2],f32> + %0 = torch.operator "onnx.ReduceMean"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32> + return %0 : !torch.vtensor<[3,1,2],f32> +} + +// CHECK-LABEL: func.func @test_reduce_mean_negative_axes_keepdims_example +func.func @test_reduce_mean_negative_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[NONE:.+]] = torch.constant.none + // CHECK: %[[INT0:.+]] = torch.constant.int 0 + // CHECK: %[[INT3:.+]] = torch.constant.int 3 + // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 + // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool + // CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int + // CHECK: torch.aten.mul.int %3, %int3 : !torch.int, !torch.int -> !torch.int + // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int + // CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list + // CHECK: %[[TRUE:.+]] = torch.constant.bool true + // CHECK: torch.aten.mean.dim %arg0, %6, %true, %none : !torch.vtensor<[3,2,2],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[3,1,2],f32> + %0 = torch.operator "onnx.ReduceMean"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32> + return %0 : !torch.vtensor<[3,1,2],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_reduce_min_bool_inputs +func.func @test_reduce_min_bool_inputs(%arg0: !torch.vtensor<[4,2],i1>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[4,1],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT0:.+]] = torch.constant.int 0 + // CHECK: %[[INT2:.+]] = torch.constant.int 2 + // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 + // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool + // CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int + // CHECK: torch.aten.mul.int %3, %int2 : !torch.int, !torch.int -> !torch.int + // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int + // CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list + // CHECK: %[[TRUE:.+]] = torch.constant.bool true + // CHECK: torch.aten.amin %arg0, %6, %true : !torch.vtensor<[4,2],i1>, !torch.list, !torch.bool -> !torch.vtensor<[4,1],i1> + %0 = torch.operator "onnx.ReduceMin"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[4,2],i1>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[4,1],i1> + return %0 : !torch.vtensor<[4,1],i1> +} + +// CHECK-LABEL: func.func @test_reduce_min_default_axes_keepdims_example +func.func @test_reduce_min_default_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT1:.+]] = torch.constant.int 1 + // CHECK: torch.aten.Bool.int %int1 : !torch.int -> !torch.bool + // CHECK: %[[INT0:.+]] = torch.constant.int 0 + // CHECK: %[[INT1_0:.+]] = torch.constant.int 1 + // CHECK: %[[INT2:.+]] = torch.constant.int 2 + // CHECK: torch.prim.ListConstruct %int0, %int1_0, %int2 : (!torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: torch.aten.amin %arg0, %1, %0 : !torch.vtensor<[3,2,2],f32>, !torch.list, !torch.bool -> !torch.vtensor<[1,1,1],f32> + %0 = torch.operator "onnx.ReduceMin"(%arg0) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>) -> !torch.vtensor<[1,1,1],f32> + return %0 : !torch.vtensor<[1,1,1],f32> +} + +// CHECK-LABEL: func.func @test_reduce_min_do_not_keepdims_example +func.func @test_reduce_min_do_not_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT0:.+]] = torch.constant.int 0 + // CHECK: %[[INT3:.+]] = torch.constant.int 3 + // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 + // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool + // CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int + // CHECK: torch.aten.mul.int %3, %int3 : !torch.int, !torch.int -> !torch.int + // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int + // CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list + // CHECK: %[[FALSE:.+]] = torch.constant.bool false + // CHECK: torch.aten.amin %arg0, %6, %false : !torch.vtensor<[3,2,2],f32>, !torch.list, !torch.bool -> !torch.vtensor<[3,2],f32> + %0 = torch.operator "onnx.ReduceMin"(%arg0, %arg1) {torch.onnx.keepdims = 0 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32> + return %0 : !torch.vtensor<[3,2],f32> +} + +// CHECK-LABEL: func.func @test_reduce_min_empty_set +func.func @test_reduce_min_empty_set(%arg0: !torch.vtensor<[2,0,4],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[2,1,4],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT0:.+]] = torch.constant.int 0 + // CHECK: %[[INT3:.+]] = torch.constant.int 3 + // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 + // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool + // CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int + // CHECK: torch.aten.mul.int %3, %int3 : !torch.int, !torch.int -> !torch.int + // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int + // CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list + // CHECK: %[[TRUE:.+]] = torch.constant.bool true + // CHECK: torch.aten.amin %arg0, %6, %true : !torch.vtensor<[2,0,4],f32>, !torch.list, !torch.bool -> !torch.vtensor<[2,1,4],f32> + %0 = torch.operator "onnx.ReduceMin"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[2,0,4],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[2,1,4],f32> + return %0 : !torch.vtensor<[2,1,4],f32> +} + +// CHECK-LABEL: func.func @test_reduce_min_keepdims_example +func.func @test_reduce_min_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT0:.+]] = torch.constant.int 0 + // CHECK: %[[INT3:.+]] = torch.constant.int 3 + // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 + // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool + // CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int + // CHECK: torch.aten.mul.int %3, %int3 : !torch.int, !torch.int -> !torch.int + // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int + // CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list + // CHECK: %[[TRUE:.+]] = torch.constant.bool true + // CHECK: torch.aten.amin %arg0, %6, %true : !torch.vtensor<[3,2,2],f32>, !torch.list, !torch.bool -> !torch.vtensor<[3,1,2],f32> + %0 = torch.operator "onnx.ReduceMin"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32> + return %0 : !torch.vtensor<[3,1,2],f32> +} + +// CHECK-LABEL: func.func @test_reduce_min_negative_axes_keepdims_example +func.func @test_reduce_min_negative_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT0:.+]] = torch.constant.int 0 + // CHECK: %[[INT3:.+]] = torch.constant.int 3 + // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 + // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool + // CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int + // CHECK: torch.aten.mul.int %3, %int3 : !torch.int, !torch.int -> !torch.int + // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int + // CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list + // CHECK: %[[TRUE:.+]] = torch.constant.bool true + // CHECK: torch.aten.amin %arg0, %6, %true : !torch.vtensor<[3,2,2],f32>, !torch.list, !torch.bool -> !torch.vtensor<[3,1,2],f32> + %0 = torch.operator "onnx.ReduceMin"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32> + return %0 : !torch.vtensor<[3,1,2],f32> +} + +// ----- + // CHECK-LABEL: func.func @test_sinh func.func @test_sinh_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 4 : si64, torch.onnx_meta.opset_version = 9 : si64} { // CHECK: torch.aten.sinh %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> From 8649b84e3f768daec7d43dc439f1d91cb44a9c6e Mon Sep 17 00:00:00 2001 From: Vivek Khandelwal Date: Tue, 19 Dec 2023 05:47:11 +0530 Subject: [PATCH 024/283] [MLIR][ONNX] Add OnnxToTorch support for AveragePool op. (#2672) This commit adds the OnnxToTorch support for AveragePool op. Signed-Off By: vivekkhandelwal1424@gmail.com --- .../TorchOnnxToTorch/DefaultDomainAtoF.cpp | 110 ++++++++++++++++++ .../TorchOnnxToTorch/simple_ops_a_to_f.mlir | 21 ++++ 2 files changed, 131 insertions(+) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp index 44ced9eb4b64..ae061922e260 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp @@ -9,6 +9,7 @@ #include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" +#include "torch-mlir/Dialect/Torch/Utils/Utils.h" using namespace mlir; using namespace mlir::torch; @@ -164,6 +165,115 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( binder.op, resultType, operand); return success(); }); + patterns.onOp( + "AveragePool", 19, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + std::string autoPad; + SmallVector dilation; + if (binder.customOpNameStringAttr(autoPad, "auto_pad", "NOTSET")) + return failure(); + if (autoPad != "NOTSET") { + // TODO: Add support for `auto_pad` != "NOTSET" + return rewriter.notifyMatchFailure( + binder.op, "unsupported conversion: auto_pad != NOTSET"); + } + if (binder.s64IntegerArrayAttr(dilation, "dilations", {})) { + return failure(); + } + if (dilation.size() > 0) { + return rewriter.notifyMatchFailure( + binder.op, "dilation is not supported by torch.aten.avgpool op"); + } + + Torch::ValueTensorType resultType; + Value operand; + bool ceilMode, countIncludePad; + if (binder.tensorOperand(operand) || + binder.s64BoolAttr(ceilMode, "ceil_mode", false) || + binder.s64BoolAttr(countIncludePad, "count_include_pad", false) || + binder.tensorResultType(resultType)) + return failure(); + // Determine the rank of input tensor. + std::optional maybeRank = Torch::getTensorRank(operand); + if (!maybeRank) + return rewriter.notifyMatchFailure(binder.op, + "Unimplemented: unranked tensor"); + unsigned rank = *maybeRank; + + SmallVector kernel, padding, strides; + if (binder.s64IntegerArrayAttr(kernel, "kernel_shape", {})) { + return failure(); + } + if (kernel.size() != rank - 2) { + return rewriter.notifyMatchFailure( + binder.op, "kernel list size does not match the number of axes"); + } + if (binder.s64IntegerArrayAttr(padding, "pads", {0})) { + return failure(); + } + if (padding.size() != 1 && padding.size() != rank - 2) { + return rewriter.notifyMatchFailure( + binder.op, "padding list size does not match the number of axes"); + } + if (binder.s64IntegerArrayAttr(strides, "strides", {1})) { + return failure(); + } + if (strides.size() != 1 && strides.size() != rank - 2) { + return rewriter.notifyMatchFailure( + binder.op, "strides list size does not match the number of axes"); + } + + SmallVector cstKernel, cstPadding, cstStrides; + for (int64_t i : kernel) { + cstKernel.push_back(rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(i))); + } + for (int64_t i : padding) { + cstPadding.push_back(rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(i))); + } + for (int64_t i : strides) { + cstStrides.push_back(rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(i))); + } + Value kernelSizeList = rewriter.create( + binder.getLoc(), + Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), + cstKernel); + Value paddingList = rewriter.create( + binder.getLoc(), + Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), + cstPadding); + Value stridesList = rewriter.create( + binder.getLoc(), + Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), + cstStrides); + Value cstCeilMode = + rewriter.create(binder.getLoc(), ceilMode); + Value cstCountIncludePad = rewriter.create( + binder.getLoc(), countIncludePad); + Value cstNone = rewriter.create(binder.getLoc()); + + if (rank == 3) { + rewriter.replaceOpWithNewOp( + binder.op, resultType, operand, kernelSizeList, stridesList, + paddingList, cstCeilMode, cstCountIncludePad); + return success(); + } else if (rank == 4) { + rewriter.replaceOpWithNewOp( + binder.op, resultType, operand, kernelSizeList, stridesList, + paddingList, cstCeilMode, cstCountIncludePad, + /*divisor_override=*/cstNone); + return success(); + } else if (rank == 5) { + rewriter.replaceOpWithNewOp( + binder.op, resultType, operand, kernelSizeList, stridesList, + paddingList, cstCeilMode, cstCountIncludePad, + /*divisor_override=*/cstNone); + return success(); + } + return failure(); + }); patterns.onOp( "BitShift", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir index 397d72a4896b..20c8e493dcc7 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir @@ -434,3 +434,24 @@ func.func @test_floor(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4 %0 = torch.operator "onnx.Floor"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> return %0 : !torch.vtensor<[3,4,5],f32> } + +// CHECK-LABEL: @test_averagepool_1d_default +func.func @test_averagepool_1d_default(%arg0: !torch.vtensor<[1,3,32],f32>) -> !torch.vtensor<[1,3,31],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.avg_pool1d %arg0, %0, %2, %1, %false, %true : !torch.vtensor<[1,3,32],f32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.bool -> !torch.vtensor<[1,3,31],f32> + %0 = torch.operator "onnx.AveragePool"(%arg0) {torch.onnx.kernel_shape = [2 : si64], torch.onnx.count_include_pad = 1 : si64} : (!torch.vtensor<[1,3,32],f32>) -> !torch.vtensor<[1,3,31],f32> + return %0 : !torch.vtensor<[1,3,31],f32> +} + +// CHECK-LABEL: @test_averagepool_2d_ceil +func.func @test_averagepool_2d_ceil(%arg0: !torch.vtensor<[1,1,4,4],f32>) -> !torch.vtensor<[1,1,2,2],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.avg_pool2d %arg0, %0, %2, %1, %true, %false, %none : !torch.vtensor<[1,1,4,4],f32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,1,2,2],f32> + %0 = torch.operator "onnx.AveragePool"(%arg0) {torch.onnx.ceil_mode = 1 : si64, torch.onnx.kernel_shape = [3 : si64, 3 : si64], torch.onnx.strides = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,1,4,4],f32>) -> !torch.vtensor<[1,1,2,2],f32> + return %0 : !torch.vtensor<[1,1,2,2],f32> +} + +// CHECK-LABEL: @test_averagepool_3d_default +func.func @test_averagepool_3d_default(%arg0: !torch.vtensor<[1,3,32,32,32],f32>) -> !torch.vtensor<[1,3,31,31,31],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.avg_pool3d %arg0, %0, %2, %1, %false, %false_2, %none : !torch.vtensor<[1,3,32,32,32],f32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,3,31,31,31],f32> + %0 = torch.operator "onnx.AveragePool"(%arg0) {torch.onnx.kernel_shape = [2 : si64, 2 : si64, 2 : si64]} : (!torch.vtensor<[1,3,32,32,32],f32>) -> !torch.vtensor<[1,3,31,31,31],f32> + return %0 : !torch.vtensor<[1,3,31,31,31],f32> +} From 89cfbe894df2592a296f48440005da244768b0fa Mon Sep 17 00:00:00 2001 From: Yinrun Lyu <5969932+yinrun@users.noreply.github.com> Date: Tue, 19 Dec 2023 14:46:55 +0800 Subject: [PATCH 025/283] Update PYTHONPATH in development.md (#2644) Modify PYTHONPATH to new related directory in docs. --- build_tools/write_env_file.sh | 2 +- docs/development.md | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/build_tools/write_env_file.sh b/build_tools/write_env_file.sh index 8f3c9a59357f..05179c56a07c 100755 --- a/build_tools/write_env_file.sh +++ b/build_tools/write_env_file.sh @@ -13,7 +13,7 @@ portable_realpath() { td="$(portable_realpath "$(dirname "$0")"/..)" build_dir="$(portable_realpath "${TORCH_MLIR_BUILD_DIR:-$td/build}")" -python_packages_dir="$build_dir/python_packages" +python_packages_dir="$build_dir/tools/torch-mlir/python_packages" write_env_file() { echo "Updating $build_dir/.env file" diff --git a/docs/development.md b/docs/development.md index c60312e7ac5e..93ec50f4be9a 100644 --- a/docs/development.md +++ b/docs/development.md @@ -109,13 +109,13 @@ cmake --build build ### Linux and macOS ```shell -export PYTHONPATH=`pwd`/build/python_packages/torch_mlir:`pwd`/projects/pt1/examples +export PYTHONPATH=`pwd`/build/tools/torch-mlir/python_packages/torch_mlir:`pwd`/projects/pt1/examples ``` ### Windows PowerShell ```shell -$env:PYTHONPATH = "$PWD/build/python_packages/torch_mlir;$PWD/projects/pt1/examples" +$env:PYTHONPATH = "$PWD/build/tools/torch-mlir/python_packages/torch_mlir;$PWD/projects/pt1/examples" ``` ## Testing MLIR output in various dialects @@ -126,7 +126,7 @@ Make sure you have activated the virtualenv and set the `PYTHONPATH` above (if running on Windows, modify the environment variable as shown above): ```shell source mlir_venv/bin/activate -export PYTHONPATH=`pwd`/build/tpython_packages/torch_mlir:`pwd`/projects/pt1/examples +export PYTHONPATH=`pwd`/build/tools/torch-mlir/python_packages/torch_mlir:`pwd`/projects/pt1/examples python projects/pt1/examples/torchscript_resnet18_all_output_types.py ``` From ebaab4200f2c9dd7dc817361e916b69eb0379ff8 Mon Sep 17 00:00:00 2001 From: Andreas Falkenberg <149819731+afalkenberg1@users.noreply.github.com> Date: Tue, 19 Dec 2023 08:07:27 -0800 Subject: [PATCH 026/283] [ONNX] ONNX -> TORCH for Erf (#2673) TorchOnnxToTorch For Erf function --- .../TorchOnnxToTorch/DefaultDomainAtoF.cpp | 12 ++++++++++++ .../TorchOnnxToTorch/simple_ops_a_to_f.mlir | 7 +++++++ 2 files changed, 19 insertions(+) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp index ae061922e260..510ce8121fde 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp @@ -461,6 +461,18 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( binder.op, resultType, lhs, rhs); return success(); }); + patterns.onOp("Erf", 13, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value operand; + std::string direction; + if (binder.tensorOperand(operand) || + binder.tensorResultType(resultType)) + return failure(); + rewriter.replaceOpWithNewOp( + binder.op, resultType, operand); + return success(); + }); patterns.onOp("Floor", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir index 20c8e493dcc7..5f36d1bbfaa0 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir @@ -414,6 +414,13 @@ func.func @test_equal_bcast(%arg0: !torch.vtensor<[3,4,5],si32>, %arg1: !torch.v return %0 : !torch.vtensor<[3,4,5],i1> } +// CHECK-LABEL: @test_erf +func.func @test_erf(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.erf %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> + %0 = torch.operator "onnx.Erf"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> +} + // CHECK-LABEL: @test_equal func.func @test_equal(%arg0: !torch.vtensor<[3,4,5],si32>, %arg1: !torch.vtensor<[3,4,5],si32>) -> !torch.vtensor<[3,4,5],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.eq.Tensor %arg0, %arg1 : !torch.vtensor<[3,4,5],si32>, !torch.vtensor<[3,4,5],si32> -> !torch.vtensor<[3,4,5],i1> From be3e74b647375d1c72f075bebf77dc31e74477a2 Mon Sep 17 00:00:00 2001 From: Han-Chung Wang Date: Tue, 19 Dec 2023 13:28:37 -0800 Subject: [PATCH 027/283] Integrate llvm/llvm-project@282d50147628 (2023-12-19) (#2675) --- externals/llvm-project | 2 +- lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp | 2 +- .../Torch/Transforms/ReifyAbstractInterpCalculationsUtils.cpp | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/externals/llvm-project b/externals/llvm-project index aa165edca854..282d50147628 160000 --- a/externals/llvm-project +++ b/externals/llvm-project @@ -1 +1 @@ -Subproject commit aa165edca8545b212de084d5b18c3d30347f774a +Subproject commit 282d501476284c46fd943dcbae87494cb08e2c5f diff --git a/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp b/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp index a4b02cf9e17f..8ba0479625d8 100644 --- a/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp +++ b/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp @@ -274,7 +274,7 @@ class ReduceTrailingUnderscoreInplaceVariant : public RewritePattern { SmallVector fragments; llvm::SplitString(op->getName().getStringRef(), fragments, "."); - assert(fragments.size() >= 3 && fragments[2].endswith("_") && + assert(fragments.size() >= 3 && fragments[2].ends_with("_") && "IsTrailingUnderscoreInplaceVariant incorrectly applied"); fragments[2] = fragments[2].drop_back(); std::string noUnderscoreName = llvm::join(fragments, "."); diff --git a/lib/Dialect/Torch/Transforms/ReifyAbstractInterpCalculationsUtils.cpp b/lib/Dialect/Torch/Transforms/ReifyAbstractInterpCalculationsUtils.cpp index 290beb1da7c9..7c3ceab3afec 100644 --- a/lib/Dialect/Torch/Transforms/ReifyAbstractInterpCalculationsUtils.cpp +++ b/lib/Dialect/Torch/Transforms/ReifyAbstractInterpCalculationsUtils.cpp @@ -78,7 +78,7 @@ LogicalResult Torch::wrapWithCalculateOpIfLibraryFunctionAvailable( // mechanically consistent with existing torch conventions of in-place vs. // out-of-place (value-semantic) variants), remove the prefix when // looking them up in the library. - if (name.startswith("valsem.")) + if (name.starts_with("valsem.")) name = name.drop_front(strlen("valsem.")); if (isa(op)) name = cast(op)->getAttr("name").cast().getValue(); From 869c25877a492ce214e87023e45919fd225ad145 Mon Sep 17 00:00:00 2001 From: Han-Chung Wang Date: Tue, 19 Dec 2023 18:07:23 -0800 Subject: [PATCH 028/283] Integrate llvm/llvm-project@99045b60b575 to fix bazel build. (#2677) https://github.com/llvm/torch-mlir/commit/be3e74b647375d1c72f075bebf77dc31e74477a2 breaks bazel in post-submit. The revision bumps it to include the bazel fix. --- externals/llvm-project | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/externals/llvm-project b/externals/llvm-project index 282d50147628..99045b60b575 160000 --- a/externals/llvm-project +++ b/externals/llvm-project @@ -1 +1 @@ -Subproject commit 282d501476284c46fd943dcbae87494cb08e2c5f +Subproject commit 99045b60b57571079f9cb4aea57870692523fbe8 From 20ab88284098e6ef4250d652609dc89542aa1d54 Mon Sep 17 00:00:00 2001 From: Sungsoon Cho <1025787+godot73@users.noreply.github.com> Date: Tue, 19 Dec 2023 20:59:19 -0800 Subject: [PATCH 029/283] Fix typo in DecomposeBernoulli() match failure messages. (#2676) --- lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index d4712e547264..8162d2bb6131 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -3520,7 +3520,7 @@ class DecomposeAtenBernoulliOp : public OpRewritePattern { Value input = op.getSelf(); if (!op.getGenerator().getType().isa()) return rewriter.notifyMatchFailure( - op, "The generator has to ben None because only global default " + op, "The generator has to be None because only global default " "generator is supported"); Value output; if (failed( @@ -3546,7 +3546,7 @@ class DecomposeAtenBernoulliLikeOp : public OpRewritePattern { Value p = op.getP(); if (!op.getGenerator().getType().template isa()) return rewriter.notifyMatchFailure( - op, "The generator has to ben None because only global default " + op, "The generator has to be None because only global default " "generator is supported"); auto inputType = input.getType().cast(); @@ -3578,7 +3578,7 @@ class DecomposeAtenBernoulliTensorOp Value prob = op.getP(); if (!op.getGenerator().getType().isa()) return rewriter.notifyMatchFailure( - op, "The generator has to ben None because only global default " + op, "The generator has to be None because only global default " "generator is supported"); Value output; if (failed( From 8fa81d181b3abf032f23245eadbfe9801d1ddbd3 Mon Sep 17 00:00:00 2001 From: Rik Huijzer Date: Wed, 20 Dec 2023 09:34:50 +0100 Subject: [PATCH 030/283] Tweak development.md for more speed (#2667) Adding the `--progress` flag shows the same output as what `git clone` would show. This is very nice for slow connections. Without it, the command may run for many minutes without providing any indication that it is still doing something. For `--depth=1`, I think it should be safe as most people have new enough git versions nowadays, but let's be safe and make it an optional suggestion. I ran all the tests fine with `--depth=1`, but I don't know whether things will keep working when the submodules get updated for systems with old git versions. --- docs/development.md | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/docs/development.md b/docs/development.md index 93ec50f4be9a..7927bb39a35f 100644 --- a/docs/development.md +++ b/docs/development.md @@ -5,9 +5,12 @@ ```shell git clone https://github.com/llvm/torch-mlir cd torch-mlir -git submodule update --init +git submodule update --init --progress ``` +Optionally, use `--depth=1` to make a shallow clone of the submodules. +While this is running, you can already setup the Python venv and dependencies in the next step. + ## Setup your Python VirtualEnvironment and Dependencies Also, ensure that you have the appropriate `python-dev` package installed From a24aadbfab8eea598a982dd9f56178c2ba5561ab Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Wed, 20 Dec 2023 10:09:10 -0800 Subject: [PATCH 031/283] [aten] Make `torch.aten.matmul` to `linalg` work for non-broadcasting case (#2659) Broadcasting for `torch.aten.matmul` is optional so a MxN with NxK matmul should be legalized to a `linalg.matmul`. --- lib/Conversion/TorchToLinalg/Linear.cpp | 24 ++++++++++++++++++++++-- test/Conversion/TorchToLinalg/basic.mlir | 14 ++++++++++++++ 2 files changed, 36 insertions(+), 2 deletions(-) diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp index b263786c3dbb..7c5f2c88c033 100644 --- a/lib/Conversion/TorchToLinalg/Linear.cpp +++ b/lib/Conversion/TorchToLinalg/Linear.cpp @@ -191,8 +191,9 @@ class ConvertAtenMatmulOp : public OpConversionPattern { Value lhs = adaptor.getSelf(); Value rhs = adaptor.getOther(); - if (failed(verifyLinalgCompatibleTypes(op, rewriter))) + if (failed(verifyLinalgCompatibleTypes(op, rewriter))) { return failure(); + } auto lhsType = lhs.getType().cast(); auto rhsType = rhs.getType().cast(); @@ -260,7 +261,26 @@ class ConvertAtenMatmulOp : public OpConversionPattern { return success(); } - // Fourth Case: Batch-Matrix Multiplication. + // Fourth Case: Vec-Vec Multiplication. + if (lhsRank == 2 && rhsRank == 2) { + Value lhsDim0 = getDimOp(rewriter, loc, lhs, 0); + Value lhsDim1 = getDimOp(rewriter, loc, lhs, 1); + Value rhsDim0 = getDimOp(rewriter, loc, rhs, 0); + Value rhsDim1 = getDimOp(rewriter, loc, rhs, 1); + checkDimEqualHelper(rewriter, loc, lhsDim1, rhsDim0); + + Value zeroTensor = createZeroInitTensor( + rewriter, loc, ValueRange{lhsDim0, rhsDim1}, elementType); + Value matmul = + rewriter + .create(loc, zeroTensor.getType(), + ValueRange{lhs, rhs}, zeroTensor) + .getResult(0); + rewriter.replaceOpWithNewOp(op, newResultType, matmul); + return success(); + } + + // Fifth Case: Batch-Matrix Multiplication. // TODO: Handle batch matrix multiplication when one of the matrix is unity // rank and the other has batch dimension. if (lhsRank > 1 && rhsRank > 1) { diff --git a/test/Conversion/TorchToLinalg/basic.mlir b/test/Conversion/TorchToLinalg/basic.mlir index 0aaca941b0d9..486b8b641dfd 100644 --- a/test/Conversion/TorchToLinalg/basic.mlir +++ b/test/Conversion/TorchToLinalg/basic.mlir @@ -29,6 +29,20 @@ func.func @torch.aten.mm$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.v // ----- +// CHECK-LABEL: func.func @torch.aten.matmul.2d +func.func @torch.aten.matmul.2d(%arg0: !torch.vtensor<[8,16],f32>, %arg1: !torch.vtensor<[16,8],f32>) -> !torch.vtensor<[8,8],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK-DAG: %[[LHS:.+]] = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[8,16],f32> -> tensor<8x16xf32> + // CHECK-DAG: %[[RHS:.+]] = torch_c.to_builtin_tensor %arg1 : !torch.vtensor<[16,8],f32> -> tensor<16x8xf32> + // CHECK-DAG: %[[ZERO:.+]] = arith.constant 0.000000e+00 : f32 + // CHECK-DAG: %[[EMPTY:.+]] = tensor.empty() : tensor<8x8xf32> + // CHECK-DAG: %[[FILL:.+]] = linalg.fill ins(%[[ZERO]] : f32) outs(%[[EMPTY]] : tensor<8x8xf32>) -> tensor<8x8xf32> + // CHECK: %[[MATMUL:.+]] = linalg.matmul ins(%[[LHS]], %[[RHS]] : tensor<8x16xf32>, tensor<16x8xf32>) outs(%[[FILL]] : tensor<8x8xf32>) -> tensor<8x8xf32> + %0 = torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[8,16],f32>, !torch.vtensor<[16,8],f32> -> !torch.vtensor<[8,8],f32> + return %0 : !torch.vtensor<[8,8],f32> +} + +// ----- + // CHECK-LABEL: func.func @torch.aten.mm$basic_strict( // CHECK-NOT: assert func.func @torch.aten.mm$basic_strict(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,2],f32> From 11cc92d4ab41d29a7478fc330b8fc9debf469481 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Wed, 20 Dec 2023 10:09:39 -0800 Subject: [PATCH 032/283] [onnx] Lowerings from `onnx.tan` (#2642) Started work on the `tan` lowerings for ONNX to Torch. Uses `sin` and `cos` to represent a `tan`. --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 45 +++++++++++++++++++ .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 14 +++++- .../TorchToLinalg/Uncategorized.cpp | 24 +++++----- .../Transforms/AbstractInterpLibrary.cpp | 15 +++++++ .../build_tools/abstract_interp_lib_gen.py | 10 +++++ .../build_tools/torch_ods_gen.py | 1 + .../test_suite/elementwise.py | 40 +++++++++++++++++ .../TorchOnnxToTorch/simple_ops_q_to_z.mlir | 9 ++++ 8 files changed, 147 insertions(+), 11 deletions(-) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 4fb5b5cd3b18..6013f6da3cfc 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -1066,6 +1066,51 @@ def Torch_AtenAcos_Op : Torch_Op<"aten.acos_", [ }]; } +def Torch_AtenTanOp : Torch_Op<"aten.tan", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::tan : (Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenTanOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenTanOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + +def Torch_AtenTan_Op : Torch_Op<"aten.tan_", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::tan_ : (Tensor) -> (Tensor)`"; + let arguments = (ins + Torch_NonValueTensorType:$self + ); + let results = (outs + Torch_NonValueTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenTan_Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenTan_Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + def Torch_AtenAtanOp : Torch_Op<"aten.atan", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 3637f7f35327..7630b9f282b2 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -794,7 +794,19 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( binder.op, resultType, operand); return success(); }); - + + patterns.onOp("Tan", 7, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value operand; + if (binder.tensorOperand(operand) || + binder.tensorResultType(resultType)) + return failure(); + rewriter.replaceOpWithNewOp( + binder.op, resultType, operand); + return success(); + }); + patterns.onOp( "Transpose", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) { diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index 2b4a95984425..e947ae73ace0 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -216,6 +216,10 @@ static Value createLinalgPayloadCalculationForElementwiseOp( return b.create(loc, payloadArgs[0]); if (isa(op)) return b.create(loc, payloadArgs[0]); + if (isa(op)) { + return createCalculationForMathOpWithDtypeConversion( + b, converter, payloadArgs[0], op); + } if (isa(op)) { return createCalculationForMathOpWithDtypeConversion( b, converter, payloadArgs[0], op); @@ -1319,15 +1323,15 @@ class ConvertElementwiseOp : public ConversionPattern { LogicalResult matchAndRewrite(Operation *op, ArrayRef operands, ConversionPatternRewriter &rewriter) const override { - if (!isa) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.tan\"(%arg0: !torch.list) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.atan\"(%arg0: !torch.list) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -11396,6 +11400,17 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %4 : !torch.tuple\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.tan\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %int6 = torch.constant.int 6\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %int6 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.atan2\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" " %int6 = torch.constant.int 6\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index 66e47bd45ef8..338f5e97e100 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -59,6 +59,9 @@ def aten〇triu〡shape(self: List[int], diagonal: int = 0) -> List[int]: def aten〇tril〡shape(self: List[int], diagonal: int = 0) -> List[int]: return upstream_shape_functions.unary(self) +def aten〇tan〡shape(self: List[int]) -> List[int]: + return upstream_shape_functions.unary(self) + def aten〇atan〡shape(self: List[int]) -> List[int]: return upstream_shape_functions.unary(self) @@ -3721,6 +3724,13 @@ def aten〇var_mean〡dtype(self_rank_dtype: Tuple[int, int], unbiased: bool = T return torch.float64, self_dtype return self_dtype, self_dtype +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇tan〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + if is_integer_dtype(self_dtype): + return torch.float32 + return self_dtype + @check_dtype_function(_check_two_tensor_op()) def aten〇atan2〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index ef1d707e3c00..efee6c852eb4 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -278,6 +278,7 @@ def emit_with_mutating_variants(key, **kwargs): "aten::expm1 : (Tensor) -> (Tensor)", "aten::cos : (Tensor) -> (Tensor)", "aten::acos : (Tensor) -> (Tensor)", + "aten::tan : (Tensor) -> (Tensor)", "aten::atan : (Tensor) -> (Tensor)", "aten::atan2 : (Tensor, Tensor) -> (Tensor)", "aten::neg : (Tensor) -> (Tensor)", diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py index 0b45a151c681..33c420a1c517 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py @@ -3009,6 +3009,46 @@ def ElementwiseAcosIntModule_basic(module, tu: TestUtils): # ============================================================================== +class ElementwiseTanModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True), + ]) + def forward(self, a): + return torch.tan(a) + + +@register_test_case(module_factory=lambda: ElementwiseTanModule()) +def ElementwiseTanModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4)) + +# ============================================================================== + +class ElementwiseTanIntModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int32, True), + ]) + def forward(self, a): + return torch.tan(a) + + +@register_test_case(module_factory=lambda: ElementwiseTanIntModule()) +def ElementwiseTanIntModule_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 4, low=1, high=10).to(torch.int32)) + +# ============================================================================== + class ElementwiseNegModule(torch.nn.Module): def __init__(self): diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir index da2a5c44aa95..0f4fcb08cdfb 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir @@ -795,6 +795,15 @@ func.func @test_sinh_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[ // ----- +// CHECK-LABEL: func.func @test_tan +func.func @test_tan(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 7 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[TAN:.+]] = torch.aten.tan %arg0 + %0 = torch.operator "onnx.Tan"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> +} + +// ----- + // CHECK-LABEL: func.func @test_transpose_default func.func @test_transpose_default(%arg0: !torch.vtensor<[2,3,4],f32>) -> !torch.vtensor<[4,3,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64} { // CHECK-DAG: %[[I0:.+]] = torch.constant.int 0 From 832899817210ce506e9be9888cb2f7d2a5b59630 Mon Sep 17 00:00:00 2001 From: Rik Huijzer Date: Wed, 20 Dec 2023 22:08:21 +0100 Subject: [PATCH 033/283] Allow printing all IR in `torch_mlir.compile` (#2669) This PR adds the `enable_ir_printing` option to `torch_mlir.compile`, which can be used to print the IR for all intermediate passes. When running the added test file via: ```shell $ python test/python/compile.py 2> tiny.stderr ``` the file `tiny.stderr` is about 700 KB. --- projects/pt1/python/torch_mlir/__init__.py | 12 +++++-- .../pt1/python/torch_mlir/compiler_utils.py | 8 +++-- test/python/compile.py | 34 +++++++++++++++++++ 3 files changed, 50 insertions(+), 4 deletions(-) create mode 100644 test/python/compile.py diff --git a/projects/pt1/python/torch_mlir/__init__.py b/projects/pt1/python/torch_mlir/__init__.py index 8bbcce9943d9..1cf1aa0e048a 100644 --- a/projects/pt1/python/torch_mlir/__init__.py +++ b/projects/pt1/python/torch_mlir/__init__.py @@ -319,7 +319,8 @@ def compile(model: torch.nn.Module, backend_legal_ops: Optional[Sequence[str]] = None, extra_library: Iterable[Callable] = [], verbose: bool = False, - use_make_fx: bool = False): + use_make_fx: bool = False, + enable_ir_printing: bool = False): """Convert a PyTorch model to MLIR. Args: @@ -348,7 +349,13 @@ def compile(model: torch.nn.Module, into the abstract interpretation library. See `docs/adding_abstract_interpretation_functions.md` for more info on the format the functions should have. - verbose: If true, print extra information about the conversion. + verbose: If true, print extra information about the conversion to + stdout. + enable_ir_printing: If true, print the IR before and after each pass to + stderr. This is equivalent to setting MLIR's `-print-ir-after-all` + flag. Note that this can easily generate many gigabytes of text, + so make sure to pipe stderr to a file (for example, run + `python tinymodel.py 2> tinymodel.stderr` on Linux). Returns: An MLIR module that contains the converted model in the specified @@ -452,6 +459,7 @@ def compile(model: torch.nn.Module, mb.module, f"builtin.module(torchscript-module-to-torch-backend-pipeline{option_string})", "Lowering TorchScript IR -> Torch Backend IR", + enable_ir_printing=enable_ir_printing, ) return _lower_mlir_module(verbose, output_type, mb.module) diff --git a/projects/pt1/python/torch_mlir/compiler_utils.py b/projects/pt1/python/torch_mlir/compiler_utils.py index 56e250e16802..3a64473de118 100644 --- a/projects/pt1/python/torch_mlir/compiler_utils.py +++ b/projects/pt1/python/torch_mlir/compiler_utils.py @@ -27,7 +27,8 @@ class TorchMlirCompilerError(Exception): def run_pipeline_with_repro_report(module, pipeline: str, - description: str): + description: str, + enable_ir_printing: bool = False): """Runs `pipeline` on `module`, with a nice repro report if it fails.""" module_name = get_module_name_for_debug_dump(module) try: @@ -36,8 +37,11 @@ def run_pipeline_with_repro_report(module, asm_for_error_report = module.operation.get_asm( large_elements_limit=10, enable_debug_info=True) # Lower module in place to make it ready for compiler backends. - with module.context: + with module.context as ctx: pm = PassManager.parse(pipeline) + if enable_ir_printing: + ctx.enable_multithreading(False) + pm.enable_ir_printing() pm.run(module.operation) except Exception as e: # TODO: More robust. diff --git a/test/python/compile.py b/test/python/compile.py new file mode 100644 index 000000000000..fc2917e9c76a --- /dev/null +++ b/test/python/compile.py @@ -0,0 +1,34 @@ +# RUN: %PYTHON -s %s 2>&1 | FileCheck %s + +import gc +import sys +import torch +import torch_mlir + + +def run_test(f): + print("TEST:", f.__name__, file=sys.stderr) + f() + gc.collect() + + +class TinyModel(torch.nn.Module): + def __init__(self): + super(TinyModel, self).__init__() + + self.linear = torch.nn.Linear(20, 30) + + def forward(self, x): + x = self.linear(x) + return x + + +# CHECK-LABEL: TEST: test_enable_ir_printing +@run_test +def test_enable_ir_printing(): + torch_mlir.compile(TinyModel(), + torch.ones(1, 3, 20, 20), + output_type="linalg-on-tensors", + enable_ir_printing=True) +# CHECK: // -----// IR Dump Before Canonicalizer (canonicalize) +# CHECK-NEXT: module attributes {torch.debug_module_name = "TinyModel"} { From d75cff6cd1ce691083708ff0226ebb1dd02ef5ee Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Wed, 20 Dec 2023 19:22:49 -0800 Subject: [PATCH 034/283] NFC: Remove unused variable causing a warning. --- lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 7630b9f282b2..a8ea1cc102c7 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -242,7 +242,6 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( Torch::ValueTensorType resultType; Value data; Value axes; - Value result; if (binder.tensorOperands(data, axes) || binder.tensorResultType(resultType)) return failure(); From 3226241521f67c3b54a61f5cc6df265059da2c5f Mon Sep 17 00:00:00 2001 From: Vivek Khandelwal Date: Tue, 19 Dec 2023 12:29:23 +0000 Subject: [PATCH 035/283] [MLIR][ONNX] Add OnnxToTorch support for Conv and ConvTranspose op. This commit adds the OnnxToTorch support for Conv and ConvTranspose op. Signed-Off By: vivekkhandelwal1424@gmail.com --- .../TorchOnnxToTorch/DefaultDomainAtoF.cpp | 336 ++++++++++++++++++ .../TorchOnnxToTorch/simple_ops_a_to_f.mlir | 130 +++++++ 2 files changed, 466 insertions(+) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp index 510ce8121fde..421b4edc9e18 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp @@ -426,6 +426,342 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( } return failure(); }); + patterns.onOp( + "Conv", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + std::string autoPad; + if (binder.customOpNameStringAttr(autoPad, "auto_pad", "NOTSET")) + return failure(); + if (autoPad != "NOTSET") { + // TODO: Add support for `auto_pad` != "NOTSET" + return rewriter.notifyMatchFailure( + binder.op, "unsupported conversion: auto_pad != NOTSET"); + } + + Torch::ValueTensorType resultType; + Value input, weight; + int64_t group; + if (binder.tensorOperands(input, weight) || + binder.s64IntegerAttr(group, "group", 1) || + binder.tensorResultType(resultType)) + return failure(); + + auto weightTensorType = weight.getType().cast(); + if (!weightTensorType || !weightTensorType.hasSizes()) { + return rewriter.notifyMatchFailure( + binder.op, "Expected weight type having sizes"); + } + ArrayRef weightShape = weightTensorType.getSizes(); + SmallVector kernelShape; + if (binder.s64IntegerArrayAttr(kernelShape, "kernel_shape", {})) + return failure(); + if (kernelShape.size()) { + if (kernelShape.size() != weightShape.size() - 2) { + return rewriter.notifyMatchFailure( + binder.op, + "unsupported conversion: kernel_shape list size should have " + "number of values equal to weight_rank - 2"); + } else { + for (unsigned i = 0; i < kernelShape.size(); i++) { + if (weightShape[i + 2] != kernelShape[i]) { + return rewriter.notifyMatchFailure( + binder.op, "unsupported conversion: kernel_shape value " + "should be equal to the weight tensor shape"); + } + } + } + } + + // Determine the rank of input tensor. + std::optional maybeRank = Torch::getTensorRank(input); + if (!maybeRank) + return rewriter.notifyMatchFailure(binder.op, + "Unimplemented: unranked tensor"); + unsigned rank = *maybeRank; + + SmallVector padding, strides, dilations; + SmallVector defaultPadding, defaultStrides, defaultDilations; + for (unsigned i = 0; i < rank - 2; i++) { + defaultPadding.push_back(0); + defaultStrides.push_back(1); + defaultDilations.push_back(1); + } + // Padding for the beginning and ending along each spatial axis, it can + // take any value greater than or equal to 0. The value represent the + // number of pixels added to the beginning and end part of the + // corresponding axis. pads format should be as follow [x1_begin, + // x2_begin…x1_end, x2_end,…], where xi_begin the number of pixels added + // at the beginning of axis i and xi_end, the number of pixels added at + // the end of axis i. + if (binder.s64IntegerArrayAttr(padding, "pads", defaultPadding)) { + return failure(); + } + if (padding.size() != rank - 2 && padding.size() != 2 * (rank - 2)) { + return rewriter.notifyMatchFailure( + binder.op, "padding list size does not match the number of axes"); + } + if (binder.s64IntegerArrayAttr(dilations, "dilations", + defaultDilations)) { + return failure(); + } + if (dilations.size() != rank - 2) { + return rewriter.notifyMatchFailure( + binder.op, + "dilations list size does not match the number of axes"); + } + if (binder.s64IntegerArrayAttr(strides, "strides", defaultStrides)) { + return failure(); + } + if (strides.size() != rank - 2) { + return rewriter.notifyMatchFailure( + binder.op, "strides list size does not match the number of axes"); + } + + SmallVector cstPadding, cstStrides, cstDilations, + cstOutputPadding; + if (padding.size() != 2 * (rank - 2)) { + for (int64_t i : padding) { + cstPadding.push_back(rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(i))); + } + } else { + for (unsigned i = 0; i < padding.size() / 2; i++) { + if (padding[i] != padding[i + (padding.size() / 2)]) { + // TODO: Add support for different padding values for the + // beginning and ending along each spatial axis + return rewriter.notifyMatchFailure( + binder.op, + "unsupported conversion: padding values for the beginning " + "and ending along each spatial axis must be equal"); + } + cstPadding.push_back(rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(padding[i]))); + } + } + for (int64_t i : dilations) { + cstDilations.push_back(rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(i))); + } + for (int64_t i : strides) { + cstStrides.push_back(rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(i))); + } + Value cstZero = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(0)); + cstOutputPadding = {cstZero, cstZero}; + + Value paddingList = rewriter.create( + binder.getLoc(), + Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), + cstPadding); + Value dilationsList = rewriter.create( + binder.getLoc(), + Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), + cstDilations); + Value stridesList = rewriter.create( + binder.getLoc(), + Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), + cstStrides); + Value outputPaddingList = rewriter.create( + binder.getLoc(), + Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), + cstOutputPadding); + Value transposed = + rewriter.create(binder.getLoc(), false); + Value bias; + if (binder.op->getNumOperands() == 3) { + if (binder.tensorOperandAtIndex(bias, 2)) { + return failure(); + } + } else { + bias = rewriter.create(binder.getLoc()); + } + Value cstGroup = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(group)); + + rewriter.replaceOpWithNewOp( + binder.op, resultType, input, weight, bias, stridesList, + paddingList, dilationsList, transposed, outputPaddingList, + cstGroup); + return success(); + }); + patterns.onOp( + "ConvTranspose", 11, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + std::string autoPad; + if (binder.customOpNameStringAttr(autoPad, "auto_pad", "NOTSET")) + return failure(); + if (autoPad != "NOTSET") { + // TODO: Add support for `auto_pad` != "NOTSET" + return rewriter.notifyMatchFailure( + binder.op, "unsupported conversion: auto_pad != NOTSET"); + } + SmallVector outputShape; + if (binder.s64IntegerArrayAttr(outputShape, "output_shape", {})) + return failure(); + if (outputShape.size()) { + // TODO: Add support for non-None output_shape value. + return rewriter.notifyMatchFailure( + binder.op, + "unsupported conversion: output_shape should be absent"); + } + Torch::ValueTensorType resultType; + Value input, weight; + int64_t group; + if (binder.tensorOperands(input, weight) || + binder.s64IntegerAttr(group, "group", 1) || + binder.tensorResultType(resultType)) + return failure(); + + auto weightTensorType = weight.getType().cast(); + if (!weightTensorType || !weightTensorType.hasSizes()) { + return rewriter.notifyMatchFailure( + binder.op, "Expected weight type having sizes"); + } + ArrayRef weightShape = weightTensorType.getSizes(); + SmallVector kernelShape; + if (binder.s64IntegerArrayAttr(kernelShape, "kernel_shape", {})) + return failure(); + if (kernelShape.size()) { + if (kernelShape.size() != weightShape.size() - 2) { + return rewriter.notifyMatchFailure( + binder.op, + "unsupported conversion: kernel_shape list size should have " + "number of values equal to weight_rank - 2"); + } else { + for (unsigned i = 0; i < kernelShape.size(); i++) { + if (weightShape[i + 2] != kernelShape[i]) { + return rewriter.notifyMatchFailure( + binder.op, "unsupported conversion: kernel_shape value " + "should be equal to the weight tensor shape"); + } + } + } + } + + // Determine the rank of input tensor. + std::optional maybeRank = Torch::getTensorRank(input); + if (!maybeRank) + return rewriter.notifyMatchFailure(binder.op, + "Unimplemented: unranked tensor"); + unsigned rank = *maybeRank; + + SmallVector padding, strides, dilations, outputPadding; + SmallVector defaultPadding, defaultStrides, defaultDilations, defaultOutputPadding; + for (unsigned i = 0; i < rank - 2; i++) { + defaultPadding.push_back(0); + defaultStrides.push_back(1); + defaultDilations.push_back(1); + defaultOutputPadding.push_back(0); + } + // Padding for the beginning and ending along each spatial axis, it can + // take any value greater than or equal to 0. The value represent the + // number of pixels added to the beginning and end part of the + // corresponding axis. pads format should be as follow [x1_begin, + // x2_begin…x1_end, x2_end,…], where xi_begin the number of pixels added + // at the beginning of axis i and xi_end, the number of pixels added at + // the end of axis i. + if (binder.s64IntegerArrayAttr(padding, "pads", defaultPadding)) { + return failure(); + } + if (padding.size() != rank - 2 && padding.size() != 2 * (rank - 2)) { + return rewriter.notifyMatchFailure( + binder.op, "padding list size does not match the number of axes"); + } + if (binder.s64IntegerArrayAttr(dilations, "dilations", + defaultDilations)) { + return failure(); + } + if (dilations.size() != rank - 2) { + return rewriter.notifyMatchFailure( + binder.op, + "dilations list size does not match the number of axes"); + } + if (binder.s64IntegerArrayAttr(strides, "strides", defaultStrides)) { + return failure(); + } + if (strides.size() != rank - 2) { + return rewriter.notifyMatchFailure( + binder.op, "strides list size does not match the number of axes"); + } + if (binder.s64IntegerArrayAttr(outputPadding, "output_padding", + defaultOutputPadding)) { + return failure(); + } + if (outputPadding.size() != rank - 2) { + return rewriter.notifyMatchFailure( + binder.op, + "output_padding list size does not match the number of axes"); + } + + SmallVector cstPadding, cstStrides, cstDilations, + cstOutputPadding; + if (padding.size() != 2 * (rank - 2)) { + for (int64_t i : padding) { + cstPadding.push_back(rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(i))); + } + } else { + for (unsigned i = 0; i < padding.size() / 2; i++) { + if (padding[i] != padding[i + (padding.size() / 2)]) { + // TODO: Add support for different padding values for the + // beginning and ending along each spatial axis + return rewriter.notifyMatchFailure( + binder.op, + "unsupported conversion: padding values for the beginning " + "and ending along each spatial axis must be equal"); + } + cstPadding.push_back(rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(padding[i]))); + } + } + for (int64_t i : dilations) { + cstDilations.push_back(rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(i))); + } + for (int64_t i : strides) { + cstStrides.push_back(rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(i))); + } + for (int64_t i : outputPadding) { + cstOutputPadding.push_back(rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(i))); + } + + Value paddingList = rewriter.create( + binder.getLoc(), + Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), + cstPadding); + Value dilationsList = rewriter.create( + binder.getLoc(), + Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), + cstDilations); + Value stridesList = rewriter.create( + binder.getLoc(), + Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), + cstStrides); + Value outputPaddingList = rewriter.create( + binder.getLoc(), + Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), + cstOutputPadding); + Value transposed = + rewriter.create(binder.getLoc(), true); + Value bias; + if (binder.op->getNumOperands() == 3) { + if (binder.tensorOperandAtIndex(bias, 2)) { + return failure(); + } + } else { + bias = rewriter.create(binder.getLoc()); + } + Value cstGroup = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(group)); + + rewriter.replaceOpWithNewOp( + binder.op, resultType, input, weight, bias, stridesList, + paddingList, dilationsList, transposed, outputPaddingList, + cstGroup); + return success(); + }); patterns.onOp("Cos", 7, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir index 5f36d1bbfaa0..438d6c7adda8 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir @@ -462,3 +462,133 @@ func.func @test_averagepool_3d_default(%arg0: !torch.vtensor<[1,3,32,32,32],f32> %0 = torch.operator "onnx.AveragePool"(%arg0) {torch.onnx.kernel_shape = [2 : si64, 2 : si64, 2 : si64]} : (!torch.vtensor<[1,3,32,32,32],f32>) -> !torch.vtensor<[1,3,31,31,31],f32> return %0 : !torch.vtensor<[1,3,31,31,31],f32> } + +// CHECK-LABEL: @test_conv_with_strides_no_padding +func.func @test_conv_with_strides_no_padding(%arg0: !torch.vtensor<[1,1,7,5],f32>, %arg1: !torch.vtensor<[1,1,3,3],f32>) -> !torch.vtensor<[1,1,3,2],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[C0:.*]] = torch.constant.int 0 + // CHECK: %[[C0_0:.*]] = torch.constant.int 0 + // CHECK: %[[C1:.*]] = torch.constant.int 1 + // CHECK: %[[C1_0:.*]] = torch.constant.int 1 + // CHECK: %[[C2:.*]] = torch.constant.int 2 + // CHECK: %[[C2_0:.*]] = torch.constant.int 2 + // CHECK: %[[C0_1:.*]] = torch.constant.int 0 + // CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0_0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[DILATIONS:.*]] = torch.prim.ListConstruct %[[C1]], %[[C1_0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C2]], %[[C2_0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[OUTPUT_PADDING:.*]] = torch.prim.ListConstruct %[[C0_1]], %[[C0_1]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[TRANSPOSED:.*]] = torch.constant.bool false + // CHECK: %[[BIAS:.*]] = torch.constant.none + // CHECK: %[[GROUPS:.*]] = torch.constant.int 1 + // CHECK: torch.aten.convolution %arg0, %arg1, %[[BIAS]], %[[STRIDE]], %[[PADDING]], %[[DILATIONS]], %[[TRANSPOSED]], %[[OUTPUT_PADDING]], %[[GROUPS]] : !torch.vtensor<[1,1,7,5],f32>, !torch.vtensor<[1,1,3,3],f32>, !torch.none, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[1,1,3,2],f32> + %0 = torch.operator "onnx.Conv"(%arg0, %arg1) {torch.onnx.kernel_shape = [3 : si64, 3 : si64], torch.onnx.pads = [0 : si64, 0 : si64, 0 : si64, 0 : si64], torch.onnx.strides = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,1,7,5],f32>, !torch.vtensor<[1,1,3,3],f32>) -> !torch.vtensor<[1,1,3,2],f32> + return %0 : !torch.vtensor<[1,1,3,2],f32> +} + +// CHECK-LABEL: @test_conv_with_strides_padding +func.func @test_conv_with_strides_padding(%arg0: !torch.vtensor<[1,1,7,5],f32>, %arg1: !torch.vtensor<[1,1,3,3],f32>) -> !torch.vtensor<[1,1,4,3],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[C1:.*]] = torch.constant.int 1 + // CHECK: %[[C1_0:.*]] = torch.constant.int 1 + // CHECK: %[[C1_1:.*]] = torch.constant.int 1 + // CHECK: %[[C1_2:.*]] = torch.constant.int 1 + // CHECK: %[[C2:.*]] = torch.constant.int 2 + // CHECK: %[[C2_0:.*]] = torch.constant.int 2 + // CHECK: %[[C0:.*]] = torch.constant.int 0 + // CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[C1]], %[[C1_0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[DILATIONS:.*]] = torch.prim.ListConstruct %[[C1_1]], %[[C1_2]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C2]], %[[C2_0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[OUTPUT_PADDING:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[TRANSPOSED:.*]] = torch.constant.bool false + // CHECK: %[[BIAS:.*]] = torch.constant.none + // CHECK: %[[GROUPS:.*]] = torch.constant.int 1 + // CHECK: torch.aten.convolution %arg0, %arg1, %[[BIAS]], %[[STRIDE]], %[[PADDING]], %[[DILATIONS]], %[[TRANSPOSED]], %[[OUTPUT_PADDING]], %[[GROUPS]] : !torch.vtensor<[1,1,7,5],f32>, !torch.vtensor<[1,1,3,3],f32>, !torch.none, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[1,1,4,3],f32> + %0 = torch.operator "onnx.Conv"(%arg0, %arg1) {torch.onnx.kernel_shape = [3 : si64, 3 : si64], torch.onnx.pads = [1 : si64, 1 : si64, 1 : si64, 1 : si64], torch.onnx.strides = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,1,7,5],f32>, !torch.vtensor<[1,1,3,3],f32>) -> !torch.vtensor<[1,1,4,3],f32> + return %0 : !torch.vtensor<[1,1,4,3],f32> +} + +// CHECK-LABEL: @test_convtranspose_dilations +func.func @test_convtranspose_dilations(%arg0: !torch.vtensor<[1,1,3,3],f32>, %arg1: !torch.vtensor<[1,1,2,2],f32>) -> !torch.vtensor<[1,1,5,5],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[C0:.*]] = torch.constant.int 0 + // CHECK: %[[C0_0:.*]] = torch.constant.int 0 + // CHECK: %[[C2:.*]] = torch.constant.int 2 + // CHECK: %[[C2_0:.*]] = torch.constant.int 2 + // CHECK: %[[C1:.*]] = torch.constant.int 1 + // CHECK: %[[C1_0:.*]] = torch.constant.int 1 + // CHECK: %[[C0_1:.*]] = torch.constant.int 0 + // CHECK: %[[C0_2:.*]] = torch.constant.int 0 + // CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0_0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[DILATIONS:.*]] = torch.prim.ListConstruct %[[C2]], %[[C2_0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C1]], %[[C1_0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[OUTPUT_PADDING:.*]] = torch.prim.ListConstruct %[[C0_1]], %[[C0_2]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[TRANSPOSED:.*]] = torch.constant.bool true + // CHECK: %[[BIAS:.*]] = torch.constant.none + // CHECK: %[[GROUPS:.*]] = torch.constant.int 1 + // CHECK: torch.aten.convolution %arg0, %arg1, %[[BIAS]], %[[STRIDE]], %[[PADDING]], %[[DILATIONS]], %[[TRANSPOSED]], %[[OUTPUT_PADDING]], %[[GROUPS]] : !torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1,1,2,2],f32>, !torch.none, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[1,1,5,5],f32> + %0 = torch.operator "onnx.ConvTranspose"(%arg0, %arg1) {torch.onnx.dilations = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1,1,2,2],f32>) -> !torch.vtensor<[1,1,5,5],f32> + return %0 : !torch.vtensor<[1,1,5,5],f32> +} + +// CHECK-LABEL: @test_convtranspose +func.func @test_convtranspose(%arg0: !torch.vtensor<[1,1,3,3],f32>, %arg1: !torch.vtensor<[1,2,3,3],f32>) -> !torch.vtensor<[1,2,5,5],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[C0:.*]] = torch.constant.int 0 + // CHECK: %[[C0_0:.*]] = torch.constant.int 0 + // CHECK: %[[C1:.*]] = torch.constant.int 1 + // CHECK: %[[C1_0:.*]] = torch.constant.int 1 + // CHECK: %[[C1_1:.*]] = torch.constant.int 1 + // CHECK: %[[C1_2:.*]] = torch.constant.int 1 + // CHECK: %[[C0_1:.*]] = torch.constant.int 0 + // CHECK: %[[C0_2:.*]] = torch.constant.int 0 + // CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0_0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[DILATIONS:.*]] = torch.prim.ListConstruct %[[C1]], %[[C1_0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C1_1]], %[[C1_2]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[OUTPUT_PADDING:.*]] = torch.prim.ListConstruct %[[C0_1]], %[[C0_2]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[TRANSPOSED:.*]] = torch.constant.bool true + // CHECK: %[[BIAS:.*]] = torch.constant.none + // CHECK: %[[GROUPS:.*]] = torch.constant.int 1 + // CHECK: torch.aten.convolution %arg0, %arg1, %[[BIAS]], %[[STRIDE]], %[[PADDING]], %[[DILATIONS]], %[[TRANSPOSED]], %[[OUTPUT_PADDING]], %[[GROUPS]] : !torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1,2,3,3],f32>, !torch.none, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[1,2,5,5],f32> + %0 = torch.operator "onnx.ConvTranspose"(%arg0, %arg1) : (!torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1,2,3,3],f32>) -> !torch.vtensor<[1,2,5,5],f32> + return %0 : !torch.vtensor<[1,2,5,5],f32> +} + +// CHECK-LABEL: @test_convtranspose_pad + func.func @test_convtranspose_pad(%arg0: !torch.vtensor<[1,1,3,3],f32>, %arg1: !torch.vtensor<[1,2,3,3],f32>) -> !torch.vtensor<[1,2,10,8],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[C0:.*]] = torch.constant.int 0 + // CHECK: %[[C0_0:.*]] = torch.constant.int 0 + // CHECK: %[[C1:.*]] = torch.constant.int 1 + // CHECK: %[[C1_0:.*]] = torch.constant.int 1 + // CHECK: %[[C3:.*]] = torch.constant.int 3 + // CHECK: %[[C2:.*]] = torch.constant.int 2 + // CHECK: %[[C1_1:.*]] = torch.constant.int 1 + // CHECK: %[[C1_2:.*]] = torch.constant.int 1 + // CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0_0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[DILATIONS:.*]] = torch.prim.ListConstruct %[[C1]], %[[C1_0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C3]], %[[C2]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[OUTPUT_PADDING:.*]] = torch.prim.ListConstruct %[[C1_1]], %[[C1_2]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[TRANSPOSED:.*]] = torch.constant.bool true + // CHECK: %[[BIAS:.*]] = torch.constant.none + // CHECK: %[[GROUPS:.*]] = torch.constant.int 1 + // CHECK: torch.aten.convolution %arg0, %arg1, %[[BIAS]], %[[STRIDE]], %[[PADDING]], %[[DILATIONS]], %[[TRANSPOSED]], %[[OUTPUT_PADDING]], %[[GROUPS]] : !torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1,2,3,3],f32>, !torch.none, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[1,2,10,8],f32> + %0 = torch.operator "onnx.ConvTranspose"(%arg0, %arg1) {torch.onnx.output_padding = [1 : si64, 1 : si64], torch.onnx.strides = [3 : si64, 2 : si64]} : (!torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1,2,3,3],f32>) -> !torch.vtensor<[1,2,10,8],f32> + return %0 : !torch.vtensor<[1,2,10,8],f32> + } + +// CHECK-LABEL: @test_convtranspose_pads + func.func @test_convtranspose_pads(%arg0: !torch.vtensor<[1,1,3,3],f32>, %arg1: !torch.vtensor<[1,2,3,3],f32>) -> !torch.vtensor<[1,2,7,3],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[C1:.*]] = torch.constant.int 1 + // CHECK: %[[C2:.*]] = torch.constant.int 2 + // CHECK: %[[C1_0:.*]] = torch.constant.int 1 + // CHECK: %[[C1_1:.*]] = torch.constant.int 1 + // CHECK: %[[C3:.*]] = torch.constant.int 3 + // CHECK: %[[C2_0:.*]] = torch.constant.int 2 + // CHECK: %[[C0:.*]] = torch.constant.int 0 + // CHECK: %[[C0_1:.*]] = torch.constant.int 0 + // CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[C1]], %[[C2]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[DILATIONS:.*]] = torch.prim.ListConstruct %[[C1_0]], %[[C1_1]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C3]], %[[C2_0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[OUTPUT_PADDING:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0_1]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[TRANSPOSED:.*]] = torch.constant.bool true + // CHECK: %[[BIAS:.*]] = torch.constant.none + // CHECK: %[[GROUPS:.*]] = torch.constant.int 1 + // CHECK: torch.aten.convolution %arg0, %arg1, %[[BIAS]], %[[STRIDE]], %[[PADDING]], %[[DILATIONS]], %[[TRANSPOSED]], %[[OUTPUT_PADDING]], %[[GROUPS]] : !torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1,2,3,3],f32>, !torch.none, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[1,2,7,3],f32> + %0 = torch.operator "onnx.ConvTranspose"(%arg0, %arg1) {torch.onnx.pads = [1 : si64, 2 : si64, 1 : si64, 2 : si64], torch.onnx.strides = [3 : si64, 2 : si64]} : (!torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1,2,3,3],f32>) -> !torch.vtensor<[1,2,7,3],f32> + return %0 : !torch.vtensor<[1,2,7,3],f32> + } From 779a141f8d6a276c73ddcf6fea5ee40ae39be0d2 Mon Sep 17 00:00:00 2001 From: John Wu Date: Thu, 21 Dec 2023 07:26:20 -0800 Subject: [PATCH 036/283] Mentioned helpful tooling to convert Onnx models to Torch MLIR (#2683) - Going through the `#torch-mlir` channel on the `llvm` discord, I realize that there are some useful commands that would be extremely helpful in creating Onnx lowers to Torch MLIR. Seems a lot of people are contributing to this. So, I thought it would be good to add this information to the docs. These tools helped streamlined the development of this PR: https://github.com/llvm/torch-mlir/pull/2682 --- docs/importers/onnx_importer.md | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/docs/importers/onnx_importer.md b/docs/importers/onnx_importer.md index acc45bb2e602..796beba1f045 100644 --- a/docs/importers/onnx_importer.md +++ b/docs/importers/onnx_importer.md @@ -28,13 +28,25 @@ are relatively straight-forward to map, following this general procedure: `tools/torch-mlir/test/python/onnx_importer/Output`. The `.mlir` files under there should provide good variants to drive lit test coverage of conversion. + * (Optionally) If there is an Onnx file that uses the op of interest, + convert that file to Onnx MLIR form using the following Python command, + `python -m torch_mlir.tools.import_onnx my_model.onnx`. * There are often many variants of tests for checking conformance of different historic ONNX encodings, but these are often not load bearing at the MLIR level. * Pick a handful of test cases and add them to - `test/Conversion/TorchOnnxToTorch/simple_ops_x_to_y.mlir` corresponding to an - alphabetic breakdown. At this time, ignore tests that are not exercising + `test/Conversion/TorchOnnxToTorch/simple_ops_x_to_y.mlir` corresponding to + an alphabetic breakdown. At this time, ignore tests that are not exercising useful differences in the pattern implementations. + * (Optionally) Use `torch-mlir-opt` to validate the outputs of the new op. + First, build the project using + `cmake --build build --target tools/torch-mlir/all`. This will generate + the conversion binary, `torch-mlir-opt`. Then call `torch-mlir-opt` with + the MLIR pass `convert-torch-onnx-to-torch`: + ``` + build/bin/torch-mlir-opt -convert-torch-onnx-to-torch \ + -split-input-file [DESIRED_ONNX_FILE].mlir + ``` * Generate failure test cases: * Some ops have forms that do not (easily) map to torch-mlir. If you leave an op under-implemented, add a failing test case to From 46f2cb50dca5e789d1114b127d9a4312fbb8e3d9 Mon Sep 17 00:00:00 2001 From: John Wu Date: Thu, 21 Dec 2023 07:29:22 -0800 Subject: [PATCH 037/283] [onnx] Lower onnx.HardSigmoid to torch (#2682) The expression for HardSigmoid in Onnx (https://onnx.ai/onnx/operators/onnx__HardSigmoid.html): max(0, min(1, alpha * x + beta)) is inherently different from HardSigmoid in Torch (https://pytorch.org/docs/stable/generated/torch.nn.Hardsigmoid.html) which is: if x < -3 -> 0 elif x > 3 -> 1 else x/6 + 1/2 That being said, it was just better to compute out the entire expression when translating the Onnx expression to Torch mlir, which is done in this PR. Some of the logic is shared from the files in `DecomposeComplexOps`. Therefore, refactored some shared logic between `DecomposeComplexOps` and `DefaultDomainGToP` and put it in a `Utils` file. --- .../torch-mlir/Dialect/Torch/Utils/Utils.h | 12 ++++ .../TorchOnnxToTorch/DefaultDomainGtoP.cpp | 41 +++++++++++ .../Torch/Transforms/DecomposeComplexOps.cpp | 29 -------- lib/Dialect/Torch/Utils/Utils.cpp | 31 ++++++++- .../TorchOnnxToTorch/simple_ops_g_to_p.mlir | 68 ++++++++++++++++++- 5 files changed, 150 insertions(+), 31 deletions(-) diff --git a/include/torch-mlir/Dialect/Torch/Utils/Utils.h b/include/torch-mlir/Dialect/Torch/Utils/Utils.h index 0e4c2b0a0ab7..25d35f0f9f2b 100644 --- a/include/torch-mlir/Dialect/Torch/Utils/Utils.h +++ b/include/torch-mlir/Dialect/Torch/Utils/Utils.h @@ -11,6 +11,7 @@ #include "mlir/IR/PatternMatch.h" #include "mlir/IR/Value.h" +#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h" #include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h" namespace mlir { @@ -117,6 +118,17 @@ LogicalResult checkDefaultStrideHelper(Operation *op, PatternRewriter &rewriter, Value opSize, Value opStride, Location loc); +// Helper to create a tensor filled with the given scalar. Scalar would be +// converted the to the element type of the given tensor type. +Value createInitTensor(PatternRewriter &rewriter, Location loc, + BaseTensorType resultType, Value scalar, + Value sizeList); + +// Helper to create a rank 0 tensor filled with the given `scalar`. `scalar` +// would be converted to the element type of the given `inputType`. +Value createRank0Tensor(PatternRewriter &rewriter, Location loc, + BaseTensorType inputType, Value scalar); + } // namespace Torch } // namespace torch } // namespace mlir diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index 0191fcf3619e..b9bb6a540a02 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -8,6 +8,7 @@ //===----------------------------------------------------------------------===// #include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h" +#include "torch-mlir/Dialect/Torch/Utils/Utils.h" using namespace mlir; using namespace mlir::torch; @@ -27,7 +28,47 @@ using namespace mlir::torch::onnx_c; // thing here, so we simplify. void mlir::torch::onnx_c::populateDefaultDomainGtoP( OnnxCustomOpConversionPattern &patterns) { + patterns.onOp("HardSigmoid", 6, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value tensorOperand; + float alpha, beta; + if (binder.tensorOperand(tensorOperand) || + binder.f32FloatAttr(alpha, "alpha", 0.2) || + binder.f32FloatAttr(beta, "beta", 0.5) || + binder.tensorResultType(resultType)) + return failure(); + + // HardSigmoid computes the following expression: max(0, min(1, alpha * x + beta)) + Value constAlpha = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getF64FloatAttr(alpha)); + + Value constBeta = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getF64FloatAttr(beta)); + + // Expression: alpha * x + beta + Value alpha_x_plus_beta = rewriter.create( + binder.getLoc(), resultType, tensorOperand, constBeta, /*alpha=*/constAlpha); + // Expression: min(1, alpha * x + beta) + Value constantOne = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(1)); + Value oneTensor = createRank0Tensor(rewriter, binder.getLoc(), + resultType, constantOne); + Value minExpression = rewriter.create( + binder.getLoc(), resultType, oneTensor, alpha_x_plus_beta); + + // Expression: max(0, min(1, alpha * x + beta)) + Value constantZero = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(0)); + Value zeroTensor = createRank0Tensor(rewriter, binder.getLoc(), + resultType, constantZero); + rewriter.replaceOpWithNewOp( + binder.op, resultType, zeroTensor, minExpression); + return success(); + }); patterns.onOp( "Gelu", 20, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Value operand; diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 8162d2bb6131..d8b8639e0a75 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -126,35 +126,6 @@ static Value createTensorSub(PatternRewriter &rewriter, Location loc, return sub; } -// Helper to create a tensor filled with the given scalar. Scalar would be -// converted the to the element type of the given tensor type. -static Value createInitTensor(PatternRewriter &rewriter, Location loc, - BaseTensorType resultType, Value scalar, - Value sizeList) { - assert(resultType.hasDtype() && "result must have dtype"); - Value noneVal = rewriter.create(loc); - Value dtype = getDtypeIntValueForType(rewriter, loc, resultType.getDtype()); - return rewriter.create(loc, resultType, sizeList, scalar, dtype, - /*layout=*/noneVal, - /*device=*/noneVal, - /*memory_format=*/noneVal); -} - -// Helper to create a rank 0 tensor filled with the given `scalar`. `scalar` -// would be converted to the element type of the given `inputType`. -static Value createRank0Tensor(PatternRewriter &rewriter, Location loc, - BaseTensorType inputType, Value scalar) { - assert(inputType.hasDtype() && "input must have dtype"); - SmallVector sizes; - BaseTensorType rank0TensorTy = - inputType.getWithSizesAndDtype(ArrayRef(sizes), inputType.getDtype()) - .cast(); - Value dimList = rewriter.create( - loc, Torch::ListType::get(Torch::IntType::get(inputType.getContext())), - ValueRange{}); - return createInitTensor(rewriter, loc, rank0TensorTy, scalar, dimList); -} - // Share code between `softmax_backward` and `log_softmax_backward` ops. // Returns x - y * sum(z, dim). static Value createSoftmaxBackwardCommonKernel(PatternRewriter &rewriter, diff --git a/lib/Dialect/Torch/Utils/Utils.cpp b/lib/Dialect/Torch/Utils/Utils.cpp index c5b0eec50be2..4bf5f7e13d1f 100644 --- a/lib/Dialect/Torch/Utils/Utils.cpp +++ b/lib/Dialect/Torch/Utils/Utils.cpp @@ -10,6 +10,7 @@ #include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "mlir/IR/BuiltinDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" +#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h" using namespace mlir; using namespace mlir::torch; @@ -74,7 +75,6 @@ torch_upstream::ScalarType Torch::getScalarTypeForType(Type type) { } llvm::report_fatal_error("unhandled type for getScalarTypeForType"); } - Type Torch::getTypeForTorchType( MLIRContext *context, Type type, mlir::IntegerType::SignednessSemantics signedness) { @@ -471,3 +471,32 @@ LogicalResult Torch::checkDefaultStrideHelper(Operation *op, return success(); } } + +// Helper to create a tensor filled with the given scalar. Scalar would be +// converted the to the element type of the given tensor type. +Value Torch::createInitTensor(PatternRewriter &rewriter, Location loc, + BaseTensorType resultType, Value scalar, + Value sizeList) { + assert(resultType.hasDtype() && "result must have dtype"); + Value noneVal = rewriter.create(loc); + Value dtype = getDtypeIntValueForType(rewriter, loc, resultType.getDtype()); + return rewriter.create(loc, resultType, sizeList, scalar, dtype, + /*layout=*/noneVal, + /*device=*/noneVal, + /*memory_format=*/noneVal); +} + +// Helper to create a rank 0 tensor filled with the given `scalar`. `scalar` +// would be converted to the element type of the given `inputType`. +Value Torch::createRank0Tensor(PatternRewriter &rewriter, Location loc, + BaseTensorType inputType, Value scalar) { + assert(inputType.hasDtype() && "input must have dtype"); + SmallVector sizes; + BaseTensorType rank0TensorTy = + inputType.getWithSizesAndDtype(ArrayRef(sizes), inputType.getDtype()) + .cast(); + Value dimList = rewriter.create( + loc, Torch::ListType::get(Torch::IntType::get(inputType.getContext())), + ValueRange{}); + return createInitTensor(rewriter, loc, rank0TensorTy, scalar, dimList); +} diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir index 27e7f2c6aa26..08bb69f23fc5 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir @@ -198,4 +198,70 @@ func.func @test_less_or_equal(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch. // CHECK: torch.aten.le.Tensor %[[ARG0]], %[[ARG1]] : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],i1> %0 = torch.operator "onnx.LessOrEqual"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],i1> return %0 : !torch.vtensor<[3,4,5],i1> -} \ No newline at end of file +} + +// CHECK-LABEL: @test_hardsigmoid_example +func.func @test_hardsigmoid_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 6 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[ALPHA_FLOAT:.*]] = torch.constant.float 5.000000e-01 + // CHECK: %[[BETA_FLOAT:.*]] = torch.constant.float 0.60000002384185791 + // CHECK: %[[ALPHA_MULTI_X_PLUS_BETA:.*]] = torch.aten.add.Scalar %arg0, %[[BETA_FLOAT:.*]], %[[ALPHA_FLOAT:.*]] : !torch.vtensor<[3],f32>, !torch.float, !torch.float -> !torch.vtensor<[3],f32> + // CHECK: %[[INT_1:.*]] = torch.constant.int 1 + // CHECK: %[[TENSOR_DIMENSION_LIST_FOR_ONE:.*]] = torch.prim.ListConstruct : () -> !torch.list + // CHECK: %[[NONE_FOR_ONE:.*]] = torch.constant.none + // CHECK: %[[INT_TYPE_FOR_TENSOR_ONE:.*]] = torch.constant.int 6 + // CHECK: %[[ONE_TENSOR:.*]] = torch.aten.full %[[TENSOR_DIMENSION_LIST_FOR_ONE:.*]], %[[INT_1:.*]], %[[INT_TYPE_FOR_TENSOR_ONE:.*]], %[[NONE_FOR_ONE:.*]], %[[NONE_1:.*]], %[[NONE_1:.*]] : !torch.list, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32> + // CHECK: %[[MIN_EXPRESSION:.*]] = torch.aten.minimum %[[ONE_TENSOR:.*]], %[[ALPHA_MULTI_X_PLUS_BETA:.*]] : !torch.vtensor<[],f32>, !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> + // CHECK: %[[INT_0:.*]] = torch.constant.int 0 + // CHECK: %[[TENSOR_DIMENSION_LIST_FOR_ZERO:.*]] = torch.prim.ListConstruct : () -> !torch.list + // CHECK: %[[NONE_FOR_ZERO:.*]] = torch.constant.none + // CHECK: %[[INT_TYPE_FOR_TENSOR_ZERO:.*]] = torch.constant.int 6 + // CHECK: %[[ZERO_TENSOR:.*]] = torch.aten.full %[[TENSOR_DIMENSION_LIST_FOR_ZERO:.*]], %[[INT_0:.*]], %[[INT_TYPE_FOR_TENSOR_ZERO:.*]], %[[NONE_FOR_ZERO:.*]], %none_0, %none_0 : !torch.list, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32> + // CHECK: %[[RESULT:.*]] = torch.aten.maximum %[[ZERO_TENSOR:.*]], %[[MIN_EXPRESSION:.*]] : !torch.vtensor<[],f32>, !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> + // CHECK: return %[[RESULT:.*]] : !torch.vtensor<[3],f32> + + %0 = torch.operator "onnx.HardSigmoid"(%arg0) {torch.onnx.alpha = 5.000000e-01 : f32, torch.onnx.beta = 6.000000e-01 : f32} : (!torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> + return %0 : !torch.vtensor<[3],f32> +} + +// CHECK-LABEL: @test_hardsigmoid +func.func @test_hardsigmoid(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 6 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[ALPHA_FLOAT:.*]] = torch.constant.float 5.000000e-01 + // CHECK: %[[BETA_FLOAT:.*]] = torch.constant.float 0.60000002384185791 + // CHECK: %[[ALPHA_MULTI_X_PLUS_BETA:.*]] = torch.aten.add.Scalar %arg0, %[[BETA_FLOAT:.*]], %[[ALPHA_FLOAT:.*]] : !torch.vtensor<[3,4,5],f32>, !torch.float, !torch.float -> !torch.vtensor<[3,4,5],f32> + // CHECK: %[[INT_1:.*]] = torch.constant.int 1 + // CHECK: %[[TENSOR_DIMENSION_LIST_FOR_ONE:.*]] = torch.prim.ListConstruct : () -> !torch.list + // CHECK: %[[NONE_FOR_ONE:.*]] = torch.constant.none + // CHECK: %[[INT_TYPE_FOR_TENSOR_ONE:.*]] = torch.constant.int 6 + // CHECK: %[[ONE_TENSOR:.*]] = torch.aten.full %[[TENSOR_DIMENSION_LIST_FOR_ONE:.*]], %[[INT_1:.*]], %[[INT_TYPE_FOR_TENSOR_ONE:.*]], %[[NONE_FOR_ONE:.*]], %[[NONE_1:.*]], %[[NONE_1:.*]] : !torch.list, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32> + // CHECK: %[[MIN_EXPRESSION:.*]] = torch.aten.minimum %[[ONE_TENSOR:.*]], %[[ALPHA_MULTI_X_PLUS_BETA:.*]] : !torch.vtensor<[],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> + // CHECK: %[[INT_0:.*]] = torch.constant.int 0 + // CHECK: %[[TENSOR_DIMENSION_LIST_FOR_ZERO:.*]] = torch.prim.ListConstruct : () -> !torch.list + // CHECK: %[[NONE_FOR_ZERO:.*]] = torch.constant.none + // CHECK: %[[INT_TYPE_FOR_TENSOR_ZERO:.*]] = torch.constant.int 6 + // CHECK: %[[ZERO_TENSOR:.*]] = torch.aten.full %[[TENSOR_DIMENSION_LIST_FOR_ZERO:.*]], %[[INT_0:.*]], %[[INT_TYPE_FOR_TENSOR_ZERO:.*]], %[[NONE_FOR_ZERO:.*]], %none_0, %none_0 : !torch.list, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32> + // CHECK: %[[RESULT:.*]] = torch.aten.maximum %[[ZERO_TENSOR:.*]], %[[MIN_EXPRESSION:.*]] : !torch.vtensor<[],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> + // CHECK: return %[[RESULT:.*]] : !torch.vtensor<[3,4,5],f32> + %0 = torch.operator "onnx.HardSigmoid"(%arg0) {torch.onnx.alpha = 5.000000e-01 : f32, torch.onnx.beta = 6.000000e-01 : f32} : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> +} + +// CHECK-LABEL: @test_hardsigmoid_default +func.func @test_hardsigmoid_default(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 6 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[ALPHA_FLOAT:.*]] = torch.constant.float 0.20000000298023224 + // CHECK: %[[BETA_FLOAT:.*]] = torch.constant.float 5.000000e-01 + // CHECK: %[[ALPHA_MULTI_X_PLUS_BETA:.*]] = torch.aten.add.Scalar %arg0, %[[BETA_FLOAT:.*]], %[[ALPHA_FLOAT:.*]] : !torch.vtensor<[3,4,5],f32>, !torch.float, !torch.float -> !torch.vtensor<[3,4,5],f32> + // CHECK: %[[INT_1:.*]] = torch.constant.int 1 + // CHECK: %[[TENSOR_DIMENSION_LIST_FOR_ONE:.*]] = torch.prim.ListConstruct : () -> !torch.list + // CHECK: %[[NONE_FOR_ONE:.*]] = torch.constant.none + // CHECK: %[[INT_TYPE_FOR_TENSOR_ONE:.*]] = torch.constant.int 6 + // CHECK: %[[ONE_TENSOR:.*]] = torch.aten.full %[[TENSOR_DIMENSION_LIST_FOR_ONE:.*]], %[[INT_1:.*]], %[[INT_TYPE_FOR_TENSOR_ONE:.*]], %[[NONE_FOR_ONE:.*]], %[[NONE_1:.*]], %[[NONE_1:.*]] : !torch.list, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32> + // CHECK: %[[MIN_EXPRESSION:.*]] = torch.aten.minimum %[[ONE_TENSOR:.*]], %[[ALPHA_MULTI_X_PLUS_BETA:.*]] : !torch.vtensor<[],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> + // CHECK: %[[INT_0:.*]] = torch.constant.int 0 + // CHECK: %[[TENSOR_DIMENSION_LIST_FOR_ZERO:.*]] = torch.prim.ListConstruct : () -> !torch.list + // CHECK: %[[NONE_FOR_ZERO:.*]] = torch.constant.none + // CHECK: %[[INT_TYPE_FOR_TENSOR_ZERO:.*]] = torch.constant.int 6 + // CHECK: %[[ZERO_TENSOR:.*]] = torch.aten.full %[[TENSOR_DIMENSION_LIST_FOR_ZERO:.*]], %[[INT_0:.*]], %[[INT_TYPE_FOR_TENSOR_ZERO:.*]], %[[NONE_FOR_ZERO:.*]], %none_0, %none_0 : !torch.list, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32> + // CHECK: torch.aten.maximum %[[ZERO_TENSOR:.*]], %[[MIN_EXPRESSION:.*]] : !torch.vtensor<[],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> + %0 = torch.operator "onnx.HardSigmoid"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> +} From ccd469ca0d626d29fea3ab35d5956cc2882a12be Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Thu, 21 Dec 2023 08:40:10 -0800 Subject: [PATCH 038/283] [fx] Upstream the turbine FxImporter to torch-mlir. (#2681) Changes made during upstreaming: * Removed comments attributing some copied code back to torch-mlir (since it is now repatriated). * Re-organized imports. * Inlined RefMapping/RefTracker and TypeSubclassMap from an external utility module. * Added FxImporter class comments. * Updated stack trace extraction to be fail safe. * Added an entry-point for `import_frozen_exported_program` which uses the shiny new upstream `torch.export.export()` API (versus the lower-level/older API that Turbine is presently using). This necessitated a small FX rewrite to line external state management up with current conventions. * Adapted one of Turbine's importer tests to go with this initial submission. Turbine unfortunately has a lot of more-integration-ey tests, and I would like to extract those as more of unit tests of the importer features and upstream them that way vs trying to copy directly. For now, one overall test with the initial submission gets us moving. I acknowledge that there are some code quality things that could be improved in this submission: this was authored over the course of many months (and often via some trial and error). I would like to keep it relatively converged with the downstream for the next few steps while getting the test suite upstreamed. And then it will be easier to take a hygienic pass through the code. Including co-authors for contributors in the git log of the original repository. Co-authored-by: Ean Garvey <87458719+monorimet@users.noreply.github.com> Co-authored-by: Avinash Sharma Co-authored-by: Arham Khan Co-authored-by: brucekimrokcmu Co-authored-by: saienduri <77521230+saienduri@users.noreply.github.com> --- python/CMakeLists.txt | 1 + python/torch_mlir/extras/fx_importer.py | 1238 +++++++++++++++++++++++ test/python/fx_importer/basic_test.py | 80 ++ 3 files changed, 1319 insertions(+) create mode 100644 python/torch_mlir/extras/fx_importer.py create mode 100644 test/python/fx_importer/basic_test.py diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt index f29429b7246c..b8f8394459d9 100644 --- a/python/CMakeLists.txt +++ b/python/CMakeLists.txt @@ -35,6 +35,7 @@ declare_mlir_python_sources(TorchMLIRPythonSources.Importers ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}" ADD_TO_PARENT TorchMLIRPythonSources SOURCES + extras/fx_importer.py extras/onnx_importer.py ) diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py new file mode 100644 index 000000000000..9ec90e766c46 --- /dev/null +++ b/python/torch_mlir/extras/fx_importer.py @@ -0,0 +1,1238 @@ +# Copyright 2023 Advanced Micro Devices, Inc +# +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. + +import logging +import operator +import re +from types import NoneType, BuiltinMethodType, BuiltinFunctionType +from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union +import weakref + +import numpy as np + +import torch +import torch.export +import torch.fx as torch_fx +from torch.fx.passes.shape_prop import TensorMetadata + +from torch import ( + dtype as TorchDtype, + FunctionSchema, +) + +from torch._ops import ( + OpOverload as TorchOpOverload, +) + +from torch._subclasses import ( + FakeTensor as TorchFakeTensor, +) + +from torch.fx import ( + Graph, + GraphModule, +) + +from torch.fx.node import ( + Argument as NodeArgument, +) + +from ..ir import ( + Attribute, + Block, + Context, + DenseResourceElementsAttr, + FloatAttr, + BF16Type, + ComplexType, + F16Type, + F32Type, + F64Type, + FunctionType, + InsertionPoint, + IntegerAttr, + IntegerType, + RankedTensorType, + Location, + Module, + Operation, + StringAttr, + SymbolTable, + Type as IrType, + Value, +) + +from ..dialects import ( + func as func_dialect, +) + +__all__ = [ + "FxImporter", +] + +# An external callback that, given a Python value and a GraphNodeImporter, may choose +# to materialize IR to load the value as a vtensor. If it returns None, then default +# literal resolution proceeds. +LiteralResolverCallback = Callable[[Any, "GraphNodeImporter"], Optional[Value]] + +REQUIRED_DIALCTS = [ + "builtin", + "func", + "torch", +] + +TORCH_DTYPE_TO_MLIR_TYPE_ASM = { + torch.float16: "f16", + torch.bfloat16: "bf16", + torch.float32: "f32", + torch.float64: "f64", + torch.uint8: "ui8", + torch.int8: "si8", + torch.int16: "si16", + torch.int32: "si32", + torch.int64: "si64", + torch.bool: "i1", + torch.qint8: "!torch.qint8", + torch.quint8: "!torch.quint8", + torch.complex32: "complex", + torch.complex64: "complex", + torch.complex128: "complex", +} + +TORCH_DTYPE_TO_MLIR_TYPE: Dict[torch.dtype, Callable[[], IrType]] = { + torch.float16: lambda: F16Type.get(), + torch.bfloat16: lambda: BF16Type.get(), + torch.float32: lambda: F32Type.get(), + torch.float64: lambda: F64Type.get(), + torch.uint8: lambda: IntegerType.get_unsigned(8), + torch.int8: lambda: IntegerType.get_signed(8), + torch.int16: lambda: IntegerType.get_signed(16), + torch.int32: lambda: IntegerType.get_signed(32), + torch.int64: lambda: IntegerType.get_signed(64), + torch.bool: lambda: IntegerType.get_signless(1), + torch.qint8: lambda: IntegerType.get_signed(8), + torch.quint8: lambda: IntegerType.get_unsigned(8), + torch.complex32: lambda: ComplexType.get(F16Type.get()), + torch.complex64: lambda: ComplexType.get(F32Type.get()), + torch.complex128: lambda: ComplexType.get(F64Type.get()), +} + +TORCH_DTYPE_TO_NPY_TYPE = { + # torch.qint8: None, # no equivalent np datatype + # torch.quint8: None, + torch.uint8: np.uint8, + torch.int8: np.int8, + torch.int16: np.int16, + torch.int32: np.int32, + torch.int64: np.int64, + # torch.bf16: None, there's no equivalent np datatype so this isn't supported right now + torch.float16: np.float16, + torch.float32: np.float32, + torch.float64: np.float64, + torch.bool: np.bool_, + # torch.complex32: None, # no equivalent precision for numpy + torch.complex64: np.complex64, + torch.complex128: np.complex128, +} + +TORCH_DTYPE_TO_INT = { + torch.uint8: 0, + torch.int8: 1, + torch.int16: 2, + torch.int32: 3, + torch.int64: 4, + torch.float16: 5, + torch.float32: 6, + torch.float64: 7, + # torch.complex_half 8 + torch.complex32: 9, + torch.complex64: 10, + torch.bool: 11, + # torch.qint8: 12, # quantized dtypes are not supported in all backends, currently we do not support them + # torch.quint8: 13, + # torch.qint32 14 + torch.bfloat16: 15, +} + +TORCH_MEMORY_FORMAT_TO_INT = { + torch.contiguous_format: 0, + torch.preserve_format: 1, + torch.channels_last: 2, + torch.channels_last_3d: 3, +} + +TORCH_LAYOUT_TO_INT = { + torch.strided: 0, + torch.sparse_coo: 1, + torch.sparse_csr: 2, + torch.sparse_csc: 3, + torch.sparse_bsr: 4, + torch.sparse_bsc: 5, +} + +PY_BUILTIN_TO_TORCH_OP = { + "truediv": torch.ops.aten.div, + "mul": torch.ops.aten.mul, + "add": torch.ops.aten.add, + "sub": torch.ops.aten.sub, + "lt": torch.ops.aten.lt, + "le": torch.ops.aten.le, + "ge": torch.ops.aten.ge, + "ne": torch.ops.aten.ne, + "gt": torch.ops.aten.gt, +} + +SYMBOLIC_TORCH_OPS = { + torch.ops.aten.sym_size, + torch.ops.aten.sym_stride, + torch.ops.aten.sym_numel, +} + +SYMBOLIC_OP_TO_TORCH_OP = { + (torch.ops.aten.sym_size, 1): torch.ops.aten.size.default, + (torch.ops.aten.sym_size, 2): torch.ops.aten.size.int, + (torch.ops.aten.sym_stride, 1): torch.ops.aten.stride.default, + (torch.ops.aten.sym_stride, 2): torch.ops.aten.stride.int, + (torch.ops.aten.sym_numel, 1): torch.ops.aten.numel.default, +} + + +"""Check whether an object in our graph is symbolic""" + + +def is_symbolic(obj: Any) -> bool: + return isinstance(obj, (torch.SymInt, torch.SymFloat, torch.SymBool)) + + +def is_builtin_function_or_method(obj: Any) -> bool: + return isinstance(obj, (BuiltinMethodType, BuiltinFunctionType)) + + +class FxImporter: + """Main entry-point for importing an fx.GraphModule. + + The FxImporter is a low-level class intended for framework integrators. + It provides several options for customization: + + * config_check: Optionally allows some per-import configuration safety + checks to be skipped. + * literal_resolver_callback: Callback that will be invoked when a literal, + live torch.Tensor is encountered in the FX graph, allowing the default + action (which is to inline the data as a DenseResourceElementsAttr) to + be completely overriden. + * py_attr_tracker: Weak reference tracker for live PyTorch objects used + to unique them with respect to attributes. If not specified, there will + be one reference tracker per import, but this can be injected to share + the same uniqueing across imports (i.e. if building multiple functions + into the same context or module). + """ + + __slots__ = [ + "_c", + "_cc", + "_literal_resolver_callback", + "_m", + "_m_ip", + "_py_attr_tracker", + "symbol_table", + ] + + def __init__( + self, + *, + module: Optional[Module] = None, + context: Optional[Context] = None, + config_check: bool = True, + literal_resolver_callback: Optional[LiteralResolverCallback] = None, + py_attr_tracker: Optional["RefTracker"] = None, + ): + if module is not None: + assert context is None, "If configuring with a Module, context must be None" + self._m = module + self._c = self.module.context + else: + self._c = context if context else Context() + self._m = Module.create(Location.unknown(self._c)) + if config_check: + # Production code can disable this for a bit of a boost. + self._config_check() + self._py_attr_tracker = py_attr_tracker or RefTracker() + self._cc = ContextCache(self._c, py_attr_tracker=self._py_attr_tracker) + self._m_ip = InsertionPoint(self._m.body) + self._literal_resolver_callback = literal_resolver_callback + self.symbol_table = SymbolTable(self._m.operation) + + def _config_check(self): + for dname in REQUIRED_DIALCTS: + try: + self._c.dialects[dname] + logging.debug("Context has registered dialect '%s'", dname) + except IndexError: + raise RuntimeError( + f"The MLIR context {self._c} is missing required dialect '{dname}'" + ) + + @property + def module(self) -> Module: + return self._m + + @property + def module_op(self) -> Operation: + return self._m.operation + + def import_frozen_exported_program(self, prog: torch.export.ExportedProgram): + """Imports a consolidated torch.export.ExportedProgram instance. + + If using the new torch.export path (vs a lower level precursor), then this is + the recommended way to canonically use this importer. + + The ExportedProgram form differs from some of the earlier work primarily in + how it deals with references to external tensors from "outside". In this form, + all such references are checked to have originated from within the exported + scope or from an @assume_constant_result wrapped function. Then they are + transformed to graph inputs and stashed in one of two data structures on + the ExportedProgram: + inputs_to_buffers / buffers : For non-parameter buffers. + inputs_to_parameters / parameters : For parameter buffers. + The values of the mapping in inputs_to_{buffers|parameters} are in the + state_dict. This replaces get_attr nodes that would have classically been + present during lower level tracing. + Historically, torch-mlir has assumed that all such external accesses are + frozen, and this entry-point preserves this behavior, treating each distinct + torch.Tensor encountered in such a way as a `torch.vtensor.literal` (or + delegating to the literal_resolver_callback to make a policy decision). + + As we anticipate more nuanced treatment options in the future, we name this + method to indicate that it is producing "frozen" modules. Additional top-level + approaches to handling state can be introduced later as an addition. + """ + sig = prog.graph_signature + state_dict = prog.state_dict + arg_replacements: dict[str, Any] = {} + # Lift buffers. + for input_name, state_name in sig.inputs_to_buffers.items(): + try: + state_value = state_dict[state_name] + except KeyError as e: + raise AssertionError("Could not find state mapping for buffer") from e + arg_replacements[input_name] = state_value + + # Lift parameters. + for input_name, state_name in sig.inputs_to_parameters.items(): + try: + state_value = state_dict[state_name] + except KeyError as e: + raise AssertionError( + "Could not find state mapping for parameter" + ) from e + arg_replacements[input_name] = state_value + + # Remove any lifted placeholders, replacing their uses with the state + # replacement value. + g = prog.graph + for node in g.nodes: + if node.op == "placeholder": + replacement = arg_replacements.get(node.name) + if replacement is None: + continue + node.replace_all_uses_with(replacement) + g.erase_node(node) + + self.import_stateless_graph(g) + + def import_graph_module(self, gm: GraphModule): + """Low-level import of a GraphModule assuming that it has been functionalized.""" + self.import_stateless_graph(gm.graph) + + def import_stateless_graph(self, g: Graph, func_name: str = "main"): + """Low-level import of a functionalized, assumed stateless Graph as a func.""" + ftype, loc = self._graph_to_function_meta(g) + # TODO: The FuncOp constructor requires a context-manager context. + # Fix upstream and then unnest. + # See: https://github.com/nod-ai/SHARK-Turbine/issues/138 + with loc: + func = func_dialect.FuncOp( + func_name, + ftype, + ip=self._m_ip, + ) + entry_block = Block.create_at_start(func.body, ftype.inputs) + node_importer = GraphNodeImporter( + self, + self._c, + self._cc, + entry_block, + literal_resolver_callback=self._literal_resolver_callback, + ) + node_importer.import_nodes(g.nodes) + self.symbol_table.insert(func) + + def _graph_to_function_meta(self, g: Graph) -> Tuple[FunctionType, Location]: + """Extracts function metadata from the Graph. + + Principally, this includes the FunctionType, but in the future, + it should also return other annotations (input strides, etc) that + affect compilation and should be included as arg attrs. + """ + input_types = [] + result_types = [] + loc = None + for node in g.nodes: + # Assume that the first node we can get a location for is about as + # good as it gets as an overall function location. + if loc is None: + loc = self._cc.get_node_location(node) + if node.op == "placeholder": + input_types.append(self._cc.node_val_to_type(node)) + elif node.op == "output": + # An output node's args[0] is the return value. This seems to + # always be "boxed" as a tuple, which we emit as multi-results. + for result_node in node.args[0]: + if result_node is None: + result_types.append( + IrType.parse("!torch.none", context=self._c) + ) + else: + result_types.append(self._cc.node_val_to_type(result_node)) + return ( + FunctionType.get(input_types, result_types, context=self._c), + loc if loc else Location.unknown(self._c), + ) + + +class ContextCache: + """Caches per-context lookups of various things that we ask for repeatedly.""" + + __slots__ = [ + "_c", + "_dtype_to_type", + "_tensor_metadata_cache", + "_py_attr_tracker", + # Types. + "torch_bool_type", + "torch_float_type", + "torch_int_type", + "torch_none_type", + "torch_str_type", + "torch_device_type", + ] + + def __init__( + self, context: Context, *, py_attr_tracker: Optional["RefTracker"] = None + ): + self._c = context + self._dtype_to_type: Dict[TorchDtype, IrType] = {} + self._tensor_metadata_cache: Dict[Tuple[torch.Size, torch.dtype], IrType] = {} + self._py_attr_tracker = py_attr_tracker or RefTracker() + + # Common types. + with context: + self.torch_bool_type = IrType.parse("!torch.bool") + self.torch_float_type = IrType.parse("!torch.float") + self.torch_int_type = IrType.parse("!torch.int") + self.torch_none_type = IrType.parse("!torch.none") + self.torch_str_type = IrType.parse("!torch.str") + self.torch_device_type = IrType.parse("!torch.Device") + + def integer_attr(self, value: int, bits: int) -> Attribute: + c = self._c + return IntegerAttr.get(IntegerType.get_signless(bits, c), value) + + """Strips symbolic elements from a torch.Size object and returns shape asm""" + + def format_asm_shape(self, shape: torch.Size) -> str: + return ",".join("?" if is_symbolic(d) else str(d) for d in list(shape)) + + """Return IrType for !torch.vtensor with the given shape and dtype""" + + def get_vtensor_type(self, shape: torch.Size, dtype: torch.dtype): + shape_asm = self.format_asm_shape(shape) + mlir_dtype = str(self.dtype_to_type(dtype)) + return IrType.parse( + f"!torch.vtensor<[{shape_asm}],{str(mlir_dtype)}>", context=self._c + ) + + def node_val_to_type(self, node: torch_fx.Node) -> IrType: + try: + tensor_meta = node.meta.get("tensor_meta") + val = node.meta.get("val") + if tensor_meta is not None: + assert isinstance(tensor_meta, TensorMetadata) + # Quantized tensor meta data is not preserved in our lowering, + # so throw error instead of silently doing wrong thing. + if tensor_meta.is_quantized: + raise NotImplementedError( + f"Quantized tensor meta data is not supported." + ) + else: + return self.tensor_metadata_to_type(tensor_meta) + elif val is not None: + # some nodes with symbolic inputs pass a 'val' attribute rather than + # tensor_meta + if isinstance(val, TorchFakeTensor): + return self.get_vtensor_type(val.size(), val.dtype) + + t = SCALAR_TYPE_TO_TORCH_MLIR_TYPE.get(type(val)) + if t is not None: + return IrType.parse(t, self._c) + + raise NotImplementedError( + f"FIXME: Unsupported placeholder node (this often indicates that a necessary) " + f"fx preprocessing pass was not run): {node.meta}" + ) + except KeyError as e: + raise RuntimeError( + f"FIXME: Illegal access to torch.fx.Node.meta: {e} ({node.meta.keys()} : {node.meta})" + ) + + def tensor_metadata_to_type(self, tm: TensorMetadata) -> IrType: + tm_shape = tuple( + item.node if is_symbolic(item) else item for item in list(tm.shape) + ) + + key = (tm_shape, tm.dtype) + t = self._tensor_metadata_cache.get(key) + if t is None: + t = self.get_vtensor_type(tm.shape, tm.dtype) + self._tensor_metadata_cache[key] = t + return t + + def dtype_to_type(self, dtype: TorchDtype) -> IrType: + t = self._dtype_to_type.get(dtype) + if t is None: + try: + asm = TORCH_DTYPE_TO_MLIR_TYPE_ASM[dtype] + except IndexError: + raise ValueError(f"Unknown conversion from {dtype} to IREE type") + t = IrType.parse(asm, self._c) + self._dtype_to_type[dtype] = t + return t + + def tensor_to_vtensor_type(self, tensor: torch.Tensor) -> IrType: + dtype_asm = str(self.dtype_to_type(tensor.dtype)) + return IrType.parse(f"!torch.vtensor<{list(tensor.size())},{dtype_asm}>") + + def get_node_location(self, node: torch_fx.Node) -> Optional[Location]: + stack_trace = node.meta.get("stack_trace") + if stack_trace is None: + return None + # Ugh. + # TODO: Avoid needing to regex match this. + # https://github.com/pytorch/pytorch/issues/91000 + stack_trace = node.stack_trace + if stack_trace: + m = re.search(r"""File "([^"]+)", line ([0-9]+),""", stack_trace) + if m: + filename, line = m.group(1), int(m.group(2)) + return Location.file(filename, line, col=0, context=self._c) + return Location.unknown(context=self._c) + + +class GraphNodeImporter: + """Imports graph nodes into an MLIR function. + + The caller must have already created the function. + """ + + __slots__ = [ + "_b", + "_c", + "_cc", + "_literal_resolver_callback", + "_v", + "_multi_result_nodes", + "fx_importer", + ] + + def __init__( + self, + fx_importer: FxImporter, + context: Context, + context_cache: ContextCache, + block: Block, + *, + literal_resolver_callback: Optional[LiteralResolverCallback] = None, + ): + self.fx_importer = fx_importer + self._c = context + self._cc = context_cache + self._b = block + # Map of (Node, result_index) to MLIR Value. + self._v: Dict[Tuple[torch_fx.Node, int], Value] = {} + # Statically multi-result nodes which we have de-tupled are noted here. + # They will have their getitem calls short-circuited. + self._multi_result_nodes: Set[torch_fx.Node] = set() + self._literal_resolver_callback = literal_resolver_callback + + def import_nodes(self, nodes: Sequence[torch_fx.Node]): + with InsertionPoint(self._b): + loc = Location.unknown() + num_placeholders = 0 + for node in nodes: + op = node.op + # Attempt to extract locations. Not everything has them, + # so we do our best. + new_loc = self._cc.get_node_location(node) + if new_loc is not None: + loc = new_loc + if op == "placeholder": + # Associate the placeholder node with corresponding block + # argument. + self._v[(node, 0)] = self._b.arguments[num_placeholders] + num_placeholders += 1 + elif op == "call_function": + target = node.target + if target == operator.getitem: + # Special case handling of getitem for when it is resolving + # against a function call that we know has returned multiple + # results. We short-circuit this case because we have modeled + # function calls to natively return multiple results vs tupling. + getitem_ref, getitem_index = node.args + if getitem_ref in self._multi_result_nodes: + try: + self._v[(node, 0)] = self._v[ + (getitem_ref, getitem_index) + ] + except IndexError: + raise RuntimeError( + f"getitem de-aliasing failed. This likely " + f"indicates a programmer error that usually " + f"would have happened at runtime. Please " + f"notify developers if this case happens " + f"(at {loc})." + ) + else: + raise NotImplementedError( + f"General getitem access to non-multi-result ops" + ) + elif isinstance(target, TorchOpOverload): + # Dispatch to an ATen op. + self._import_torch_op_overload(loc, node, target) + elif target in SYMBOLIC_TORCH_OPS or ( + is_symbolic(node.meta.get("val")) + and is_builtin_function_or_method(target) + ): + self._import_symbolic_torch_op(loc, node, target) + else: + raise NotImplementedError( + f"FIX ME: Unimplemented call_function: target={node.target}, {node.meta}" + ) + elif op == "output": + # args[0] is a singleton tuple that we flatten into multiple + # results. + operands = [self._import_argument(loc, arg) for arg in node.args[0]] + func_dialect.ReturnOp(operands, loc=loc) + + def _promote_symbolic_scalar_int_float(self, loc, graph, param): + temp_target = torch.ops.aten.Float.Scalar + temp_node = torch.fx.Node( + graph=graph, + name=f"{str(param)}_as_float", + op="call_function", + target=temp_target, + args=(param,), + kwargs={}, + return_type=float, + ) + temp_node.meta["val"] = torch.sym_float(param.meta["val"]) + self._import_torch_op_overload(loc, temp_node, temp_target) + return temp_node + + def _import_symbolic_torch_op( + self, + loc: Location, + node: torch_fx.Node, + target: Union[ + torch._ops.OpOverloadPacket, BuiltinMethodType, BuiltinFunctionType + ], + ): + # parse builtin operations like add, sub, mul, etc. because dynamo captures these + # operations on symbolic arguments as regular python expressions rather than as torch ops + if is_builtin_function_or_method(target): + arg_types = [ + arg.meta["val"].node.pytype + if isinstance(arg, torch.fx.Node) + else type(arg) + for arg in node.args + ] + is_int = [item == int for item in arg_types] + if all(is_int): + op_overload = "int" + elif any(is_int): + if target.__name__ in ("add", "lt", "ge", "ne", "gt"): + op_overload = "float_int" + # put float arg first, as expected in signature + if arg_types[1] == float: + node.args = (node.args[1], node.args[0]) + else: + # promote int argument to float - following torch-mlir convention + arg0, arg1 = node.args + if is_int[0]: + if isinstance(arg0, torch.fx.Node): + prom_arg = self._promote_symbolic_scalar_int_float( + loc, node.graph, arg0 + ) + new_args = (prom_arg, arg1) + else: + arg0 = float(arg0) + new_args = (arg0, arg1) + else: + if isinstance(arg1, torch.fx.Node): + prom_arg = self._promote_symbolic_scalar_int_float( + loc, node.graph, arg1 + ) + new_args = (arg0, prom_arg) + else: + arg1 = float(arg1) + new_args = (arg0, arg1) + + node.args = new_args + op_overload = "float" + else: + op_overload = "float" + + torch_op = PY_BUILTIN_TO_TORCH_OP.get(target.__name__) + assert ( + torch_op is not None + ), f"Unsupported builtin function for symbolic types: {target} with args {node.args}" + concrete_target = getattr(torch_op, op_overload) + else: + concrete_target = SYMBOLIC_OP_TO_TORCH_OP.get((target, len(node.args))) + + assert ( + concrete_target is not None + ), f"Unable to parse symbolic operation: {target} with args {node.args}" + self._import_torch_op_overload(loc, node, concrete_target) + + def _import_torch_op_overload( + self, loc: Location, node: torch_fx.Node, target: TorchOpOverload + ): + # replace lift_fresh_copy with clone op + if target == torch.ops.aten.lift_fresh_copy.default: + node.target = target = torch.ops.aten.clone.default + node.args = (node.args[0], None) + elif target == torch.ops.aten.lift_fresh_copy.out: + node.target = target = torch.ops.aten.clone.out + node.args = (node.args[0], None, node.args[1]) + # TODO: generalize empty.memory_format in the future + # Currently, the aten.baddbmm.default op for Unet includes multiplying an + # empty.memory_format input with a constant, which creates NaN values + # because empty.memory_format contains uninitialized data. Converting + # aten.baddbmm.default -> aten.zeros.default fixes the correctness issue + elif target == torch.ops.aten.empty.memory_format: + if len(node.users) == 1: + for key_node in node.users: + if key_node.target == torch.ops.aten.baddbmm.default: + node.target = target = torch.ops.aten.zeros.default + + schema = target._schema + assert isinstance(schema, FunctionSchema) + + # Map to a `torch` dialect name. + namespace, sep, unqualified_name = schema.name.partition("::") + assert sep, f"Malformed Torch op name {schema.name}" + mlir_op_name = f"torch.{namespace}.{unqualified_name}" + if schema.overload_name != "": + mlir_op_name += f".{schema.overload_name}" + + # Intervening to use Scalar ops due to incorrect ops from AOT-autograd with scalar arguments. + if mlir_op_name in TENSOR_SCALAR_OP_CONVERTER and ( + isinstance(node.args[1], float) or isinstance(node.args[1], int) + ): + mlir_op_name = TENSOR_SCALAR_OP_CONVERTER[mlir_op_name] + # we are dynamically changing which op is emitted here due to an issue in + # torch dynamo where it emits the Tensor variant of ops even when processing + # scalar arguments, therefore we retrieve the schema as well so that we + # consume the correct typing information when subsequently importing the + # function arguments and result types + # i.e. the code below is basically doing `schema = torch.ops.aten.my_op.Scalar._schema` + op_attrs = mlir_op_name.split(".") + op_overload = getattr(torch, "ops") + for i in range(1, len(op_attrs)): + op_overload = getattr(op_overload, op_attrs[i]) + schema = op_overload._schema + + return_count = len(schema.returns) + if return_count == 1: + # Unary return directly maps a single meta["val"] and cannot be subscripted. + # if "tensor_meta" is None, this will throw unsupported placeholder node error + result_types = [self._cc.node_val_to_type(node)] + elif return_count == 0: + # Some torch ops do have 0 returns, and these are supported with ZeroResults + # op trait. Python bindings for IR creation allow us to pass empty result_types + # for such ops. Therefore, we pass an empty result types for these cases. + result_types = [] + else: + # Multi-return will unpack the meta["val"] and trigger our getitem subscripting + # short-circuit above. Note that if we ever choose to also fully reify Python + # level result tuples, we will need to create a tuple-boxed version of this and + # redirect to it for generic object access. + + result_types = [] + for v in node.meta["val"]: + result_types.append(self._cc.tensor_metadata_to_type(v)) + result_types = tuple(result_types) + + self._multi_result_nodes.add(node) + # Unroll operands from formal parameters, args and kwargs. + operands = [] + for i, parameter in enumerate(schema.arguments): + if parameter.kwarg_only and parameter.name in node.kwargs: + operands.append( + self._import_argument( + loc, node.kwargs[parameter.name], parameter.type + ) + ) + elif i < len(node.args): + operands.append( + self._import_argument(loc, node.args[i], parameter.type) + ) + else: + operands.append( + self._import_default_value( + loc, parameter.default_value, parameter.type + ) + ) + + # Support unregistered torch ops using torch.operator. + # torch.operator is used to represent ops from registry + # which haven't been generated by torch_ods_gen.py. + if not self._c.is_registered_operation(mlir_op_name): + operation = Operation.create( + "torch.operator", + attributes={"name": StringAttr.get(mlir_op_name)}, + results=result_types, + operands=operands, + loc=loc, + ) + else: + operation = Operation.create( + mlir_op_name, + results=result_types, + operands=operands, + loc=loc, + ) + + # Record value mapping. + for i, value in enumerate(operation.results): + self._v[(node, i)] = value + + def _import_argument( + self, loc: Location, arg: NodeArgument, expected_jit_type=None + ) -> Value: + """Import an FX `Argument`, which must result to an MLIR `Value`.""" + if isinstance(arg, torch_fx.Node): + # If implementing boxed support for multi-result nodes, then + # this will need to do something more intelligent. + if arg in self._multi_result_nodes: + raise RuntimeError(f"Attempt to de-reference a multi-result node") + + # catch references to dynamically created constant attributes and make sure they have an origin in our module + if arg.op == "get_attr" and (arg.target, 0) not in self._v: + gm = arg.graph.owning_module + assert hasattr( + gm, arg.target + ), f"Attempting to retrieve attribute '{arg.target}' from module, but no such attribute exists" + obj = getattr(gm, arg.target) + with loc: + self._v[(arg, 0)] = self._import_literal(obj) + + return self._v[(arg, 0)] + elif isinstance(arg, torch_fx.immutable_collections.immutable_list): + return self._import_list_argument(loc, arg, expected_jit_type) + elif isinstance(expected_jit_type, torch.TensorType) and not isinstance( + arg, torch.Tensor + ): + # promote scalars to tensor types as appropriate + return self._import_scalar_as_tensor(loc, arg) + else: + with loc: + return self._import_literal(arg) + + def _import_literal(self, py_value: Any) -> Value: + # Apply the conversion callback. + user_callback = self._literal_resolver_callback + if user_callback: + user_value = user_callback(py_value, self) + if user_value is not None: + assert isinstance(user_value, Value) + return user_value + + # Default conversion path. + converter = LITERAL_CONVERTER_MAP.lookup(type(py_value)) + if converter is None: + raise TypeError( + f"Unsupported argument -> literal conversion for {py_value.__class__}" + ) + return converter(py_value, self, self._cc) + + def _import_scalar_as_tensor(self, loc: Location, arg: NodeArgument) -> Value: + tensor_arg = torch.tensor(arg) + result_type = self._cc.get_vtensor_type(tensor_arg.size(), tensor_arg.dtype) + with loc: + constant_arg = LITERAL_CONVERTER_MAP.lookup(type(arg))(arg, self, self._cc) + + return Operation.create( + name="torch.prim.NumToTensor.Scalar", + results=[result_type], + operands=[constant_arg], + loc=loc, + ).result + + def _import_list_argument( + self, loc: Location, arg: NodeArgument, expected_jit_type + ) -> Value: + assert ( + isinstance(expected_jit_type, torch.ListType) + or ( + isinstance(expected_jit_type, torch.OptionalType) + and isinstance(expected_jit_type.getElementType(), torch.ListType) + ) + or isinstance(expected_jit_type, NoneType) + ), f"Unexpected jit type as list argument: {arg} of type {expected_jit_type}" + + # parse list type + if expected_jit_type is None: + element_type = type(arg[0]) + else: + element_jit_type = expected_jit_type.getElementType() + + # this branch is needed to handle Optional[List[]] types + if isinstance(element_jit_type, torch.ListType): + element_jit_type = element_jit_type.getElementType() + + # this handles getting the inner types for List[Optional[]] types + is_optional_type = isinstance(element_jit_type, torch.OptionalType) + if is_optional_type: + element_jit_type = element_jit_type.getElementType() + element_type = TORCH_TYPE_TO_PY_TYPE[type(element_jit_type)] + + # create list operands + list_operands = [] + + for operand in arg: + operand_type = type(operand) + if isinstance(operand, torch.fx.Node): + if operand in self._multi_result_nodes: + raise RuntimeError(f"Attempt to de-reference a multi-result node") + val = self._v[(operand, 0)] + val_type = str(val.type) + assert ( + isinstance(element_type, str) and element_type in val_type + ) or SCALAR_TYPE_TO_TORCH_MLIR_TYPE.get( + element_type + ) == val_type, f"Heterogeneous lists are not supported: expected {element_type}, got {val_type}" + else: + assert (is_optional_type and operand_type is NoneType) or ( + element_type == operand_type + ), f"Heterogeneous lists are not supported: expected {element_type}, got {operand_type}" + + operand_jit_type = ( + torch.NoneType if operand_type is NoneType else element_jit_type + ) + val = self._import_default_value(loc, operand, operand_jit_type) + + list_operands.append(val) + + # construct list op + if is_optional_type: + list_type = PY_TYPE_TO_TORCH_OPTIONAL_LIST_TYPE[element_type] + else: + list_type = PY_TYPE_TO_TORCH_LIST_TYPE[element_type] + + result_type = IrType.parse(list_type, context=self._c) + operation = Operation.create( + "torch.prim.ListConstruct", + results=[result_type], + operands=list_operands, + loc=loc, + ) + + return operation.result + + def _import_default_value(self, loc: Location, arg, expected_jit_type) -> Value: + """Imports a defaulted value for a known function schema.""" + if isinstance(arg, list): + return self._import_list_argument(loc, arg, expected_jit_type) + + # The LITERAL_CONVERTER_MAP maps each arg to its respective constant + # of the expected jit IR type (types like torch.dtype will form a chain of + # maps to get to constant of expected_jit_type). + cvt = LITERAL_CONVERTER_MAP.lookup(type(arg)) + if cvt is None: + raise RuntimeError(f"Unhandled default value ({arg.__class__}): {arg})") + with loc: + return cvt(arg, self, self._cc) + + +def _make_constant_op( + op_name: str, value_attr: Attribute, result_type: Optional[IrType] = None +) -> Operation: + return Operation.create( + op_name, + results=[result_type if result_type else value_attr.type], + attributes={"value": value_attr}, + ) + + +def create_mlir_tensor_type(tensor: torch.Tensor) -> IrType: + try: + dtype = tensor.dtype + element_type = TORCH_DTYPE_TO_MLIR_TYPE[dtype]() + tensor_type = RankedTensorType.get(tuple(tensor.size()), element_type) + return tensor_type + except KeyError: + raise TypeError(f"Could not map Torch dtype {dtype} to an IREE type") + + +def _make_vtensor_literal_op( + tensor: torch.Tensor, vtensor_type: IrType, py_attr_tracker: "RefTracker" +) -> Operation: + mapping = py_attr_tracker.track(tensor) + if mapping.is_empty: + # Resolve the attribute. + npy_dtype = TORCH_DTYPE_TO_NPY_TYPE.get(tensor.dtype) + assert ( + npy_dtype is not None + ), f"Can not create literal tensor for unsupported datatype: {tensor.dtype}" + # We need a raw buffer of data in order to create an ElementsAttr for the invocation of torch.vtensor.literal, + # but torch.Tensor does not fulfill the python buffer/array interface hence we must convert to a numpy array to get + # a raw buffer of our data. We can't call torch.Tensor.numpy() directly because this internally forces a call to + # detach() which throws an error as we are operating in a FakeTensorMode, hence the simplest way to get this raw + # buffer is via the indirection: Tensor -> list -> numpy array. This allows us to create a vtensor literal as + # desired, but also limits which data types we can support in this function (see TORCH_DTYPE_TO_NPY_TYPE above) + np_tensor = np.array(tensor.tolist()).astype(npy_dtype) + bytes_view = memoryview(np_tensor) + tensor_type = create_mlir_tensor_type(tensor) + shape_desc = "_".join([str(d) for d in tensor.shape]) + blob_name = f"torch_tensor_{shape_desc}_{str(tensor.dtype)}" + elements_attr = DenseResourceElementsAttr.get_from_buffer( + bytes_view, + blob_name, + tensor_type, + ) + mapping.value = elements_attr + else: + elements_attr = mapping.value + return Operation.create( + name="torch.vtensor.literal", + results=[vtensor_type], + attributes={"value": elements_attr}, + ) + + +################################################################################ +# TypeSubclassMapping +################################################################################ + + +class TypeSubclassMap: + """Mapping of super-types to values. + + Maintains a cache of actual types seen and uses that instead of a linear + scan. + """ + + __slots__ = [ + "_cache", + "_mapping", + ] + + def __init__(self): + # The linear list of converters. + self._mapping: List[Tuple[type, Any]] = [] + # When there is a hit on the linear mapping, memoize it here. + self._cache: Dict[type, Any] = {} + + def map(self, t: type, value: Any): + self._mapping.append((t, value)) + self._cache[t] = value + + def lookup(self, t: type) -> Any: + try: + return self._cache[t] + except KeyError: + pass + for t_super, value in self._mapping: + if issubclass(t, t_super): + self._cache[t] = value + return value + else: + self._cache[t] = None + return None + + +############################################################################### +# Reference mapping +############################################################################### + + +# Opaque value to indicate something is empty. Used in cases where 'None' +# may have a different meaning. +class EmptyType: + ... + + +Empty = EmptyType() + + +class RefMapping: + __slots__ = [ + "_referrent", + "value", + ] + + def __init__(self, referrent: Any): + if referrent is not Empty: + self._referrent = weakref.ref(referrent) + self.value = Empty + + @property + def is_empty(self): + return self.value is Empty + + def __repr__(self): + return ( + f" " + f"{self.value if self.value is not Empty else 'empty'}>" + ) + + +class RefTracker: + """Tracks live references from Python values to symbolic associations.""" + + def __init__(self): + self._refs: Dict[int, RefMapping] = {} + + def track(self, referrent: Any) -> RefMapping: + ref_id = id(referrent) + existing = self._refs.get(ref_id) + if existing: + return existing + info = RefMapping(referrent) + if referrent is not Empty: + weakref.finalize(referrent, self._ref_finalizer, ref_id) + self._refs[ref_id] = info + return info + + def _ref_finalizer(self, ref_id: int): + del self._refs[ref_id] + + +################################################################################ +# Mappings +################################################################################ + +LITERAL_CONVERTER_MAP = TypeSubclassMap() +LITERAL_CONVERTER_MAP.map( + NoneType, + lambda arg, gni, cc: Operation.create( + "torch.constant.none", results=[cc.torch_none_type] + ).result, +) +LITERAL_CONVERTER_MAP.map( + bool, + lambda arg, gni, cc: _make_constant_op( + "torch.constant.bool", cc.integer_attr(arg, 1), cc.torch_bool_type + ).result, +) +LITERAL_CONVERTER_MAP.map( + int, + lambda arg, gni, cc: _make_constant_op( + "torch.constant.int", cc.integer_attr(arg, 64), cc.torch_int_type + ).result, +) +LITERAL_CONVERTER_MAP.map( + float, + lambda arg, gni, cc: _make_constant_op( + "torch.constant.float", FloatAttr.get_f64(arg), cc.torch_float_type + ).result, +) +LITERAL_CONVERTER_MAP.map( + str, + lambda arg, gni, cc: _make_constant_op( + "torch.constant.str", StringAttr.get(arg), cc.torch_str_type + ).result, +) +LITERAL_CONVERTER_MAP.map( + torch.Tensor, + lambda arg, gni, cc: _make_vtensor_literal_op( + arg, cc.tensor_to_vtensor_type(arg), cc._py_attr_tracker + ).result, +) +LITERAL_CONVERTER_MAP.map( + torch.device, + lambda arg, gni, cc: _make_constant_op( + "torch.constant.device", StringAttr.get(str(arg)), cc.torch_device_type + ).result, +) +LITERAL_CONVERTER_MAP.map( + torch.dtype, + lambda arg, gni, cc: LITERAL_CONVERTER_MAP.lookup(int)( + TORCH_DTYPE_TO_INT[arg], gni, cc + ), +) +LITERAL_CONVERTER_MAP.map( + torch.layout, + lambda arg, gni, cc: LITERAL_CONVERTER_MAP.lookup(int)( + TORCH_LAYOUT_TO_INT[arg], gni, cc + ), +) +LITERAL_CONVERTER_MAP.map( + torch.memory_format, + lambda arg, gni, cc: LITERAL_CONVERTER_MAP.lookup(int)( + TORCH_MEMORY_FORMAT_TO_INT[arg], gni, cc + ), +) + +TORCH_TYPE_TO_PY_TYPE = { + torch.IntType: int, + torch.FloatType: float, + torch.StringType: str, + torch.BoolType: bool, + torch.TensorType: "vtensor", +} + +PY_TYPE_TO_TORCH_LIST_TYPE = { + int: "!torch.list", + float: "!torch.list", + str: "!torch.list", + bool: "!torch.list", + "tensor": "!torch.list", + "vtensor": "!torch.list", +} + +PY_TYPE_TO_TORCH_OPTIONAL_LIST_TYPE = { + int: "!torch.list>", + float: "!torch.list>", + str: "!torch.list>", + bool: "!torch.list>", + "tensor": "!torch.list>", + "vtensor": "!torch.list>", +} + +SCALAR_TYPE_TO_TORCH_MLIR_TYPE = { + torch.SymInt: "!torch.int", + torch.SymFloat: "!torch.float", + torch.SymBool: "!torch.bool", + int: "!torch.int", + float: "!torch.float", + str: "!torch.str", + bool: "!torch.bool", + NoneType: "!torch.none", +} + + +# AOT-autograd sometimes falsely emit tensor version op with scalar arguments. +# We may remove this dictionary, if we fix such behavior in the backend. +TENSOR_SCALAR_OP_CONVERTER = { + "torch.aten.mul.Tensor": "torch.aten.mul.Scalar", + "torch.aten.div.Tensor": "torch.aten.div.Scalar", + "torch.aten.add.Tensor": "torch.aten.add.Scalar", + "torch.aten.sub.Tensor": "torch.aten.sub.Scalar", + "torch.aten.floor_divide": "torch.aten.floor_divide.Scalar", +} diff --git a/test/python/fx_importer/basic_test.py b/test/python/fx_importer/basic_test.py new file mode 100644 index 000000000000..62d3b1203e03 --- /dev/null +++ b/test/python/fx_importer/basic_test.py @@ -0,0 +1,80 @@ +# Copyright 2023 Advanced Micro Devices, Inc +# +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. + +# RUN: %PYTHON %s | FileCheck %s + +from typing import Optional + +import torch +import torch.export +import torch.nn as nn + +from torch_mlir.extras.fx_importer import FxImporter +from torch_mlir import ir +from torch_mlir.dialects import torch as torch_d + + +def export_and_import( + f, + *args, + fx_importer: Optional[FxImporter] = None, + constraints: Optional[torch.export.Constraint] = None, + **kwargs, +): + context = ir.Context() + torch_d.register_dialect(context) + + if fx_importer is None: + fx_importer = FxImporter(context=context) + prog = torch.export.export(f, args, kwargs, constraints=constraints) + fx_importer.import_frozen_exported_program(prog) + return fx_importer.module_op + + +def run(f): + print(f"{f.__name__}") + print("-" * len(f.__name__)) + f() + print() + + +@run +# CHECK-LABEL: test_import_frozen_exported_program +# CHECK: func.func @main(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> +# CHECK-DAG: %[[a:.+]] = torch.vtensor.literal(dense_resource : tensor<1x4xf32>) : !torch.vtensor<[1,4],f32> +# CHECK-DAG: %[[b:.+]] = torch.vtensor.literal(dense_resource : tensor<3x1xf32>) : !torch.vtensor<[3,1],f32> +# CHECK-DAG: %[[p:.+]] = torch.vtensor.literal(dense_resource : tensor<1x1xf32>) : !torch.vtensor<[1,1],f32> +# CHECK-DAG: %[[tanh:.+]] = torch.aten.tanh %[[ARG0]] +# CHECK-DAG: %[[mul_a:.+]] = torch.aten.mul.Tensor %[[tanh]], %[[a]] +# CHECK-DAG: %[[mul_b:.+]] = torch.aten.mul.Tensor %[[mul_a]], %[[b]] +# CHECK-DAG: %[[mul_p:.+]] = torch.aten.mul.Tensor %[[mul_b]], %[[p]] +# CHECK: return %[[mul_p]] +# +# Validate dialect resources exist. +# CHECK: dialect_resources: +# CHECK-DAG: torch_tensor_1_4_torch.float32 +# CHECK-DAG: torch_tensor_3_1_torch.float32 +# CHECK-DAG: torch_tensor_1_1_torch.float32 +def test_import_frozen_exported_program(): + # Tests the basic structural premises of import_frozen_exported_program, + # namely that free tensors (buffers) and parameters are treated as + # literals and frozen. + @torch._dynamo.assume_constant_result + def get_a(): + return torch.randn(1, 4) + + class Basic(nn.Module): + def __init__(self): + super().__init__() + self.b = torch.randn(3, 1) + self.p = nn.Parameter(torch.randn(1, 1)) + + def forward(self, x): + return torch.tanh(x) * get_a() * self.b * self.p + + m = export_and_import(Basic(), torch.randn(3, 4)) + print(m) From 85b86b36a28ceebbe57d9aca4083708ea8e675af Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Thu, 21 Dec 2023 17:05:18 -0800 Subject: [PATCH 039/283] [onnx] Fix importer variable names to make `mlir` legal (#2690) Some names for `onnx` identifiers are not legal in `mlir-ir`. Sanitize so that the generated `ir` is legal. --- python/torch_mlir/extras/onnx_importer.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/python/torch_mlir/extras/onnx_importer.py b/python/torch_mlir/extras/onnx_importer.py index a9dd52253601..dbf0adc490bd 100644 --- a/python/torch_mlir/extras/onnx_importer.py +++ b/python/torch_mlir/extras/onnx_importer.py @@ -38,6 +38,7 @@ from dataclasses import dataclass import numpy as np +import re from ..ir import ( ArrayAttr, @@ -464,13 +465,18 @@ def type_proto_to_type(self, tp: onnx.TypeProto) -> IrType: # See TypeProto: sequence_type, map_type, optional_type, sparse_tensor_type. raise OnnxImportError(f"Unsupported ONNX TypeProto: {tp}") + def _sanitize_name(self, name): + if not name.isidentifier(): + name = "_" + name + return re.sub("[:/]", "_", name) + def tensor_proto_to_attr(self, tp: onnx.TensorProto) -> Attribute: tensor_type = self.tensor_proto_to_builtin_type(tp) if tp.HasField("raw_data"): # Conveniently, DenseResourceElementsAttr shares the raw data # format. We just give it maximum numeric alignment. return DenseResourceElementsAttr.get_from_buffer( - tp.raw_data, tp.name, tensor_type, alignment=8 + tp.raw_data, self._sanitize_name(tp.name), tensor_type, alignment=8 ) else: # We have to do a data type specific instantiation from proto fields. From 9a72c6584e72f009eae765956cfd8a9d55f49497 Mon Sep 17 00:00:00 2001 From: Vivek Khandelwal Date: Thu, 21 Dec 2023 16:04:02 +0000 Subject: [PATCH 040/283] [MLIR][ONNX] Add OnnxToTorch support for BatchNormalization and Concat op. This commit adds the OnnxToTorch support for BatchNormalization and Concat op. Signed-Off By: vivekkhandelwal1424@gmail.com --- .../TorchOnnxToTorch/DefaultDomainAtoF.cpp | 61 +++++++++ .../TorchOnnxToTorch/simple_ops_a_to_f.mlir | 128 ++++++++++++++++++ 2 files changed, 189 insertions(+) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp index 421b4edc9e18..6e74f39b1b4b 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp @@ -165,6 +165,43 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( binder.op, resultType, operand); return success(); }); + patterns.onOp("BatchNormalization", 15, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value input, weight, bias, runningMean, runningVar; + bool training; + float momentum, eps; + if (binder.s64BoolAttr(training, "training_mode", 0)) + return failure(); + if (training) { + // TODO: Add support for training = true + return rewriter.notifyMatchFailure( + binder.op, "unsupported conversion: training = true"); + } + + if (binder.tensorOperandAtIndex(input, 0) || + binder.tensorOperandAtIndex(weight, 1) || + binder.tensorOperandAtIndex(bias, 2) || + binder.tensorOperandAtIndex(runningMean, 3) || + binder.tensorOperandAtIndex(runningVar, 4) || + binder.f32FloatAttr(momentum, "momentum", 0.9) || + binder.f32FloatAttr(eps, "epsilon", 1e-05) || + binder.tensorResultType(resultType)) + return failure(); + + Value cstFalse = rewriter.create( + binder.getLoc(), false); + Value cstMomentum = rewriter.create( + binder.getLoc(), rewriter.getF64FloatAttr(momentum)); + Value cstEps = rewriter.create( + binder.getLoc(), rewriter.getF64FloatAttr(eps)); + + rewriter.replaceOpWithNewOp( + binder.op, resultType, input, weight, bias, runningMean, + runningVar, /*training=*/cstFalse, cstMomentum, cstEps, + /*cudnn_enabled=*/cstFalse); + return success(); + }); patterns.onOp( "AveragePool", 19, [](OpBinder binder, ConversionPatternRewriter &rewriter) { @@ -426,6 +463,30 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( } return failure(); }); + patterns.onOp( + "Concat", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + SmallVector tensors; + int64_t dim; + if (binder.tensorOperands(tensors, binder.op->getNumOperands()) || + binder.s64IntegerAttr(dim, "axis", 0) || + binder.tensorResultType(resultType)) + return failure(); + Type listElemType = + tensors[0] + .getType() + .cast() + .getWithSizesAndDtype(/*optionalSizes=*/std::nullopt, + /*optionalDtype=*/nullptr); + Type listType = Torch::ListType::get(listElemType); + Value tensorList = rewriter.create( + binder.op->getLoc(), listType, tensors); + Value cstDim = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(dim)); + rewriter.replaceOpWithNewOp(binder.op, resultType, + tensorList, cstDim); + return success(); + }); patterns.onOp( "Conv", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) { std::string autoPad; diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir index 438d6c7adda8..a637837cd671 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir @@ -592,3 +592,131 @@ func.func @test_convtranspose(%arg0: !torch.vtensor<[1,1,3,3],f32>, %arg1: !torc %0 = torch.operator "onnx.ConvTranspose"(%arg0, %arg1) {torch.onnx.pads = [1 : si64, 2 : si64, 1 : si64, 2 : si64], torch.onnx.strides = [3 : si64, 2 : si64]} : (!torch.vtensor<[1,1,3,3],f32>, !torch.vtensor<[1,2,3,3],f32>) -> !torch.vtensor<[1,2,7,3],f32> return %0 : !torch.vtensor<[1,2,7,3],f32> } + +// CHECK-LABEL: @test_batchnorm_epsilon +func.func @test_batchnorm_epsilon(%arg0: !torch.vtensor<[2,3,4,5],f32>, %arg1: !torch.vtensor<[3],f32>, %arg2: !torch.vtensor<[3],f32>, %arg3: !torch.vtensor<[3],f32>, %arg4: !torch.vtensor<[3],f32>) -> !torch.vtensor<[2,3,4,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 15 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: %[[MOMENTUM:.*]] = torch.constant.float 0.89999997615814208 + // CHECK: %[[EPS:.*]] = torch.constant.float 0.0099999997764825821 + // CHECK: torch.aten.batch_norm %arg0, %arg1, %arg2, %arg3, %arg4, %[[FALSE]], %[[MOMENTUM]], %[[EPS]], %[[FALSE]] : !torch.vtensor<[2,3,4,5],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.bool, !torch.float, !torch.float, !torch.bool -> !torch.vtensor<[2,3,4,5],f32> + %0 = torch.operator "onnx.BatchNormalization"(%arg0, %arg1, %arg2, %arg3, %arg4) {torch.onnx.epsilon = 0.00999999977 : f32} : (!torch.vtensor<[2,3,4,5],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>) -> !torch.vtensor<[2,3,4,5],f32> + return %0 : !torch.vtensor<[2,3,4,5],f32> +} + +// CHECK-LABEL: @test_batchnorm_example +func.func @test_batchnorm_example(%arg0: !torch.vtensor<[2,3,4,5],f32>, %arg1: !torch.vtensor<[3],f32>, %arg2: !torch.vtensor<[3],f32>, %arg3: !torch.vtensor<[3],f32>, %arg4: !torch.vtensor<[3],f32>) -> !torch.vtensor<[2,3,4,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 15 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: %[[MOMENTUM:.*]] = torch.constant.float 0.89999997615814208 + // CHECK: %[[EPS:.*]] = torch.constant.float 9.9999997473787516E-6 + // CHECK: torch.aten.batch_norm %arg0, %arg1, %arg2, %arg3, %arg4, %[[FALSE]], %[[MOMENTUM]], %[[EPS]], %[[FALSE]] : !torch.vtensor<[2,3,4,5],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.bool, !torch.float, !torch.float, !torch.bool -> !torch.vtensor<[2,3,4,5],f32> + %0 = torch.operator "onnx.BatchNormalization"(%arg0, %arg1, %arg2, %arg3, %arg4) : (!torch.vtensor<[2,3,4,5],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>) -> !torch.vtensor<[2,3,4,5],f32> + return %0 : !torch.vtensor<[2,3,4,5],f32> +} + +// CHECK-LABEL: @test_concat_1d_axis_0 +func.func @test_concat_1d_axis_0(%arg0: !torch.vtensor<[2],f32>, %arg1: !torch.vtensor<[2],f32>) -> !torch.vtensor<[4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[TENSORS_LIST:.*]] = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>) -> !torch.list + // CHECK: %[[DIM:.*]] = torch.constant.int 0 + // CHECK: torch.aten.cat %[[TENSORS_LIST]], %[[DIM]] : !torch.list, !torch.int -> !torch.vtensor<[4],f32> + %0 = torch.operator "onnx.Concat"(%arg0, %arg1) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>) -> !torch.vtensor<[4],f32> + return %0 : !torch.vtensor<[4],f32> +} + +// CHECK-LABEL: @test_concat_1d_axis_negative_1 +func.func @test_concat_1d_axis_negative_1(%arg0: !torch.vtensor<[2],f32>, %arg1: !torch.vtensor<[2],f32>) -> !torch.vtensor<[4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[TENSORS_LIST:.*]] = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>) -> !torch.list + // CHECK: %[[DIM:.*]] = torch.constant.int -1 + // CHECK: torch.aten.cat %[[TENSORS_LIST]], %[[DIM]] : !torch.list, !torch.int -> !torch.vtensor<[4],f32> + %0 = torch.operator "onnx.Concat"(%arg0, %arg1) {torch.onnx.axis = -1 : si64} : (!torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>) -> !torch.vtensor<[4],f32> + return %0 : !torch.vtensor<[4],f32> +} + +// CHECK-LABEL: @test_concat_2d_axis_0 +func.func @test_concat_2d_axis_0(%arg0: !torch.vtensor<[2,2],f32>, %arg1: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[4,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[TENSORS_LIST:.*]] = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[2,2],f32>, !torch.vtensor<[2,2],f32>) -> !torch.list + // CHECK: %[[DIM:.*]] = torch.constant.int 0 + // CHECK: torch.aten.cat %[[TENSORS_LIST]], %[[DIM]] : !torch.list, !torch.int -> !torch.vtensor<[4,2],f32> + %0 = torch.operator "onnx.Concat"(%arg0, %arg1) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[2,2],f32>, !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[4,2],f32> + return %0 : !torch.vtensor<[4,2],f32> +} + +// CHECK-LABEL: @test_concat_2d_axis_1 +func.func @test_concat_2d_axis_1(%arg0: !torch.vtensor<[2,2],f32>, %arg1: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2,4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[TENSORS_LIST:.*]] = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[2,2],f32>, !torch.vtensor<[2,2],f32>) -> !torch.list + // CHECK: %[[DIM:.*]] = torch.constant.int 1 + // CHECK: torch.aten.cat %[[TENSORS_LIST]], %[[DIM]] : !torch.list, !torch.int -> !torch.vtensor<[2,4],f32> + %0 = torch.operator "onnx.Concat"(%arg0, %arg1) {torch.onnx.axis = 1 : si64} : (!torch.vtensor<[2,2],f32>, !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2,4],f32> + return %0 : !torch.vtensor<[2,4],f32> +} + +// CHECK-LABEL: @test_concat_2d_axis_negative_1 +func.func @test_concat_2d_axis_negative_1(%arg0: !torch.vtensor<[2,2],f32>, %arg1: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2,4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[TENSORS_LIST:.*]] = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[2,2],f32>, !torch.vtensor<[2,2],f32>) -> !torch.list + // CHECK: %[[DIM:.*]] = torch.constant.int -1 + // CHECK: torch.aten.cat %[[TENSORS_LIST]], %[[DIM]] : !torch.list, !torch.int -> !torch.vtensor<[2,4],f32> + %0 = torch.operator "onnx.Concat"(%arg0, %arg1) {torch.onnx.axis = -1 : si64} : (!torch.vtensor<[2,2],f32>, !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2,4],f32> + return %0 : !torch.vtensor<[2,4],f32> +} + +// CHECK-LABEL: @test_concat_2d_axis_negative_2 +func.func @test_concat_2d_axis_negative_2(%arg0: !torch.vtensor<[2,2],f32>, %arg1: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[4,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[TENSORS_LIST:.*]] = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[2,2],f32>, !torch.vtensor<[2,2],f32>) -> !torch.list + // CHECK: %[[DIM:.*]] = torch.constant.int -2 + // CHECK: torch.aten.cat %[[TENSORS_LIST]], %[[DIM]] : !torch.list, !torch.int -> !torch.vtensor<[4,2],f32> + %0 = torch.operator "onnx.Concat"(%arg0, %arg1) {torch.onnx.axis = -2 : si64} : (!torch.vtensor<[2,2],f32>, !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[4,2],f32> + return %0 : !torch.vtensor<[4,2],f32> +} + +// CHECK-LABEL: @test_concat_3d_axis_0 +func.func @test_concat_3d_axis_0(%arg0: !torch.vtensor<[2,2,2],f32>, %arg1: !torch.vtensor<[2,2,2],f32>) -> !torch.vtensor<[4,2,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[TENSORS_LIST:.*]] = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[2,2,2],f32>, !torch.vtensor<[2,2,2],f32>) -> !torch.list + // CHECK: %[[DIM:.*]] = torch.constant.int 0 + // CHECK: torch.aten.cat %[[TENSORS_LIST]], %[[DIM]] : !torch.list, !torch.int -> !torch.vtensor<[4,2,2],f32> + %0 = torch.operator "onnx.Concat"(%arg0, %arg1) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[2,2,2],f32>, !torch.vtensor<[2,2,2],f32>) -> !torch.vtensor<[4,2,2],f32> + return %0 : !torch.vtensor<[4,2,2],f32> +} + +// CHECK-LABEL: @test_concat_3d_axis_1 +func.func @test_concat_3d_axis_1(%arg0: !torch.vtensor<[2,2,2],f32>, %arg1: !torch.vtensor<[2,2,2],f32>) -> !torch.vtensor<[2,4,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[TENSORS_LIST:.*]] = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[2,2,2],f32>, !torch.vtensor<[2,2,2],f32>) -> !torch.list + // CHECK: %[[DIM:.*]] = torch.constant.int 1 + // CHECK: torch.aten.cat %[[TENSORS_LIST]], %[[DIM]] : !torch.list, !torch.int -> !torch.vtensor<[2,4,2],f32> + %0 = torch.operator "onnx.Concat"(%arg0, %arg1) {torch.onnx.axis = 1 : si64} : (!torch.vtensor<[2,2,2],f32>, !torch.vtensor<[2,2,2],f32>) -> !torch.vtensor<[2,4,2],f32> + return %0 : !torch.vtensor<[2,4,2],f32> +} + +// CHECK-LABEL: @test_concat_3d_axis_2 +func.func @test_concat_3d_axis_2(%arg0: !torch.vtensor<[2,2,2],f32>, %arg1: !torch.vtensor<[2,2,2],f32>) -> !torch.vtensor<[2,2,4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[TENSORS_LIST:.*]] = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[2,2,2],f32>, !torch.vtensor<[2,2,2],f32>) -> !torch.list + // CHECK: %[[DIM:.*]] = torch.constant.int 2 + // CHECK: torch.aten.cat %[[TENSORS_LIST]], %[[DIM]] : !torch.list, !torch.int -> !torch.vtensor<[2,2,4],f32> + %0 = torch.operator "onnx.Concat"(%arg0, %arg1) {torch.onnx.axis = 2 : si64} : (!torch.vtensor<[2,2,2],f32>, !torch.vtensor<[2,2,2],f32>) -> !torch.vtensor<[2,2,4],f32> + return %0 : !torch.vtensor<[2,2,4],f32> +} + +// CHECK-LABEL: @test_concat_3d_axis_negative_1 +func.func @test_concat_3d_axis_negative_1(%arg0: !torch.vtensor<[2,2,2],f32>, %arg1: !torch.vtensor<[2,2,2],f32>) -> !torch.vtensor<[2,2,4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[TENSORS_LIST:.*]] = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[2,2,2],f32>, !torch.vtensor<[2,2,2],f32>) -> !torch.list + // CHECK: %[[DIM:.*]] = torch.constant.int -1 + // CHECK: torch.aten.cat %[[TENSORS_LIST]], %[[DIM]] : !torch.list, !torch.int -> !torch.vtensor<[2,2,4],f32> + %0 = torch.operator "onnx.Concat"(%arg0, %arg1) {torch.onnx.axis = -1 : si64} : (!torch.vtensor<[2,2,2],f32>, !torch.vtensor<[2,2,2],f32>) -> !torch.vtensor<[2,2,4],f32> + return %0 : !torch.vtensor<[2,2,4],f32> +} + +// CHECK-LABEL: @test_concat_3d_axis_negative_2 +func.func @test_concat_3d_axis_negative_2(%arg0: !torch.vtensor<[2,2,2],f32>, %arg1: !torch.vtensor<[2,2,2],f32>) -> !torch.vtensor<[2,4,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[TENSORS_LIST:.*]] = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[2,2,2],f32>, !torch.vtensor<[2,2,2],f32>) -> !torch.list + // CHECK: %[[DIM:.*]] = torch.constant.int -2 + // CHECK: torch.aten.cat %[[TENSORS_LIST]], %[[DIM]] : !torch.list, !torch.int -> !torch.vtensor<[2,4,2],f32> + %0 = torch.operator "onnx.Concat"(%arg0, %arg1) {torch.onnx.axis = -2 : si64} : (!torch.vtensor<[2,2,2],f32>, !torch.vtensor<[2,2,2],f32>) -> !torch.vtensor<[2,4,2],f32> + return %0 : !torch.vtensor<[2,4,2],f32> +} + +// CHECK-LABEL: @test_concat_3d_axis_negative_3 +func.func @test_concat_3d_axis_negative_3(%arg0: !torch.vtensor<[2,2,2],f32>, %arg1: !torch.vtensor<[2,2,2],f32>) -> !torch.vtensor<[4,2,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[TENSORS_LIST:.*]] = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[2,2,2],f32>, !torch.vtensor<[2,2,2],f32>) -> !torch.list + // CHECK: %[[DIM:.*]] = torch.constant.int -3 + // CHECK: torch.aten.cat %[[TENSORS_LIST]], %[[DIM]] : !torch.list, !torch.int -> !torch.vtensor<[4,2,2],f32> + %0 = torch.operator "onnx.Concat"(%arg0, %arg1) {torch.onnx.axis = -3 : si64} : (!torch.vtensor<[2,2,2],f32>, !torch.vtensor<[2,2,2],f32>) -> !torch.vtensor<[4,2,2],f32> + return %0 : !torch.vtensor<[4,2,2],f32> +} From 0849fd0a0681598e1eaadf8bcc23699235819973 Mon Sep 17 00:00:00 2001 From: Vivek Khandelwal Date: Fri, 22 Dec 2023 08:01:13 +0000 Subject: [PATCH 041/283] [MLIR][ONNX] Fix onnx.conv lowering to handle bias tensor Signed-Off By: Vivek Khandelwal --- .../TorchOnnxToTorch/DefaultDomainAtoF.cpp | 6 ++++-- .../TorchOnnxToTorch/simple_ops_a_to_f.mlir | 20 +++++++++++++++++++ 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp index 6e74f39b1b4b..61bea1d866f1 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp @@ -501,7 +501,8 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( Torch::ValueTensorType resultType; Value input, weight; int64_t group; - if (binder.tensorOperands(input, weight) || + if (binder.tensorOperandAtIndex(input, 0) || + binder.tensorOperandAtIndex(weight, 1) || binder.s64IntegerAttr(group, "group", 1) || binder.tensorResultType(resultType)) return failure(); @@ -668,7 +669,8 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( Torch::ValueTensorType resultType; Value input, weight; int64_t group; - if (binder.tensorOperands(input, weight) || + if (binder.tensorOperandAtIndex(input, 0) || + binder.tensorOperandAtIndex(weight, 1) || binder.s64IntegerAttr(group, "group", 1) || binder.tensorResultType(resultType)) return failure(); diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir index a637837cd671..dc4d3e163052 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir @@ -505,6 +505,26 @@ func.func @test_conv_with_strides_padding(%arg0: !torch.vtensor<[1,1,7,5],f32>, return %0 : !torch.vtensor<[1,1,4,3],f32> } +// CHECK-LABEL: @test_conv_with_bias_strides_padding +func.func @test_conv_with_bias_strides_padding(%arg0: !torch.vtensor<[?,?,224,224],f32>, %arg1: !torch.vtensor<[64,3,7,7],f32>, %arg2: !torch.vtensor<[64],f32>) -> !torch.vtensor<[?,64,112,112],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[C3:.*]] = torch.constant.int 3 + // CHECK: %[[C3_0:.*]] = torch.constant.int 3 + // CHECK: %[[C1:.*]] = torch.constant.int 1 + // CHECK: %[[C1_0:.*]] = torch.constant.int 1 + // CHECK: %[[C2:.*]] = torch.constant.int 2 + // CHECK: %[[C2_0:.*]] = torch.constant.int 2 + // CHECK: %[[C0:.*]] = torch.constant.int 0 + // CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[C3]], %[[C3_0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[DILATIONS:.*]] = torch.prim.ListConstruct %[[C1]], %[[C1_0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C2]], %[[C2_0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[OUTPUT_PADDING:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[TRANSPOSED:.*]] = torch.constant.bool false + // CHECK: %[[GROUPS:.*]] = torch.constant.int 1 + // CHECK: torch.aten.convolution %arg0, %arg1, %arg2, %[[STRIDE]], %[[PADDING]], %[[DILATIONS]], %[[TRANSPOSED]], %[[OUTPUT_PADDING]], %[[GROUPS]] : !torch.vtensor<[?,?,224,224],f32>, !torch.vtensor<[64,3,7,7],f32>, !torch.vtensor<[64],f32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[?,64,112,112],f32> + %0 = torch.operator "onnx.Conv"(%arg0, %arg1, %arg2) {torch.onnx.dilations = [1 : si64, 1 : si64], torch.onnx.group = 1 : si64, torch.onnx.kernel_shape = [7 : si64, 7 : si64], torch.onnx.pads = [3 : si64, 3 : si64, 3 : si64, 3 : si64], torch.onnx.strides = [2 : si64, 2 : si64]} : (!torch.vtensor<[?,?,224,224],f32>, !torch.vtensor<[64,3,7,7],f32>, !torch.vtensor<[64],f32>) -> !torch.vtensor<[?,64,112,112],f32> + return %0 : !torch.vtensor<[?,64,112,112],f32> +} + // CHECK-LABEL: @test_convtranspose_dilations func.func @test_convtranspose_dilations(%arg0: !torch.vtensor<[1,1,3,3],f32>, %arg1: !torch.vtensor<[1,1,2,2],f32>) -> !torch.vtensor<[1,1,5,5],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[C0:.*]] = torch.constant.int 0 From ee75e8d1ae72b0b4868c6dae383709b83eb7e842 Mon Sep 17 00:00:00 2001 From: saienduri <77521230+saienduri@users.noreply.github.com> Date: Tue, 26 Dec 2023 10:20:13 -0800 Subject: [PATCH 042/283] [MLIR][ONNX] Add OnnxToTorch support for Reshape Op (#2698) This commit adds the OnnxToTorch support for Reshape op. --- .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 91 +++++++++ .../TorchOnnxToTorch/simple_ops_q_to_z.mlir | 190 ++++++++++++++++++ 2 files changed, 281 insertions(+) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index a8ea1cc102c7..f943f288fc40 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -873,4 +873,95 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( rewriter.replaceOp(binder.op, operand); return success(); }); + + patterns.onOp( + "Reshape", 5, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value data; + Value shape; + int64_t allowzero; + if (binder.tensorOperands(data, shape) || + binder.tensorResultType(resultType) || + binder.s64IntegerAttr(allowzero, "allowzero", 0)) + return failure(); + Torch::BaseTensorType shapeType = + shape.getType().cast(); + SmallVector dimList; + SmallVector selectSizes; + selectSizes.push_back(1); + Type selectResultType = shapeType.getWithSizesAndDtype( + llvm::ArrayRef(selectSizes), shapeType.getOptionalDtype()); + auto shapeSizes = + dyn_cast(shape.getType()).getSizes(); + auto dataSizes = + dyn_cast(data.getType()).getSizes(); + Value zero = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); + if (allowzero == 0) { + // convert shape (tensor) into torch int list while dealing with zero + // vals + for (int i = 0; i < shapeSizes[0]; i++) { + // Go through the shape list and get each dim in the list + Value selectIndex = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), i)); + Value extract = rewriter.create( + binder.getLoc(), selectResultType, shape, zero, selectIndex); + Value dim = rewriter.create( + binder.getLoc(), rewriter.getType(), extract); + // deal with zero axis values: replace with original dim value in + // input + Value isZero = + rewriter.create(binder.getLoc(), dim, zero); + isZero = + rewriter.create(binder.getLoc(), isZero); + Value adjustment; + int64_t inputDimsSize = dataSizes.size(); + if (i < inputDimsSize) { + adjustment = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), + dataSizes[i])); + } + // Will never have a 0 in the shape tensor input at an index out of + // bounds of original input dims Therefore, no need to adjust + else { + adjustment = zero; + } + Value finalOffset = rewriter.create( + binder.getLoc(), isZero, adjustment); + Value finalDim = rewriter.create( + binder.getLoc(), dim, finalOffset); + dimList.push_back(finalDim); + } + Value dimValueList = rewriter.create( + binder.getLoc(), + Torch::ListType::get( + Torch::IntType::get(binder.op->getContext())), + dimList); + rewriter.replaceOpWithNewOp( + binder.op, resultType, data, dimValueList); + return success(); + } + // convert axes (tensor) into torch int list + for (int i = 0; i < shapeSizes[0]; i++) { + // Go through the axes list and get each dim in the list + Value selectIndex = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), i)); + Value extract = rewriter.create( + binder.getLoc(), selectResultType, shape, zero, selectIndex); + Value dim = rewriter.create( + binder.getLoc(), rewriter.getType(), extract); + dimList.push_back(dim); + } + Value dimValueList = rewriter.create( + binder.getLoc(), + Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), + dimList); + rewriter.replaceOpWithNewOp(binder.op, resultType, + data, dimValueList); + return success(); + }); } diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir index 0f4fcb08cdfb..5aca8688dac5 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir @@ -830,3 +830,193 @@ func.func @test_transpose_all_permutations_4(%arg0: !torch.vtensor<[2,3,4],f32>) // CHECK: return %[[TRANSPOSE1]] return %0 : !torch.vtensor<[4,2,3],f32> } + +// ----- + +// CHECK-LABEL: func.func @test_reshape_negative_dim +func.func @test_reshape_negative_dim(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torch.vtensor<[3],si64>) -> !torch.vtensor<[2,6,2],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT0:.+]] = torch.constant.int 0 + // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 + // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: torch.aten.eq.int %1, %int0 : !torch.int, !torch.int -> !torch.bool + // CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int + // CHECK: %[[INT2:.+]] = torch.constant.int 2 + // CHECK: torch.aten.mul.int %3, %int2 : !torch.int, !torch.int -> !torch.int + // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int + // CHECK: %[[INT1:.+]] = torch.constant.int 1 + // CHECK: torch.aten.select.int %arg1, %int0, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: torch.aten.item %6 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: torch.aten.eq.int %7, %int0 : !torch.int, !torch.int -> !torch.bool + // CHECK: torch.aten.Int.bool %8 : !torch.bool -> !torch.int + // CHECK: %[[INT3:.+]] = torch.constant.int 3 + // CHECK: torch.aten.mul.int %9, %int3 : !torch.int, !torch.int -> !torch.int + // CHECK: torch.aten.add.int %7, %10 : !torch.int, !torch.int -> !torch.int + // CHECK: %[[INT2_1:.+]] = torch.constant.int 2 + // CHECK: torch.aten.select.int %arg1, %int0, %int2_1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: torch.aten.item %12 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: torch.aten.eq.int %13, %int0 : !torch.int, !torch.int -> !torch.bool + // CHECK: torch.aten.Int.bool %14 : !torch.bool -> !torch.int + // CHECK: %[[INT4:.+]] = torch.constant.int 4 + // CHECK: torch.aten.mul.int %15, %int4 : !torch.int, !torch.int -> !torch.int + // CHECK: torch.aten.add.int %13, %16 : !torch.int, !torch.int -> !torch.int + // CHECK: torch.prim.ListConstruct %5, %11, %17 : (!torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: torch.aten.reshape %arg0, %18 : !torch.vtensor<[2,3,4],f32>, !torch.list -> !torch.vtensor<[2,6,2],f32> + %0 = torch.operator "onnx.Reshape"(%arg0, %arg1) : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[2,6,2],f32> + return %0 : !torch.vtensor<[2,6,2],f32> +} + +// CHECK-LABEL: func.func @test_reshape_negative_extended_dims +func.func @test_reshape_negative_extended_dims(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[1,2,3,4],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT0:.+]] = torch.constant.int 0 + // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 + // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: torch.aten.eq.int %1, %int0 : !torch.int, !torch.int -> !torch.bool + // CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int + // CHECK: %[[INT2:.+]] = torch.constant.int 2 + // CHECK: torch.aten.mul.int %3, %int2 : !torch.int, !torch.int -> !torch.int + // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int + // CHECK: %[[INT1:.+]] = torch.constant.int 1 + // CHECK: torch.aten.select.int %arg1, %int0, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: torch.aten.item %6 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: torch.aten.eq.int %7, %int0 : !torch.int, !torch.int -> !torch.bool + // CHECK: torch.aten.Int.bool %8 : !torch.bool -> !torch.int + // CHECK: %[[INT3:.+]] = torch.constant.int 3 + // CHECK: torch.aten.mul.int %9, %int3 : !torch.int, !torch.int -> !torch.int + // CHECK: torch.aten.add.int %7, %10 : !torch.int, !torch.int -> !torch.int + // CHECK: %[[INT2_1:.+]] = torch.constant.int 2 + // CHECK: torch.aten.select.int %arg1, %int0, %int2_1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: torch.aten.item %12 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: torch.aten.eq.int %13, %int0 : !torch.int, !torch.int -> !torch.bool + // CHECK: torch.aten.Int.bool %14 : !torch.bool -> !torch.int + // CHECK: %[[INT4:.+]] = torch.constant.int 4 + // CHECK: torch.aten.mul.int %15, %int4 : !torch.int, !torch.int -> !torch.int + // CHECK: torch.aten.add.int %13, %16 : !torch.int, !torch.int -> !torch.int + // CHECK: %[[INT3_2:.+]] = torch.constant.int 3 + // CHECK: torch.aten.select.int %arg1, %int0, %int3_2 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: torch.aten.item %18 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: torch.aten.eq.int %19, %int0 : !torch.int, !torch.int -> !torch.bool + // CHECK: torch.aten.Int.bool %20 : !torch.bool -> !torch.int + // CHECK: torch.aten.mul.int %21, %int0 : !torch.int, !torch.int -> !torch.int + // CHECK: torch.aten.add.int %19, %22 : !torch.int, !torch.int -> !torch.int + // CHECK: torch.prim.ListConstruct %5, %11, %17, %23 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: torch.aten.reshape %arg0, %24 : !torch.vtensor<[2,3,4],f32>, !torch.list -> !torch.vtensor<[1,2,3,4],f32> + %0 = torch.operator "onnx.Reshape"(%arg0, %arg1) : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[4],si64>) -> !torch.vtensor<[1,2,3,4],f32> + return %0 : !torch.vtensor<[1,2,3,4],f32> +} + +// CHECK-LABEL: func.func @test_reshape_one_dim +func.func @test_reshape_one_dim(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[24],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT0:.+]] = torch.constant.int 0 + // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 + // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: torch.aten.eq.int %1, %int0 : !torch.int, !torch.int -> !torch.bool + // CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int + // CHECK: %[[INT2:.+]] = torch.constant.int 2 + // CHECK: torch.aten.mul.int %3, %int2 : !torch.int, !torch.int -> !torch.int + // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int + // CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list + // CHECK: torch.aten.reshape %arg0, %6 : !torch.vtensor<[2,3,4],f32>, !torch.list -> !torch.vtensor<[24],f32> + %0 = torch.operator "onnx.Reshape"(%arg0, %arg1) : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[24],f32> + return %0 : !torch.vtensor<[24],f32> +} + +// CHECK-LABEL: func.func @test_reshape_reduced_dims +func.func @test_reshape_reduced_dims(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torch.vtensor<[2],si64>) -> !torch.vtensor<[2,12],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT0:.+]] = torch.constant.int 0 + // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 + // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[2],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: torch.aten.eq.int %1, %int0 : !torch.int, !torch.int -> !torch.bool + // CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int + // CHECK: %[[INT2:.+]] = torch.constant.int 2 + // CHECK: torch.aten.mul.int %3, %int2 : !torch.int, !torch.int -> !torch.int + // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int + // CHECK: %[[INT1:.+]] = torch.constant.int 1 + // CHECK: torch.aten.select.int %arg1, %int0, %int1 : !torch.vtensor<[2],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: torch.aten.item %6 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: torch.aten.eq.int %7, %int0 : !torch.int, !torch.int -> !torch.bool + // CHECK: torch.aten.Int.bool %8 : !torch.bool -> !torch.int + // CHECK: %[[INT3:.+]] = torch.constant.int 3 + // CHECK: torch.aten.mul.int %9, %int3 : !torch.int, !torch.int -> !torch.int + // CHECK: torch.aten.add.int %7, %10 : !torch.int, !torch.int -> !torch.int + // CHECK: torch.prim.ListConstruct %5, %11 : (!torch.int, !torch.int) -> !torch.list + // CHECK: torch.aten.reshape %arg0, %12 : !torch.vtensor<[2,3,4],f32>, !torch.list -> !torch.vtensor<[2,12],f32> + %0 = torch.operator "onnx.Reshape"(%arg0, %arg1) : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2],si64>) -> !torch.vtensor<[2,12],f32> + return %0 : !torch.vtensor<[2,12],f32> +} + +// CHECK-LABEL: func.func @test_reshape_reordered_all_dims +func.func @test_reshape_reordered_all_dims(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torch.vtensor<[3],si64>) -> !torch.vtensor<[4,2,3],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT0:.+]] = torch.constant.int 0 + // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 + // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: torch.aten.eq.int %1, %int0 : !torch.int, !torch.int -> !torch.bool + // CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int + // CHECK: %[[INT2:.+]] = torch.constant.int 2 + // CHECK: torch.aten.mul.int %3, %int2 : !torch.int, !torch.int -> !torch.int + // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int + // CHECK: %[[INT1:.+]] = torch.constant.int 1 + // CHECK: torch.aten.select.int %arg1, %int0, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: torch.aten.item %6 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: torch.aten.eq.int %7, %int0 : !torch.int, !torch.int -> !torch.bool + // CHECK: torch.aten.Int.bool %8 : !torch.bool -> !torch.int + // CHECK: %[[INT3:.+]] = torch.constant.int 3 + // CHECK: torch.aten.mul.int %9, %int3 : !torch.int, !torch.int -> !torch.int + // CHECK: torch.aten.add.int %7, %10 : !torch.int, !torch.int -> !torch.int + // CHECK: %[[INT2_1:.+]] = torch.constant.int 2 + // CHECK: torch.aten.select.int %arg1, %int0, %int2_1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: torch.aten.item %12 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: torch.aten.eq.int %13, %int0 : !torch.int, !torch.int -> !torch.bool + // CHECK: torch.aten.Int.bool %14 : !torch.bool -> !torch.int + // CHECK: %[[INT4:.+]] = torch.constant.int 4 + // CHECK: torch.aten.mul.int %15, %int4 : !torch.int, !torch.int -> !torch.int + // CHECK: torch.aten.add.int %13, %16 : !torch.int, !torch.int -> !torch.int + // CHECK: torch.prim.ListConstruct %5, %11, %17 : (!torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: torch.aten.reshape %arg0, %18 : !torch.vtensor<[2,3,4],f32>, !torch.list -> !torch.vtensor<[4,2,3],f32> + %0 = torch.operator "onnx.Reshape"(%arg0, %arg1) : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[4,2,3],f32> + return %0 : !torch.vtensor<[4,2,3],f32> +} + +// CHECK-LABEL: func.func @test_reshape_zero_and_negative_dim +func.func @test_reshape_zero_and_negative_dim(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[2,3,1,4],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT0:.+]] = torch.constant.int 0 + // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 + // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: torch.aten.eq.int %1, %int0 : !torch.int, !torch.int -> !torch.bool + // CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int + // CHECK: %[[INT2:.+]] = torch.constant.int 2 + // CHECK: torch.aten.mul.int %3, %int2 : !torch.int, !torch.int -> !torch.int + // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int + // CHECK: %[[INT1:.+]] = torch.constant.int 1 + // CHECK: torch.aten.select.int %arg1, %int0, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: torch.aten.item %6 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: torch.aten.eq.int %7, %int0 : !torch.int, !torch.int -> !torch.bool + // CHECK: torch.aten.Int.bool %8 : !torch.bool -> !torch.int + // CHECK: %[[INT3:.+]] = torch.constant.int 3 + // CHECK: torch.aten.mul.int %9, %int3 : !torch.int, !torch.int -> !torch.int + // CHECK: torch.aten.add.int %7, %10 : !torch.int, !torch.int -> !torch.int + // CHECK: %[[INT2_1:.+]] = torch.constant.int 2 + // CHECK: torch.aten.select.int %arg1, %int0, %int2_1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: torch.aten.item %12 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: torch.aten.eq.int %13, %int0 : !torch.int, !torch.int -> !torch.bool + // CHECK: torch.aten.Int.bool %14 : !torch.bool -> !torch.int + // CHECK: %[[INT4:.+]] = torch.constant.int 4 + // CHECK: torch.aten.mul.int %15, %int4 : !torch.int, !torch.int -> !torch.int + // CHECK: torch.aten.add.int %13, %16 : !torch.int, !torch.int -> !torch.int + // CHECK: %[[INT3_2:.+]] = torch.constant.int 3 + // CHECK: torch.aten.select.int %arg1, %int0, %int3_2 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: torch.aten.item %18 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: torch.aten.eq.int %19, %int0 : !torch.int, !torch.int -> !torch.bool + // CHECK: torch.aten.Int.bool %20 : !torch.bool -> !torch.int + // CHECK: torch.aten.mul.int %21, %int0 : !torch.int, !torch.int -> !torch.int + // CHECK: torch.aten.add.int %19, %22 : !torch.int, !torch.int -> !torch.int + // CHECK: torch.prim.ListConstruct %5, %11, %17, %23 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: torch.aten.reshape %arg0, %24 : !torch.vtensor<[2,3,4],f32>, !torch.list -> !torch.vtensor<[2,3,1,4],f32> + %0 = torch.operator "onnx.Reshape"(%arg0, %arg1) : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[4],si64>) -> !torch.vtensor<[2,3,1,4],f32> + return %0 : !torch.vtensor<[2,3,1,4],f32> +} From 4f252c88b486c73f1d7bf776168c2ace09c2a169 Mon Sep 17 00:00:00 2001 From: Vivek Khandelwal Date: Tue, 26 Dec 2023 23:55:31 +0530 Subject: [PATCH 043/283] [MLIR][ONNX] Add OnnxToTorch support for GlobalAveragePool op. (#2692) This commit adds the OnnxToTorch support for GlobalAveragePool op. Signed-Off By: vivekkhandelwal1424@gmail.com --- .../TorchOnnxToTorch/DefaultDomainGtoP.cpp | 71 +++++++++++++++++++ .../TorchOnnxToTorch/simple_ops_g_to_p.mlir | 36 ++++++++++ 2 files changed, 107 insertions(+) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index b9bb6a540a02..74cf7472bd68 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -242,6 +242,77 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( binder.op, resultType, mm, c, constBeta); return success(); }); + patterns.onOp( + "GlobalAveragePool", 1, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value operand; + if (binder.tensorOperand(operand) || + binder.tensorResultType(resultType)) + return failure(); + + auto inputTensorType = operand.getType().cast(); + if (!inputTensorType || !inputTensorType.hasSizes()) { + return rewriter.notifyMatchFailure( + binder.op, "Expected input type having sizes"); + } + ArrayRef inputShape = inputTensorType.getSizes(); + unsigned inputRank = inputShape.size(); + if (!resultType || !resultType.hasSizes()) { + return rewriter.notifyMatchFailure( + binder.op, "Expected result type having sizes"); + } + ArrayRef resultShape = resultType.getSizes(); + + SmallVector cstKernel, cstPadding, cstStrides; + Value cstZero = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(0)); + Value cstOne = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(1)); + for (unsigned i = 2; i < inputRank; i++) { + int64_t kernelSize = inputShape[i] - resultShape[i] + 1; + cstKernel.push_back(rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(kernelSize))); + cstPadding.push_back(cstZero); + cstStrides.push_back(cstOne); + } + Value kernelSizeList = rewriter.create( + binder.getLoc(), + Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), + cstKernel); + Value paddingList = rewriter.create( + binder.getLoc(), + Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), + cstPadding); + Value stridesList = rewriter.create( + binder.getLoc(), + Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), + cstStrides); + Value cstFalse = rewriter.create(binder.getLoc(), false); + Value cstCeilMode = cstFalse; + Value cstCountIncludePad = cstFalse; + Value cstNone = rewriter.create(binder.getLoc()); + + if (inputRank == 3) { + rewriter.replaceOpWithNewOp( + binder.op, resultType, operand, kernelSizeList, stridesList, + paddingList, cstCeilMode, cstCountIncludePad); + return success(); + } else if (inputRank == 4) { + rewriter.replaceOpWithNewOp( + binder.op, resultType, operand, kernelSizeList, stridesList, + paddingList, cstCeilMode, cstCountIncludePad, + /*divisor_override=*/cstNone); + return success(); + } else if (inputRank == 5) { + rewriter.replaceOpWithNewOp( + binder.op, resultType, operand, kernelSizeList, stridesList, + paddingList, cstCeilMode, cstCountIncludePad, + /*divisor_override=*/cstNone); + return success(); + } + return failure(); + }); patterns.onOp("LeakyRelu", 16, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir index 08bb69f23fc5..c75077492b2f 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir @@ -265,3 +265,39 @@ func.func @test_hardsigmoid_default(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torc %0 = torch.operator "onnx.HardSigmoid"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> return %0 : !torch.vtensor<[3,4,5],f32> } + +// ----- + +// CHECK-LABEL: @test_globalaveragepool +func.func @test_globalaveragepool(%arg0: !torch.vtensor<[1,3,5,5],f32>) -> !torch.vtensor<[1,3,1,1],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 1 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[C0:.*]] = torch.constant.int 0 + // CHECK: %[[C1:.*]] = torch.constant.int 1 + // CHECK: %[[C5:.*]] = torch.constant.int 5 + // CHECK: %[[C5_0:.*]] = torch.constant.int 5 + // CHECK: %[[KERNELSIZE:.*]] = torch.prim.ListConstruct %[[C5]], %[[C5_0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C1]], %[[C1]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: torch.aten.avg_pool2d %arg0, %[[KERNELSIZE]], %[[STRIDE]], %[[PADDING]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[1,3,5,5],f32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,3,1,1],f32> + %0 = torch.operator "onnx.GlobalAveragePool"(%arg0) : (!torch.vtensor<[1,3,5,5],f32>) -> !torch.vtensor<[1,3,1,1],f32> + return %0 : !torch.vtensor<[1,3,1,1],f32> +} + +// ----- + +// CHECK-LABEL: @test_globalaveragepool_precomputed +func.func @test_globalaveragepool_precomputed(%arg0: !torch.vtensor<[1,1,3,3],f32>) -> !torch.vtensor<[1,1,1,1],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 1 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[C0:.*]] = torch.constant.int 0 + // CHECK: %[[C1:.*]] = torch.constant.int 1 + // CHECK: %[[C3:.*]] = torch.constant.int 3 + // CHECK: %[[C3_0:.*]] = torch.constant.int 3 + // CHECK: %[[KERNELSIZE:.*]] = torch.prim.ListConstruct %[[C3]], %[[C3_0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[PADDING:.*]] = torch.prim.ListConstruct %[[C0]], %[[C0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[STRIDE:.*]] = torch.prim.ListConstruct %[[C1]], %[[C1]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: torch.aten.avg_pool2d %arg0, %[[KERNELSIZE]], %[[STRIDE]], %[[PADDING]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[1,1,3,3],f32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1,1],f32> + %0 = torch.operator "onnx.GlobalAveragePool"(%arg0) : (!torch.vtensor<[1,1,3,3],f32>) -> !torch.vtensor<[1,1,1,1],f32> + return %0 : !torch.vtensor<[1,1,1,1],f32> +} From abc6b0a25a8d6d42b440a32514c508302767469e Mon Sep 17 00:00:00 2001 From: aldesilv <35313749+aldesilv@users.noreply.github.com> Date: Wed, 27 Dec 2023 09:34:48 -0800 Subject: [PATCH 044/283] onnx to torch pow support (#2656) --- .../TorchOnnxToTorch/DefaultDomainGtoP.cpp | 15 ++++++++++++++- .../TorchOnnxToTorch/simple_ops_g_to_p.mlir | 7 +++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index 74cf7472bd68..710c9823f757 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -329,4 +329,17 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( binder.op, resultType, operand, constAlpha); return success(); }); -} \ No newline at end of file + patterns.onOp("Pow", 1, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value lhs, rhs; + if (binder.tensorOperands(lhs, rhs) || + binder.tensorResultType(resultType)) { + return failure(); + } + rewriter.replaceOpWithNewOp( + binder.op, resultType, lhs, rhs); + return success(); + }); +} + diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir index c75077492b2f..5dd0225b9c41 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir @@ -200,6 +200,13 @@ func.func @test_less_or_equal(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch. return %0 : !torch.vtensor<[3,4,5],i1> } +// CHECK-LABEL: func.func @test_pow + func.func @test_pow(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 15 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.pow.Tensor_Tensor %arg0, %arg1 : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> + %0 = torch.operator "onnx.Pow"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> + } + // CHECK-LABEL: @test_hardsigmoid_example func.func @test_hardsigmoid_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 6 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[ALPHA_FLOAT:.*]] = torch.constant.float 5.000000e-01 From 6847fc1fc69d71d1d9b9485881f80ea12218564f Mon Sep 17 00:00:00 2001 From: "Xida Ren (Cedar)" Date: Wed, 27 Dec 2023 10:08:09 -0800 Subject: [PATCH 045/283] Fix since-opset too high (#2701) Addresses two of the ops from https://github.com/llvm/torch-mlir/issues/2689 https://github.com/llvm/torch-mlir/issues/2700 --- .../TorchOnnxToTorch/DefaultDomainAtoF.cpp | 4 +- .../TorchOnnxToTorch/DefaultDomainGtoP.cpp | 4 +- .../unsupported_fb_opt_ops.mlir | 40 +++++++++++++++++++ 3 files changed, 44 insertions(+), 4 deletions(-) create mode 100644 test/Conversion/TorchOnnxToTorch/unsupported_fb_opt_ops.mlir diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp index 61bea1d866f1..3d2cf8aaee73 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp @@ -377,7 +377,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( return success(); }); patterns.onOp( - "Cast", 19, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + "Cast", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value operand; int64_t dtypeIntOnnx, dtypeIntTorch; @@ -848,7 +848,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( binder.op, resultType, lhs, rhs); return success(); }); - patterns.onOp("Equal", 19, + patterns.onOp("Equal", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value lhs, rhs; diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index 710c9823f757..e7544d6c12ff 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -169,7 +169,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( return success(); }); patterns.onOp( - "Gemm", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + "Gemm", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value a, b, c; float alpha, beta; @@ -313,7 +313,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( } return failure(); }); - patterns.onOp("LeakyRelu", 16, + patterns.onOp("LeakyRelu", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value operand; diff --git a/test/Conversion/TorchOnnxToTorch/unsupported_fb_opt_ops.mlir b/test/Conversion/TorchOnnxToTorch/unsupported_fb_opt_ops.mlir new file mode 100644 index 000000000000..8401c378b77c --- /dev/null +++ b/test/Conversion/TorchOnnxToTorch/unsupported_fb_opt_ops.mlir @@ -0,0 +1,40 @@ +// RUN: torch-mlir-opt <%s -split-input-file -verify-diagnostics -convert-torch-onnx-to-torch +// FB OPT OPS from https://github.com/llvm/torch-mlir/issues/2689 + +// ----- +// Fixed unecessarily high since-opset value +func.func @cast_operation(%arg0: !torch.vtensor<[?,?,?,?],si64>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + %208 = torch.operator "onnx.Cast"(%arg0) { + torch.onnx.to = 1 : si64 + } : (!torch.vtensor<[?,?,?,?],si64>) -> !torch.vtensor<[?,?,?,?],f32> + return %208 : !torch.vtensor<[?,?,?,?],f32> +} + +// ----- +func.func @div_operation(%arg0: !torch.vtensor<[1,64,768],f32>, + %arg1: !torch.vtensor<[1,64,1],f32>) + -> !torch.vtensor<[1,64,768],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + %209 = torch.operator "onnx.Div"(%arg0, %arg1) : (!torch.vtensor<[1,64,768],f32>, !torch.vtensor<[1,64,1],f32>) -> !torch.vtensor<[1,64,768],f32> + return %209 : !torch.vtensor<[1,64,768],f32> +} + +// ----- +// Fixed. +// this is the onnx opset 1 version of Equal, only int types. +// this used to fail to legalize because the "since" value is set unecessarily high (19) +func.func @equal_operation(%arg0: !torch.vtensor<[4],si64>, + %arg1: !torch.vtensor<[4],si64>) + -> !torch.vtensor<[4],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 1 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + %205 = torch.operator "onnx.Equal"(%arg0, %arg1) : (!torch.vtensor<[4],si64>, !torch.vtensor<[4],si64>) -> !torch.vtensor<[4],i1> + return %205 : !torch.vtensor<[4],i1> +} + + +// ----- +func.func @reduce_mean_operation(%arg0: !torch.vtensor<[1,64,768],f32>) + -> !torch.vtensor<[1,64,1],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // The ReduceMean operation as provided. + // expected-error @+1 {{failed to legalize operation 'torch.operator'}} + %211 = torch.operator "onnx.ReduceMean"(%arg0) {torch.onnx.axes = [-1 : si64]} : (!torch.vtensor<[1,64,768],f32>) -> !torch.vtensor<[1,64,1],f32> + return %211 : !torch.vtensor<[1,64,1],f32> +} \ No newline at end of file From 336cfb64b531a3d34cae74f007448749420d64ac Mon Sep 17 00:00:00 2001 From: aldesilv <35313749+aldesilv@users.noreply.github.com> Date: Wed, 27 Dec 2023 10:50:08 -0800 Subject: [PATCH 046/283] OnnxToTorch support for onnx.Mul op (#2699) --- .../TorchOnnxToTorch/DefaultDomainGtoP.cpp | 13 ++++++++++++- .../TorchOnnxToTorch/simple_ops_g_to_p.mlir | 9 +++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index e7544d6c12ff..c24fd0c6549c 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -99,6 +99,18 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( binder.op, resultType, lhs, rhs); return success(); }); + patterns.onOp("Mul", 7, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value lhs, rhs; + if (binder.tensorOperands(lhs, rhs) || + binder.tensorResultType(resultType)) { + return failure(); + } + rewriter.replaceOpWithNewOp( + binder.op, resultType, lhs, rhs); + return success(); + }); patterns.onOp("Greater", 16, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; @@ -342,4 +354,3 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( return success(); }); } - diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir index 5dd0225b9c41..085c6ea6a889 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir @@ -151,6 +151,15 @@ func.func @test_matmul_4d(%arg0: !torch.vtensor<[1,2,3,4],f32>, %arg1: !torch.vt // ----- +// CHECK-LABEL: func.func @test_mul + func.func @test_mul(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.mul.Tensor %arg0, %arg1 : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> + %0 = torch.operator "onnx.Mul"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> + } + +// ----- + // CHECK-LABEL: @test_gelu_default_1 func.func @test_gelu_default_1(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[STR1:.*]] = torch.constant.str "none" From 2d796b750250b204fd0ad0d29f844b81db95578c Mon Sep 17 00:00:00 2001 From: aldesilv <35313749+aldesilv@users.noreply.github.com> Date: Wed, 27 Dec 2023 11:07:35 -0800 Subject: [PATCH 047/283] lower onnx max op to torch aten maximum op (#2618) lower onnx min op to torch aten minimum op --- .../Conversion/TorchOnnxToTorch/Patterns.h | 7 ++ .../TorchOnnxToTorch/DefaultDomainGtoP.cpp | 105 ++++++++++++++++-- .../TorchOnnxToTorch/simple_ops_g_to_p.mlir | 43 +++++++ 3 files changed, 145 insertions(+), 10 deletions(-) diff --git a/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h b/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h index 85d6f805f3f6..d842ea77bd3c 100644 --- a/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h +++ b/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h @@ -76,6 +76,13 @@ struct OpBinder { return failure(); return success(); } + + ParseResult tensorOperandsList( llvm::SmallVectorImpl &values) { + for (int i = 0; i < op->getNumOperands(); i++) { + values.push_back(op->getOperand(i)); + } + return success(); + } // Result type matchers of different arities. ParseResult tensorResultType(Torch::ValueTensorType &type0) { diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index c24fd0c6549c..d154edb1ab75 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -88,8 +88,45 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( operand, vApproximate); return success(); }); - patterns.onOp("MatMul", 13, + patterns.onOp("Less", 13, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value lhs, rhs; + if (binder.tensorOperands(lhs, rhs) || + binder.tensorResultType(resultType)) { + return failure(); + } + rewriter.replaceOpWithNewOp( + binder.op, resultType, lhs, rhs); + return success(); + }); + + patterns.onOp("LessOrEqual", 1, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value lhs, rhs; + if (binder.tensorOperands(lhs, rhs) || + binder.tensorResultType(resultType)) { + return failure(); + } + rewriter.replaceOpWithNewOp( + binder.op, resultType, lhs, rhs); + return success(); + }); + patterns.onOp("Log", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value operand; + if (binder.tensorOperand(operand) || + binder.tensorResultType(resultType)) { + return failure(); + } + rewriter.replaceOpWithNewOp( + binder.op, resultType, operand); + return success(); + }); + patterns.onOp("MatMul", 13, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value lhs, rhs; if (binder.tensorOperands(lhs, rhs) || @@ -135,19 +172,67 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( binder.op, resultType, lhs, rhs); return success(); }); - patterns.onOp("Less", 13, + patterns.onOp("Max", 1, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + llvm::SmallVector operands; + if (binder.tensorOperandsList(operands) || + binder.tensorResultType(resultType) || + operands.size() == 0) { + return failure(); + } + Value result = operands[0]; + for (int i = 1; i < operands.size(); i++) { + result = rewriter.create( + binder.getLoc(), resultType, result, operands[i]); + } + rewriter.replaceOp( + binder.op, result.getDefiningOp()); + return success(); + }); + patterns.onOp("Min", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; - Value lhs, rhs; - if (binder.tensorOperands(lhs, rhs) || + llvm::SmallVector operands; + if (binder.tensorOperandsList(operands) || + binder.tensorResultType(resultType) || + operands.size() == 0) { + return failure(); + } + Value result = operands[0]; + for (int i = 1; i < operands.size(); i++) { + result = rewriter.create( + binder.getLoc(), resultType, result, operands[i]); + } + rewriter.replaceOp( + binder.op, result.getDefiningOp()); + return success(); + }); + patterns.onOp("Neg", 1, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value operand; + if (binder.tensorOperand(operand) || binder.tensorResultType(resultType)) { return failure(); } - rewriter.replaceOpWithNewOp( - binder.op, resultType, lhs, rhs); - return success(); + rewriter.replaceOpWithNewOp( + binder.op, resultType, operand); + return success(); }); - patterns.onOp("LessOrEqual", 16, + patterns.onOp("Not", 1, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value operand; + if (binder.tensorOperand(operand) || + binder.tensorResultType(resultType)) { + return failure(); + } + rewriter.replaceOpWithNewOp( + binder.op, resultType, operand); + return success(); + }); + patterns.onOp("Or", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value lhs, rhs; @@ -155,9 +240,9 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( binder.tensorResultType(resultType)) { return failure(); } - rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( binder.op, resultType, lhs, rhs); - return success(); + return success(); }); patterns.onOp( "GatherElements", 13, diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir index 085c6ea6a889..e224ddfa2944 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir @@ -317,3 +317,46 @@ func.func @test_globalaveragepool_precomputed(%arg0: !torch.vtensor<[1,1,3,3],f3 %0 = torch.operator "onnx.GlobalAveragePool"(%arg0) : (!torch.vtensor<[1,1,3,3],f32>) -> !torch.vtensor<[1,1,1,1],f32> return %0 : !torch.vtensor<[1,1,1,1],f32> } + +// CHECK-LABEL: func.func @test_max_example + func.func @test_max_example(%arg0: !torch.vtensor<[3],f32>, %arg1: !torch.vtensor<[3],f32>, %arg2: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.maximum %arg0, %arg1 : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> + %0 = torch.operator "onnx.Max"(%arg0, %arg1, %arg2) : (!torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> + return %0 : !torch.vtensor<[3],f32> + } + +// CHECK-LABEL: func.func @test_min_example + func.func @test_min_example(%arg0: !torch.vtensor<[3],f32>, %arg1: !torch.vtensor<[3],f32>, %arg2: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.minimum %arg0, %arg1 : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> + %0 = torch.operator "onnx.Min"(%arg0, %arg1, %arg2) : (!torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> + return %0 : !torch.vtensor<[3],f32> + } + + +// CHECK-LABEL: func.func @test_log + func.func @test_log(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.log %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> + %0 = torch.operator "onnx.Log"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> + } + +// CHECK-LABEL: func.func @test_neg + func.func @test_neg(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.neg %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> + %0 = torch.operator "onnx.Neg"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> + } + +// CHECK-LABEL: func.func @test_not_2d +func.func @test_not_2d(%arg0: !torch.vtensor<[3,4],i1>) -> !torch.vtensor<[3,4],i1> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 1 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.bitwise_not %arg0 : !torch.vtensor<[3,4],i1> -> !torch.vtensor<[3,4],i1> + %0 = torch.operator "onnx.Not"(%arg0) : (!torch.vtensor<[3,4],i1>) -> !torch.vtensor<[3,4],i1> + return %0 : !torch.vtensor<[3,4],i1> + } + +// CHECK-LABEL: func.func @test_or2d + func.func @test_or2d(%arg0: !torch.vtensor<[3,4],i1>, %arg1: !torch.vtensor<[3,4],i1>) -> !torch.vtensor<[3,4],i1> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 7 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.bitwise_or.Tensor %arg0, %arg1 : !torch.vtensor<[3,4],i1>, !torch.vtensor<[3,4],i1> -> !torch.vtensor<[3,4],i1> + %0 = torch.operator "onnx.Or"(%arg0, %arg1) : (!torch.vtensor<[3,4],i1>, !torch.vtensor<[3,4],i1>) -> !torch.vtensor<[3,4],i1> + return %0 : !torch.vtensor<[3,4],i1> + } From 1b40b6384e8d7e716f42142f90c00c653c7d2635 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Wed, 27 Dec 2023 12:13:34 -0800 Subject: [PATCH 048/283] [onnx] Add torch-mlir-import-onnx native port as an optional tool/library. (#2694) As noted in the plan when this work started, we need to produce an ORT EP plugin for a downstream project, and this will necessitate a C-based ONNX importer (as opposed to the existing Python one). Because this comes with dependencies that we do not want to impart on various projects, this is optional in torch-mlir. It is also factored so that it can be used as standalone sources in downstreams that need it. Since it only depends on public C APIs on the MLIR side, this will make build coupling a lot better (since a C++ dep is not needed on the compiler and it is trivial to dynamically load). Our original plan was just to maintain this fork off to the side in our ORT plugin, but once work started, it seemed better to write it clean and contribute it upstream for anyone to use. We expect that for non-ORT use, the Python importer will have better ergonomics for most folks. I will follow-up with a test suite refactor so that we can drive the Python or C importer. This is a relatively mechanical port from Python to C, borrowing some scaffolding from the old JitIR importer. It does attempt to lay some groundwork for external data, which will need to be implemented on the Python side as well. --- CMakeLists.txt | 2 + projects/CMakeLists.txt | 4 + projects/onnx_c_importer/CMakeLists.txt | 44 + projects/onnx_c_importer/OnnxImporter.cpp | 1009 +++++++++++++++++ projects/onnx_c_importer/OnnxImporter.h | 240 ++++ projects/onnx_c_importer/README.md | 7 + projects/onnx_c_importer/import-onnx-main.cpp | 103 ++ 7 files changed, 1409 insertions(+) create mode 100644 projects/onnx_c_importer/CMakeLists.txt create mode 100644 projects/onnx_c_importer/OnnxImporter.cpp create mode 100644 projects/onnx_c_importer/OnnxImporter.h create mode 100644 projects/onnx_c_importer/README.md create mode 100644 projects/onnx_c_importer/import-onnx-main.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index f821d60034c8..376aea80eea3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -51,6 +51,8 @@ option(TORCH_MLIR_ENABLE_PYTORCH_EXTENSIONS "Enables PyTorch native extension fe cmake_dependent_option(TORCH_MLIR_ENABLE_JIT_IR_IMPORTER "Enables JIT IR Importer" ON TORCH_MLIR_ENABLE_PYTORCH_EXTENSIONS OFF) cmake_dependent_option(TORCH_MLIR_ENABLE_LTC "Enables LTC backend" OFF TORCH_MLIR_ENABLE_PYTORCH_EXTENSIONS OFF) +option(TORCH_MLIR_ENABLE_ONNX_C_IMPORTER "Enables the ONNX C importer" OFF) + #------------------------------------------------------------------------------- # Configure out-of-tree vs in-tree build #------------------------------------------------------------------------------- diff --git a/projects/CMakeLists.txt b/projects/CMakeLists.txt index d4fead890269..ea7e34593aba 100644 --- a/projects/CMakeLists.txt +++ b/projects/CMakeLists.txt @@ -1,5 +1,9 @@ include(AddMLIRPython) +if(TORCH_MLIR_ENABLE_ONNX_C_IMPORTER) + add_subdirectory(onnx_c_importer) +endif() + ################################################################################ # PyTorch # Configure PyTorch if we have any features enabled which require it. diff --git a/projects/onnx_c_importer/CMakeLists.txt b/projects/onnx_c_importer/CMakeLists.txt new file mode 100644 index 000000000000..b685c732f5dc --- /dev/null +++ b/projects/onnx_c_importer/CMakeLists.txt @@ -0,0 +1,44 @@ +message(STATUS "Enabling onnx_c_importer...") + +include(FetchContent) + +find_package(Protobuf) +if(NOT Protobuf_FOUND) + message(FATAL_ERROR + "In order to build C ONNX support, the Protobuf package must be installed " + "on the system. Without this ONNX will attempt to build it in the project " + "and the dependent ABSEIL build system is incompatible. " + "On Ubuntu, install with: " + "apt install libprotobuf-dev protobuf-compiler\n\n" + "(or this entire component can be disabled with " + "-DTORCH_MLIR_ENABLE_ONNX_C_IMPORTER=OFF)") +endif() + +option(ONNX_DISABLE_EXCEPTIONS "For compatibility with LLVM build" ON) + +FetchContent_Declare( + onnx + EXCLUDE_FROM_ALL + GIT_REPOSITORY https://github.com/onnx/onnx.git + GIT_TAG v1.15.0 + GIT_SHALLOW ON + GIT_PROGRESS ON +) +FetchContent_MakeAvailable(onnx) + +add_llvm_executable( + torch-mlir-import-onnx + PARTIAL_SOURCES_INTENDED + + import-onnx-main.cpp + OnnxImporter.h + OnnxImporter.cpp +) + +target_link_libraries( + torch-mlir-import-onnx + LLVMSupport + MLIRCAPIIR + TorchMLIRCAPI + onnx +) diff --git a/projects/onnx_c_importer/OnnxImporter.cpp b/projects/onnx_c_importer/OnnxImporter.cpp new file mode 100644 index 000000000000..4a61a2800ca5 --- /dev/null +++ b/projects/onnx_c_importer/OnnxImporter.cpp @@ -0,0 +1,1009 @@ +//===------------------------------------------------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#include "OnnxImporter.h" + +#include "mlir-c/BuiltinAttributes.h" +#include "mlir-c/BuiltinTypes.h" + +#include +#include + +using namespace torch_mlir_onnx; + +namespace { + +std::string SanitizeNameAsIdentifier(std::string_view in) { + std::string out; + if (!in.empty() && !std::isalnum(in.front())) { + out.append("_"); + } + out.append(in); + for (char &c : out) { + if (c == ':' || c == '/') + c = '_'; + } + return out; +} + +template +void AppendDelimittedStrings(std::string &into, T &container) { + bool first = true; + for (auto &item : container) { + if (first) { + first = false; + } else { + into.append(", "); + } + into.append(item); + } +} + +inline MlirStringRef toMlirStringRef(const std::string_view &s) { + return mlirStringRefCreate(s.data(), s.size()); +} + +inline MlirStringRef toMlirStringRef(const std::string &s) { + return mlirStringRefCreate(s.data(), s.size()); +} + +inline MlirStringRef toMlirStringRef(const char *s) { + return mlirStringRefCreate(s, std::strlen(s)); +} + +inline MlirNamedAttribute toMlirNamedAttribute(const char *s, + MlirAttribute attr) { + MlirContext context = mlirAttributeGetContext(attr); + MlirIdentifier ident = mlirIdentifierGet(context, toMlirStringRef(s)); + return mlirNamedAttributeGet(ident, attr); +} + +std::string getMlirAsm(MlirType t) { + std::string result; + mlirTypePrint( + t, + +[](MlirStringRef sr, void *userData) { + std::string *s = static_cast(userData); + s->append(sr.data, sr.length); + }, + static_cast(&result)); + return result; +} + +// C++ helpers to create operations. +void addToMlirOperationState(MlirOperationState &state, + MlirNamedAttribute namedAttr) { + mlirOperationStateAddAttributes(&state, 1, &namedAttr); +} + +void addToMlirOperationState( + MlirOperationState &state, + std::vector> &attrs) { + for (auto &p : attrs) { + addToMlirOperationState(state, + toMlirNamedAttribute(p.first.c_str(), p.second)); + } +} + +void addToMlirOperationState(MlirOperationState &state, MlirRegion region) { + mlirOperationStateAddOwnedRegions(&state, 1, ®ion); +} + +[[maybe_unused]] void addToMlirOperationState(MlirOperationState &state, + MlirValue value) { + mlirOperationStateAddOperands(&state, 1, &value); +} + +void addToMlirOperationState(MlirOperationState &state, + const std::vector &values) { + mlirOperationStateAddOperands(&state, values.size(), values.data()); +} + +void addToMlirOperationState(MlirOperationState &state, MlirType resultType) { + mlirOperationStateAddResults(&state, 1, &resultType); +} + +void addToMlirOperationState(MlirOperationState &state, + const std::vector &resultTypes) { + mlirOperationStateAddResults(&state, resultTypes.size(), resultTypes.data()); +} + +[[maybe_unused]] void addToMlirOperationState(MlirOperationState &state) {} + +template +void addToMlirOperationState(MlirOperationState &state, T &&t, U &&u, + Ts &&...ts) { + addToMlirOperationState(state, std::forward(t)); + addToMlirOperationState(state, std::forward(u), std::forward(ts)...); +} + +template +MlirOperation createMlirOperation(std::string name, MlirLocation loc, + Ts &&...ts) { + MlirOperationState state = mlirOperationStateGet(toMlirStringRef(name), loc); + addToMlirOperationState(state, std::forward(ts)...); + return mlirOperationCreate(&state); +} + +template +MlirOperation createMlirOperationAtEnd(MlirBlock block, std::string name, + MlirLocation loc, Ts &&...ts) { + MlirOperation operation = + createMlirOperation(name, loc, std::forward(ts)...); + mlirBlockInsertOwnedOperationBefore(block, mlirBlockGetTerminator(block), + operation); + return operation; +} + +} // namespace + +// ---------------------------------------------------------------------------// +// ModelInfo +// ---------------------------------------------------------------------------// + +ModelInfo::ModelInfo() = default; + +void ModelInfo::DebugDumpProto() { + std::string debug_string = model_proto_.DebugString(); + fprintf(stderr, "%s\n", debug_string.c_str()); +} + +Status ModelInfo::Initialize() { + if (!model_proto_.has_graph()) { + return SetError("ONNX ModelProto has no main graph"); + } + main_graph_ = std::make_unique(*this, model_proto_.graph()); + if (failed(main_graph_->Initialize())) { + return failure(); + } + + return success(); +} + +// ---------------------------------------------------------------------------// +// GraphInfo +// ---------------------------------------------------------------------------// + +Status GraphInfo::Initialize() { + // Initialize look up tables. + for (const onnx::TensorProto &t : graph_proto_.initializer()) { + initializer_map_.emplace(t.name(), t); + } + for (const onnx::ValueInfoProto &v : graph_proto_.value_info()) { + value_info_map_.emplace(v.name(), v); + } + for (const onnx::ValueInfoProto &v : graph_proto_.input()) { + declared_inputs_.emplace_back(&v); + } + for (const onnx::ValueInfoProto &v : graph_proto_.output()) { + outputs_.emplace_back(&v); + } + + // Generate the effective input map, which for old models can be a subset of + // the input map. + if (model_info_.config().elide_initialized_inputs) { + // Default. Add declared inputs to the input map unless if they appear + // as an initializer. + for (const onnx::ValueInfoProto *it : declared_inputs_) { + std::string_view key = it->name(); + if (initializer_map_.find(key) != initializer_map_.end()) { + // In initializers. Skip. + continue; + } + inputs_.emplace_back(it); + } + } else { + // Fallback for some legacy compatibility. + inputs_ = declared_inputs_; + std::vector illegal_keys; + for (const onnx::ValueInfoProto *it : inputs_) { + std::string_view key = it->name(); + if (initializer_map_.find(key) != initializer_map_.end()) { + illegal_keys.push_back(key); + } + } + if (!illegal_keys.empty()) { + std::string error = "When not in elide_initialized_inputs=true mode, we " + "expect inputs to not have an initial value (got "; + AppendDelimittedStrings(error, illegal_keys); + error.append(")"); + return model_info_.SetError(std::move(error)); + } + } + + // Index the inputs and outputs. + for (auto *input : inputs_) { + input_map_.emplace(input->name(), *input); + } + for (auto *output : outputs_) { + output_map_.emplace(output->name(), *output); + } + return success(); +} + +const onnx::TypeProto *GraphInfo::FindTypeProtoForName(std::string_view name) { + // Node outputs don't typically have type information, but shape inference + // will associate them in the value_info. If not there, it may be a + // graph output, which must have type information. + { + auto it = value_info_map_.find(name); + if (it != value_info_map_.end()) { + return &it->second.type(); + } + } + { + auto it = output_map_.find(name); + if (it != output_map_.end()) { + return &it->second.type(); + } + } + + std::string msg = "No type information associated with '"; + msg.append(name); + msg.append("'. Run shape inference?"); + model_info_.SetError(std::move(msg)); + return nullptr; +} + +// ---------------------------------------------------------------------------// +// ContextCache +// ---------------------------------------------------------------------------// + +MlirType ContextCache::ConvertTypeProto(const onnx::TypeProto &tp) { + if (tp.has_tensor_type()) { + // Convert Tensor TypeProto. + const onnx::TypeProto_Tensor &tt = tp.tensor_type(); + if (!tt.has_shape()) { + std::string msg = + "Unsupported Tensor type without shape (run shape inference?): "; + msg.append(tt.DebugString()); + model_info_.SetError(std::move(msg)); + return {nullptr}; + } + + MlirType element_type = ConvertTensorElementType(tt.elem_type()); + if (mlirTypeIsNull(element_type)) { + return {nullptr}; + } + shared_dims_.clear(); + shared_dims_.reserve(6); + for (const onnx::TensorShapeProto::Dimension &dim : tt.shape().dim()) { + if (dim.has_dim_value()) { + // Static. + shared_dims_.push_back(dim.dim_value()); + } else { + // Dynamic. + shared_dims_.push_back(-1); + } + } + + return GetVtensorType(shared_dims_, element_type); + } else { + std::string msg = "Unsupported ONNX TypeProto: "; + msg.append(tp.DebugString()); + model_info_.SetError(std::move(msg)); + return {nullptr}; + } +} + +MlirType ContextCache::ConvertTensorElementType(int elem_type) { + auto it = elem_type_map_.find(elem_type); + if (it != elem_type_map_.end()) { + return it->second; + } + + MlirType t = {nullptr}; + switch (elem_type) { + case onnx::TensorProto::FLOAT: + t = mlirF32TypeGet(context_); + break; + case onnx::TensorProto::UINT8: + t = mlirIntegerTypeUnsignedGet(context_, 8); + break; + case onnx::TensorProto::INT8: + t = mlirIntegerTypeSignedGet(context_, 8); + break; + case onnx::TensorProto::UINT16: + t = mlirIntegerTypeUnsignedGet(context_, 16); + break; + case onnx::TensorProto::INT16: + t = mlirIntegerTypeSignedGet(context_, 16); + break; + case onnx::TensorProto::INT32: + t = mlirIntegerTypeSignedGet(context_, 32); + break; + case onnx::TensorProto::UINT32: + t = mlirIntegerTypeUnsignedGet(context_, 32); + break; + case onnx::TensorProto::INT64: + t = mlirIntegerTypeSignedGet(context_, 64); + break; + case onnx::TensorProto::UINT64: + t = mlirIntegerTypeUnsignedGet(context_, 64); + break; + case onnx::TensorProto::BOOL: + t = mlirIntegerTypeGet(context_, 1); + break; + case onnx::TensorProto::FLOAT16: + t = mlirF16TypeGet(context_); + break; + case onnx::TensorProto::DOUBLE: + t = mlirF64TypeGet(context_); + break; + case onnx::TensorProto::COMPLEX64: + t = mlirComplexTypeGet(mlirF32TypeGet(context_)); + break; + case onnx::TensorProto::COMPLEX128: + t = mlirComplexTypeGet(mlirF64TypeGet(context_)); + break; + case onnx::TensorProto::BFLOAT16: + t = mlirBF16TypeGet(context_); + break; + case onnx::TensorProto::FLOAT8E4M3FN: + t = mlirFloat8E4M3FNTypeGet(context_); + break; + case onnx::TensorProto::FLOAT8E4M3FNUZ: + t = mlirFloat8E4M3FNUZTypeGet(context_); + break; + case onnx::TensorProto::FLOAT8E5M2: + t = mlirFloat8E5M2TypeGet(context_); + break; + case onnx::TensorProto::FLOAT8E5M2FNUZ: + t = mlirFloat8E5M2FNUZTypeGet(context_); + break; + default: { + std::string msg = "Unknown ONNX tensor element type: "; + msg.append(std::to_string(elem_type)); + model_info_.SetError(std::move(msg)); + return {nullptr}; + } + } + + assert(t.ptr && "did not convert type"); + elem_type_map_[elem_type] = t; + return t; +} + +MlirAttribute +ContextCache::ConvertTensorProtoToAttr(const onnx::TensorProto &tp) { + MlirType tensor_type = ConvertTensorProtoToBuiltinType(tp); + if (tp.has_raw_data()) { + std::string sanitized_name = SanitizeNameAsIdentifier(tp.name()); + // Conveniently, DenseResourceElementsAttr shares the raw data + // format. We just give it maximum numeric alignment. + return mlirUnmanagedDenseResourceElementsAttrGet( + tensor_type, toMlirStringRef(sanitized_name), + const_cast(static_cast(tp.raw_data().data())), + tp.raw_data().size(), /*dataAlignment=*/8, /*dataIsMutable=*/false, + /*deleter=*/nullptr, /*userData=*/nullptr); + } else { + switch (tp.data_type()) { + case onnx::TensorProto::DataType::TensorProto_DataType_FLOAT: + return mlirDenseElementsAttrFloatGet(tensor_type, tp.float_data_size(), + tp.float_data().data()); + case onnx::TensorProto::DataType::TensorProto_DataType_INT32: + return mlirDenseElementsAttrInt32Get(tensor_type, tp.int32_data_size(), + tp.int32_data().data()); + case onnx::TensorProto::DataType::TensorProto_DataType_INT64: + return mlirDenseElementsAttrInt64Get(tensor_type, tp.int64_data_size(), + tp.int64_data().data()); + case onnx::TensorProto::DataType::TensorProto_DataType_DOUBLE: + return mlirDenseElementsAttrDoubleGet(tensor_type, tp.double_data_size(), + tp.double_data().data()); + case onnx::TensorProto::DataType::TensorProto_DataType_UINT32: { + // Special case. See proto. Someone apparently got lazy. + std::vector stupid_conversion; + stupid_conversion.reserve(tp.uint64_data_size()); + for (uint64_t v : tp.uint64_data()) + stupid_conversion.push_back(v); + return mlirDenseElementsAttrUInt32Get( + tensor_type, stupid_conversion.size(), stupid_conversion.data()); + } + case onnx::TensorProto::DataType::TensorProto_DataType_UINT64: + return mlirDenseElementsAttrUInt64Get(tensor_type, tp.uint64_data_size(), + tp.uint64_data().data()); + } + } + + std::string message = + "Unable to convert ONNX TensorProto to MLIR attribute: "; + message.append(tp.DebugString()); + model_info_.SetError(std::move(message)); + return {nullptr}; +} + +MlirType +ContextCache::ConvertTensorProtoToBuiltinType(const onnx::TensorProto &tp) { + MlirType element_type = ConvertTensorElementType(tp.data_type()); + if (mlirTypeIsNull(element_type)) + return {nullptr}; + + shared_dims_.clear(); + for (auto dim : tp.dims()) { + shared_dims_.push_back(dim); + } + return mlirRankedTensorTypeGet(shared_dims_.size(), shared_dims_.data(), + element_type, + /*encoding=*/{nullptr}); +} + +MlirType +ContextCache::ConvertTensorProtoToVtensorType(const onnx::TensorProto &tp) { + MlirType element_type = ConvertTensorElementType(tp.data_type()); + if (mlirTypeIsNull(element_type)) + return {nullptr}; + + shared_dims_.clear(); + for (auto dim : tp.dims()) { + shared_dims_.push_back(dim); + } + + return GetVtensorType(shared_dims_, element_type); +} + +MlirType ContextCache::GetVtensorType(const std::vector &dims, + MlirType element_type) { + std::string type_asm = "!torch.vtensor<["; + // Add dimension list. + bool first_dim = true; + for (int dim : dims) { + if (first_dim) + first_dim = false; + else + type_asm.push_back(','); + if (dim < 0) + type_asm.push_back('?'); + else + type_asm.append(std::to_string(dim)); + } + type_asm.append("],"); + + // Add element type. + type_asm.append(getMlirAsm(element_type)); + type_asm.push_back('>'); + + // Look in cache. + auto found_it = asm_type_map_.find(type_asm); + if (found_it != asm_type_map_.end()) { + return found_it->second; + } + + // Parse. + MlirType t = mlirTypeParseGet(context_, toMlirStringRef(type_asm)); + if (mlirTypeIsNull(t)) { + std::string message = + "internal error: could not parse !torch.vtensor type: "; + message.append(type_asm); + model_info_.SetError(std::move(message)); + return t; + } + + asm_type_map_[std::move(type_asm)] = t; + return t; +} + +// ---------------------------------------------------------------------------// +// NodeImporter +// ---------------------------------------------------------------------------// + +NodeImporter::NodeImporter(GraphInfo &graph_info, ContextCache &cc, + MlirOperation module_op) + : graph_info_(graph_info), cc_(cc), + context_(mlirOperationGetContext(module_op)), module_op_(module_op), + func_op_({nullptr}), body_block_({nullptr}) { + std::string locName = "graph:"; + locName.append(graph_info.graph_proto().name()); + default_loc_ = mlirLocationNameGet(context_, toMlirStringRef(locName), + /*childLoc=*/{nullptr}); +} + +Status NodeImporter::DefineFunction(std::optional name) { + const onnx::GraphProto &p = graph_info_.graph_proto(); + MlirRegion moduleBodyRegion = mlirOperationGetRegion(module_op_, 0); + MlirBlock moduleBody = mlirRegionGetFirstBlock(moduleBodyRegion); + MlirAttribute nameAttr; + if (name) { + // Explicitly named. + nameAttr = mlirStringAttrGet(context_, toMlirStringRef(*name)); + } else { + // Name it according to the graph. + nameAttr = mlirStringAttrGet(context_, toMlirStringRef(p.name())); + } + + // Derive the FunctionType. + std::vector input_types; + std::vector input_locs; + std::vector output_types; + for (auto *input : graph_info_.inputs()) { + MlirType t = cc_.ConvertTypeProto(input->type()); + if (mlirTypeIsNull(t)) { + return failure(); + } + input_types.push_back(t); + input_locs.push_back(default_loc_); + } + for (auto *output : graph_info_.outputs()) { + MlirType t = cc_.ConvertTypeProto(output->type()); + if (mlirTypeIsNull(t)) { + return failure(); + } + output_types.push_back(t); + } + MlirType ftype = + mlirFunctionTypeGet(context_, input_types.size(), input_types.data(), + output_types.size(), output_types.data()); + + // Create func.func. + func_op_ = createMlirOperationAtEnd( + moduleBody, "func.func", default_loc_, mlirRegionCreate(), + toMlirNamedAttribute("function_type", mlirTypeAttrGet(ftype)), + toMlirNamedAttribute("sym_name", nameAttr)); + + // Add entry block. + body_block_ = mlirBlockCreate(input_types.size(), input_types.data(), + input_locs.data()); + MlirRegion bodyRegion = mlirOperationGetRegion(func_op_, 0); + mlirRegionAppendOwnedBlock(bodyRegion, body_block_); + + // Map the block args to names and store for evaluation. + for (int i = 0, e = graph_info_.inputs().size(); i < e; ++i) { + std::string_view name = graph_info_.inputs()[i]->name(); + MlirValue value = mlirBlockGetArgument(body_block_, i); + nv_map_[name] = value; + } + + PopulateGraphAttrs(func_op_); + return success(); +} + +void NodeImporter::PopulateGraphAttrs(MlirOperation container_op) { + const onnx::ModelProto &m = graph_info_.model_info().model_proto(); + MlirType i64_type = mlirIntegerTypeSignedGet(context_, 64); + int default_opset_version = 0; + std::unordered_map opset_versions; + // Determine model level opset versions. + for (const onnx::OperatorSetIdProto &opset_import : m.opset_import()) { + if (opset_import.has_domain()) { + opset_versions[opset_import.domain()] = + mlirIntegerAttrGet(i64_type, opset_import.version()); + } else { + default_opset_version = opset_import.version(); + } + } + + // Set the default domain version. + if (default_opset_version != 0) { + mlirOperationSetDiscardableAttributeByName( + container_op, toMlirStringRef("torch.onnx_meta.opset_version"), + mlirIntegerAttrGet(i64_type, default_opset_version)); + } + + // Set versions for other domains. + if (!opset_versions.empty()) { + std::vector version_attrs; + for (auto it : opset_versions) { + version_attrs.push_back(mlirNamedAttributeGet( + mlirIdentifierGet(context_, toMlirStringRef(it.first)), it.second)); + } + MlirAttribute dict_attr = mlirDictionaryAttrGet( + context_, version_attrs.size(), version_attrs.data()); + mlirOperationSetDiscardableAttributeByName( + container_op, toMlirStringRef("torch.onnx_meta.opset_versions"), + dict_attr); + } + + // IR version and producer. + mlirOperationSetDiscardableAttributeByName( + container_op, toMlirStringRef("torch.onnx_meta.ir_version"), + mlirIntegerAttrGet(i64_type, m.ir_version())); + mlirOperationSetDiscardableAttributeByName( + container_op, toMlirStringRef("torch.onnx_meta.producer_name"), + mlirStringAttrGet(context_, toMlirStringRef(m.producer_name()))); + mlirOperationSetDiscardableAttributeByName( + container_op, toMlirStringRef("torch.onnx_meta.producer_version"), + mlirStringAttrGet(context_, toMlirStringRef(m.producer_version()))); +} + +Status NodeImporter::ImportAll() { + // TODO: Consider pulling in initializers on demand since there can be so + // much unused crap. + for (auto it : graph_info_.initializer_map()) { + if (failed(ImportInitializer(it.second))) + return failure(); + } + for (auto it : graph_info_.graph_proto().node()) { + if (failed(ImportNode(it))) + return failure(); + } + + // Lookup the outputs, which should all be in the nv_map if the graph was + // properly formed. + std::vector output_values; + for (const auto *output : graph_info_.outputs()) { + std::string_view name = output->name(); + auto found_it = nv_map_.find(name); + if (found_it == nv_map_.end()) { + std::string msg = "Non topologically produced ONNX graph output '"; + msg.append(name); + msg.append("'"); + return SetError(std::move(msg)); + } + output_values.push_back(found_it->second); + } + + createMlirOperationAtEnd(body_block_, "func.return", default_loc_, + output_values); + return success(); +} + +Status NodeImporter::ImportInitializer(const onnx::TensorProto &initializer) { + std::string_view name = initializer.name(); + MlirLocation loc = mlirLocationNameGet(context_, toMlirStringRef(name), + /*childLoc=*/{nullptr}); + + MlirAttribute value_attr = cc_.ConvertTensorProtoToAttr(initializer); + MlirType vtensor_type = cc_.ConvertTensorProtoToVtensorType(initializer); + if (mlirAttributeIsNull(value_attr) || mlirTypeIsNull(vtensor_type)) + return failure(); + + MlirOperation op = createMlirOperationAtEnd( + body_block_, "torch.vtensor.literal", loc, vtensor_type, + toMlirNamedAttribute("value", value_attr)); + MlirValue result = mlirOperationGetResult(op, 0); + + auto inserted = nv_map_.insert(std::make_pair(name, result)); + if (!inserted.second) { + std::string msg = "Multiple nodes produced a value for '"; + msg.append(name); + msg.append("', most recent from "); + msg.append(initializer.DebugString()); + return SetError(std::move(msg)); + } + + return success(); +} + +Status NodeImporter::ImportNode(const onnx::NodeProto &node) { + std::string_view op_type = node.op_type(); + // Handle special-form op types that do not go down the generic path. + if (op_type == "ConstantOfShape") { + return ImportConstantOfShapeNode(node); + } + + return ImportGeneralNode(node); +} + +Status NodeImporter::ImportGeneralNode(const onnx::NodeProto &node) { + MlirLocation loc = mlirLocationNameGet(context_, toMlirStringRef(node.name()), + /*childLoc=*/{nullptr}); + + // Map inputs to values. + std::vector input_values; + for (auto &input_name : node.input()) { + auto found_it = nv_map_.find(input_name); + if (found_it == nv_map_.end()) { + std::string msg = "Non topologically produced ONNX node input '"; + msg.append(input_name); + msg.append("'"); + return SetError(std::move(msg)); + } + input_values.push_back(found_it->second); + } + + // Map outputs to types. + std::vector output_types; + for (auto &output_name : node.output()) { + const onnx::TypeProto *type_proto = + graph_info_.FindTypeProtoForName(output_name); + if (!type_proto) + return failure(); + + MlirType t = cc_.ConvertTypeProto(*type_proto); + if (mlirTypeIsNull(t)) + return failure(); + output_types.push_back(t); + } + + // Derive the op name. + std::string op_name = "onnx."; + op_name.append(node.op_type()); + MlirAttribute op_name_attr = + mlirStringAttrGet(context_, toMlirStringRef(op_name)); + + // General attributes. + std::vector> general_attributes; + for (auto &onnx_attr : node.attribute()) { + MlirAttribute attr = ImportGeneralAttribute(onnx_attr); + if (mlirAttributeIsNull(attr)) + return failure(); + std::string full_name = "torch.onnx."; + full_name.append(onnx_attr.name()); + general_attributes.push_back(std::make_pair(full_name, attr)); + } + + // Create op. + MlirOperation op = createMlirOperationAtEnd( + body_block_, "torch.operator", loc, output_types, input_values, + toMlirNamedAttribute("name", op_name_attr), general_attributes); + + // Record the result values. + for (int i = 0, e = output_types.size(); i < e; ++i) { + MlirValue result = mlirOperationGetResult(op, i); + std::string_view name = node.output(i); + auto inserted = nv_map_.insert(std::make_pair(name, result)); + if (!inserted.second) { + std::string msg = "Multiple nodes produced a value for '"; + msg.append(name); + msg.append("', most recent from "); + msg.append(node.DebugString()); + return SetError(std::move(msg)); + } + } + + return success(); +} + +MlirAttribute +NodeImporter::ImportGeneralAttribute(const onnx::AttributeProto &onnx_attr) { + switch (onnx_attr.type()) { + case onnx::AttributeProto::UNDEFINED: + SetError("'UNDEFINED' attribute type not supported"); + return {nullptr}; + case onnx::AttributeProto::FLOAT: + return mlirFloatAttrDoubleGet(context_, mlirF32TypeGet(context_), + onnx_attr.f()); + case onnx::AttributeProto::INT: + return mlirIntegerAttrGet(mlirIntegerTypeSignedGet(context_, 64), + onnx_attr.i()); + case onnx::AttributeProto::STRING: + return mlirStringAttrGet(context_, toMlirStringRef(onnx_attr.s())); + case onnx::AttributeProto::TENSOR: + return cc_.ConvertTensorProtoToAttr(onnx_attr.t()); + case onnx::AttributeProto::GRAPH: + SetError("'GRAPH' attribute type not supported on this node"); + return {nullptr}; + case onnx::AttributeProto::SPARSE_TENSOR: + SetError("'SPARSE_TENSOR' attribute type not supported on this node"); + return {nullptr}; + case onnx::AttributeProto::TYPE_PROTO: + SetError("'TYPE_PROTO' attribute type not supported on this node"); + return {nullptr}; + case onnx::AttributeProto::FLOATS: { + std::vector attrs; + for (auto f : onnx_attr.floats()) + attrs.push_back( + mlirFloatAttrDoubleGet(context_, mlirF32TypeGet(context_), f)); + return mlirArrayAttrGet(context_, attrs.size(), attrs.data()); + } + case onnx::AttributeProto::INTS: { + std::vector attrs; + for (auto i : onnx_attr.ints()) + attrs.push_back( + mlirIntegerAttrGet(mlirIntegerTypeSignedGet(context_, 64), i)); + return mlirArrayAttrGet(context_, attrs.size(), attrs.data()); + } + case onnx::AttributeProto::STRINGS: { + std::vector attrs; + for (auto s : onnx_attr.strings()) + attrs.push_back(mlirStringAttrGet(context_, toMlirStringRef(s))); + return mlirArrayAttrGet(context_, attrs.size(), attrs.data()); + } + case onnx::AttributeProto::TENSORS: { + std::vector attrs; + for (auto &t : onnx_attr.tensors()) { + MlirAttribute attr = cc_.ConvertTensorProtoToAttr(t); + if (mlirAttributeIsNull(attr)) + return {nullptr}; + attrs.push_back(attr); + } + return mlirArrayAttrGet(context_, attrs.size(), attrs.data()); + } + case onnx::AttributeProto::GRAPHS: + SetError("'GRAPHS' attribute type not supported on this node"); + return {nullptr}; + case onnx::AttributeProto::SPARSE_TENSORS: + SetError("'SPARSE_TENSORS' attribute type not supported on this node"); + return {nullptr}; + case onnx::AttributeProto::TYPE_PROTOS: + SetError("'TYPE_PROTOS' attribute type not supported on this node"); + return {nullptr}; + } + + std::string msg = "Unhandled ONNX attribute type code "; + msg.append(std::to_string(onnx_attr.type())); + msg.append(": "); + msg.append(onnx_attr.DebugString()); + SetError(std::move(msg)); + return {nullptr}; +} + +Status NodeImporter::ImportConstantOfShapeNode(const onnx::NodeProto &node) { + std::string_view name = node.name(); + MlirLocation loc = mlirLocationNameGet(context_, toMlirStringRef(name), + /*childLoc=*/{nullptr}); + + // This op is special: It has an input of the shape, and in full generality + // could involve eager production of constants of variable size. In + // practice, the DNN profile for ONNX makes this very difficult to do + // and we hard-assert that the input can be resolved to an immediate + // value. + if (node.input_size() != 1 || node.output_size() != 1) { + return SetError("ConstantOfShape node must have one input and output"); + } + + // Shape. + std::vector shape; + if (failed(GetImmediateShapeTensor(node.input(0), shape))) + return failure(); + + // Value. + const onnx::AttributeProto *value_proto = nullptr; + for (auto &attr : node.attribute()) { + if (attr.name() == "value") { + value_proto = &attr; + break; + } + } + if (!value_proto) { + return SetError("ConstantOfShape node must have a 'value' attribute"); + } + if (value_proto->type() != onnx::AttributeProto_AttributeType_TENSOR) { + return SetError("ConstantOfShape node must have a tensor value attribute"); + } + + // Create the splat. + const onnx::TensorProto &tensor_proto = value_proto->t(); + if (tensor_proto.dims_size() != 1 || tensor_proto.dims(0) != 1) { + return SetError("ConstantOfShape node expected a scalar tensor value"); + } + auto tensorTypeFor = [&](MlirType element_type) { + return mlirRankedTensorTypeGet(shape.size(), shape.data(), element_type, + /*encoding*/ {nullptr}); + }; + MlirAttribute splat_attr = {nullptr}; + switch (tensor_proto.data_type()) { + case onnx::TensorProto::DataType::TensorProto_DataType_FLOAT: + splat_attr = mlirDenseElementsAttrFloatSplatGet( + tensorTypeFor(mlirF32TypeGet(context_)), tensor_proto.float_data(0)); + break; + case onnx::TensorProto::DataType::TensorProto_DataType_INT32: + splat_attr = mlirDenseElementsAttrFloatSplatGet( + tensorTypeFor(mlirIntegerTypeSignedGet(context_, 32)), + tensor_proto.int32_data(0)); + break; + case onnx::TensorProto::DataType::TensorProto_DataType_INT64: + splat_attr = mlirDenseElementsAttrFloatSplatGet( + tensorTypeFor(mlirIntegerTypeSignedGet(context_, 64)), + tensor_proto.int64_data(0)); + break; + case onnx::TensorProto::DataType::TensorProto_DataType_DOUBLE: + splat_attr = mlirDenseElementsAttrFloatSplatGet( + tensorTypeFor(mlirF64TypeGet(context_)), tensor_proto.double_data(0)); + break; + case onnx::TensorProto::DataType::TensorProto_DataType_UINT64: + splat_attr = mlirDenseElementsAttrFloatSplatGet( + tensorTypeFor(mlirIntegerTypeUnsignedGet(context_, 64)), + tensor_proto.uint64_data(0)); + break; + case onnx::TensorProto::DataType::TensorProto_DataType_UINT32: + // Special case: inline data is stored in uint64. + splat_attr = mlirDenseElementsAttrFloatSplatGet( + tensorTypeFor(mlirIntegerTypeUnsignedGet(context_, 32)), + tensor_proto.uint64_data(0)); + break; + } + + if (mlirAttributeIsNull(splat_attr)) { + std::string message = + "ConstantOfShape node has an unsupported splat data type: "; + message.append(tensor_proto.DebugString()); + return SetError(std::move(message)); + } + + // Create the vtensor type for the result. + MlirType splat_type = mlirAttributeGetType(splat_attr); + MlirType element_type = mlirShapedTypeGetElementType(splat_type); + MlirType vtensor_type = cc_.GetVtensorType(shape, element_type); + if (mlirTypeIsNull(vtensor_type)) + return failure(); + + MlirOperation op = createMlirOperationAtEnd( + body_block_, "torch.vtensor.literal", loc, vtensor_type, + toMlirNamedAttribute("value", splat_attr)); + MlirValue result = mlirOperationGetResult(op, 0); + + // Export to the nv_map. + auto inserted = nv_map_.insert(std::make_pair(name, result)); + if (!inserted.second) { + std::string msg = "Multiple nodes produced a value for '"; + msg.append(name); + msg.append("', most recent from "); + msg.append(node.DebugString()); + return SetError(std::move(msg)); + } + + return success(); +} + +Status NodeImporter::GetImmediateShapeTensor(const std::string &name, + std::vector &shape) { + auto found_it = graph_info_.initializer_map().find(name); + if (found_it == graph_info_.initializer_map().end()) { + std::string message = "An immediate shape value for '"; + message.append(name); + message.append("' was required but it is dynamically produced"); + return SetError(std::move(message)); + } + + const onnx::TensorProto &tp = found_it->second; + shape.clear(); + + // Since this is being interpreted as a shape, we only support some limited + // types. + size_t raw_data_size; + switch (tp.data_type()) { + case onnx::TensorProto::DataType::TensorProto_DataType_INT32: { + auto *raw_data = graph_info_.GetOptionalRawData(tp, raw_data_size); + if (raw_data) { + std::copy(raw_data, raw_data + raw_data_size, std::back_inserter(shape)); + } else { + for (auto v : tp.int32_data()) + shape.push_back(v); + } + return success(); + } + case onnx::TensorProto::DataType::TensorProto_DataType_INT64: { + auto *raw_data = graph_info_.GetOptionalRawData(tp, raw_data_size); + if (raw_data) { + std::copy(raw_data, raw_data + raw_data_size, std::back_inserter(shape)); + } else { + for (auto v : tp.int64_data()) + shape.push_back(v); + } + return success(); + } + case onnx::TensorProto::DataType::TensorProto_DataType_UINT32: { + auto *raw_data = + graph_info_.GetOptionalRawData(tp, raw_data_size); + if (raw_data) { + std::copy(raw_data, raw_data + raw_data_size, std::back_inserter(shape)); + } else { + // Stupid special case: stored in uint64. + for (auto v : tp.uint64_data()) + shape.push_back(v); + } + return success(); + } + case onnx::TensorProto::DataType::TensorProto_DataType_UINT64: { + auto *raw_data = + graph_info_.GetOptionalRawData(tp, raw_data_size); + if (raw_data) { + std::copy(raw_data, raw_data + raw_data_size, std::back_inserter(shape)); + } else { + for (auto v : tp.uint64_data()) + shape.push_back(v); + } + return success(); + } + } + + { + std::string message = + "An immediate shape value could not be converted from TensorProto: "; + message.append(tp.DebugString()); + return SetError(std::move(message)); + } +} + +void NodeImporter::DebugDumpModule() { + auto callback = +[](MlirStringRef sr, void *) { + fwrite(sr.data, sizeof(char), sr.length, stderr); + }; + mlirOperationPrint(module_op_, callback, nullptr); +} diff --git a/projects/onnx_c_importer/OnnxImporter.h b/projects/onnx_c_importer/OnnxImporter.h new file mode 100644 index 000000000000..57070e0e5f2a --- /dev/null +++ b/projects/onnx_c_importer/OnnxImporter.h @@ -0,0 +1,240 @@ +//===------------------------------------------------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +// Stand-alone ONNX -> MLIR importer. +// This library only depends on ONNX (and transitively protobuf, of course) +// and the MLIR C API. It does this to minimize its dependency surface area +// and make it possible to integrate as source code into other systems while +// retaining this implementation as the source of truth. +// +// It uses a hybrid of LLVM and Google C++ coding style, preferring the latter +// for class members/accessors because canonical protobuf coding presumes +// this kind of style. + +#include "mlir-c/IR.h" +#include "onnx/onnx_pb.h" + +#include +#include +#include + +namespace torch_mlir_onnx { + +struct Config; +class GraphInfo; +class ModelInfo; + +struct Config { + // Ancient ONNX exporters would often add a model input for anything that + // might be mutable, providing an initializer for it as well. More modern + // tools tools realized this is a really bad idea for a lot of reasons. + // We choose to assume more recent norms, even if encountering older + // models. Setting this to False probably won't do what you want but + // should produce interesting errors to waste your time deciphering. + // We mainly use it as a way to document in the code that we are + // making an assumption. + bool elide_initialized_inputs = true; +}; + +/// A light-weight status. It only encapsulates success/failure. +/// Full error information will be set on the ModelInfo. +class Status { +public: + static Status success(bool isSuccess = true) { return Status(isSuccess); } + static Status failure(bool isFailure = true) { return Status(!isFailure); } + + bool is_success() { return is_success_; } + +private: + Status(bool is_success) : is_success_(is_success) {} + bool is_success_; +}; + +static inline Status success() { return Status::success(); } +static inline Status failure() { return Status::failure(); } +static inline bool succeeded(Status status) { return status.is_success(); } +static inline bool failed(Status status) { return !status.is_success(); } + +// Accounting for a GraphProto. +class GraphInfo { +public: + GraphInfo(ModelInfo &model_info, const onnx::GraphProto &graph_proto) + : model_info_(model_info), graph_proto_(graph_proto) {} + ModelInfo &model_info() { return model_info_; } + const onnx::GraphProto &graph_proto() { return graph_proto_; } + + /// Post-construction, failable initialization. + Status Initialize(); + + /// Finds a TypeProto for the given value name. If returning nullptr, then + /// an error will have been set. + const onnx::TypeProto *FindTypeProtoForName(std::string_view name); + + /// Attempts to access the raw or external data of the TensorProto. If the + /// the data is located in those positions, returns a types pointer to it + /// and stores the number of elements to `out_size`. Otherwise, nullptr is + /// returned (and no error is set). + template + const ElementType *GetOptionalRawData(const onnx::TensorProto &tp, + size_t &out_size) { + if (tp.has_raw_data()) { + out_size = tp.raw_data().size() / sizeof(ElementType); + return reinterpret_cast(tp.raw_data().data()); + } + return nullptr; + } + + std::vector &inputs() { return inputs_; } + std::unordered_map & + input_map() { + return input_map_; + } + std::vector &outputs() { return outputs_; } + std::unordered_map & + output_map() { + return output_map_; + } + + std::unordered_map & + initializer_map() { + return initializer_map_; + } + +private: + ModelInfo &model_info_; + const onnx::GraphProto &graph_proto_; + + std::unordered_map + initializer_map_; + std::unordered_map + value_info_map_; + + std::vector declared_inputs_; + std::vector inputs_; + std::vector outputs_; + std::unordered_map input_map_; + std::unordered_map + output_map_; +}; + +/// Top-level accounting and accessors for an ONNX model. +class ModelInfo { +public: + ModelInfo(); + Config &config() { return config_; } + onnx::ModelProto &model_proto() { return model_proto_; } + + /// Post-construction, failable initialization. + Status Initialize(); + + GraphInfo &main_graph() { return *main_graph_; } + const std::string &error_message() { return error_message_; } + + Status SetError(std::string msg) { + error_message_ = std::move(msg); + return failure(); + } + + void DebugDumpProto(); + +private: + Config config_; + onnx::ModelProto model_proto_; + std::unique_ptr main_graph_; + + std::string error_message_; +}; + +class ContextCache { +public: + ContextCache(ModelInfo &model_info, MlirContext context) + : model_info_(model_info), context_(context) {} + + MlirContext context() { return context_; } + + /// Converts the TypeProto to an MlirType, returning a null type and + /// setting an error if not possible. + MlirType ConvertTypeProto(const onnx::TypeProto &tp); + + /// Converts the ONNX element type code to an MlirType, returning a null type + /// and setting an error if not possible. + MlirType ConvertTensorElementType(int element_type_code); + + /// Converts an ONNX TensorProto to an MlirAttribute, returning a null + /// attribute and setting an error if not possible. + MlirAttribute ConvertTensorProtoToAttr(const onnx::TensorProto &tp); + + /// Converts the ONNX TensorProto to an Mlir RankedTensor type. + MlirType ConvertTensorProtoToBuiltinType(const onnx::TensorProto &tp); + + /// Converts the ONNX TensorProto to a !torch.vtensor type. + MlirType ConvertTensorProtoToVtensorType(const onnx::TensorProto &tp); + + /// Gets a !torch.vtensor type for the given dims and element type. + /// Dynamic dims are represented as -1. + /// If it was not possible to create the type, sets an error and returns + /// the null type. + MlirType GetVtensorType(const std::vector &dims, + MlirType element_type); + +private: + ModelInfo &model_info_; + MlirContext context_; + + std::unordered_map elem_type_map_; + std::unordered_map asm_type_map_; + std::vector shared_dims_; +}; + +/// Imports graph nodes into a function. +class NodeImporter { +public: + NodeImporter(GraphInfo &graph_info, ContextCache &cc, + MlirOperation module_op); + + /// Called after construction to define the function in the module. Must be + /// called prior to importing nodes. + Status DefineFunction(std::optional name = {}); + + /// Imports all nodes topologically. + Status ImportAll(); + + void DebugDumpModule(); + +private: + void PopulateGraphAttrs(MlirOperation container_op); + Status ImportInitializer(const onnx::TensorProto &initializer); + Status ImportNode(const onnx::NodeProto &node); + MlirAttribute ImportGeneralAttribute(const onnx::AttributeProto &onnx_attr); + + // Special-form nodes. + Status ImportGeneralNode(const onnx::NodeProto &node); + Status ImportConstantOfShapeNode(const onnx::NodeProto &node); + + /// Looks for an initializer for `name` and attempts to treat it as a 1D + /// shape, filling `shape` if successful. Returns failure and sets an error + /// if not. + Status GetImmediateShapeTensor(const std::string &name, + std::vector &shape); + + Status SetError(std::string msg) { + return graph_info_.model_info().SetError(std::move(msg)); + } + + GraphInfo &graph_info_; + ContextCache &cc_; + MlirContext context_; + MlirOperation module_op_; + MlirOperation func_op_; + MlirBlock body_block_; + MlirLocation default_loc_; + std::unordered_map nv_map_; +}; + +} // namespace torch_mlir_onnx diff --git a/projects/onnx_c_importer/README.md b/projects/onnx_c_importer/README.md new file mode 100644 index 000000000000..571c6fd41cd8 --- /dev/null +++ b/projects/onnx_c_importer/README.md @@ -0,0 +1,7 @@ +# ONNX C Importer + +This project provides a C implementation of the `onnx_importer.py`, which is +the canonical source. It is provided as sample code for anyone who wishes to +integrate it into their system. By design, it only depends on the ONNX API +and the MLIR C API via the `mlir-c` headers. As such, it should be easy to +build into any system that already has those things by adding the sources. diff --git a/projects/onnx_c_importer/import-onnx-main.cpp b/projects/onnx_c_importer/import-onnx-main.cpp new file mode 100644 index 000000000000..58ebd98b6a70 --- /dev/null +++ b/projects/onnx_c_importer/import-onnx-main.cpp @@ -0,0 +1,103 @@ +//===------------------------------------------------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +// This main driver uses LLVM tool-making facilities and the support lib. +// The actual importer libraries, however, only depend on the C API so that +// they can be included in foreign projects more easily. + +#include "torch-mlir-c/Registration.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/InitLLVM.h" +#include "llvm/Support/raw_ostream.h" + +#include "OnnxImporter.h" + +#include "onnx/onnx_pb.h" + +#include +#include + +using namespace llvm; +using namespace torch_mlir_onnx; + +struct MlirState { + MlirState() { + context = mlirContextCreateWithThreading(false); + torchMlirRegisterAllDialects(context); + module = mlirModuleCreateEmpty(mlirLocationUnknownGet(context)); + } + ~MlirState() { + mlirModuleDestroy(module); + mlirContextDestroy(context); + } + + MlirContext context; + MlirModule module; +}; + +int main(int argc, char **argv) { + static cl::opt inputFilename( + cl::Positional, cl::desc(""), cl::init("-")); + + static cl::opt outputFilename("o", cl::desc("Output filename"), + cl::value_desc("filename"), + cl::init("-")); + + InitLLVM y(argc, argv); + cl::ParseCommandLineOptions(argc, argv, "torch-mlir-onnx-import-c"); + + // Open the input as an istream because that is what protobuf likes. + std::unique_ptr alloced_input_stream; + std::istream *input_stream = nullptr; + if (inputFilename == "-") { + errs() << "(parsing from stdin)\n"; + input_stream = &std::cin; + } else { + alloced_input_stream = std::make_unique( + inputFilename, std::ios::in | std::ios::binary); + if (!*alloced_input_stream) { + errs() << "error: could not open input file " << inputFilename << "\n"; + return 1; + } + input_stream = alloced_input_stream.get(); + } + + // Parse the model proto. + ModelInfo model_info; + if (!model_info.model_proto().ParseFromIstream(input_stream)) { + errs() << "Failed to parse ONNX ModelProto from " << inputFilename << "\n"; + return 2; + } + + if (failed(model_info.Initialize())) { + errs() << "error: Import failure: " << model_info.error_message() << "\n"; + model_info.DebugDumpProto(); + return 3; + } + model_info.DebugDumpProto(); + + // Import. + MlirState owned_state; + ContextCache cc(model_info, owned_state.context); + NodeImporter importer(model_info.main_graph(), cc, + mlirModuleGetOperation(owned_state.module)); + if (failed(importer.DefineFunction())) { + errs() << "error: Could not define MLIR function for graph: " + << model_info.error_message() << "\n"; + return 4; + } + if (failed(importer.ImportAll())) { + errs() << "error: Could not import one or more graph nodes: " + << model_info.error_message() << "\n"; + return 5; + } + importer.DebugDumpModule(); + + return 0; +} From d560698e3d610ecdc56667c713e2338c47bf4f44 Mon Sep 17 00:00:00 2001 From: "Xida Ren (Cedar)" Date: Wed, 27 Dec 2023 17:53:07 -0800 Subject: [PATCH 049/283] Lower `onnx.split` to `torch.aten` (#2686) --- .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 142 ++++++++++++++++++ .../TorchOnnxToTorch/simple_ops_q_to_z.mlir | 30 ++++ 2 files changed, 172 insertions(+) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index f943f288fc40..c0e9af22ceb9 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -794,6 +794,148 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( return success(); }); + // split with fixed-size parts + // Arguments: + // - input: the tensor to split + // Attributes: + // - axis: the axis along which to split the input + // - num_outputs: the number of outputs to produce + // Outputs: + // - outputs: the produced outputs. Variadic with num_outputs elements. + // Note: torch.aten gives a list of tensors, but ONNX gives a variadic list of + // tensors + // so we need to unpack the list + patterns.onOp( + "Split", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Value self; + int64_t axis; + int64_t num_outputs; + if (binder.tensorOperand(self)) + return rewriter.notifyMatchFailure( + binder.op, "Not converting to AtenSplitTensorOp due to input " + "tensor mismatch"); + if (binder.s64IntegerAttr(axis, "axis", 0)) + return rewriter.notifyMatchFailure(binder.op, + "Failed to get axis attribute"); + if (binder.s64IntegerAttr(num_outputs, "num_outputs", 0)) + return rewriter.notifyMatchFailure( + binder.op, "Failed to get num_outputs attribute"); + + auto result0Ty = + binder.op->getResult(0).getType().cast(); + auto selfTy = self.getType().cast(); + + int64_t dim = axis; + if (dim < 0) + dim += selfTy.getSizes().size(); + + // set intermediate shape to the shape of the first result + // if the results are of different shapes + // set the splitted axis to variable shape + llvm::SmallVector intermediateShape(result0Ty.getSizes()); + for (auto result : binder.op->getResultTypes()) { + int64_t d = result.cast().getSizes()[dim]; + intermediateShape[dim] = d == intermediateShape[dim] ? d : -1; + } + + Value dimValue = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), dim)); + + Value splitSize = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), num_outputs)); + + // TODO: Attempting to use the shape expected by the ONNX mlir as ground + // truth. For now just use dynamic shapes. + auto resultOuterType = + Torch::ListType::get(rewriter.getType( + /*std::optional>=*/intermediateShape, + result0Ty.getOptionalDtype())); + Torch::AtenSplitTensorOp new_op = + rewriter.create( + binder.getLoc(), resultOuterType, self, splitSize, dimValue); + + // the onnx op is variadic with multiple results, but AtenSplitWithSizes + // outputs a list so we need to unpack the list + rewriter.replaceOpWithNewOp( + binder.op, binder.op->getResults().getType(), new_op.getResult()); + + return success(); + }); + + // split with variable parts + // Arguments: + // - input: the tensor to split + // - split: the sizes of the splits to be produced + // Attributes: + // - axis: the axis along which to split the input + // - num_outputs: the number of outputs to produce + // Outputs: + // - outputs: the produced outputs. Variadic with num_outputs elements. + // Note: torch.aten gives a list of tensors, but ONNX gives a variadic list of + // tensors + // so we need to unpack the list + patterns.onOp( + "Split", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Value self; + Value split; + int64_t axis; + int64_t num_outputs; + if (binder.tensorOperandAtIndex(self, 0) || + binder.tensorOperandAtIndex(split, 1)) + return rewriter.notifyMatchFailure( + binder.op, "Not converting to AtenSplitWithSizesOp due to input " + "tensor mismatch"); + if (binder.s64IntegerAttr(axis, "axis", 0)) + return rewriter.notifyMatchFailure(binder.op, + "Failed to get axis attribute"); + if (binder.s64IntegerAttr(num_outputs, "num_outputs", 0)) + return rewriter.notifyMatchFailure( + binder.op, "Failed to get num_outputs attribute"); + + auto result0Ty = + binder.op->getResult(0).getType().cast(); + auto selfTy = + cast(binder.op->getOperand(0).getType()); + + int64_t dim = axis; + if (dim < 0) + dim += selfTy.getSizes().size(); + + llvm::SmallVector intermediateShape(result0Ty.getSizes()); + for (auto result : binder.op->getResultTypes()) { + int64_t d = result.cast().getSizes()[dim]; + intermediateShape[dim] = d == intermediateShape[dim] ? d : -1; + } + + Torch::PrimTolistOp splitToList = rewriter.create( + binder.getLoc(), + Torch::ListType::get(rewriter.getType()), split); + + Value dimValue = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), dim)); + + // TODO: Attempting to use the shape expected by the ONNX mlir as ground + // truth. For now just use dynamic shapes. + auto resultOuterType = + Torch::ListType::get(rewriter.getType( + /*std::optional>=*/intermediateShape, + result0Ty.getOptionalDtype())); + Torch::AtenSplitWithSizesOp new_op = + rewriter.create( + binder.getLoc(), resultOuterType, self, + splitToList.getResult(0), dimValue); + + // the onnx op is variadic with multiple results, but AtenSplitWithSizes + // outputs a list so we need to unpack the list + rewriter.replaceOpWithNewOp( + binder.op, binder.op->getResults().getType(), new_op.getResult()); + + return success(); + }); + patterns.onOp("Tan", 7, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir index 5aca8688dac5..b2a19334ab09 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir @@ -795,6 +795,36 @@ func.func @test_sinh_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[ // ----- +// CHECK-LABEL: func.func @test_split_variable_parts_2d_opset18( +// CHECK-SAME: %[[VAL_INPUT:.*]]: !torch.vtensor<[2,6],f32>, +// CHECK-SAME: %[[VAL_SPLIT:.*]]: !torch.vtensor<[2],si64> +// CHECK: %[[VAL_SPLIT_LIST:.*]] = torch.prim.tolist(%[[VAL_SPLIT]]) : !torch.vtensor<[2],si64> -> !torch.list +// CHECK: %[[VAL_AXIS:.*]] = torch.constant.int 1 +// CHECK: %[[VAL_RESULT_LIST:.*]] = torch.aten.split_with_sizes %[[VAL_INPUT]], %[[VAL_SPLIT_LIST]], %[[VAL_AXIS]] : !torch.vtensor<[2,6],f32>, !torch.list, !torch.int -> !torch.list> +// CHECK: %[[VAL_VARIADIC_RETURN_VALUE:.*]]:2 = torch.prim.ListUnpack %[[VAL_RESULT_LIST]] : !torch.list> -> !torch.vtensor<[2,2],f32>, !torch.vtensor<[2,4],f32> +// CHECK: return %[[VAL_VARIADIC_RETURN_VALUE]]#0, %[[VAL_VARIADIC_RETURN_VALUE]]#1 : !torch.vtensor<[2,2],f32>, !torch.vtensor<[2,4],f32> +func.func @test_split_variable_parts_2d_opset18(%arg0: !torch.vtensor<[2,6],f32>, %arg1: !torch.vtensor<[2],si64>) -> (!torch.vtensor<[2,2],f32>, !torch.vtensor<[2,4],f32>) attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + %0:2 = torch.operator "onnx.Split"(%arg0, %arg1) {torch.onnx.axis = 1 : si64} : (!torch.vtensor<[2,6],f32>, !torch.vtensor<[2],si64>) -> (!torch.vtensor<[2,2],f32>, !torch.vtensor<[2,4],f32>) + return %0#0, %0#1 : !torch.vtensor<[2,2],f32>, !torch.vtensor<[2,4],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_split_2d_uneven_split_opset18( +// CHECK-SAME: %[[INPUT_TENSOR:.*]]: !torch.vtensor<[2,8],f32>) -> (!torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,2],f32>) attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { +// CHECK: %[[AXIS:.*]] = torch.constant.int 1 +// CHECK: %[[SPLIT_SIZE:.*]] = torch.constant.int 3 +// CHECK: %[[SPLIT_RESULT:.*]] = torch.aten.split.Tensor %[[INPUT_TENSOR]], %[[SPLIT_SIZE]], %[[AXIS]] : !torch.vtensor<[2,8],f32>, !torch.int, !torch.int -> !torch.list> +// CHECK: %[[UNPACKED_TENSORS:.*]]:3 = torch.prim.ListUnpack %[[SPLIT_RESULT]] : !torch.list> -> !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,2],f32> +// CHECK: return %[[UNPACKED_TENSORS]]#0, %[[UNPACKED_TENSORS]]#1, %[[UNPACKED_TENSORS]]#2 : !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,2],f32> +// CHECK: } +func.func @test_split_2d_uneven_split_opset18(%arg0: !torch.vtensor<[2,8],f32>) -> (!torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,2],f32>) attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + %0:3 = torch.operator "onnx.Split"(%arg0) {torch.onnx.axis = 1 : si64, torch.onnx.num_outputs = 3 : si64} : (!torch.vtensor<[2,8],f32>) -> (!torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,2],f32>) + return %0#0, %0#1, %0#2 : !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,3],f32>, !torch.vtensor<[2,2],f32> +} + +// ----- + // CHECK-LABEL: func.func @test_tan func.func @test_tan(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 7 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[TAN:.+]] = torch.aten.tan %arg0 From 8e389ff2ffac781648721696c716f141048c38c9 Mon Sep 17 00:00:00 2001 From: Sungsoon Cho <1025787+godot73@users.noreply.github.com> Date: Wed, 27 Dec 2023 20:33:18 -0800 Subject: [PATCH 050/283] Implement lowering of torch.aten.exponential (#2680) https://github.com/llvm/torch-mlir/issues/2646 Decompose aten.exponential() into: -exp(1-x)/lambda --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 25 ++++++++++ .../Transforms/AbstractInterpLibrary.cpp | 7 +++ .../Torch/Transforms/DecomposeComplexOps.cpp | 46 +++++++++++++++++++ .../Transforms/LowerToBackendContract.cpp | 1 + projects/pt1/e2e_testing/xfail_sets.py | 1 + .../build_tools/abstract_interp_lib_gen.py | 7 +++ .../build_tools/torch_ods_gen.py | 1 + .../torch_mlir_e2e_test/test_suite/rng.py | 23 ++++++++++ 8 files changed, 111 insertions(+) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 6013f6da3cfc..16eb5565bedd 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -4739,6 +4739,31 @@ def Torch_AtenBernoulliPOp : Torch_Op<"aten.bernoulli.p", [ }]; } +def Torch_AtenExponentialOp : Torch_Op<"aten.exponential", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::exponential : (Tensor, float, Generator?) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + Torch_FloatType:$lambd, + AnyTorchOptionalGeneratorType:$generator + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenExponentialOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); + } + void AtenExponentialOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); + } + }]; +} + def Torch_AtenMultinomialOp : Torch_Op<"aten.multinomial", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index 1031f4aa7e53..25e83899bc1d 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -7580,6 +7580,9 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " func.func @\"__torch_mlir_shape_fn.aten.uniform\"(%arg0: !torch.list, %arg1: !torch.float, %arg2: !torch.float, %arg3: !torch.any) -> !torch.list {\n" " return %arg0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.exponential\"(%arg0: !torch.list, %arg1: !torch.float, %arg2: !torch.any) -> !torch.list {\n" +" return %arg0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.rand\"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional) -> !torch.list {\n" " return %arg0 : !torch.list\n" " }\n" @@ -9382,6 +9385,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.exponential\"(%arg0: !torch.tuple, %arg1: !torch.float, %arg2: !torch.any) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.rand\"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional) -> !torch.int {\n" " %int6 = torch.constant.int 6\n" " %none = torch.constant.none\n" diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index d8b8639e0a75..63fa66ccc31e 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -3562,6 +3562,51 @@ class DecomposeAtenBernoulliTensorOp }; } // namespace +namespace { +// Decompose exponential() to do inverse transform sampling. +// - https://en.wikipedia.org/wiki/Inverse_transform_sampling +// With the exponential distribution, F(x) = 1 - exp(-lambda * x). Thus, +// exponential() = - ln(1 - uniform(0, 1)) / lambda. +class DecomposeAtenExponentialOp : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenExponentialOp op, + PatternRewriter &rewriter) const override { + if (!op.getGenerator().getType().isa()) + return rewriter.notifyMatchFailure( + op, "The generator has to be None because only global default " + "generator is supported"); + + Location loc = op.getLoc(); + Type resultType = op.getType(); + + // Create a uniform random op with low and high set to 0.0 and 1.0, + // respectively. + Value none = rewriter.create(loc); + Value zero = + rewriter.create(loc, rewriter.getF64FloatAttr(0.0)); + Value one = + rewriter.create(loc, rewriter.getF64FloatAttr(1.0)); + Value emptyTensor = rewriter.create( + loc, resultType, op.getSelf(), zero, /*dtype=*/none, /*layout=*/none, + /*device=*/none, /*pin_memoty=*/none, /*memory_format=*/none); + Value x = rewriter.create(loc, resultType, emptyTensor, + /*from=*/zero, /*to=*/one, + /*generator=*/none); + + Value negX = rewriter.create(loc, resultType, x); + Value oneMinusX = + rewriter.create(loc, resultType, negX, one, + /*alpha=*/one); + Value lnOneMinusX = rewriter.create(loc, resultType, oneMinusX); + Value negLambda = rewriter.create(loc, op.getLambd()); + rewriter.replaceOpWithNewOp(op, resultType, lnOneMinusX, + negLambda); + return success(); + } +}; +} // namespace + namespace { template class DecomposeAtenAddCLikeOp : public OpRewritePattern { @@ -6410,6 +6455,7 @@ class DecomposeComplexOpsPass addPatternIfTargetOpIsIllegal< DecomposeAtenBernoulliLikeOp>(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp index 79f64ef32fbf..933140d3013d 100644 --- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp +++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp @@ -427,6 +427,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index c70b01b47819..6f683a43c34f 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -1397,6 +1397,7 @@ "CeilFloatModule_basic", "DivFloatModule_basic", "EqIntModule_basic", + "ExponentialModule_basic", "GeFloatIntModule_basic", "GeFloatModule_basic", "GeIntModule_basic", diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index 338f5e97e100..2e6094a6fa20 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -831,6 +831,9 @@ def aten〇copy〡shape(self: List[int], src: List[int], non_blocking: bool = Fa def aten〇uniform〡shape(self: List[int], from_: float = 0., to: float = 1., generator: Any = None) -> List[int]: return self +def aten〇exponential〡shape(self: List[int], lambd: float = 1., generator: Any = None) -> List[int]: + return self + def aten〇rand〡shape(size: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]: return size @@ -2267,6 +2270,10 @@ def aten〇uniform〡dtype(self_rank_dtype: Tuple[int, int], from_: float = 0., self_rank, self_dtype = self_rank_dtype return self_dtype +def aten〇exponential〡dtype(self_rank_dtype: Tuple[int, int], lambd: float = 1., generator: Any = None) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + @check_dtype_function([Invocation([1]), Invocation([1], dtype=torch.float16), Invocation([1], dtype=torch.complex64)]) diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index efee6c852eb4..fb458f6a5d91 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -378,6 +378,7 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::bernoulli : (Tensor, Generator?) -> (Tensor)") emit("aten::bernoulli_.float : (Tensor, float, Generator?) -> (Tensor)") emit("aten::bernoulli.p : (Tensor, float, Generator?) -> (Tensor)") + emit("aten::exponential : (Tensor, float, Generator?) -> (Tensor)") emit("aten::multinomial : (Tensor, int, bool, Generator?) -> (Tensor)") emit("aten::randint.low : (int, int, int[], int?, int?, Device?, bool?) -> (Tensor)") emit("aten::randint : (int, int[], int?, int?, Device?, bool?) -> (Tensor)") diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/rng.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/rng.py index 1baa462462f1..dedd2b398bd4 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/rng.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/rng.py @@ -157,6 +157,29 @@ def UniformNoCorrelationModule_basic(module, tu: TestUtils): # ============================================================================== +class ExponentialModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float64, True), + ]) + def forward(self, x): + a = torch.ops.aten.exponential(x, 3.0) + mean = torch.mean(a) + std = torch.std(a) + return mean, std + + +@register_test_case(module_factory=lambda: ExponentialModule()) +def ExponentialModule_basic(module, tu: TestUtils): + module.forward( + tu.rand(512, 512, 16).double()) + +# ============================================================================== + class BernoulliModule(torch.nn.Module): def __init__(self): super().__init__() From 9fc212ea9afe0f1e31b4e2ee03bc6db296e84190 Mon Sep 17 00:00:00 2001 From: "Xida Ren (Cedar)" Date: Thu, 28 Dec 2023 09:31:41 -0800 Subject: [PATCH 051/283] support Onnx opset 1-13 ReduceMean where axes is supplied as an attr (#2703) (instead of an input) Addresses part of #2689. fixes #2702 --- .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 79 ++++++++++++++++++- .../unsupported_fb_opt_ops.mlir | 1 - 2 files changed, 78 insertions(+), 2 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index c0e9af22ceb9..4756315c800e 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -545,8 +545,9 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( /*dtype=*/noneVal); return success(); }); + // onnx.ReduceMean with axes provided as argument introduced in opset 18 patterns.onOp( - "ReduceMean", 13, + "ReduceMean", 18, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value data; @@ -632,6 +633,82 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( /*dtype=*/noneVal); return success(); }); + + // onnx.ReduceMean with axes provided as attribute + patterns.onOp( + "ReduceMean", 1, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value data; + llvm::SmallVector axes; + int64_t keepDims; + int64_t noop_with_empty_axes; + if (binder.tensorOperand(data) || + binder.tensorResultType(resultType) || + binder.s64IntegerArrayAttr(axes, "axes", 0) || + binder.s64IntegerAttr(keepDims, "keepdims", 1) || + binder.s64IntegerAttr(noop_with_empty_axes, "noop_with_empty_axes", + 0)) + return failure(); + SmallVector dimList; + SmallVector selectSizes; + selectSizes.push_back(1); + Value noneVal = rewriter.create(binder.getLoc()); + // deal with case when axes is empty + if (axes.size() == 0) { + if (noop_with_empty_axes == 0) { + Value keepDimsConstInt = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), keepDims)); + Value keepDimsBool = rewriter.create( + binder.getLoc(), keepDimsConstInt); + rewriter.replaceOpWithNewOp( + binder.op, resultType, data, /*dim=*/noneVal, keepDimsBool, + /*dtype=*/noneVal); + } else { + rewriter.replaceOp(binder.op, data); + } + return success(); + } + Value zero = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); + int64_t adjustmentInt = + cast(data.getType()).getSizes().size(); + Value adjustment = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), + adjustmentInt)); + // convert axes (tensor) into torch int list while dealing with neg axis + for (int i = 0; i < axes.size(); i++) { + // Go through the axes list and get each dim in the list + int64_t dim = axes[i]; + if (dim < 0) { + dim += adjustmentInt; + } + // deal with neg axis: if (axis < 0) axis += rank + Value finalDim = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), dim)); + dimList.push_back(finalDim); + } + Value dimValueList = rewriter.create( + binder.getLoc(), + Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), + dimList); + Value keepDimBool; + if (keepDims == 1) { + keepDimBool = + rewriter.create(binder.getLoc(), true); + } else { + keepDimBool = + rewriter.create(binder.getLoc(), false); + } + rewriter.replaceOpWithNewOp( + binder.op, resultType, data, dimValueList, keepDimBool, + /*dtype=*/noneVal); + return success(); + }); patterns.onOp( "ReduceMin", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) { diff --git a/test/Conversion/TorchOnnxToTorch/unsupported_fb_opt_ops.mlir b/test/Conversion/TorchOnnxToTorch/unsupported_fb_opt_ops.mlir index 8401c378b77c..3ed9f1c6ebe6 100644 --- a/test/Conversion/TorchOnnxToTorch/unsupported_fb_opt_ops.mlir +++ b/test/Conversion/TorchOnnxToTorch/unsupported_fb_opt_ops.mlir @@ -34,7 +34,6 @@ func.func @equal_operation(%arg0: !torch.vtensor<[4],si64>, func.func @reduce_mean_operation(%arg0: !torch.vtensor<[1,64,768],f32>) -> !torch.vtensor<[1,64,1],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // The ReduceMean operation as provided. - // expected-error @+1 {{failed to legalize operation 'torch.operator'}} %211 = torch.operator "onnx.ReduceMean"(%arg0) {torch.onnx.axes = [-1 : si64]} : (!torch.vtensor<[1,64,768],f32>) -> !torch.vtensor<[1,64,1],f32> return %211 : !torch.vtensor<[1,64,1],f32> } \ No newline at end of file From 6660a26594dc82cd3dd6fc33c9269ff09ecd263a Mon Sep 17 00:00:00 2001 From: "Xida Ren (Cedar)" Date: Thu, 28 Dec 2023 17:20:32 -0800 Subject: [PATCH 052/283] lower torch.aten.isinf to linalg (#2638) Co-authored-by: Rob Suderman --- .../TorchToLinalg/Uncategorized.cpp | 11 ++++++-- .../base_lazy_backend/shape_inference.cpp | 4 +++ projects/pt1/e2e_testing/xfail_sets.py | 4 ++- .../test_suite/elementwise.py | 25 +++++++++++++++++++ 4 files changed, 41 insertions(+), 3 deletions(-) diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index e947ae73ace0..0943534dbd9c 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -426,6 +426,12 @@ static Value createLinalgPayloadCalculationForElementwiseOp( } if (isa(op)) return b.create(loc, payloadArgs[0]); + if (isa(op)){ + Value abs = b.create(loc, payloadArgs[0]); + Value infinity = b.create( + loc, b.getFloatAttr(abs.getType(), std::numeric_limits::infinity())); + return createEqual(b, loc, abs.getType(), abs, infinity); + } if (isa(op)) { auto negate = createCalculationForMathOpWithDtypeConversion( b, converter, payloadArgs[0], op); @@ -1343,7 +1349,7 @@ class ConvertElementwiseOp : public ConversionPattern { AtenThresholdBackwardOp, AtenHardtanhBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenNegOp, AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp, - AtenLogicalXorOp, AtenLogicalNotOp, AtenTriuOp, AtenTrilOp, + AtenLogicalXorOp, AtenLogicalNotOp, AtenIsinfOp, AtenTriuOp, AtenTrilOp, AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp, AtenAtanOp, AtenAcosOp, AtenRealOp, AtenImagOp>(op)) return rewriter.notifyMatchFailure(op, "not a supported elementwise op"); @@ -1992,7 +1998,8 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality( AtenLtTensorOp, AtenLeTensorOp, AtenThresholdOp, AtenThresholdBackwardOp, AtenHardtanhBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp, AtenAtanOp, - AtenAcosOp, AtenLogicalXorOp, AtenLogicalNotOp, AtenTriuOp, AtenTrilOp, + AtenAcosOp, AtenLogicalXorOp, AtenLogicalNotOp, AtenIsinfOp, AtenTriuOp, + AtenTrilOp, AtenRemainderScalarOp, AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp, AtenRealOp, AtenImagOp>(); patterns.add(typeConverter, context); diff --git a/projects/ltc/csrc/base_lazy_backend/shape_inference.cpp b/projects/ltc/csrc/base_lazy_backend/shape_inference.cpp index d5458f9c4ea6..244ee7b88cc0 100644 --- a/projects/ltc/csrc/base_lazy_backend/shape_inference.cpp +++ b/projects/ltc/csrc/base_lazy_backend/shape_inference.cpp @@ -39,6 +39,10 @@ std::vector compute_shape_div(const at::Tensor& self, return {Shape(self.scalar_type(), self.sizes().vec())}; } +std::vector compute_shape_isinf(const at::Tensor& self) { + return {Shape(at::kBool, self.sizes().vec())}; +} + std::vector compute_shape_max_pool3d_with_indices( const at::Tensor& self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 6f683a43c34f..d6cb60e57367 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -1033,6 +1033,7 @@ "ElementwiseAddScalarIntModule_basic", "ElementwiseAddScalar_TensorLiteralInt32_Module_basic", "ElementwiseAtenDivIntScalarModule_basic", + "ElementwiseAtenIsinfOpModule_basic", "ElementwiseAtenWhereSelfModule_basic", "ElementwiseBinaryModule_basic", "ElementwiseBinaryStaticShapeModule_basic", @@ -1328,6 +1329,8 @@ "SliceWholeTensorModule_basic", "TensorFloatModule_basic", "TensorIntModule_basic", + "AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic", + "AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic", }) - { ### Test failing in make_fx_tosa but not in tosa @@ -1489,5 +1492,4 @@ "ElementwiseBitwiseAndScalarInt64Module_basic", "ElementwiseBitwiseAndScalarInt32Module_basic", "ElementwiseBitwiseAndScalarInt8Module_basic", - "ElementwiseIsinfModule_basic", } diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py index 33c420a1c517..15e45b52eb44 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py @@ -3385,6 +3385,31 @@ def ElementwiseAtenLogicalNotOpModule_basic(module, tu: TestUtils): module.forward(tu.randint(4, 5, high=2).bool()) +# ============================================================================== + +class ElementwiseAtenIsinfOpModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([2, 5], torch.float32, True), + ]) + def forward(self, x): + return torch.ops.aten.isinf(x) + +@register_test_case(module_factory=lambda: ElementwiseAtenIsinfOpModule()) +def ElementwiseAtenIsinfOpModule_basic(module, tu: TestUtils): + test_input = torch.tensor( + [ + [1, float('inf'), 2, float('-inf'), float('nan')], + [1, float('inf'), float('-inf'), float('nan'), 3], + ] + ) + module.forward(test_input) + + # ============================================================================== From 9adad9bc407d92860b99f74b02da3a07b315d6b0 Mon Sep 17 00:00:00 2001 From: kumardeepakamd <123522031+kumardeepakamd@users.noreply.github.com> Date: Tue, 2 Jan 2024 11:05:11 -0800 Subject: [PATCH 053/283] Add support for reflection_pad1d (#2706) Adds a lowering to Linalg for reflection_pad1d. Based on ideas/code from draft PR https://github.com/llvm/torch-mlir/pull/2693. --------- Co-authored-by: Kumar Deepak --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 24 +++ lib/Conversion/TorchToLinalg/DataMovement.cpp | 139 ++++++++++++++++++ .../Transforms/AbstractInterpLibrary.cpp | 50 +++++++ .../build_tools/abstract_interp_lib_gen.py | 27 ++++ .../build_tools/torch_ods_gen.py | 1 + .../torch_mlir_e2e_test/test_suite/basic.py | 72 +++++++++ 6 files changed, 313 insertions(+) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 16eb5565bedd..23e65d75d77f 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -7869,6 +7869,30 @@ def Torch_AtenConstantPadNdOp : Torch_Op<"aten.constant_pad_nd", [ }]; } +def Torch_AtenReflectionPad1dOp : Torch_Op<"aten.reflection_pad1d", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::reflection_pad1d : (Tensor, int[]) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchListOfTorchIntType:$padding + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenReflectionPad1dOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenReflectionPad1dOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + def Torch_AtenPadOp : Torch_Op<"aten.pad", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp index dae387422b52..6534e859881e 100644 --- a/lib/Conversion/TorchToLinalg/DataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp @@ -107,6 +107,143 @@ LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor, return success(); } +// Example: +// input = tensor([[[0., 1., 2., 3.], +// [4., 5., 6., 7.]]]) +// torch.ops.aten.reflection_pad1d(input, (3,1)) ; padding_left = 3, padding_right = 1 +// tensor([[[3., 2., 1., 0., 1., 2., 3., 2.], +// [7., 6., 5., 4., 5., 6., 7., 6.]]]) +// Checks: 1) Each of padding_left and padding_right must be non-negative less than size of last dimension +// Implementation: a) Construct a result tensor of shape of input tensor except for the last dimension. +// The last dimension of the result tensor should be last dimension of input tensor + +// left padding size + right padding size. INitialize result tensor to all zeros +// b) Setup affine map to take slice from input tensor of size left padding starting from +// second column onwards as first column is reflection boundary +// c) Reflect the affine map to have resultant slice reflected +// d) Take the slice and write from begining in result tensor +// e) write the original tensor next into result tensor +// f) Setup affine map to take slice from input tensor of right padding size ending +// at second last column as last column is reflection boundary for right padding +// g) Reflect the affine map to have resultant slice reflected +// h) Take the slice and write from left padding size + orignal tensor last dim size +// into result tensor +// Uses the ideas/code used for AtenReflectionPad2dOp +namespace { +class ConvertAtenReflectionPad1dOp + : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(AtenReflectionPad1dOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (failed(verifyLinalgCompatibleTypes(op, rewriter))) + return failure(); + + SmallVector padInts; + if (!matchPattern(op.getPadding(), m_TorchListOfConstantInts(padInts))) + return rewriter.notifyMatchFailure( + op, "only constant int padding range is supported"); + + MLIRContext *context = rewriter.getContext(); + Location loc = op.getLoc(); + + // Lambda Unitility Functions + // Create an Integer expression of x + y + auto createIAdd = [&](Value x, Value y) { + return rewriter.create(loc, x, y); + }; + + // Create an integer expression of x - y + auto createISub = [&](Value x, Value y) { + return rewriter.create(loc, x, y); + }; + + enum PadLocation {PAD_LEFT = 0, PAD_RIGHT = 1, PAD_CENTER=2}; + + Value input = adaptor.getSelf(); + Type indexType = rewriter.getIndexType(); + Value zero = getConstant(rewriter, loc, 0, indexType); + Value one = getConstant(rewriter, loc, 1, indexType); + auto inputType = llvm::cast(input.getType()); + auto outputType = llvm::cast(getTypeConverter()->convertType(op->getResult(0).getType())); + unsigned numDims = inputType.getRank(); + assert(numDims >= 2 && "Not enough input dimensions"); + int64_t lastDim = numDims - 1; + SmallVector inputShape = getTensorSizes(rewriter, loc, input); + Value lastDimSize = inputShape[lastDim]; // input [1,2,4], then lastDim = 2, inputShape[2] will give 4 + + Value tileWidth[3], extractOffset[3], insertOffset[3]; + + tileWidth[PAD_LEFT] = getConstant(rewriter, loc, padInts[PAD_LEFT], indexType); + tileWidth[PAD_RIGHT] = getConstant(rewriter, loc, padInts[PAD_RIGHT], indexType); + tileWidth[PAD_CENTER] = lastDimSize; + + extractOffset[PAD_LEFT] = one; + // for (1,2,4) input, padding (3,1) lastDimSize=4, 4 - 1 - 1 = 2 [3,5, 6,7], so start offset to 6, which is right + // lasDimSize - (tileWidth[PAD_RIGHT] + one) + extractOffset[PAD_RIGHT] = createISub(lastDimSize, createIAdd(tileWidth[PAD_RIGHT], one)); + extractOffset[PAD_CENTER] = zero; + + insertOffset[PAD_LEFT] = zero; + insertOffset[PAD_RIGHT] = createIAdd(lastDimSize, tileWidth[PAD_LEFT]); + insertOffset[PAD_CENTER] = tileWidth[PAD_LEFT]; + + + SmallVector resultShape{inputShape}; + // Result's last dimension will have shape lastDimSize + left padding size + right padding size + resultShape[lastDim] = createIAdd(resultShape[lastDim], createIAdd(tileWidth[PAD_LEFT], tileWidth[PAD_RIGHT])); + Value resultTensor = createZeroInitTensor(rewriter, loc, resultShape, inputType.getElementType()); + + // Helper to reflect/reverse the i-th dimension of an affine map without symbols. This only works if applied on a tensor + // for which the corresponding dimension has a statically known size + auto reflectDim = [](AffineMap map, unsigned numDims, int64_t i, int64_t size) { + AffineExpr d = map.getResult(i); + return map.replace(d, size - d - 1, numDims, 0); // left reflect for (3,1) on input shape (1,2,4). size = 3, lastDim=2, numDims=3 + }; + + SmallVector iteratorTypes{numDims, utils::IteratorType::parallel}; + auto idMap = AffineMap::getMultiDimIdentityMap(numDims, context); + SmallVector allOneStrides(numDims, one); + + auto addTileToResult = [&](PadLocation padPosition) { + // Create the tile by extracting a slice from the input tensor. + SmallVector extractShape{inputShape}; + extractShape[lastDim] = tileWidth[padPosition]; + SmallVector extractOffsets(numDims, zero); + extractOffsets[lastDim] = extractOffset[padPosition]; + Value tile = rewriter.create( + loc, input, extractOffsets, extractShape, allOneStrides); + + + auto inputMap = AffineMap::getMultiDimIdentityMap(numDims, context); + // Setup the affine map function to resverse the tile along the horizontal for left and right slices + if(padPosition < PAD_CENTER) { + inputMap = reflectDim(inputMap, numDims, lastDim, padInts[padPosition]); + // Take reflected slice as per inputMap + tile = rewriter.create(loc, llvm::cast(tile.getType()), tile, + tile, ArrayRef({inputMap, idMap}), iteratorTypes, + [](OpBuilder &b, Location nestedLoc, ValueRange args) { + b.create(nestedLoc, args[0]); + }).getResult(0); + } + // Insert the tile in the resultTensor + SmallVector insertOffsets(numDims, zero); + insertOffsets[lastDim] = insertOffset[padPosition]; + resultTensor = rewriter.create(loc, tile, resultTensor, insertOffsets, extractShape, allOneStrides); + }; + + if(padInts[PAD_LEFT] > 0) + addTileToResult(PAD_LEFT); + if(padInts[PAD_RIGHT] > 0) + addTileToResult(PAD_RIGHT); + addTileToResult(PAD_CENTER); + + rewriter.replaceOpWithNewOp(op, outputType, resultTensor); + return success(); + } +}; +} + namespace { class ConvertAtenFlattenUsingIntsOp : public OpConversionPattern { @@ -1413,6 +1550,8 @@ void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target) { MLIRContext *context = patterns.getContext(); + target.addIllegalOp(); + patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index 25e83899bc1d..4adf55556a2e 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -8331,6 +8331,41 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.pad_shape_fn(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.reflection_pad1d\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" +" %false = torch.constant.bool false\n" +" %int-1 = torch.constant.int -1\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int2 = torch.constant.int 2\n" +" %int1 = torch.constant.int 1\n" +" %int0 = torch.constant.int 0\n" +" %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %1 = torch.aten.ge.int %0, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %2 = torch.aten.__getitem__.t %arg0, %int-1 : !torch.list, !torch.int -> !torch.int\n" +" %3 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %4 = torch.aten.__getitem__.t %arg1, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %5 = torch.aten.lt.int %3, %2 : !torch.int, !torch.int -> !torch.bool\n" +" %6 = torch.prim.If %5 -> (!torch.bool) {\n" +" %8 = torch.aten.lt.int %4, %2 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %8 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %6 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %7 = call @__torch__.pad_shape_fn(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" +" return %7 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.index.Tensor\"(%arg0: !torch.list, %arg1: !torch.list>>) -> !torch.list {\n" " %0 = call @__torch__.index_tensor_like(%arg0, %arg1) : (!torch.list, !torch.list>>) -> !torch.list\n" " return %0 : !torch.list\n" @@ -8952,6 +8987,21 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.reflection_pad1d\"(%arg0: !torch.tuple, %arg1: !torch.list) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: padding size expected to be 2\"\n" +" %int2 = torch.constant.int 2\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" +" %2 = torch.aten.eq.int %1, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" return %0#1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.contiguous\"(%arg0: !torch.tuple, %arg1: !torch.int) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index 2e6094a6fa20..48949c318e22 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -1271,6 +1271,21 @@ def aten〇constant_pad_nd〡shape(self: List[int], pad: List[int], value: float def aten〇pad〡shape(self: List[int], pad: List[int], mode: str = "constant", value: Optional[float] = None) -> List[int]: return pad_shape_fn(self, pad) +#Padding size must be smaller than the size of the last dimension +@check_shape_function([ErrorInvocation(TensorOfShape(1, 2, 4), padding=[4,1]), + Invocation(TensorOfShape(1, 2, 4), padding=[3,3]), + ErrorInvocation(TensorOfShape(1, 2, 4), padding=[1,4]), + ErrorInvocation(TensorOfShape(1, 4), padding=[4,1]), + Invocation(TensorOfShape(1, 4), padding=[3,3]), + ErrorInvocation(TensorOfShape(1, 4), padding=[1,4])]) +def aten〇reflection_pad1d〡shape(self: List[int], padding: List[int]) -> List[int]: + assert len(self) >= 2 + hdim = self[-1] + padding_left = padding[0] + padding_right = padding[1] + assert padding_left < hdim and padding_right < hdim + return pad_shape_fn(self, padding) + # TODO: upstream this def index_tensor_like(self: List[int], indices: List[Optional[List[int]]]) -> List[int]: assert len(indices) <= len(self), "More indices than dimensions to index" @@ -1804,6 +1819,18 @@ def aten〇constant_pad_nd〡dtype(self_rank_dtype: Tuple[int, int], pad: List[i self_rank, self_dtype = self_rank_dtype return self_dtype + +@check_dtype_function([ErrorInvocation(TensorOfShape(2, 3, 4), padding=1), + ErrorInvocation(TensorOfShape(2, 3, 4), padding=[]), + ErrorInvocation(TensorOfShape(2, 3, 4), padding=[2]), + Invocation(TensorOfShape(2, 3, 4), padding=[2,1]), + Invocation(TensorOfShape(5, 5, 4), padding=[1,2]), + ErrorInvocation(TensorOfShape(2, 3, 4), padding=[3,2,1])]) +def aten〇reflection_pad1d〡dtype(self_rank_dtype: Tuple[int, int], padding: List[int]) -> int: + self_rank, self_dtype = self_rank_dtype + assert len(padding) == 2, 'padding size expected to be 2' + return self_dtype + @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) def aten〇contiguous〡dtype(self_rank_dtype: Tuple[int, int], memory_format: int = 0) -> int: self_rank, self_dtype = self_rank_dtype diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index fb458f6a5d91..4d5b65c1dcd6 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -541,6 +541,7 @@ def emit_with_mutating_variants(key, **kwargs): # Misc tensor ops. emit("aten::constant_pad_nd : (Tensor, int[], Scalar) -> (Tensor)") + emit("aten::reflection_pad1d : (Tensor, int[]) -> (Tensor)") emit("aten::pad : (Tensor, int[], str, float?) -> (Tensor)") emit("aten::squeeze.dim : (Tensor, int) -> (Tensor)", has_folder=True) emit("aten::squeeze : (Tensor) -> (Tensor)", has_folder=True) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py index 971aa1efca77..20bc293e796b 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py @@ -552,8 +552,80 @@ def ConstantPadNdPartialStaticModule_basic(module, tu: TestUtils): # ============================================================================== +class ReflectionPad1dModule3dInput(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([1, 2, 4], torch.float32, True), + ]) + def forward(self, x): + return torch.ops.aten.reflection_pad1d(x, (3,1)) + + +@register_test_case(module_factory=lambda: ReflectionPad1dModule3dInput()) +def ReflectionPad1dModule3dInput_basic(module, tu: TestUtils): + module.forward(tu.rand(1,2,4)) + + +class ReflectionPad1dModule2dInput(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([2, 4], torch.float32, True), + ]) + def forward(self, x): + return torch.ops.aten.reflection_pad1d(x, (3,2)) + + +@register_test_case(module_factory=lambda: ReflectionPad1dModule2dInput()) +def ReflectionPad1dModule2dInput_basic(module, tu: TestUtils): + module.forward(tu.rand(2,4)) + +class ReflectionPad1dModule3dInputLeft(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([1, 4, 5], torch.float32, True), + ]) + def forward(self, x): + return torch.ops.aten.reflection_pad1d(x, (2,0)) + + +@register_test_case(module_factory=lambda: ReflectionPad1dModule3dInputLeft()) +def ReflectionPad1dModule3dInput_Left(module, tu: TestUtils): + module.forward(tu.rand(1,4,5)) + +class ReflectionPad1dModule2dInputRight(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([3, 6], torch.float32, True), + ]) + def forward(self, x): + return torch.ops.aten.reflection_pad1d(x, (0,3)) + + +@register_test_case(module_factory=lambda: ReflectionPad1dModule2dInputRight()) +def ReflectionPad1dModule2dInput_Right(module, tu: TestUtils): + module.forward(tu.rand(3,6)) + +# ============================================================================== class TransposeIntModule(torch.nn.Module): def __init__(self): From 80bd093d56c5f7b36c6852fdff05afac0d7c3a00 Mon Sep 17 00:00:00 2001 From: Andreas Falkenberg <149819731+afalkenberg1@users.noreply.github.com> Date: Fri, 29 Dec 2023 11:21:55 -0800 Subject: [PATCH 054/283] Added tensorResultTypeAtIndex to Patterns.h Need this for LayerNorm --- .../torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h b/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h index d842ea77bd3c..06bbb1ac526f 100644 --- a/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h +++ b/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h @@ -95,6 +95,16 @@ struct OpBinder { return success(); } + ParseResult tensorResultTypeAtIndex(Torch::ValueTensorType &typeIdx, int64_t idx) { + if (idx >= op->getNumResults()) + return failure(); + auto t = toValidTensorType(op->getResult(idx).getType()); + if (!t) + return failure(); + typeIdx = t; + return success(); + } + // Attribute accessors. ParseResult s64BoolAttr(bool &value, StringRef nameSuffix, bool defaultValue = false) { From 690827fe52768eff2be5b168581824de57de2e1b Mon Sep 17 00:00:00 2001 From: Vivek Khandelwal Date: Tue, 2 Jan 2024 10:42:20 +0000 Subject: [PATCH 055/283] build: manually update PyTorch version Set PyTorch and TorchVision version to nightly release 2024-01-02. Signed-Off By: Vivek Khandelwal --- projects/pt1/e2e_testing/xfail_sets.py | 1 - pytorch-hash.txt | 2 +- pytorch-requirements.txt | 2 +- torchvision-requirements.txt | 2 +- 4 files changed, 3 insertions(+), 4 deletions(-) diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index d6cb60e57367..76f84344bd42 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -139,7 +139,6 @@ # START tests failing due to: torch._dynamo.exc.Unsupported: data dependent operator: aten._local_scalar_dense.default 'AtenFloatScalarModule_basic', 'AtenIntBoolOpModule_basic', - 'OneHotModule_basic', 'QuantizedMLP_basic', 'ScalarImplicitFloatModule_basic', 'ScalarImplicitIntModule_basic', diff --git a/pytorch-hash.txt b/pytorch-hash.txt index 2caf78c61ce4..6d22e4c8b2c5 100644 --- a/pytorch-hash.txt +++ b/pytorch-hash.txt @@ -1 +1 @@ -a111e45dfe64cd565b2c0369b683f67d6658d2cc +7003edfbb4995b3712ba46aa7e39f1256b7fa4a6 diff --git a/pytorch-requirements.txt b/pytorch-requirements.txt index 1cc35e3e3787..1d07d05a1f36 100644 --- a/pytorch-requirements.txt +++ b/pytorch-requirements.txt @@ -1,3 +1,3 @@ -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html --pre -torch==2.2.0.dev20231204 +torch==2.3.0.dev20240101 diff --git a/torchvision-requirements.txt b/torchvision-requirements.txt index 5481769e4fe2..c44f5222d172 100644 --- a/torchvision-requirements.txt +++ b/torchvision-requirements.txt @@ -1,3 +1,3 @@ -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html --pre -torchvision==0.17.0.dev20231204 +torchvision==0.18.0.dev20240101 From 1778314620b796de7a7aba61f00396cecbd29a0b Mon Sep 17 00:00:00 2001 From: "Xida Ren (Cedar)" Date: Wed, 3 Jan 2024 12:52:59 -0500 Subject: [PATCH 056/283] add basic cumsum. this doesn't support the exclusive and reverse attrs (#2717) fixes #2711 --- .../TorchOnnxToTorch/DefaultDomainAtoF.cpp | 56 +++++++++++++++++++ .../unsupported_fb_opt_ops.mlir | 9 +++ 2 files changed, 65 insertions(+) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp index 3d2cf8aaee73..86f23bee162c 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp @@ -836,6 +836,62 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( binder.op, resultType, operand); return success(); }); + patterns.onOp( + "CumSum", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Location loc = binder.getLoc(); + Torch::ValueTensorType resultType; + Value operand; + Value axisTensor; + if (binder.tensorOperands(operand, axisTensor) || + binder.tensorResultType(resultType)) + return failure(); + + int64_t exclusive; + int64_t reverse; + // if bind succeeds and either is set, fail because not implemented + if (binder.s64IntegerAttr(exclusive, "exclusive", 0)) + if (exclusive != 0) + return rewriter.notifyMatchFailure( + binder.op, "unsupported onnx.CumSum conversion: exclusive"); + if (binder.s64IntegerAttr(reverse, "reverse", 0)) + if (reverse != 0) + return rewriter.notifyMatchFailure( + binder.op, "unsupported onnx.CumSum conversion: reverse"); + + // deal with neg axis: if (axis < 0) axis += rank + int64_t rank = + cast(operand.getType()).getSizes().size(); + Value rankVal = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), + rank)); + Value zero = rewriter.create( + loc, rewriter.getI64IntegerAttr(0)); + + Value axisScalar = rewriter.create( + binder.getLoc(), rewriter.getType(), axisTensor); + Value isNegative = + rewriter.create(binder.getLoc(), axisScalar, zero); + isNegative = rewriter.create(binder.getLoc(), + isNegative); + Value finalOffset = rewriter.create( + binder.getLoc(), isNegative, rankVal); + Value dim = rewriter.create( + binder.getLoc(), axisScalar, finalOffset); + + Torch::BaseTensorType resultTensorType = resultType.cast(); + if (!resultTensorType.hasDtype()) { + return rewriter.notifyMatchFailure( + binder.op, "expected result type to have a dtype"); + } + // resultTensorType.print(llvm::outs()); + Value resultDType = + Torch::getDtypeIntValueForType(rewriter, loc, resultTensorType.getDtype()); + + rewriter.replaceOpWithNewOp( + binder.op, resultType, operand, dim, resultDType); + return success(); + }); patterns.onOp("Div", 14, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; diff --git a/test/Conversion/TorchOnnxToTorch/unsupported_fb_opt_ops.mlir b/test/Conversion/TorchOnnxToTorch/unsupported_fb_opt_ops.mlir index 3ed9f1c6ebe6..6659935ffa6f 100644 --- a/test/Conversion/TorchOnnxToTorch/unsupported_fb_opt_ops.mlir +++ b/test/Conversion/TorchOnnxToTorch/unsupported_fb_opt_ops.mlir @@ -36,4 +36,13 @@ func.func @reduce_mean_operation(%arg0: !torch.vtensor<[1,64,768],f32>) // The ReduceMean operation as provided. %211 = torch.operator "onnx.ReduceMean"(%arg0) {torch.onnx.axes = [-1 : si64]} : (!torch.vtensor<[1,64,768],f32>) -> !torch.vtensor<[1,64,1],f32> return %211 : !torch.vtensor<[1,64,1],f32> +} + +// ----- +// Fixed. +func.func @cumsum_operation(%arg0: !torch.vtensor<[2,3],f64>, + %arg1: !torch.vtensor<[],si32>) + -> !torch.vtensor<[2,3],f64> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + %212 = torch.operator "onnx.CumSum"(%arg0, %arg1) : (!torch.vtensor<[2,3],f64>, !torch.vtensor<[],si32>) -> !torch.vtensor<[2,3],f64> + return %212 : !torch.vtensor<[2,3],f64> } \ No newline at end of file From 3e9bacdb514af36dea80f4e40d251f2a2cca4c4e Mon Sep 17 00:00:00 2001 From: Aart Bik <39774503+aartbik@users.noreply.github.com> Date: Wed, 3 Jan 2024 16:10:50 -0800 Subject: [PATCH 057/283] [torch-mlir] update e2e test class documentation (#2722) The doc seems copy-and-paste from the linalg-on-tensors class --- .../python/torch_mlir_e2e_test/configs/stablehlo_backend.py | 4 ++-- .../pt1/python/torch_mlir_e2e_test/configs/tosa_backend.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/projects/pt1/python/torch_mlir_e2e_test/configs/stablehlo_backend.py b/projects/pt1/python/torch_mlir_e2e_test/configs/stablehlo_backend.py index 45f32bb0b3fe..8a244b756e6c 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/configs/stablehlo_backend.py +++ b/projects/pt1/python/torch_mlir_e2e_test/configs/stablehlo_backend.py @@ -18,10 +18,10 @@ class StablehloBackendTestConfig(TestConfig): - """Base class for TestConfig's that are implemented with linalg-on-tensors. + """Base class for TestConfig's that are implemented with StableHLO. This class handles all the common lowering that torch-mlir does before - reaching the linalg-on-tensors abstraction level. + reaching the StableHLO abstraction level. """ def __init__(self, backend: StablehloBackend): diff --git a/projects/pt1/python/torch_mlir_e2e_test/configs/tosa_backend.py b/projects/pt1/python/torch_mlir_e2e_test/configs/tosa_backend.py index 89b90567b1d4..8efab87a2bfe 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/configs/tosa_backend.py +++ b/projects/pt1/python/torch_mlir_e2e_test/configs/tosa_backend.py @@ -18,10 +18,10 @@ class TosaBackendTestConfig(TestConfig): - """Base class for TestConfig's that are implemented with linalg-on-tensors. + """Base class for TestConfig's that are implemented with TOSA. This class handles all the common lowering that torch-mlir does before - reaching the linalg-on-tensors abstraction level. + reaching the TOSA abstraction level. """ def __init__(self, backend: TosaBackend, use_make_fx: bool = False): super().__init__() From 4e5e34d215fa00912a8205a1d0406ee5719003a7 Mon Sep 17 00:00:00 2001 From: John Wu Date: Wed, 3 Jan 2024 19:41:10 -0800 Subject: [PATCH 058/283] [MLIR][ONNX] Add OnnxToTorch support for Slice Op (#2696) --- .../Conversion/TorchOnnxToTorch/Patterns.h | 4 +- .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 164 +++++++++++++++++- .../TorchOnnxToTorch/simple_ops_q_to_z.mlir | 129 ++++++++++++++ 3 files changed, 294 insertions(+), 3 deletions(-) diff --git a/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h b/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h index 06bbb1ac526f..1ce381005fcc 100644 --- a/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h +++ b/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h @@ -33,6 +33,8 @@ struct OpBinder { Location getLoc() { return op->getLoc(); } + int getNumOperands() { return op->getNumOperands(); } + // Operand matches of different arities. ParseResult tensorOperand(Value &value0) { if (op->getNumOperands() != 1) @@ -189,7 +191,7 @@ struct OpBinder { } ParseResult customOpNameStringAttr(std::string &value, StringRef nameSuffix, - std::string defaultValue = "") { + std::string defaultValue = "") { SmallString<64> name("torch.onnx."); name.append(nameSuffix); auto attr = op->getAttr(name); diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 4756315c800e..f0e11ad1cd13 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -643,8 +643,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( llvm::SmallVector axes; int64_t keepDims; int64_t noop_with_empty_axes; - if (binder.tensorOperand(data) || - binder.tensorResultType(resultType) || + if (binder.tensorOperand(data) || binder.tensorResultType(resultType) || binder.s64IntegerArrayAttr(axes, "axes", 0) || binder.s64IntegerAttr(keepDims, "keepdims", 1) || binder.s64IntegerAttr(noop_with_empty_axes, "noop_with_empty_axes", @@ -1092,7 +1091,168 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( rewriter.replaceOp(binder.op, operand); return success(); }); + patterns.onOp( + "Slice", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultTorchType; + Value operand, starts, ends; + // Handle if axes are not provided + + if (binder.tensorOperandAtIndex(operand, 0) || + binder.tensorOperandAtIndex(starts, 1) || + binder.tensorOperandAtIndex(ends, 2) || + binder.tensorResultType(resultTorchType)) { + return failure(); + } + + auto context = rewriter.getContext(); + auto operandTorchTy = operand.getType().cast(); + auto operandTy = + operandTorchTy.toBuiltinTensor().dyn_cast(); + + if (!operandTy) + return rewriter.notifyMatchFailure( + binder.op, + "Expected tensor operator argument to be a ranked tensor type"); + + auto startsTorchTy = starts.getType().cast(); + auto startsTy = + startsTorchTy.toBuiltinTensor().dyn_cast(); + int startSize = startsTy.getDimSize(0); + + auto endsTorchTy = ends.getType().cast(); + auto endsTy = + endsTorchTy.toBuiltinTensor().dyn_cast(); + int endSize = endsTy.getDimSize(0); + auto resultTy = + resultTorchType.toBuiltinTensor().dyn_cast(); + if (!resultTy) + return rewriter.notifyMatchFailure( + binder.op, "Expected result type to be a ranked tensor type"); + + Location loc = binder.getLoc(); + + // Binding `axes` from its arguments or through a default value + Value axes; + if (binder.getNumOperands() >= 4) { + if (binder.tensorOperandAtIndex(axes, 3)) { + return failure(); + } + } else { + // The default axes value is the range from 0 to the number of + // dimensions + Value none = rewriter.create(loc); + auto defaultAxesType = Torch::ValueTensorType::get( + context, ArrayRef{operandTy.getRank()}, + rewriter.getIntegerType(64, /*signed*/ 1)); + Value arangeLength = rewriter.create( + loc, rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), + operandTy.getRank())); + axes = rewriter.create( + loc, defaultAxesType, arangeLength, none, none, none, none); + } + + // Binding `steps` from its arguments or through a default value + Value steps; + if (binder.getNumOperands() >= 5) { + if (binder.tensorOperandAtIndex(steps, 4)) { + return failure(); + } + } else { + // The default `steps` value is a 1d tensor filled with ones with a + // size of the dimension of the operand + Value none = rewriter.create(loc); + auto defaultStepsType = Torch::ValueTensorType::get( + context, ArrayRef{operandTy.getRank()}, + rewriter.getIntegerType(64, /*signed*/ 1)); + Value sizeStepInput = rewriter.create( + loc, rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), + operandTy.getRank())); + Value sizeStepsInput = rewriter.create( + loc, + Torch::ListType::get( + Torch::IntType::get(binder.op->getContext())), + sizeStepInput); + steps = rewriter.create( + loc, defaultStepsType, sizeStepsInput, none, none, none, none); + } + if (!(endsTy.getRank() == 1 && startsTy.getRank() == 1 && + startSize == endSize)) + return rewriter.notifyMatchFailure( + binder.op, "Expected the rank of starts and ends tensors to be 1 " + "and their dimensions to match"); + + auto axesTorchTy = axes.getType().cast(); + auto axesTy = + axesTorchTy.toBuiltinTensor().dyn_cast(); + int64_t numAxes = axesTy.getDimSize(0); + + if (!(axesTy && numAxes == endSize)) + return rewriter.notifyMatchFailure( + binder.op, "Axes should be the same size of starts and ends"); + + auto stepsTy = steps.getType() + .cast() + .toBuiltinTensor() + .dyn_cast(); + + if (!(stepsTy && stepsTy.getDimSize(0) == endsTy.getDimSize(0))) + return rewriter.notifyMatchFailure( + binder.op, "Steps should be the same size of starts and ends"); + + Value zero = rewriter.create( + loc, rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); + + auto select = [&](Value v, Value k) -> Value { + auto ty = v.getType().cast(); + auto sel = rewriter.create( + loc, + Torch::ValueTensorType::get(ty.getContext(), ArrayRef{1}, + ty.getOptionalDtype()), + v, zero, k); + Value item = rewriter.create( + loc, rewriter.getType(), sel); + return item; + }; + + llvm::SmallVector intermediateShape(operandTy.getShape()); + for (int i = 0, s = operandTy.getRank(); i < s; ++i) { + if (operandTy.getDimSize(i) != resultTy.getDimSize(i)) { + intermediateShape[i] = -1; + } + } + auto intermediateType = Torch::ValueTensorType::get( + context, intermediateShape, resultTorchType.getOptionalDtype()); + for (int i = 0; i < numAxes; ++i) { + + Value k = rewriter.create( + loc, rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), i)); + Value kTensor = rewriter.create( + loc, + Torch::ValueTensorType::get( + context, ArrayRef{1}, + rewriter.getIntegerType(64, /*signed*/ 1)), + k); + + Value start = select(starts, kTensor); + Value end = select(ends, kTensor); + Value axis = select(axes, kTensor); + Value step = select(steps, kTensor); + + auto sliceType = intermediateType; + if (i == numAxes - 1) + sliceType = resultTorchType; + operand = rewriter.create( + loc, sliceType, operand, axis, start, end, step); + } + + rewriter.replaceOp(binder.op, operand); + return success(); + }); patterns.onOp( "Reshape", 5, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir index b2a19334ab09..91421d944129 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir @@ -863,6 +863,135 @@ func.func @test_transpose_all_permutations_4(%arg0: !torch.vtensor<[2,3,4],f32>) // ----- +// CHECK-LABEL: func.func @test_slice +func.func @test_slice(%arg0: !torch.vtensor<[20,10,5],f32>, %arg1: !torch.vtensor<[2],si64>, %arg2: !torch.vtensor<[2],si64>, %arg3: !torch.vtensor<[2],si64>, %arg4: !torch.vtensor<[2],si64>) -> !torch.vtensor<[3,10,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + //CHECK: %[[INDEX_TO_GRAB:.*]] = torch.constant.int 0 + + //CHECK: %[[CONST_0:.*]] = torch.constant.int 0 + //CHECK: %[[ZERO_INDEX_VEC:.*]] = torch.prim.NumToTensor.Scalar %[[CONST_0:.*]] : !torch.int -> !torch.vtensor<[1],si64> + //CHECK: %[[STARTS_INDEX_VEC_0:.*]] = torch.aten.index_select %arg1, %[[INDEX_TO_GRAB:.*]], %[[ZERO_INDEX_VEC:.*]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + //CHECK: %[[STARTS_ELEMENT_0:.*]] = torch.aten.item %[[STARTS_INDEX_VEC_0:.*]] : !torch.vtensor<[1],si64> -> !torch.int + //CHECK: %[[ENDS_INDEX_VEC_0:.*]] = torch.aten.index_select %arg2, %[[INDEX_TO_GRAB:.*]], %[[ZERO_INDEX_VEC:.*]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + //CHECK: %[[ENDS_ELEMENT_0:.*]] = torch.aten.item %[[ENDS_INDEX_VEC_0:.*]] : !torch.vtensor<[1],si64> -> !torch.int + //CHECK: %[[AXES_INDEX_VEC_0:.*]] = torch.aten.index_select %arg3, %[[INDEX_TO_GRAB:.*]], %[[ZERO_INDEX_VEC:.*]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + //CHECK: %[[AXES_ELEMENT_0:.*]] = torch.aten.item %[[AXES_INDEX_VEC_0:.*]] : !torch.vtensor<[1],si64> -> !torch.int + //CHECK: %[[STEPS_INDEX_VEC_0:.*]] = torch.aten.index_select %arg4, %[[INDEX_TO_GRAB:.*]], %[[ZERO_INDEX_VEC:.*]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + //CHECK: %[[STEPS_ELEMENT_0:.*]] = torch.aten.item %[[STEPS_INDEX_VEC_0:.*]] : !torch.vtensor<[1],si64> -> !torch.int + //CHECK: %[[SLICE_0:.*]] = torch.aten.slice.Tensor %arg0, %[[AXES_ELEMENT_0:.*]], %[[STARTS_ELEMENT_0:.*]], %[[ENDS_ELEMENT_0:.*]], %[[STEPS_ELEMENT_0:.*]] : !torch.vtensor<[20,10,5],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,10,5],f32> + + //CHECK: %[[CONST_1:.*]] = torch.constant.int 1 + //CHECK: %[[ONE_INDEX_VEC:.*]] = torch.prim.NumToTensor.Scalar %[[CONST_1:.*]] : !torch.int -> !torch.vtensor<[1],si64> + //CHECK: %[[STARTS_INDEX_VEC_1:.*]] = torch.aten.index_select %arg1, %[[INDEX_TO_GRAB:.*]], %[[ONE_INDEX_VEC:.*]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + //CHECK: %[[STARTS_ELEMENT_1:.*]] = torch.aten.item %[[STARTS_INDEX_VEC_1:.*]] : !torch.vtensor<[1],si64> -> !torch.int + //CHECK: %[[ENDS_INDEX_VEC_1:.*]] = torch.aten.index_select %arg2, %[[INDEX_TO_GRAB:.*]], %[[ONE_INDEX_VEC:.*]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + //CHECK: %[[ENDS_ELEMENT_1:.*]] = torch.aten.item %[[ENDS_INDEX_VEC_1:.*]] : !torch.vtensor<[1],si64> -> !torch.int + //CHECK: %[[AXES_INDEX_VEC_1:.*]] = torch.aten.index_select %arg3, %[[INDEX_TO_GRAB:.*]], %[[ONE_INDEX_VEC:.*]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + //CHECK: %[[AXES_ELEMENT_2:.*]] = torch.aten.item %[[AXES_INDEX_VEC_1:.*]] : !torch.vtensor<[1],si64> -> !torch.int + //CHECK: %[[STEPS_INDEX_VEC_1:.*]] = torch.aten.index_select %arg4, %[[INDEX_TO_GRAB:.*]], %[[ONE_INDEX_VEC:.*]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + //CHECK: %[[STEPS_ELEMENT_1:.*]] = torch.aten.item %[[STEPS_INDEX_VEC_1:.*]] : !torch.vtensor<[1],si64> -> !torch.int + //CHECK: torch.aten.slice.Tensor %[[SLICE_0:.*]], %[[AXES_ELEMENT_1:.*]], %[[STARTS_ELEMENT_1:.*]], %[[ENDS_ELEMENT_1:.*]], %[[STEPS_ELEMENT_1:.*]] : !torch.vtensor<[?,10,5],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[3,10,5],f32> + %0 = torch.operator "onnx.Slice"(%arg0, %arg1, %arg2, %arg3, %arg4) : (!torch.vtensor<[20,10,5],f32>, !torch.vtensor<[2],si64>, !torch.vtensor<[2],si64>, !torch.vtensor<[2],si64>, !torch.vtensor<[2],si64>) -> !torch.vtensor<[3,10,5],f32> + return %0 : !torch.vtensor<[3,10,5],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_slice_default_axes_and_slices +func.func @test_slice_default_axes_and_slices(%arg0: !torch.vtensor<[20,10,5],f32>, %arg1: !torch.vtensor<[3],si64>, %arg2: !torch.vtensor<[3],si64>) -> !torch.vtensor<[20,10,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + //CHECK: %[[NONE_1:.*]] = torch.constant.none + //CHECK: %[[AXES_DEFAULT_SIZE:.*]] = torch.constant.int 3 + //CHECK: %[[DEFAULT_AXES:.*]] = torch.aten.arange %[[AXES_DEFAULT_SIZE:.*]], %[[NONE_1:.*]], %[[NONE_1:.*]], %[[NONE_1:.*]], %[[NONE_1:.*]] : !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[3],si64> + //CHECK: %[[NONE_2:.*]] = torch.constant.none + //CHECK: %[[DEFAULT_SIZE_AMOUNT:.*]] = torch.constant.int 3 + //CHECK: %[[DEFAULT_SIZE_INPUT:.*]] = torch.prim.ListConstruct %[[DEFAULT_SIZE_AMOUNT:.*]] : (!torch.int) -> !torch.list + //CHECK: %[[DEFAULT_SIZES:.*]] = torch.aten.ones %[[DEFAULT_SIZE_INPUT:.*]], %[[NONE_2:.*]], %[[NONE_2:.*]], %[[NONE_2:.*]], %[[NONE_2:.*]] : !torch.list, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[3],si64> + //CHECK: %[[INDEX_TO_GRAB:.*]] = torch.constant.int 0 + + //CHECK: %[[CONST_0:.*]] = torch.constant.int 0 + //CHECK: %[[ZERO_INDEX_VEC:.*]] = torch.prim.NumToTensor.Scalar %[[CONST_0:.*]] : !torch.int -> !torch.vtensor<[1],si64> + //CHECK: %[[STARTS_INDEX_VEC_0:.*]] = torch.aten.index_select %arg1, %[[INDEX_TO_GRAB:.*]], %[[ZERO_INDEX_VEC:.*]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + //CHECK: %[[STARTS_ELEMENT_0:.*]] = torch.aten.item %[[STARTS_INDEX_VEC_0:.*]] : !torch.vtensor<[1],si64> -> !torch.int + //CHECK: %[[ENDS_INDEX_VEC_0:.*]] = torch.aten.index_select %arg2, %[[INDEX_TO_GRAB:.*]], %[[ZERO_INDEX_VEC:.*]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + //CHECK: %[[ENDS_ELEMENT_0:.*]] = torch.aten.item %[[ENDS_INDEX_VEC_0:.*]] : !torch.vtensor<[1],si64> -> !torch.int + //CHECK: %[[AXES_INDEX_VEC_0:.*]] = torch.aten.index_select %[[DEFAULT_AXES:.*]], %[[INDEX_TO_GRAB:.*]], %[[ZERO_INDEX_VEC:.*]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + //CHECK: %[[AXES_ELEMENT_0:.*]] = torch.aten.item %[[AXES_INDEX_VEC_0:.*]] : !torch.vtensor<[1],si64> -> !torch.int + //CHECK: %[[STEPS_INDEX_VEC_0:.*]] = torch.aten.index_select %[[DEFAULT_SIZES:.*]], %[[INDEX_TO_GRAB:.*]], %[[ZERO_INDEX_VEC:.*]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + //CHECK: %[[STEPS_ELEMENT_0:.*]] = torch.aten.item %[[STEPS_INDEX_VEC_0:.*]] : !torch.vtensor<[1],si64> -> !torch.int + //CHECK: %[[SLICE_0:.*]] = torch.aten.slice.Tensor %arg0, %[[AXES_ELEMENT_0:.*]], %[[STARTS_ELEMENT_0:.*]], %[[ENDS_ELEMENT_0:.*]], %[[STEPS_ELEMENT_0:.*]] : !torch.vtensor<[20,10,5],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[20,10,?],f32> + + //CHECK: %[[CONST_1:.*]] = torch.constant.int 1 + //CHECK: %[[ONE_INDEX_VEC:.*]] = torch.prim.NumToTensor.Scalar %[[CONST_1:.*]] : !torch.int -> !torch.vtensor<[1],si64> + //CHECK: %[[STARTS_INDEX_VEC_1:.*]] = torch.aten.index_select %arg1, %[[INDEX_TO_GRAB:.*]], %[[ONE_INDEX_VEC:.*]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + //CHECK: %[[STARTS_ELEMENT_1:.*]] = torch.aten.item %[[STARTS_INDEX_VEC_1:.*]] : !torch.vtensor<[1],si64> -> !torch.int + //CHECK: %[[ENDS_INDEX_VEC_1:.*]] = torch.aten.index_select %arg2, %[[INDEX_TO_GRAB:.*]], %[[ONE_INDEX_VEC:.*]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + //CHECK: %[[ENDS_ELEMENT_1:.*]] = torch.aten.item %[[ENDS_INDEX_VEC_1:.*]] : !torch.vtensor<[1],si64> -> !torch.int + //CHECK: %[[AXES_INDEX_VEC_1:.*]] = torch.aten.index_select %[[DEFAULT_AXES:.*]], %[[INDEX_TO_GRAB:.*]], %[[ONE_INDEX_VEC:.*]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + //CHECK: %[[AXES_ELEMENT_1:.*]] = torch.aten.item %[[AXES_INDEX_VEC_1:.*]] : !torch.vtensor<[1],si64> -> !torch.int + //CHECK: %[[STEPS_INDEX_VEC_1:.*]] = torch.aten.index_select %[[DEFAULT_SIZES:.*]], %[[INDEX_TO_GRAB:.*]], %[[ONE_INDEX_VEC:.*]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + //CHECK: %[[STEPS_ELEMENT_1:.*]] = torch.aten.item %[[STEPS_INDEX_VEC_1:.*]] : !torch.vtensor<[1],si64> -> !torch.int + //CHECK: %[[TWO_INDEX_VEC:.*]] = torch.aten.slice.Tensor %[[SLICE_0:.*]], %[[AXES_ELEMENT_1:.*]], %[[STARTS_ELEMENT_1:.*]], %[[ENDS_ELEMENT_1:.*]], %[[STEPS_ELEMENT_1:.*]] : !torch.vtensor<[20,10,?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[20,10,?],f32> + + //CHECK: %[[CONST_2:.*]] = torch.constant.int 2 + //CHECK: %[[TWO_INDEX_VEC:.*]] = torch.prim.NumToTensor.Scalar %[[CONST_2:.*]] : !torch.int -> !torch.vtensor<[1],si64> + //CHECK: %[[STARTS_INDEX_VEC_2:.*]] = torch.aten.index_select %arg1, %[[INDEX_TO_GRAB:.*]], %[[TWO_INDEX_VEC:.*]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + //CHECK: %[[STARTS_ELEMENT_2:.*]] = torch.aten.item %[[STARTS_INDEX_VEC_2:.*]] : !torch.vtensor<[1],si64> -> !torch.int + //CHECK: %[[ENDS_INDEX_VEC_2:.*]] = torch.aten.index_select %arg2, %[[INDEX_TO_GRAB:.*]], %[[TWO_INDEX_VEC:.*]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + //CHECK: %[[ENDS_ELEMENT_2:.*]] = torch.aten.item %[[ENDS_INDEX_VEC_2:.*]] : !torch.vtensor<[1],si64> -> !torch.int + //CHECK: %[[AXES_INDEX_VEC_2:.*]] = torch.aten.index_select %[[DEFAULT_AXES:.*]], %[[INDEX_TO_GRAB:.*]], %[[TWO_INDEX_VEC:.*]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + //CHECK: %[[AXES_ELEMENT_2:.*]] = torch.aten.item %[[AXES_INDEX_VEC_2:.*]] : !torch.vtensor<[1],si64> -> !torch.int + //CHECK: %[[STEPS_INDEX_VEC_2:.*]] = torch.aten.index_select %[[DEFAULT_SIZES:.*]], %[[INDEX_TO_GRAB:.*]], %[[TWO_INDEX_VEC:.*]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + //CHECK: %[[STEPS_ELEMENT_2:.*]] = torch.aten.item %[[STEPS_INDEX_VEC_2:.*]] : !torch.vtensor<[1],si64> -> !torch.int + //CHECK: torch.aten.slice.Tensor %[[TWO_INDEX_VEC:.*]], %[[AXES_ELEMENT_2:.*]], %[[STARTS_ELEMENT_2:.*]], %[[ENDS_ELEMENT_2:.*]], %[[STEPS_ELEMENT_2:.*]] : !torch.vtensor<[20,10,?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[20,10,1],f32> + %0 = torch.operator "onnx.Slice"(%arg0, %arg1, %arg2) : (!torch.vtensor<[20,10,5],f32>, !torch.vtensor<[3],si64>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[20,10,1],f32> + return %0 : !torch.vtensor<[20,10,1],f32> +} + +// CHECK-LABEL: func.func @test_slice_default_steps +func.func @test_slice_default_steps(%arg0: !torch.vtensor<[20,10,5],f32>, %arg1: !torch.vtensor<[3],si64>, %arg2: !torch.vtensor<[3],si64>, %arg3: !torch.vtensor<[3],si64>) -> !torch.vtensor<[20,10,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + //CHECK: %[[NONE:.*]] = torch.constant.none + //CHECK: %[[DEFAULT_SIZE_AMOUNT:.*]] = torch.constant.int 3 + //CHECK: %[[DEFAULT_SIZE_INPUT:.*]] = torch.prim.ListConstruct %[[DEFAULT_SIZE_AMOUNT:.*]] : (!torch.int) -> !torch.list + //CHECK: %[[DEFAULT_SIZES:.*]] = torch.aten.ones %[[DEFAULT_SIZE_INPUT:.*]], %[[NONE:.*]], %[[NONE:.*]], %[[NONE:.*]], %[[NONE:.*]] : !torch.list, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[3],si64> + //CHECK: %[[INDEX_TO_GRAB:.*]] = torch.constant.int 0 + + //CHECK: %[[CONST_0:.*]] = torch.constant.int 0 + //CHECK: %[[ZERO_INDEX_VEC:.*]] = torch.prim.NumToTensor.Scalar %[[CONST_0:.*]] : !torch.int -> !torch.vtensor<[1],si64> + //CHECK: %[[STARTS_INDEX_VEC_0:.*]] = torch.aten.index_select %arg1, %[[INDEX_TO_GRAB:.*]], %[[ZERO_INDEX_VEC:.*]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + //CHECK: %[[STARTS_ELEMENT_0:.*]] = torch.aten.item %[[STARTS_INDEX_VEC_0:.*]] : !torch.vtensor<[1],si64> -> !torch.int + //CHECK: %[[ENDS_INDEX_VEC_0:.*]] = torch.aten.index_select %arg2, %[[INDEX_TO_GRAB:.*]], %[[ZERO_INDEX_VEC:.*]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + //CHECK: %[[ENDS_ELEMENT_0:.*]] = torch.aten.item %[[ENDS_INDEX_VEC_0:.*]] : !torch.vtensor<[1],si64> -> !torch.int + //CHECK: %[[AXES_INDEX_VEC_0:.*]] = torch.aten.index_select %arg3, %[[INDEX_TO_GRAB:.*]], %[[ZERO_INDEX_VEC:.*]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + //CHECK: %[[AXES_ELEMENT_0:.*]] = torch.aten.item %[[AXES_INDEX_VEC_0:.*]] : !torch.vtensor<[1],si64> -> !torch.int + //CHECK: %[[STEPS_INDEX_VEC_0:.*]] = torch.aten.index_select %[[DEFAULT_SIZES:.*]], %[[INDEX_TO_GRAB:.*]], %[[ZERO_INDEX_VEC:.*]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + //CHECK: %[[STEPS_ELEMENT_0:.*]] = torch.aten.item %[[STEPS_INDEX_VEC_0:.*]] : !torch.vtensor<[1],si64> -> !torch.int + //CHECK: %[[SLICE_0:.*]] = torch.aten.slice.Tensor %arg0, %[[AXES_ELEMENT_0:.*]], %[[STARTS_ELEMENT_0:.*]], %[[ENDS_ELEMENT_0:.*]], %[[STEPS_ELEMENT_0:.*]] : !torch.vtensor<[20,10,5],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[20,10,?],f32> + + //CHECK: %[[CONST_1:.*]] = torch.constant.int 1 + //CHECK: %[[ONE_INDEX_VEC:.*]] = torch.prim.NumToTensor.Scalar %[[CONST_1:.*]] : !torch.int -> !torch.vtensor<[1],si64> + //CHECK: %[[STARTS_INDEX_VEC_1:.*]] = torch.aten.index_select %arg1, %[[INDEX_TO_GRAB:.*]], %[[ONE_INDEX_VEC:.*]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + //CHECK: %[[STARTS_ELEMENT_1:.*]] = torch.aten.item %[[STARTS_INDEX_VEC_1:.*]] : !torch.vtensor<[1],si64> -> !torch.int + //CHECK: %[[ENDS_INDEX_VEC_1:.*]] = torch.aten.index_select %arg2, %[[INDEX_TO_GRAB:.*]], %[[ONE_INDEX_VEC:.*]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + //CHECK: %[[ENDS_ELEMENT_1:.*]] = torch.aten.item %[[ENDS_INDEX_VEC_1:.*]] : !torch.vtensor<[1],si64> -> !torch.int + //CHECK: %[[AXES_INDEX_VEC_1:.*]] = torch.aten.index_select %arg3, %[[INDEX_TO_GRAB:.*]], %[[ONE_INDEX_VEC:.*]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + //CHECK: %[[AXES_ELEMENT_1:.*]] = torch.aten.item %[[AXES_INDEX_VEC_1:.*]] : !torch.vtensor<[1],si64> -> !torch.int + //CHECK: %[[STEPS_INDEX_VEC_1:.*]] = torch.aten.index_select %[[DEFAULT_SIZES:.*]], %[[INDEX_TO_GRAB:.*]], %[[ONE_INDEX_VEC:.*]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + //CHECK: %[[STEPS_ELEMENT_1:.*]] = torch.aten.item %[[STEPS_INDEX_VEC_1:.*]] : !torch.vtensor<[1],si64> -> !torch.int + //CHECK: %[[TWO_INDEX_VEC:.*]] = torch.aten.slice.Tensor %[[SLICE_0:.*]], %[[AXES_ELEMENT_1:.*]], %[[STARTS_ELEMENT_1:.*]], %[[ENDS_ELEMENT_1:.*]], %[[STEPS_ELEMENT_1:.*]] : !torch.vtensor<[20,10,?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[20,10,?],f32> + + //CHECK: %[[CONST_1:.*]] = torch.constant.int 2 + //CHECK: %[[TWO_INDEX_VEC:.*]] = torch.prim.NumToTensor.Scalar %[[CONST_1:.*]] : !torch.int -> !torch.vtensor<[1],si64> + //CHECK: %[[STARTS_INDEX_VEC_2:.*]] = torch.aten.index_select %arg1, %[[INDEX_TO_GRAB:.*]], %[[TWO_INDEX_VEC:.*]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + //CHECK: %[[STARTS_ELEMENT_2:.*]] = torch.aten.item %[[STARTS_INDEX_VEC_2:.*]] : !torch.vtensor<[1],si64> -> !torch.int + //CHECK: %[[ENDS_INDEX_VEC_2:.*]] = torch.aten.index_select %arg2, %[[INDEX_TO_GRAB:.*]], %[[TWO_INDEX_VEC:.*]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + //CHECK: %[[ENDS_ELEMENT_2:.*]] = torch.aten.item %[[ENDS_INDEX_VEC_2:.*]] : !torch.vtensor<[1],si64> -> !torch.int + //CHECK: %[[AXES_INDEX_VEC_2:.*]] = torch.aten.index_select %arg3, %[[INDEX_TO_GRAB:.*]], %[[TWO_INDEX_VEC:.*]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + //CHECK: %[[AXES_ELEMENT_2:.*]] = torch.aten.item %[[AXES_INDEX_VEC_2:.*]] : !torch.vtensor<[1],si64> -> !torch.int + //CHECK: %[[STEPS_INDEX_VEC_2:.*]] = torch.aten.index_select %[[DEFAULT_SIZES:.*]], %[[INDEX_TO_GRAB:.*]], %[[TWO_INDEX_VEC:.*]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + //CHECK: %[[STEPS_ELEMENT_2:.*]] = torch.aten.item %[[STEPS_INDEX_VEC_2:.*]] : !torch.vtensor<[1],si64> -> !torch.int + //CHECK: torch.aten.slice.Tensor %[[TWO_INDEX_VEC:.*]], %[[AXES_ELEMENT_2:.*]], %[[STARTS_ELEMENT_2:.*]], %[[ENDS_ELEMENT_2:.*]], %[[STEPS_ELEMENT_2:.*]] : !torch.vtensor<[20,10,?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[20,10,1],f32> + %0 = torch.operator "onnx.Slice"(%arg0, %arg1, %arg2, %arg3) : (!torch.vtensor<[20,10,5],f32>, !torch.vtensor<[3],si64>, !torch.vtensor<[3],si64>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[20,10,1],f32> + return %0 : !torch.vtensor<[20,10,1],f32> +} // CHECK-LABEL: func.func @test_reshape_negative_dim func.func @test_reshape_negative_dim(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torch.vtensor<[3],si64>) -> !torch.vtensor<[2,6,2],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT0:.+]] = torch.constant.int 0 From aa7e95f7c8cde77528d273633baa8887f4795187 Mon Sep 17 00:00:00 2001 From: Aart Bik <39774503+aartbik@users.noreply.github.com> Date: Thu, 4 Jan 2024 14:09:12 -0800 Subject: [PATCH 059/283] [torch-mlir] remove trailing whitespace from e2e test files (#2727) --- .../torch_mlir_e2e_test/test_suite/arange.py | 6 +++--- .../torch_mlir_e2e_test/test_suite/basic.py | 6 +++--- .../test_suite/control_flow.py | 4 ++-- .../test_suite/elementwise.py | 6 +++--- .../test_suite/elementwise_comparison.py | 2 +- .../torch_mlir_e2e_test/test_suite/matmul.py | 16 +++++++-------- .../test_suite/norm_like.py | 10 +++++----- .../test_suite/reduction.py | 20 +++++++++---------- .../test_suite/reshape_like.py | 4 ++-- .../test_suite/slice_like.py | 6 +++--- .../torch_mlir_e2e_test/test_suite/squeeze.py | 2 +- .../test_suite/type_conversion.py | 2 +- 12 files changed, 42 insertions(+), 42 deletions(-) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/arange.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/arange.py index 8237d2601711..be41d71edbe3 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/arange.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/arange.py @@ -254,7 +254,7 @@ def ArangeFalsePinMemoryModule_basic(module, tu: TestUtils): class ArangeStartOutModule(torch.nn.Module): def __init__(self): super().__init__() - + @export @annotate_args([ None, @@ -270,7 +270,7 @@ def ArangeStartOutModule_basic(module, tu: TestUtils): class ArangeStartOutViewModule(torch.nn.Module): def __init__(self): super().__init__() - + @export @annotate_args([ None, @@ -286,7 +286,7 @@ def ArangeStartOutViewModule_basic(module, tu: TestUtils): class ArangeStartOutDtypeModule(torch.nn.Module): def __init__(self): super().__init__() - + @export @annotate_args([ None, diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py index 20bc293e796b..a68d229faf39 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py @@ -13,10 +13,10 @@ # ============================================================================== class ScalarConstantTupleModule(torch.nn.Module): - + def __init__(self): super().__init__() - + @export @annotate_args([ None, @@ -4490,7 +4490,7 @@ class OneHotModule(torch.nn.Module): def __init__(self): super().__init__() - + @export @annotate_args([None, ([-1], torch.long, True)]) def forward(self, x): diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/control_flow.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/control_flow.py index 5c00a75e06da..6f8240f54d89 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/control_flow.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/control_flow.py @@ -28,7 +28,7 @@ def forward(self, x): for i in range(x_val): sum += i return sum - + @register_test_case(module_factory=lambda: TorchPrimLoopForLikeModule()) def TorchPrimLoopForLikeModule_basic(module, tu: TestUtils): @@ -50,7 +50,7 @@ def forward(self, x): while(x_val > sum): sum += 1 return sum - + @register_test_case(module_factory=lambda: TorchPrimLoopWhileLikeModule()) def TorchPrimLoopWhileLikeModule_basic(module, tu: TestUtils): diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py index 15e45b52eb44..2b86aed35e52 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py @@ -3184,7 +3184,7 @@ def ElementwiseAtenLogicalOrOpRandomFloatModule_basic(module, tu: TestUtils): class ElementwiseAtenLogicalOrOpNegativeModule(torch.nn.Module): def __init__(self): super().__init__() - + @export @annotate_args([ None, @@ -3203,7 +3203,7 @@ def ElementwiseAtenLogicalOrOpNegativeModule_basic(module, tu: TestUtils): class ElementwiseAtenLogicalOrOpBrodcastModule(torch.nn.Module): def __init__(self): super().__init__() - + @export @annotate_args([ None, @@ -4089,7 +4089,7 @@ def ElementwiseBitwiseAndScalarInt8Module_basic(module, tu: TestUtils): class GluStaticModule(torch.nn.Module): def __init__(self): super().__init__() - + @export @annotate_args([ None, diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise_comparison.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise_comparison.py index ac04eeb41109..6248ef5aa32c 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise_comparison.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise_comparison.py @@ -779,7 +779,7 @@ def __init__(self): def forward(self): input = [True, False, True, True, False] return torch.ops.aten.all(input) - + @register_test_case(module_factory=lambda: AllBoolFalseModule()) def AllBoolFalseModule_basic(module, tu: TestUtils): module.forward() diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py index e59279ab57f7..ae7ea72031a5 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py @@ -28,7 +28,7 @@ def forward(self, lhs, rhs): @register_test_case(module_factory=lambda: MatmulDot()) def Matmul_dot(module, tu: TestUtils): module.forward(tu.rand(3), tu.rand(3)) - + # ============================================================================== class Matmul2D(torch.nn.Module): @@ -48,7 +48,7 @@ def forward(self, lhs, rhs): @register_test_case(module_factory=lambda: Matmul2D()) def Matmul_2d(module, tu: TestUtils): module.forward(tu.rand(3, 4), tu.rand(4, 5)) - + # ============================================================================== class MatmulVecMat(torch.nn.Module): @@ -68,7 +68,7 @@ def forward(self, lhs, rhs): @register_test_case(module_factory=lambda: MatmulVecMat()) def Matmul_vecmat(module, tu: TestUtils): module.forward(tu.rand(4), tu.rand(4, 5)) - + # ============================================================================== class MatmulMatVec(torch.nn.Module): @@ -88,7 +88,7 @@ def forward(self, lhs, rhs): @register_test_case(module_factory=lambda: MatmulMatVec()) def Matmul_matvec(module, tu: TestUtils): module.forward(tu.rand(4, 5), tu.rand(5)) - + # ============================================================================== class Matmul3D(torch.nn.Module): @@ -108,7 +108,7 @@ def forward(self, lhs, rhs): @register_test_case(module_factory=lambda: Matmul3D()) def Matmul_3d(module, tu: TestUtils): module.forward(tu.rand(3, 4, 5), tu.rand(3, 5, 4)) - + # ============================================================================== class Matmul4d(torch.nn.Module): @@ -128,7 +128,7 @@ def forward(self, lhs, rhs): @register_test_case(module_factory=lambda: Matmul4d()) def Matmul_4d(module, tu: TestUtils): module.forward(tu.rand(4, 5, 6, 7), tu.rand(4, 5, 7, 6)) - + # ============================================================================== class Matmul4dStatic(torch.nn.Module): @@ -188,7 +188,7 @@ def forward(self, lhs, rhs): @register_test_case(module_factory=lambda: MatmulSingleDynamicBatchDim()) def MatmulSingleDynamicBatchDim_basic(module, tu: TestUtils): module.forward(tu.rand(4, 5, 6, 7), tu.rand(4, 5, 7, 6)) - + # ============================================================================== class MatmulBroadcastBatchDim(torch.nn.Module): @@ -208,7 +208,7 @@ def forward(self, lhs, rhs): @register_test_case(module_factory=lambda: MatmulBroadcastBatchDim()) def MatmulBroadcastBatchDim_basic(module, tu: TestUtils): module.forward(tu.rand(4, 5, 6, 7), tu.rand(5, 7, 6)) - + # ============================================================================== class Mv(torch.nn.Module): diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/norm_like.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/norm_like.py index 59a251082303..3b17f516f9e5 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/norm_like.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/norm_like.py @@ -130,7 +130,7 @@ def __init__(self): ]) def forward(self, x, weight, bias, running_mean, running_var): return torch.ops.aten.batch_norm( - x, weight, bias, running_mean, running_var, training=False, + x, weight, bias, running_mean, running_var, training=False, momentum=0.1, eps=0.00001, cudnn_enabled=False) @@ -156,7 +156,7 @@ def __init__(self): ]) def forward(self, x, weight, bias, running_mean, running_var): return torch.ops.aten.native_batch_norm( - x, weight, bias, running_mean, running_var, training=False, + x, weight, bias, running_mean, running_var, training=False, momentum=0.1, eps=0.00001) @@ -182,7 +182,7 @@ def __init__(self): ]) def forward(self, x, weight, bias, running_mean, running_var): return torch.ops.aten.native_batch_norm( - x, weight, bias, running_mean, running_var, training=False, + x, weight, bias, running_mean, running_var, training=False, momentum=0.1, eps=0.00001) @@ -208,7 +208,7 @@ def __init__(self): ]) def forward(self, x, weight, bias, running_mean, running_var): return torch.ops.aten.native_batch_norm( - x, weight, bias, running_mean, running_var, training=False, + x, weight, bias, running_mean, running_var, training=False, momentum=0.1, eps=0.00001) @@ -233,7 +233,7 @@ def __init__(self): ]) def forward(self, x, bias, running_mean, running_var): return torch.ops.aten.native_batch_norm( - x, None, bias, running_mean, running_var, training=False, + x, None, bias, running_mean, running_var, training=False, momentum=0.1, eps=0.00001) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py index 75e6eb261196..8418d1ae8f5a 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py @@ -826,7 +826,7 @@ def __init__(self): @export @annotate_args([ - None, + None, ([-1, -1], torch.float32, True), ]) def forward(self, a): @@ -908,7 +908,7 @@ def __init__(self): @export @annotate_args([ - None, + None, ([-1, -1], torch.float32, True), ]) def forward(self, a): @@ -1068,8 +1068,8 @@ def NormScalarOptDimKeepDimModule_basic(module, tu: TestUtils): class ReduceFrobeniusNormModule(torch.nn.Module): def __init__(self) -> None: super().__init__() - - @export + + @export @annotate_args([ None, ([-1, -1, -1], torch.float32, True), @@ -1086,8 +1086,8 @@ def ReduceFrobeniusNormModule_basic(module, tu: TestUtils): class ReduceFrobeniusNormKeepDimModule(torch.nn.Module): def __init__(self) -> None: super().__init__() - - @export + + @export @annotate_args([ None, ([-1, -1, -1], torch.float32, True), @@ -1104,8 +1104,8 @@ def ReduceFrobeniusNormKeepDimModule_basic(module, tu: TestUtils): class LinalgVectorNormModule(torch.nn.Module): def __init__(self) -> None: super().__init__() - - @export + + @export @annotate_args([ None, ([-1, -1, -1], torch.float32, True), @@ -1122,8 +1122,8 @@ def LinalgVectorNormModule_basic(module, tu: TestUtils): class LinalgVectorNormKeepDimModule(torch.nn.Module): def __init__(self) -> None: super().__init__() - - @export + + @export @annotate_args([ None, ([-1, -1, -1], torch.float32, True), diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py index a73435c3c1ad..73371058cf46 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reshape_like.py @@ -708,8 +708,8 @@ def UnsafeView1DFoldModule_basic(module, tu: TestUtils): class ReshapeAsModule(torch.nn.Module): def __init__(self) -> None: super().__init__() - - @export + + @export @annotate_args([ None, ([4, 3], torch.float32, True), diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/slice_like.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/slice_like.py index e5d31fe9cf19..8014758a7411 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/slice_like.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/slice_like.py @@ -456,7 +456,7 @@ def __init__(self): ]) def forward(self, x): return torch.ops.aten.narrow(x, dim=0, start=0, length=2) - + @register_test_case(module_factory=lambda: NarrowHorizontalTest()) def NarrowHorizontalTest_basic(module, tu: TestUtils): @@ -495,7 +495,7 @@ def __init__(self): ]) def forward(self, x): return torch.ops.aten.narrow(x, dim=0, start=0, length=2) - + @register_test_case(module_factory=lambda: NarrowHorizontalTest2()) def NarrowHorizontalTest2_basic(module, tu: TestUtils): @@ -738,7 +738,7 @@ def SplitTensorGetItem_Module_basic(module, tu: TestUtils): class SplitTensorListUnpackModule(torch.nn.Module): def __init__(self): super().__init__() - + @export @annotate_args([ None, diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/squeeze.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/squeeze.py index 8b7cf957ac78..078f3483bed8 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/squeeze.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/squeeze.py @@ -96,7 +96,7 @@ def forward(self, a): module_factory=lambda: SqueezeDimStaticModule()) def SqueezeDimModule_static(module, tu: TestUtils): module.forward(tu.rand(1, 7)) - + # ============================================================================== diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/type_conversion.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/type_conversion.py index 6e04c5fa8700..9c85eb873326 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/type_conversion.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/type_conversion.py @@ -275,7 +275,7 @@ def forward(self, x, y): @register_test_case(module_factory=lambda: TypeAsDifferentModule()) def TypeAsDifferentModule_basic(module, tu: TestUtils): module.forward( - tu.randint(3, 5, low=0, high=10, dtype=torch.int), + tu.randint(3, 5, low=0, high=10, dtype=torch.int), tu.randint(3, 5, low=0, high=10, dtype=torch.int64) ) From fb1dfa31268c59a829ded35d304969c48ede388b Mon Sep 17 00:00:00 2001 From: Kunwar Grover Date: Fri, 5 Jan 2024 04:03:41 +0530 Subject: [PATCH 060/283] Bump llvm-project to 6b65d79fbb4682468333cea42b62f15c2dffd8f3 (#2723) Co-authored-by: hanhanW --- externals/llvm-project | 2 +- lib/Dialect/TMTensor/IR/TMTensorOps.cpp | 21 ++++++++----------- lib/Dialect/Torch/IR/TorchOps.cpp | 14 +++++++++++++ .../Torch/Transforms/InlineGlobalSlots.cpp | 2 +- 4 files changed, 25 insertions(+), 14 deletions(-) diff --git a/externals/llvm-project b/externals/llvm-project index 99045b60b575..6b65d79fbb46 160000 --- a/externals/llvm-project +++ b/externals/llvm-project @@ -1 +1 @@ -Subproject commit 99045b60b57571079f9cb4aea57870692523fbe8 +Subproject commit 6b65d79fbb4682468333cea42b62f15c2dffd8f3 diff --git a/lib/Dialect/TMTensor/IR/TMTensorOps.cpp b/lib/Dialect/TMTensor/IR/TMTensorOps.cpp index dcb2f4215891..ec399fe9633e 100644 --- a/lib/Dialect/TMTensor/IR/TMTensorOps.cpp +++ b/lib/Dialect/TMTensor/IR/TMTensorOps.cpp @@ -166,7 +166,6 @@ static void matmul(OpBuilder &b, Location loc, Value lhs, ValueRange lhsSizes, }) ->getResult(0); b.create(loc, sum, output, localIVs); - b.create(loc); }); } @@ -229,13 +228,15 @@ LogicalResult AttentionOp::generateScalarImplementation(OpBuilder &b, SmallVector(weightRank, one), init, [&](OpBuilder &b, Location loc, ValueRange localIVs, ValueRange accs) { - b.create( - loc, init, - [&](OpBuilder &b, Location loc, Value elem, Value acc) { - Value x = b.create(loc, weight, localIVs); - Value max = b.create(loc, x, acc); - b.create(loc, max); - }); + auto reduceOp = b.create(loc, init); + // Build reduce body. + Block &reductionBody = reduceOp.getReductions()[0].front(); + auto bodyBuilder = OpBuilder::atBlockEnd(&reductionBody); + Value acc = reductionBody.getArgument(0); + Value x = + bodyBuilder.create(loc, weight, localIVs); + Value max = bodyBuilder.create(loc, x, acc); + bodyBuilder.create(loc, max); }) .getResult(0); // weight = (weight - max(weight)) / math.sqrt(querySizes[-1]) @@ -247,7 +248,6 @@ LogicalResult AttentionOp::generateScalarImplementation(OpBuilder &b, x = b.create(loc, x, globalMax); x = b.create(loc, x, scaleFactor); b.create(loc, x, weight, localIVs); - b.create(loc); }); // calculate exp(weight) SmallVector min(weightRank, zero), @@ -258,7 +258,6 @@ LogicalResult AttentionOp::generateScalarImplementation(OpBuilder &b, Value x = b.create(loc, weight, localIVs); x = b.create(loc, x); b.create(loc, x, weight, localIVs); - b.create(loc); }); Value expWeightSum = b.create( loc, @@ -290,7 +289,6 @@ LogicalResult AttentionOp::generateScalarImplementation(OpBuilder &b, Value y = b.create(loc, weight, coords); Value sum = b.create(loc, x, y); b.create(loc, sum, expWeightSum, outsideDims); - b.create(loc); }); }); // calculate exp(weight) / sum(exp(weight)) @@ -305,7 +303,6 @@ LogicalResult AttentionOp::generateScalarImplementation(OpBuilder &b, Value sum = b.create(loc, expWeightSum, sumIVs); x = b.create(loc, x, sum); b.create(loc, x, weight, localIVs); - b.create(loc); }); // output = weight @ value diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index 32a550ce813a..e63a4e376013 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -715,6 +715,8 @@ OpFoldResult AtenNeBoolOp::fold(FoldAdaptor adaptor) { //===----------------------------------------------------------------------===// OpFoldResult AtenSqueezeOp::fold(FoldAdaptor adaptor) { + if (getOperand().getType() != getResult().getType()) + return nullptr; if (auto tensorType = getOperand().getType().dyn_cast()) { if (tensorType.hasSizes() && tensorType.getSizes().size() == 0) return getOperand(); @@ -727,6 +729,8 @@ OpFoldResult AtenSqueezeOp::fold(FoldAdaptor adaptor) { //===----------------------------------------------------------------------===// OpFoldResult AtenSqueezeDimOp::fold(FoldAdaptor adaptor) { + if (getOperand(0).getType() != getResult().getType()) + return nullptr; if (auto tensorType = getOperand(0).getType().dyn_cast()) { if (tensorType.hasSizes() && tensorType.getSizes().size() == 0) return getOperand(0); @@ -739,6 +743,8 @@ OpFoldResult AtenSqueezeDimOp::fold(FoldAdaptor adaptor) { //===----------------------------------------------------------------------===// OpFoldResult AtenRoundOp::fold(FoldAdaptor adaptor) { + if (getSelf().getType() != getResult().getType()) + return nullptr; if (auto selfType = getSelf().getType().dyn_cast()) { if (selfType.hasDtype() && selfType.getDtype().isa()) return getSelf(); @@ -911,6 +917,8 @@ OpFoldResult AtenViewOp::fold(FoldAdaptor adaptor) { auto resType = getType().dyn_cast(); if (!resType || !resType.hasSizes() || resType.getSizes().size() != 1) return nullptr; + if (inputType != resType) + return nullptr; // Fold when both the input tensor and result are unity rank tensors. return getOperand(0); } @@ -2441,6 +2449,8 @@ OpFoldResult AtenCatOp::fold(FoldAdaptor adaptor) { auto list = getOperand(0).getDefiningOp(); if (!list || !list->hasOneUse() || list.getElements().size() != 1) return nullptr; + if (list.getElements()[0].getType() != getResult().getType()) + return nullptr; return list.getElements()[0]; } @@ -2451,6 +2461,8 @@ OpFoldResult AtenCatOp::fold(FoldAdaptor adaptor) { OpFoldResult AtenBroadcastToOp::fold(FoldAdaptor adaptor) { auto inType = getOperand(0).getType().dyn_cast(); auto outType = getResult().getType().dyn_cast(); + if (inType != outType) + return nullptr; if (!inType || !outType || !inType.hasSizes() || !outType.hasSizes()) return nullptr; if (inType.getSizes().size() != outType.getSizes().size() || @@ -2480,6 +2492,8 @@ OpFoldResult AtenSliceTensorOp::fold(FoldAdaptor adaptor) { auto inType = getOperand(0).getType().dyn_cast(); auto outType = getResult().getType().dyn_cast(); + if (inType != outType) + return nullptr; if (!inType || !outType || !inType.hasSizes() || !outType.hasSizes()) return nullptr; if (inType.getSizes().size() != outType.getSizes().size() || diff --git a/lib/Dialect/Torch/Transforms/InlineGlobalSlots.cpp b/lib/Dialect/Torch/Transforms/InlineGlobalSlots.cpp index 76b57fe8c9a3..c67e6dc0d3a7 100644 --- a/lib/Dialect/Torch/Transforms/InlineGlobalSlots.cpp +++ b/lib/Dialect/Torch/Transforms/InlineGlobalSlots.cpp @@ -95,7 +95,7 @@ static bool isUseTreatedWithValueSemantics(OpOperand &use) { class InlineGlobalSlotsAnalysisState : public AnalysisState { public: InlineGlobalSlotsAnalysisState(ProgramPoint point) : AnalysisState(point) { - setSafe(); + (void)setSafe(); } void print(raw_ostream &os) const override { From 6096fcb347691982d721d74d96794ac0d17af0d9 Mon Sep 17 00:00:00 2001 From: Han-Chung Wang Date: Thu, 4 Jan 2024 17:30:05 -0800 Subject: [PATCH 061/283] [OnnxToTorch] Delete unused variables. (#2728) --- lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 7 ------- 1 file changed, 7 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index f0e11ad1cd13..d88ca9d6d5cb 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -669,15 +669,8 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( } return success(); } - Value zero = rewriter.create( - binder.getLoc(), rewriter.getType(), - rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); int64_t adjustmentInt = cast(data.getType()).getSizes().size(); - Value adjustment = rewriter.create( - binder.getLoc(), rewriter.getType(), - rewriter.getIntegerAttr(rewriter.getIntegerType(64), - adjustmentInt)); // convert axes (tensor) into torch int list while dealing with neg axis for (int i = 0; i < axes.size(); i++) { // Go through the axes list and get each dim in the list From 985e7796a4e4c2b939c4c350047db2473fcdc8f2 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Fri, 5 Jan 2024 15:16:49 -0800 Subject: [PATCH 062/283] [linalg] Added `aten.clamp` support with integers to `torch-to-linalg` (#2718) The lowering for `aten.clamp` did not support integer types. Added support for integer types including a signed integer test. --- .../TorchToLinalg/Uncategorized.cpp | 55 +++++++++++++------ .../test_suite/elementwise.py | 28 ++++++++++ 2 files changed, 65 insertions(+), 18 deletions(-) diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index 0943534dbd9c..f742ded3f1bd 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -1007,13 +1007,6 @@ static Value createLinalgPayloadCalculationForElementwiseOp( return b.create(loc, pred, lhs, rhs); } if (auto clamp = dyn_cast(op)) { - Type dtype = converter->convertType(clamp.getType()) - .cast() - .getElementType(); - if (!dtype.isa()) { - clamp.emitError("unimplemented: non-floating point dtype"); - return nullptr; - } AtenClampOp::Adaptor adaptor(operands); auto min = adaptor.getMin(); auto max = adaptor.getMax(); @@ -1022,19 +1015,45 @@ static Value createLinalgPayloadCalculationForElementwiseOp( clamp.emitError("unimplemented: runtime optional type"); return nullptr; } - auto result = payloadArgs[0]; - if (!min.getType().isa()) { - auto minPromoted = convertScalarToDtype(b, loc, min, dtype); - auto pred = b.create(loc, arith::CmpFPredicate::ULT, - result, minPromoted); - result = b.create(loc, pred, minPromoted, result); + + Type dtype = converter->convertType(clamp.getType()) + .cast() + .getElementType(); + if (!dtype.isa()) { + clamp.emitError("unimplement type for clamp"); + return nullptr; } - if (!max.getType().isa()) { - auto maxPromoted = convertScalarToDtype(b, loc, max, dtype); - auto pred = b.create(loc, arith::CmpFPredicate::UGT, - result, maxPromoted); - result = b.create(loc, pred, maxPromoted, result); + + Type dstOriginalDtype = clamp.getType().cast().getDtype(); + bool isUnsigned = isa(dstOriginalDtype); + if (auto intTy = dstOriginalDtype.dyn_cast()) { + isUnsigned = intTy.isUnsigned(); } + auto cmpSelect = [&](Value input, Value clamp, bool getMax) -> Value { + clamp = convertScalarToDtype(b, loc, clamp, dtype, + /*srcOriginalDtype=*/std::nullopt, + /*dstOriginalDtype=*/dstOriginalDtype); + + Value pred; + if (dtype.isa()) { + auto cmp = + getMax ? arith::CmpFPredicate::UGT : arith::CmpFPredicate::ULT; + pred = b.create(loc, cmp, input, clamp); + } else if (dtype.isa()) { + auto cmp = + isUnsigned ? arith::CmpIPredicate::ult : arith::CmpIPredicate::slt; + if (getMax) + cmp = arith::invertPredicate(cmp); + pred = b.create(loc, cmp, input, clamp); + } + return b.create(loc, pred, clamp, input); + }; + + auto result = payloadArgs[0]; + if (!min.getType().isa()) + result = cmpSelect(result, min, /*getMax=*/false); + if (!max.getType().isa()) + result = cmpSelect(result, max, /*getMax=*/true); return result; } if (auto clampTensor = dyn_cast(op)) { diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py index 2b86aed35e52..c18c9103d888 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py @@ -988,6 +988,34 @@ def ElementwiseClampTensorIntModule_basic(module, tu: TestUtils): # ============================================================================== +class ElementwiseClampTensorInt8Module(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int8, True) + ]) + def forward(self, x): + min = -5 + max = 5 + min_clamp = torch.clamp(x, min) + max_clamp = torch.clamp(x, max=max) + both_clamp = torch.clamp(x, min=min, max=max) + return min_clamp, max_clamp, both_clamp + + +@register_test_case(module_factory=lambda: ElementwiseClampTensorInt8Module()) +def ElementwiseClampTensorInt8Module_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 5, low=-10, high=10, dtype=torch.int8)) + + +# ============================================================================== + + + class ElementwiseClampMinTensorFloatModule(torch.nn.Module): def __init__(self): From 4dd17f0b711523223e41606d9b7b6023b9149d46 Mon Sep 17 00:00:00 2001 From: Ben Vanik Date: Mon, 8 Jan 2024 14:26:38 -0800 Subject: [PATCH 063/283] Fixing implicit double->float truncation warnings. (#2733) Floating-point literals should use the correct type specifier. --- lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp | 4 ++-- lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp | 10 +++++----- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp index 86f23bee162c..83aa0a185fac 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp @@ -184,8 +184,8 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( binder.tensorOperandAtIndex(bias, 2) || binder.tensorOperandAtIndex(runningMean, 3) || binder.tensorOperandAtIndex(runningVar, 4) || - binder.f32FloatAttr(momentum, "momentum", 0.9) || - binder.f32FloatAttr(eps, "epsilon", 1e-05) || + binder.f32FloatAttr(momentum, "momentum", 0.9f) || + binder.f32FloatAttr(eps, "epsilon", 1e-05f) || binder.tensorResultType(resultType)) return failure(); diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index d154edb1ab75..ee1ae6bb65be 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -34,8 +34,8 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( Value tensorOperand; float alpha, beta; if (binder.tensorOperand(tensorOperand) || - binder.f32FloatAttr(alpha, "alpha", 0.2) || - binder.f32FloatAttr(beta, "beta", 0.5) || + binder.f32FloatAttr(alpha, "alpha", 0.2f) || + binder.f32FloatAttr(beta, "beta", 0.5f) || binder.tensorResultType(resultType)) return failure(); @@ -276,8 +276,8 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( binder.tensorOperandAtIndex(c, 2) || binder.s64IntegerAttr(transA, "transA", 0) || binder.s64IntegerAttr(transB, "transB", 0) || - binder.f32FloatAttr(alpha, "alpha", 1.0) || - binder.f32FloatAttr(beta, "beta", 1.0) || + binder.f32FloatAttr(alpha, "alpha", 1.0f) || + binder.f32FloatAttr(beta, "beta", 1.0f) || binder.tensorResultType(resultType)) return failure(); @@ -417,7 +417,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( float alpha; if (binder.tensorOperand(operand) || binder.tensorResultType(resultType) || - binder.f32FloatAttr(alpha, "alpha", 0.01)) + binder.f32FloatAttr(alpha, "alpha", 0.01f)) return failure(); Value constAlpha = rewriter.create( binder.getLoc(), rewriter.getType(), From 07d0645f640bdc8b09a706150fa1a9a5f85b8147 Mon Sep 17 00:00:00 2001 From: zjgarvey <47986913+zjgarvey@users.noreply.github.com> Date: Tue, 9 Jan 2024 13:14:10 -0600 Subject: [PATCH 064/283] [RFC] general support for Adaptive Pooling Ops (#2661) Adaptive pooling ops can only be decomposed into their non-adaptive counterparts in trivial cases. For example, the current decomposition for AtenAdaptiveAvgPool1dOp in DecomposeComplexOps.cpp supports outSize = inSize (i.e., do literally nothing), and outSize = 1 (i.e., do a batched average). The reason adaptive pooling ops are difficult to lower to linalg is that they are not constantly strided. They are computed by taking an input tensor of shape (N, C, Hin), and an output size Hout, and computing the output tensor at position (n,c, h) in the following way: 1. compute st(h) = (h*Hin)//Hout 2. compute en(h) = 1 + ((h+1)*Hin -1)//Hout 3. apply a computation (max or avg) to the slice: INPUT[n, c, st(h):en(h)] The provided sample implementation (for ConvertAtenAdaptiveAvgPool1dOp) uses tensor.extract to access the input tensor inside the payload of a linalg generic op. This is likely an unattractive use of linalg generic ops, which is why I am asking for some more targeted feedback on the validity of this approach before attempting to support the many other adaptive pooling ops. Specifically: - Is the performance of this implementation bad enough to warrant targeting different dialects entirely? e.g. TMtensor/linalg ext/ etc. - If the provided implementation is of acceptable performance to the community, then is it permissable to remove the Adaptive pooling decompositions from DecomposeComplexOps.cpp? Based on the current structure of the -torch-decompose-complex-ops pass, it does not seem possible to only decompose the adaptive ops in special cases (it seems to get stuck in an infinite loop on a match failure). I would be happy to instead incorporate the case logic into the conversion directly, and remove the decompositions once they are rendered completely obsolete. As long as this approach is acceptable, I can clean up the implementation with some helper functions, and quickly add support for each of the remaining Adaptive pooling ops. --- lib/Conversion/TorchToLinalg/Pooling.cpp | 237 ++++++++++++++++-- projects/pt1/e2e_testing/xfail_sets.py | 3 + projects/pt1/python/torch_mlir/__init__.py | 2 +- .../torch_mlir_e2e_test/test_suite/pooling.py | 71 +++++- 4 files changed, 280 insertions(+), 33 deletions(-) diff --git a/lib/Conversion/TorchToLinalg/Pooling.cpp b/lib/Conversion/TorchToLinalg/Pooling.cpp index 87419f0935ab..20c03f5ffeec 100644 --- a/lib/Conversion/TorchToLinalg/Pooling.cpp +++ b/lib/Conversion/TorchToLinalg/Pooling.cpp @@ -90,18 +90,19 @@ static LogicalResult createPoolingOp( SmallVector lowPaddingIncludingNC = {0, 0}; lowPaddingIncludingNC.append(paddingInts); SmallVector highPaddingIncludingNC = lowPaddingIncludingNC; - + if (ceilMode) { for (int64_t i = 0; i < dimensionality; ++i) { highPaddingIncludingNC[i + 2] += strideInts[i]; } } - Value initValue = rewriter.create(loc, cast(initValueAttr)); + Value initValue = + rewriter.create(loc, cast(initValueAttr)); paddedInput = torch_to_linalg::getPaddedTensor( op, rewriter, self, lowPaddingIncludingNC, highPaddingIncludingNC, initValue); - + Value N = getDimOp(rewriter, loc, self, 0); Value C = getDimOp(rewriter, loc, self, 1); @@ -141,7 +142,6 @@ static LogicalResult createPoolingOp( return success(); } - namespace { class ConvertAtenMaxPool2dOp : public OpConversionPattern { public: @@ -163,7 +163,8 @@ class ConvertAtenMaxPool2dOp : public OpConversionPattern { bool ceilMode; SmallVector kernelSizeIntValues; SmallVector strideInts, paddingInts, dilationInts; - if (!matchPattern(op.getDilation(), m_TorchListOfConstantInts(dilationInts))) + if (!matchPattern(op.getDilation(), + m_TorchListOfConstantInts(dilationInts))) return rewriter.notifyMatchFailure(op, "only support constant int dilations"); if (failed(checkAndGetPoolingParameters( @@ -241,7 +242,8 @@ class ConvertAtenMaxPool2dWithIndicesOp bool ceilMode; SmallVector kernelSizeIntValues; SmallVector strideInts, paddingInts, dilationInts; - if (!matchPattern(op.getDilation(), m_TorchListOfConstantInts(dilationInts))) + if (!matchPattern(op.getDilation(), + m_TorchListOfConstantInts(dilationInts))) return rewriter.notifyMatchFailure(op, "only support constant int dilations"); if (failed(checkAndGetPoolingParameters( @@ -372,7 +374,6 @@ class ConvertAtenMaxPool2dWithIndicesOp }; } // namespace - namespace { template class ConvertAtenAvgPoolOp : public OpConversionPattern { @@ -383,7 +384,7 @@ class ConvertAtenAvgPoolOp : public OpConversionPattern { ConversionPatternRewriter &rewriter) const override { if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); - + Location loc = op->getLoc(); const TypeConverter *typeConverter = this->getTypeConverter(); Value self = adaptor.getSelf(); @@ -397,9 +398,9 @@ class ConvertAtenAvgPoolOp : public OpConversionPattern { bool ceilMode; SmallVector kernelSizeIntValues; SmallVector strideInts, paddingInts, dilationInts(Dim, 1); - if (failed(checkAndGetPoolingParameters( - op, rewriter, typeConverter, ceilMode, kernelSizeIntValues, - strideInts, paddingInts))) + if (failed(checkAndGetPoolingParameters(op, rewriter, typeConverter, + ceilMode, kernelSizeIntValues, + strideInts, paddingInts))) return rewriter.notifyMatchFailure(op, "invalid pooling parameters"); // TODO: Add support for count_include_pad equal to `False`. @@ -415,20 +416,21 @@ class ConvertAtenAvgPoolOp : public OpConversionPattern { // `sumPool` contains the result of sumpool operation over the input. Value sumPool, paddedInput; - SmallVector outTensorShape; + SmallVector outTensorShape; if (failed(createPoolingOp( op, rewriter, self, /*supportNonFPInput=*/true, ceilMode, - /*dimensionality=*/Dim, kernelSizeIntValues, strideInts, paddingInts, - dilationInts, rewriter.getZeroAttr(inputElementType), outTensorShape, - paddedInput, sumPool))) + /*dimensionality=*/Dim, kernelSizeIntValues, strideInts, + paddingInts, dilationInts, rewriter.getZeroAttr(inputElementType), + outTensorShape, paddedInput, sumPool))) return rewriter.notifyMatchFailure(op, "unable to compute sumpool"); Value divisor; if constexpr (std::is_same()) { Value kHtimeskW = rewriter.create( loc, kernelSizeIntValues[0], kernelSizeIntValues[1]); - divisor = op.getDivisorOverride().getType().template isa() - ? kHtimeskW - : adaptor.getDivisorOverride(); + divisor = + op.getDivisorOverride().getType().template isa() + ? kHtimeskW + : adaptor.getDivisorOverride(); } else { divisor = kernelSizeIntValues[0]; } @@ -436,9 +438,10 @@ class ConvertAtenAvgPoolOp : public OpConversionPattern { Value outputTensor = rewriter.create( loc, getAsOpFoldResult(outTensorShape), resultElementType); - SmallVector indexingMapsAvg(2, rewriter.getMultiDimIdentityMap(Dim+2)); + SmallVector indexingMapsAvg( + 2, rewriter.getMultiDimIdentityMap(Dim + 2)); SmallVector iteratorTypesAvg( - Dim+2, utils::IteratorType::parallel); + Dim + 2, utils::IteratorType::parallel); Value avgPool = rewriter .create( @@ -459,8 +462,188 @@ class ConvertAtenAvgPoolOp : public OpConversionPattern { return success(); } }; -} +} // namespace +/* +This section is for lowering adaptive pooling ops, which cannot generally be +decomposed into typical pooling ops. Given an input tensor of rank (N,C,Hin) and +an output spatial size Hout, an element of the output tensor at position (n, c, +h) is computed as follows. + 1. compute st(h) = (h*Hin)//Hout + 2. compute en(h) = 1 + ((h+1)*Hin - 1)//Hout + 3. apply the operation (max or avg) over input[n, c, st(h):en(h)] +This is problematic for linalg ops for a few reasons: + 1. The access to the input tensor is not constantly strided + 2. The size of the window itself is not contant: en(h) - st(h) can vary with +h! Although it is a bit like using a hammer to paint, our workaround is to use +tensor.extract to access the elements of the input tensor inside our linalg +generic op's payload. + +Current TODO's: + 1. gather most of the boilerplate out of this op and make it into an +adaptive pooling helper function. + 2. figure out what to do with the conflicting decompositions in +DecomposeComplexOps.cpp + 3. Implement more efficient passes for when the kernel-size, input spatial +dims, and output spatial dims are constant. +*/ + +namespace { +class ConvertAtenAdaptiveAvgPool1dOp + : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(AtenAdaptiveAvgPool1dOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + Location loc = op->getLoc(); + const TypeConverter *typeConverter = getTypeConverter(); + + // get rank of input (same as rank of output) + int64_t rank = + adaptor.getSelf().getType().cast().getRank(); + // input operand should be NCH (i.e. rank 3) + if (rank != 3) { + return rewriter.notifyMatchFailure(op, "only supports input type NCH"); + } + + // input tensor and output shape + Value input = adaptor.getSelf(); + Value outputShape = op.getOutputSize(); + SmallVector outShapeVector; + getListConstructElements(outputShape, outShapeVector); + outShapeVector = + getTypeConvertedValues(rewriter, loc, typeConverter, outShapeVector); + Value hIn = getDimOp(rewriter, loc, input, 2); + Value hOut = outShapeVector[0]; + Value hOutIndex = castIntToIndex(rewriter, loc, hOut); + RankedTensorType inputType = input.getType().cast(); + RankedTensorType outputType = + typeConverter->convertType(op.getResult().getType()) + .cast(); + + // get elementType of input tensor + Type elementType = inputType.getElementType(); + + // make an iteration space of size kMax = 1 + ceildiv (hIn - 1) , hOut + Type boolType = rewriter.getI1Type(); + Value kIter; + Value constantOne = + rewriter.create(loc, rewriter.getIndexAttr(1)); + Value hInPlusOne = rewriter.create(loc, hIn, constantOne); + Value kMaxMinusOne = + rewriter.create(loc, hInPlusOne, hOutIndex); + Value kMax = rewriter.create(loc, constantOne, kMaxMinusOne); + kIter = rewriter.create( + loc, getAsOpFoldResult(ValueRange({kMax})), boolType); + + // need to buffer input, else there will possibly be an out of bounds access + // later buffVal = 0 for avg pooling and -inf for max pooling + Value buffVal = rewriter.create( + loc, elementType, rewriter.getFloatAttr(elementType, 0)); + SmallVector lowPadding = {0, 0, 0}; + SmallVector highPadding = {0, 0, 1}; + Value buffInput = torch_to_linalg::getPaddedTensor( + op, rewriter, input, lowPadding, highPadding, buffVal); + + // make a list of outputSizes + SmallVector outputSizes; + for (unsigned i = 0; i < rank - 1; i++) { + outputSizes.push_back(getDimOp(rewriter, loc, input, i)); + } + outputSizes.push_back(hOutIndex); + + // initialize a kernel size tensor (only for avg pooling) + Value kSizeTensor = rewriter.create( + loc, getAsOpFoldResult(ValueRange({hOutIndex})), elementType); + + // initialize an output tensor + Value initOutput = + createInitTensor(rewriter, loc, outputSizes, elementType, buffVal); + + // setup indexing maps and iterator types for linalg generic op + // for kIter (d0,d1,d2,d3) -> (d3) + // for output (d0,d1,d2,d3) -> (d0,d1,d2) + // for kSizeTensor (d0,d1,d2,d3) -> (d2) + SmallVector kIterExprs, outputExprs, kSizeTensorExprs; + for (unsigned i = 0; i < 3; i++) { + outputExprs.push_back(rewriter.getAffineDimExpr(i)); + } + kSizeTensorExprs.push_back(rewriter.getAffineDimExpr(2)); + kIterExprs.push_back(rewriter.getAffineDimExpr(3)); + SmallVector indexingMaps = AffineMap::inferFromExprList( + {kIterExprs, outputExprs, kSizeTensorExprs}); + SmallVector iteratorTypes( + 3, utils::IteratorType::parallel); + iteratorTypes.push_back(utils::IteratorType::reduction); + + Value indexOne = rewriter.create(loc, 1); + auto sumPool = rewriter.create( + loc, /*resultTensorTypes=*/ + TypeRange({initOutput.getType(), kSizeTensor.getType()}), + /*inputs=*/ValueRange({kIter}), + /*outputs=*/ValueRange({initOutput, kSizeTensor}), + /*indexingMaps=*/indexingMaps, + /*iteratorTypes=*/iteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange args) { + Value res = args[1]; + Value ind0 = b.create(loc, 0); + Value ind1 = b.create(loc, 1); + Value ind2 = b.create(loc, 2); + Value ind3 = b.create(loc, 3); + // compute start and end indices + // st = s1( s0(ind2 * Hin) // Hout ) + Value s0 = b.create(loc, ind2, hIn); + Value s1 = b.create(loc, s0, hOutIndex); + // en = e4( 1 + e3( e2( e1( e0(ind2 + 1) * hIn ) - 1 ) // hOut ) ) + Value e0 = b.create(loc, ind2, indexOne); + Value e1 = b.create(loc, e0, hIn); + Value e2 = b.create(loc, e1, indexOne); + Value e3 = b.create(loc, e2, hOutIndex); + Value e4 = b.create(loc, indexOne, e3); + // get input element @ st + ind3: + Value wIndex = b.create(loc, s1, ind3); + Value inElt = b.create( + loc, elementType, buffInput, ValueRange({ind0, ind1, wIndex})); + // check if we extracted at windex < end index + Value cond = + b.create(loc, arith::CmpIPredicate(6), wIndex, e4); + // if inElt is in bounds, include it in the computation + // else, use buffVal = 0 (for max pool use -infinity) + Value out1 = b.create(loc, cond, inElt, buffVal); + // compute Kernel size: we store this to kwTensor + Value kSize = b.create(loc, e4, s1); + Value kSizeInt = castIndexToInt64(b, loc, kSize); + Value kSizeF = b.create(loc, elementType, kSizeInt); + // accumulate out2 to res = args[1] + Value out2 = b.create(loc, res, out1); + b.create(loc, ValueRange({out2, kSizeF})); + }); + + // make a linalg generic to divide each element by the corresponding + // Kernel Width. This step is only necessary for avg pooling. + SmallVector indexingMaps1 = + AffineMap::inferFromExprList({kSizeTensorExprs, outputExprs}); + SmallVector iteratorTypes1( + 3, utils::IteratorType::parallel); + auto output = rewriter.create( + loc, /*resultTensorTypes=*/initOutput.getType(), + /*inputs=*/sumPool.getResultTensors()[1], + /*outputs=*/sumPool.getResultTensors()[0], + /*indexingMaps=*/indexingMaps1, + /*iteratorTypes=*/iteratorTypes1, + [&](OpBuilder &b, Location loc, ValueRange args) { + Value q = b.create(loc, args[1], args[0]); + b.create(loc, q); + }); + + rewriter.replaceOpWithNewOp(op, outputType, + output.getResultTensors()); + return success(); + } +}; +} // namespace void mlir::torch::torch_to_linalg::populatePoolingPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, @@ -471,8 +654,12 @@ void mlir::torch::torch_to_linalg::populatePoolingPatternsAndLegality( target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); - patterns.add>( - typeConverter, context); - patterns.add>( - typeConverter, context); + patterns + .add>( + typeConverter, context); + patterns + .add>( + typeConverter, context); + target.addIllegalOp(); + patterns.add(typeConverter, context); } diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 76f84344bd42..ddb4865ec535 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -257,6 +257,8 @@ # ERROR: Exception: Unsupported: missing default value for argument 0 in schema for aten.div.Tensor_mode "ElementwiseDivRoundingModeFloorModule_basic", "ElementwiseDivRoundingModeTruncModule_basic", + "AdaptiveAvgPool1dStaticLargerOutput_basic", + "AdaptiveAvgPool1dGeneralDynamic_basic", # ERROR: Exception: Unsupported op: get_attr "NumToTensorFloatModule_basic", @@ -1324,6 +1326,7 @@ ### Tests additionally passing in make_fx_tosa "AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic", "AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic", + "AdaptiveAvgPool1dStaticEvenMultiple_basic", "NativeGroupNormBackwardModule_basic", "SliceWholeTensorModule_basic", "TensorFloatModule_basic", diff --git a/projects/pt1/python/torch_mlir/__init__.py b/projects/pt1/python/torch_mlir/__init__.py index 1cf1aa0e048a..c916043c2cdd 100644 --- a/projects/pt1/python/torch_mlir/__init__.py +++ b/projects/pt1/python/torch_mlir/__init__.py @@ -248,7 +248,7 @@ def _get_for_tracing( # compiler where each backend can "own" its set of legal ops. BACKEND_LEGAL_OPS = { OutputType.TOSA: ['aten.flatten.using_ints', 'aten.native_layer_norm', 'aten.linear'], - OutputType.LINALG_ON_TENSORS: ['aten.flatten.using_ints', ], + OutputType.LINALG_ON_TENSORS: ['aten.flatten.using_ints','aten.adaptive_avg_pool1d'], OutputType.STABLEHLO: [], } diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py index dd18545b0bc4..1c6748538a6b 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py @@ -11,7 +11,6 @@ # ============================================================================== - class AdaptiveAvgPool2dNonUnitOutputSizeStaticModule(torch.nn.Module): def __init__(self): @@ -55,7 +54,6 @@ def AdaptiveAvgPool2dNonUnitOutputSizeDynamicModule_basic( module, tu: TestUtils): module.forward(tu.rand(1, 512, 7, 7)) - class AdaptiveAvgPool2dUnitOutputSizeStaticModule(torch.nn.Module): def __init__(self): @@ -776,12 +774,71 @@ def AvgPool1dStaticModule_basic(module, tu: TestUtils): # ============================================================================== +class AdaptiveAvgPool1dStaticLargerOutput(torch.nn.Module): + + def __init__(self): + super().__init__() + self.aap1d = torch.nn.AdaptiveAvgPool1d(output_size=13) + + @export + @annotate_args([ + None, + ([5, 512, 7], torch.float32, True) + ]) + def forward(self,x): + return self.aap1d(x) + +@register_test_case( + module_factory=lambda: AdaptiveAvgPool1dStaticLargerOutput()) +def AdaptiveAvgPool1dStaticLargerOutput_basic( + module, tu: TestUtils): + module.forward(tu.rand(5, 512, 7)) + +class AdaptiveAvgPool1dStaticEvenMultiple(torch.nn.Module): + + def __init__(self): + super().__init__() + self.aap1d = torch.nn.AdaptiveAvgPool1d(output_size=7) + + @export + @annotate_args([ + None, + ([5, 512, 147], torch.float32, True) + ]) + def forward(self,x): + return self.aap1d(x) + +@register_test_case( + module_factory=lambda: AdaptiveAvgPool1dStaticEvenMultiple()) +def AdaptiveAvgPool1dStaticEvenMultiple_basic( + module, tu: TestUtils): + module.forward(tu.rand(5, 512, 147)) + +class AdaptiveAvgPool1dGeneralDynamic(torch.nn.Module): + + def __init__(self): + super().__init__() + self.aap1d = torch.nn.AdaptiveAvgPool1d(output_size=7) + + @export + @annotate_args([ + None, + ([-1,-1,-1], torch.float32, True) + ]) + def forward(self,x): + return self.aap1d(x) + +@register_test_case( + module_factory=lambda: AdaptiveAvgPool1dGeneralDynamic()) +def AdaptiveAvgPool1dGeneralDynamic_basic( + module, tu: TestUtils): + module.forward(tu.rand(1, 512, 10)) class AdaptiveAvgPool1dNonUnitOutputSizeStaticModule(torch.nn.Module): def __init__(self): super().__init__() - self.aap1d = torch.nn.AdaptiveAvgPool1d(7) + self.aap1d = torch.nn.AdaptiveAvgPool1d(output_size=7) @export @annotate_args([ @@ -801,7 +858,7 @@ class AdaptiveAvgPool1dNonUnitOutputSizeDynamicModule(torch.nn.Module): def __init__(self): super().__init__() - self.aap1d = torch.nn.AdaptiveAvgPool1d(7) + self.aap1d = torch.nn.AdaptiveAvgPool1d(output_size=7) @export @annotate_args([ @@ -821,7 +878,7 @@ class AdaptiveAvgPool1dUnitOutputSizeStaticModule(torch.nn.Module): def __init__(self): super().__init__() - self.aap1d = torch.nn.AdaptiveAvgPool1d(1) + self.aap1d = torch.nn.AdaptiveAvgPool1d(output_size=1) @export @annotate_args([ @@ -841,7 +898,7 @@ class AdaptiveAvgPool1dUnitOutputSizeDynamicModule(torch.nn.Module): def __init__(self): super().__init__() - self.aap1d = torch.nn.AdaptiveAvgPool1d(1) + self.aap1d = torch.nn.AdaptiveAvgPool1d(output_size=1) @export @annotate_args([ @@ -855,4 +912,4 @@ def forward(self, x): module_factory=lambda: AdaptiveAvgPool1dUnitOutputSizeDynamicModule()) def AdaptiveAvgPool1dUnitOutputSizeDynamicModule_basic( module, tu: TestUtils): - module.forward(tu.rand(1, 512, 7)) \ No newline at end of file + module.forward(tu.rand(1, 512, 7)) From 35e8f8679220a35dae4e49b988eeef5e47747f8c Mon Sep 17 00:00:00 2001 From: Vivek Khandelwal Date: Mon, 8 Jan 2024 14:38:49 +0000 Subject: [PATCH 065/283] [MLIR][ONNX] Add OnnxToTorch support for Dropout and Elu op Signed-Off By: Vivek Khandelwal --- .../TorchOnnxToTorch/DefaultDomainAtoF.cpp | 75 +++++++++++++++++++ .../TorchOnnxToTorch/simple_ops_a_to_f.mlir | 58 ++++++++++++++ 2 files changed, 133 insertions(+) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp index 83aa0a185fac..87df83101718 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp @@ -904,6 +904,62 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( binder.op, resultType, lhs, rhs); return success(); }); + patterns.onOp( + "Dropout", 12, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Location loc = binder.getLoc(); + Torch::ValueTensorType resultType; + int64_t numOperands = binder.op->getNumOperands(); + SmallVector operands; + int64_t seed; + if (binder.tensorOperands(operands, numOperands) || + binder.s64IntegerAttr(seed, "seed", 0) || + binder.tensorResultTypeAtIndex(resultType, 0)) + return failure(); + + // Global Seed value is 0. + if (seed != 0) { + return rewriter.notifyMatchFailure(binder.op, + "expected seed value to be 0"); + } + + Value ratio, trainingMode; + if (numOperands == 3) { + ratio = rewriter.create(loc, operands[1]); + Value trainingModeScalar = + rewriter.create(loc, operands[2]); + Value cstOne = rewriter.create( + loc, rewriter.getI64IntegerAttr(1)); + trainingMode = rewriter.create( + loc, trainingModeScalar, cstOne); + } else if (numOperands == 2) { + ratio = rewriter.create(loc, operands[1]); + trainingMode = rewriter.create(loc, false); + } else { + ratio = rewriter.create( + loc, rewriter.getF64FloatAttr(0.5)); + trainingMode = rewriter.create(loc, false); + } + + Value dropout = rewriter.create( + loc, resultType, /*input=*/operands[0], ratio, trainingMode); + + if (binder.op->getNumResults() == 1) { + rewriter.replaceOp(binder.op, dropout); + return success(); + } + Torch::ValueTensorType maskType; + if (binder.tensorResultTypeAtIndex(maskType, 1)) + return failure(); + Value dtype = rewriter.create( + loc, rewriter.getI64IntegerAttr( + (int64_t)torch_upstream::ScalarType::Bool)); + Value none = rewriter.create(loc); + Value mask = rewriter.create( + loc, maskType, operands[0], dtype, /*layout=*/none, + /*device=*/none, /*pin_memory=*/none, /*memory_format=*/none); + rewriter.replaceOp(binder.op, {dropout, mask}); + return success(); + }); patterns.onOp("Equal", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; @@ -916,6 +972,25 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( binder.op, resultType, lhs, rhs); return success(); }); + patterns.onOp("Elu", 6, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Location loc = binder.getLoc(); + Torch::ValueTensorType resultType; + Value input; + float alpha; + if (binder.tensorOperand(input) || + binder.f32FloatAttr(alpha, "alpha") || + binder.tensorResultType(resultType)) + return failure(); + Value cstAlpha = rewriter.create( + loc, rewriter.getF64FloatAttr(alpha)); + Value cstOne = rewriter.create( + loc, rewriter.getF64FloatAttr(1.0)); + rewriter.replaceOpWithNewOp( + binder.op, resultType, input, cstAlpha, /*scale=*/cstOne, + /*input_scale=*/cstOne); + return success(); + }); patterns.onOp("Erf", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir index dc4d3e163052..8ba88d22c256 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir @@ -740,3 +740,61 @@ func.func @test_concat_3d_axis_negative_3(%arg0: !torch.vtensor<[2,2,2],f32>, %a %0 = torch.operator "onnx.Concat"(%arg0, %arg1) {torch.onnx.axis = -3 : si64} : (!torch.vtensor<[2,2,2],f32>, !torch.vtensor<[2,2,2],f32>) -> !torch.vtensor<[4,2,2],f32> return %0 : !torch.vtensor<[4,2,2],f32> } + +// CHECK-LABEL: @test_dropout +func.func @test_dropout(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.dropout %arg0, %float5.000000e-01, %false : !torch.vtensor<[3],f32>, !torch.float, !torch.bool -> !torch.vtensor<[3],f32 + %0 = torch.operator "onnx.Dropout"(%arg0) : (!torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> + return %0 : !torch.vtensor<[3],f32> +} + +// CHECK-LABEL: @test_dropout_default +func.func @test_dropout_default(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.dropout %arg0, %float5.000000e-01, %false : !torch.vtensor<[3,4,5],f32>, !torch.float, !torch.bool -> !torch.vtensor<[3,4,5],f32> + %0 = torch.operator "onnx.Dropout"(%arg0) {torch.onnx.seed = 0 : si64} : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> +} + +// CHECK-LABEL: @test_dropout_default_mask +func.func @test_dropout_default_mask(%arg0: !torch.vtensor<[3,4,5],f32>) -> (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],i1>) attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.dropout %arg0, %float5.000000e-01, %false : !torch.vtensor<[3,4,5],f32>, !torch.float, !torch.bool -> !torch.vtensor<[3,4,5],f32> + // CHECK: torch.aten.ones_like %arg0, %int11, %none, %none, %none, %none : !torch.vtensor<[3,4,5],f32>, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[3,4,5],i1> + %0:2 = torch.operator "onnx.Dropout"(%arg0) {torch.onnx.seed = 0 : si64} : (!torch.vtensor<[3,4,5],f32>) -> (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],i1>) + return %0#0, %0#1 : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],i1> +} + +// CHECK-LABEL: @test_dropout_default_mask_ratio +func.func @test_dropout_default_mask_ratio(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[],f32>) -> (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],i1>) attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.dropout %arg0, %0, %false : !torch.vtensor<[3,4,5],f32>, !torch.float, !torch.bool -> !torch.vtensor<[3,4,5],f32> + // CHECK: torch.aten.ones_like %arg0, %int11, %none, %none, %none, %none : !torch.vtensor<[3,4,5],f32>, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[3,4,5],i1> + %0:2 = torch.operator "onnx.Dropout"(%arg0, %arg1) {torch.onnx.seed = 0 : si64} : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[],f32>) -> (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],i1>) + return %0#0, %0#1 : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],i1> +} + +// CHECK-LABEL: @test_dropout_default_ratio +func.func @test_dropout_default_ratio(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.dropout %arg0, %0, %false : !torch.vtensor<[3,4,5],f32>, !torch.float, !torch.bool -> !torch.vtensor<[3,4,5],f32> + %0 = torch.operator "onnx.Dropout"(%arg0, %arg1) {torch.onnx.seed = 0 : si64} : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[],f32>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> +} + +// CHECK-LABEL: @test_training_dropout_zero_ratio +func.func @test_training_dropout_zero_ratio(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],i1>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.dropout %arg0, %0, %2 : !torch.vtensor<[3,4,5],f32>, !torch.float, !torch.bool -> !torch.vtensor<[3,4,5],f32> + %0 = torch.operator "onnx.Dropout"(%arg0, %arg1, %arg2) {torch.onnx.seed = 0 : si64} : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[],f32>, !torch.vtensor<[],i1>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> +} + +// CHECK-LABEL: @test_elu_default +func.func @test_elu_default(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 6 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.elu %arg0, %float0.000000e00, %float1.000000e00, %float1.000000e00 : !torch.vtensor<[3,4,5],f32>, !torch.float, !torch.float, !torch.float -> !torch.vtensor<[3,4,5],f32> + %0 = torch.operator "onnx.Elu"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> +} + +// CHECK-LABEL: @test_elu_example +func.func @test_elu_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 6 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.elu %arg0, %float2.000000e00, %float1.000000e00, %float1.000000e00 : !torch.vtensor<[3],f32>, !torch.float, !torch.float, !torch.float -> !torch.vtensor<[3],f32> + %0 = torch.operator "onnx.Elu"(%arg0) {torch.onnx.alpha = 2.000000e+00 : f32} : (!torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> + return %0 : !torch.vtensor<[3],f32> +} From 4707d3bdc6d7e1bb12b3e44dcf23455a7d445725 Mon Sep 17 00:00:00 2001 From: Vivek Khandelwal Date: Thu, 4 Jan 2024 15:12:51 +0000 Subject: [PATCH 066/283] [MLIR][ONNX] Add OnnxToTorch support for Bernoulli and CastLike op Signed-Off By: Vivek Khandelwal --- .../TorchOnnxToTorch/DefaultDomainAtoF.cpp | 116 +++++++++++++++--- .../TorchOnnxToTorch/simple_ops_a_to_f.mlir | 59 +++++++++ 2 files changed, 160 insertions(+), 15 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp index 87df83101718..8b6fddecfd56 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp @@ -15,6 +15,28 @@ using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::onnx_c; +static int64_t onnxDtypeIntToTorchDtypeInt(int64_t dtypeIntOnnx) { + int64_t dtypeIntTorch; + // TODO: Add complete mapping. + switch (dtypeIntOnnx) { + case 1: + dtypeIntTorch = 6; // float + break; + case 10: + dtypeIntTorch = 5; // half + break; + case 11: + dtypeIntTorch = 7; // double + break; + case 16: + dtypeIntTorch = 15; // bfloat16 + break; + default: + dtypeIntTorch = -1; // No dtype + } + return dtypeIntTorch; +} + // Simple rewrites for the default domain. // See: https://onnx.ai/onnx/operators/ // For operators that are effectively version invariant, we register with @@ -311,6 +333,53 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( } return failure(); }); + patterns.onOp( + "Bernoulli", 15, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value input; + int64_t dtypeIntOnnx, dtypeIntTorch; + if (binder.tensorOperand(input) || + binder.s64IntegerAttr(dtypeIntOnnx, "dtype", -1) || + binder.tensorResultType(resultType)) + return failure(); + + SmallString<64> name("torch.onnx."); + name.append("seed"); + auto attr = binder.op->getAttr(name); + if (attr) { + return rewriter.notifyMatchFailure( + binder.op, + "unimplemented: support not present for seed attribute"); + } + + Value none = rewriter.create(binder.getLoc()); + Value bernoulli = rewriter.create( + binder.getLoc(), input.getType(), input, /*generator=*/none); + + if (dtypeIntOnnx == -1) { + // True, if dtype attribute value is not present. + rewriter.replaceOp(binder.op, bernoulli); + return success(); + } + dtypeIntTorch = onnxDtypeIntToTorchDtypeInt(dtypeIntOnnx); + if (dtypeIntTorch == -1) { + return rewriter.notifyMatchFailure( + binder.op, + "unimplemented support for the given dtype conversion"); + } + Value constDtype = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), + dtypeIntTorch)); + Value cstFalse = + rewriter.create(binder.getLoc(), false); + rewriter.replaceOpWithNewOp( + binder.op, resultType, bernoulli, constDtype, + /*non_blocking=*/cstFalse, /*copy=*/cstFalse, + /*memory_format=*/none); + return success(); + }); patterns.onOp( "BitShift", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; @@ -386,21 +455,8 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( binder.tensorResultType(resultType)) return failure(); - // TODO: Add complete mapping. - switch (dtypeIntOnnx) { - case 1: - dtypeIntTorch = 6; // float - break; - case 10: - dtypeIntTorch = 5; // half - break; - case 11: - dtypeIntTorch = 7; // double - break; - case 16: - dtypeIntTorch = 15; // bfloat16 - break; - default: + dtypeIntTorch = onnxDtypeIntToTorchDtypeInt(dtypeIntOnnx); + if (dtypeIntTorch == -1) { return rewriter.notifyMatchFailure( binder.op, "unimplemented support for the given dtype conversion"); @@ -418,6 +474,36 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( /*memory_format=*/none); return success(); }); + patterns.onOp( + "CastLike", 15, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value input, target; + if (binder.tensorOperands(input, target) || + binder.tensorResultType(resultType)) + return failure(); + + // TODO: Add support to handle the `saturate` attribute. + // Ignoring it right now, since it's only using during the float8 + // conversions which are not supported in Torch-MLIR right now. + + Torch::ValueTensorType targetTy = + target.getType().cast(); + if (!targetTy.hasDtype()) { + return rewriter.notifyMatchFailure(binder.op, + "target tensor must have a dtype"); + } + Type targetDtype = targetTy.getDtype(); + Value constDtype = Torch::getDtypeIntValueForType( + rewriter, binder.getLoc(), targetDtype); + Value none = rewriter.create(binder.getLoc()); + Value cstFalse = + rewriter.create(binder.getLoc(), false); + rewriter.replaceOpWithNewOp( + binder.op, resultType, input, constDtype, + /*non_blocking=*/cstFalse, /*copy=*/cstFalse, + /*memory_format=*/none); + return success(); + }); patterns.onOp("Ceil", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir index 8ba88d22c256..08d6e4ea4e91 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir @@ -110,6 +110,25 @@ func.func @test_acos(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4, return %0 : !torch.vtensor<[3,4,5],f32> } +// CHECK-LABEL: @test_bernoulli +func.func @test_bernoulli(%arg0: !torch.vtensor<[10],f64>) -> !torch.vtensor<[10],f64> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 15 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %0 = torch.aten.bernoulli %arg0, %[[NONE]] : !torch.vtensor<[10],f64>, !torch.none -> !torch.vtensor<[10],f64> + %0 = torch.operator "onnx.Bernoulli"(%arg0) : (!torch.vtensor<[10],f64>) -> !torch.vtensor<[10],f64> + return %0 : !torch.vtensor<[10],f64> +} + +// CHECK-LABEL: @test_bernoulli_double +func.func @test_bernoulli_double(%arg0: !torch.vtensor<[10],f32>) -> !torch.vtensor<[10],f64> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 15 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[BERNOULLI:.*]] = torch.aten.bernoulli %arg0, %[[NONE]] : !torch.vtensor<[10],f32>, !torch.none -> !torch.vtensor<[10],f32> + // CHECK: %[[DTYPE:.*]] = torch.constant.int 7 + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: torch.aten.to.dtype %[[BERNOULLI]], %[[DTYPE]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[10],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[10],f64> + %0 = torch.operator "onnx.Bernoulli"(%arg0) {torch.onnx.dtype = 11 : si64} : (!torch.vtensor<[10],f32>) -> !torch.vtensor<[10],f64> + return %0 : !torch.vtensor<[10],f64> +} + // CHECK-LABEL: @test_bitshift_left_uint8 func.func @test_bitshift_left_uint8(%arg0: !torch.vtensor<[3],ui8>, %arg1: !torch.vtensor<[3],ui8>) -> !torch.vtensor<[3],ui8> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.bitwise_left_shift.Tensor %arg0, %arg1 : !torch.vtensor<[3],ui8>, !torch.vtensor<[3],ui8> -> !torch.vtensor<[3],ui8> @@ -323,6 +342,46 @@ func.func @test_cast_FLOAT16_to_FLOAT(%arg0: !torch.vtensor<[3,4],f16>) -> !torc return %0 : !torch.vtensor<[3,4],f32> } +// CHECK-LABEL: @test_castlike_BFLOAT16_to_FLOAT +func.func @test_castlike_BFLOAT16_to_FLOAT(%arg0: !torch.vtensor<[3,4],bf16>, %arg1: !torch.vtensor<[1],f32>) -> !torch.vtensor<[3,4],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT:.*]] = torch.constant.int 6 + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: torch.aten.to.dtype %arg0, %[[INT]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,4],bf16>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4],f32> + %0 = torch.operator "onnx.CastLike"(%arg0, %arg1) : (!torch.vtensor<[3,4],bf16>, !torch.vtensor<[1],f32>) -> !torch.vtensor<[3,4],f32> + return %0 : !torch.vtensor<[3,4],f32> +} + +// CHECK-LABEL: @test_castlike_DOUBLE_to_FLOAT +func.func @test_castlike_DOUBLE_to_FLOAT(%arg0: !torch.vtensor<[3,4],f64>, %arg1: !torch.vtensor<[1],f32>) -> !torch.vtensor<[3,4],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT:.*]] = torch.constant.int 6 + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: torch.aten.to.dtype %arg0, %[[INT]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,4],f64>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4],f32> + %0 = torch.operator "onnx.CastLike"(%arg0, %arg1) : (!torch.vtensor<[3,4],f64>, !torch.vtensor<[1],f32>) -> !torch.vtensor<[3,4],f32> + return %0 : !torch.vtensor<[3,4],f32> +} + +// CHECK-LABEL: @test_castlike_FLOAT_to_DOUBLE +func.func @test_castlike_FLOAT_to_DOUBLE(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[1],f64>) -> !torch.vtensor<[3,4],f64> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT:.*]] = torch.constant.int 7 + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: torch.aten.to.dtype %arg0, %[[INT]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,4],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4],f64> + %0 = torch.operator "onnx.CastLike"(%arg0, %arg1) : (!torch.vtensor<[3,4],f32>, !torch.vtensor<[1],f64>) -> !torch.vtensor<[3,4],f64> + return %0 : !torch.vtensor<[3,4],f64> +} + +// CHECK-LABEL: @test_castlike_FLOAT16_to_FLOAT +func.func @test_castlike_FLOAT16_to_FLOAT(%arg0: !torch.vtensor<[3,4],f16>, %arg1: !torch.vtensor<[1],f32>) -> !torch.vtensor<[3,4],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT:.*]] = torch.constant.int 6 + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: torch.aten.to.dtype %arg0, %[[INT]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,4],f16>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4],f32> + %0 = torch.operator "onnx.CastLike"(%arg0, %arg1) : (!torch.vtensor<[3,4],f16>, !torch.vtensor<[1],f32>) -> !torch.vtensor<[3,4],f32> + return %0 : !torch.vtensor<[3,4],f32> +} + // CHECK-LABEL: @test_ceil_example func.func @test_ceil_example(%arg0: !torch.vtensor<[2],f32>) -> !torch.vtensor<[2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.ceil %arg0 : !torch.vtensor<[2],f32> -> !torch.vtensor<[2],f32> From 208ae355830707a0d1e85f05feca454ad2b346f0 Mon Sep 17 00:00:00 2001 From: Vivek Khandelwal Date: Wed, 3 Jan 2024 12:55:56 +0000 Subject: [PATCH 067/283] [MLIR][ONNX] Add TorchToOnnx Support for DepthToSpace op Signed-Off By: Vivek Khandelwal --- .../torch-mlir/Dialect/Torch/Utils/Utils.h | 5 +- .../TorchOnnxToTorch/DefaultDomainAtoF.cpp | 126 +++++++++++++++++- .../Torch/Transforms/DecomposeComplexOps.cpp | 13 -- lib/Dialect/Torch/Utils/Utils.cpp | 13 ++ .../TorchOnnxToTorch/simple_ops_a_to_f.mlir | 68 ++++++++++ 5 files changed, 210 insertions(+), 15 deletions(-) diff --git a/include/torch-mlir/Dialect/Torch/Utils/Utils.h b/include/torch-mlir/Dialect/Torch/Utils/Utils.h index 25d35f0f9f2b..b5c815ca7614 100644 --- a/include/torch-mlir/Dialect/Torch/Utils/Utils.h +++ b/include/torch-mlir/Dialect/Torch/Utils/Utils.h @@ -127,7 +127,10 @@ Value createInitTensor(PatternRewriter &rewriter, Location loc, // Helper to create a rank 0 tensor filled with the given `scalar`. `scalar` // would be converted to the element type of the given `inputType`. Value createRank0Tensor(PatternRewriter &rewriter, Location loc, - BaseTensorType inputType, Value scalar); + BaseTensorType inputType, Value scalar); + +LogicalResult getTransposedType(BaseTensorType inType, int64_t dimA, + int64_t dimB, Type &transposedType); } // namespace Torch } // namespace torch diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp index 8b6fddecfd56..8657b8f84b85 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp @@ -37,6 +37,23 @@ static int64_t onnxDtypeIntToTorchDtypeInt(int64_t dtypeIntOnnx) { return dtypeIntTorch; } +static LogicalResult createTorchTransposeOp(ConversionPatternRewriter &rewriter, + Location loc, Value input, + int64_t dimA, int64_t dimB, + Value &transposed) { + Type transposedType; + if (failed(getTransposedType(input.getType().cast(), + dimA, dimB, transposedType))) + return failure(); + Value cstDimA = rewriter.create( + loc, rewriter.getI64IntegerAttr(dimA)); + Value cstDimB = rewriter.create( + loc, rewriter.getI64IntegerAttr(dimB)); + transposed = rewriter.create( + loc, transposedType, input, cstDimA, cstDimB); + return success(); +} + // Simple rewrites for the default domain. // See: https://onnx.ai/onnx/operators/ // For operators that are effectively version invariant, we register with @@ -978,11 +995,118 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( binder.op, resultType, operand, dim, resultDType); return success(); }); + patterns.onOp( + "DepthToSpace", 1, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value input; + int64_t blockSize; + std::string mode; + if (binder.tensorOperand(input) || + binder.s64IntegerAttr(blockSize, "blocksize") || + binder.customOpNameStringAttr(mode, "mode", "DCR") || + binder.tensorResultType(resultType)) + return failure(); + auto inputTy = input.getType().dyn_cast(); + if (!inputTy || !inputTy.hasSizes()) { + return rewriter.notifyMatchFailure( + binder.op, "Expected input type having sizes"); + } + SmallVector inputSizes{inputTy.getSizes()}; + if (inputSizes.size() != 4) { + return rewriter.notifyMatchFailure(binder.op, + "Expected input rank to be 4"); + } + Value b = rewriter.create( + binder.getLoc(), input, + rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(0))); + Value c = rewriter.create( + binder.getLoc(), input, + rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(1))); + Value h = rewriter.create( + binder.getLoc(), input, + rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(2))); + Value w = rewriter.create( + binder.getLoc(), input, + rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(3))); + Value cstBlockSize = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(blockSize)); + Value cstBlockSizeSquare = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(blockSize * blockSize)); + Value cDivBlockSizeSquare = rewriter.create( + binder.getLoc(), c, cstBlockSizeSquare); + cDivBlockSizeSquare = rewriter.create( + binder.getLoc(), cDivBlockSizeSquare); + Value reshapeSizesList = rewriter.create( + binder.getLoc(), + Torch::ListType::get(Torch::IntType::get(input.getContext())), + llvm::SmallVector{b, cstBlockSize, cstBlockSize, + cDivBlockSizeSquare, h, w}); + int64_t cDivBlockSizeSquareInt = + inputSizes[1] == Torch::kUnknownSize + ? Torch::kUnknownSize + : inputSizes[1] / (blockSize * blockSize); + SmallVector reshapeSizesInt{ + inputSizes[0], blockSize, blockSize, + cDivBlockSizeSquareInt, inputSizes[2], inputSizes[3]}; + Value reshapedInput = rewriter.create( + binder.getLoc(), + inputTy.getWithSizesAndDtype(reshapeSizesInt, + inputTy.getOptionalDtype()), + input, reshapeSizesList); + + Value transposedInput; + if (mode == "DCR") { + if (failed(createTorchTransposeOp( + rewriter, binder.getLoc(), reshapedInput, + /*dimA=*/1, /*dimB=*/3, transposedInput))) + return rewriter.notifyMatchFailure( + binder.op, "Failed to create TorchTranspose op"); + if (failed(createTorchTransposeOp( + rewriter, binder.getLoc(), transposedInput, + /*dimA=*/2, /*dimB=*/4, transposedInput))) + return rewriter.notifyMatchFailure( + binder.op, "Failed to create TorchTranspose op"); + } else { + // mode == "CRD" + if (failed(createTorchTransposeOp( + rewriter, binder.getLoc(), reshapedInput, + /*dimA=*/2, /*dimB=*/4, transposedInput))) + return rewriter.notifyMatchFailure( + binder.op, "Failed to create TorchTranspose op"); + if (failed(createTorchTransposeOp( + rewriter, binder.getLoc(), transposedInput, + /*dimA=*/3, /*dimB=*/4, transposedInput))) + return rewriter.notifyMatchFailure( + binder.op, "Failed to create TorchTranspose op"); + } + if (failed(createTorchTransposeOp( + rewriter, binder.getLoc(), transposedInput, + /*dimA=*/4, /*dimB=*/5, transposedInput))) + return rewriter.notifyMatchFailure( + binder.op, "Failed to create TorchTranspose op"); + + Value hMulBlockSize = rewriter.create( + binder.getLoc(), h, cstBlockSize); + Value wMulBlockSize = rewriter.create( + binder.getLoc(), w, cstBlockSize); + reshapeSizesList = rewriter.create( + binder.getLoc(), + Torch::ListType::get(Torch::IntType::get(input.getContext())), + llvm::SmallVector{b, cDivBlockSizeSquare, hMulBlockSize, + wMulBlockSize}); + rewriter.replaceOpWithNewOp( + binder.op, resultType, transposedInput, reshapeSizesList); + return success(); + }); patterns.onOp("Div", 14, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value lhs, rhs; - std::string direction; if (binder.tensorOperands(lhs, rhs) || binder.tensorResultType(resultType)) return failure(); diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 63fa66ccc31e..0a3ce2ea7797 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -2554,19 +2554,6 @@ class DecomposeAtenConvTranspose2dOp }; } // namespace -static LogicalResult getTransposedType(BaseTensorType inType, int64_t dimA, - int64_t dimB, Type &transposedType) { - if (!inType.hasSizes()) - return failure(); - SmallVector shape(inType.getSizes()); - int64_t tmp = shape[0]; - shape[0] = shape[1]; - shape[1] = tmp; - transposedType = inType.getWithSizesAndDtype(llvm::ArrayRef(shape), - inType.getOptionalDtype()); - return success(); -} - // The convolution backward op is decomposed as follows: // inputH, inputW = input.shape[2:] // output_padding_ = [ diff --git a/lib/Dialect/Torch/Utils/Utils.cpp b/lib/Dialect/Torch/Utils/Utils.cpp index 4bf5f7e13d1f..12ac1d58ee59 100644 --- a/lib/Dialect/Torch/Utils/Utils.cpp +++ b/lib/Dialect/Torch/Utils/Utils.cpp @@ -500,3 +500,16 @@ Value Torch::createRank0Tensor(PatternRewriter &rewriter, Location loc, ValueRange{}); return createInitTensor(rewriter, loc, rank0TensorTy, scalar, dimList); } + +LogicalResult Torch::getTransposedType(BaseTensorType inType, int64_t dimA, + int64_t dimB, Type &transposedType) { + if (!inType.hasSizes()) + return failure(); + SmallVector shape(inType.getSizes()); + int64_t tmp = shape[dimA]; + shape[dimA] = shape[dimB]; + shape[dimB] = tmp; + transposedType = inType.getWithSizesAndDtype(llvm::ArrayRef(shape), + inType.getOptionalDtype()); + return success(); +} diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir index 08d6e4ea4e91..aab9728ced1b 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir @@ -857,3 +857,71 @@ func.func @test_elu_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3 %0 = torch.operator "onnx.Elu"(%arg0) {torch.onnx.alpha = 2.000000e+00 : f32} : (!torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> return %0 : !torch.vtensor<[3],f32> } + +// CHECK-LABEL: @test_depthtospace_example +func.func @test_depthtospace_example(%arg0: !torch.vtensor<[1,8,2,3],f32>) -> !torch.vtensor<[1,2,4,6],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[C0:.*]] = torch.constant.int 0 + // CHECK: %[[SIZE:.*]] = torch.aten.size.int %arg0, %[[C0]] : !torch.vtensor<[1,8,2,3],f32>, !torch.int -> !torch.int + // CHECK: %[[C1:.*]] = torch.constant.int 1 + // CHECK: %[[SIZE_0:.*]] = torch.aten.size.int %arg0, %[[C1]] : !torch.vtensor<[1,8,2,3],f32>, !torch.int -> !torch.int + // CHECK: %[[C2:.*]] = torch.constant.int 2 + // CHECK: %[[SIZE_1:.*]] = torch.aten.size.int %arg0, %[[C2]] : !torch.vtensor<[1,8,2,3],f32>, !torch.int -> !torch.int + // CHECK: %[[C3:.*]] = torch.constant.int 3 + // CHECK: %[[SIZE_2:.*]] = torch.aten.size.int %arg0, %[[C3]] : !torch.vtensor<[1,8,2,3],f32>, !torch.int -> !torch.int + // CHECK: %[[C2_0:.*]] = torch.constant.int 2 + // CHECK: %[[C4:.*]] = torch.constant.int 4 + // CHECK: %[[DIV:.*]] = torch.aten.div.int %[[SIZE_0]], %[[C4]] : !torch.int, !torch.int -> !torch.float + // CHECK: %[[INT:.*]] = torch.aten.Int.float %[[DIV]] : !torch.float -> !torch.int + // CHECK: %[[RESHAPE_LIST:.*]] = torch.prim.ListConstruct %[[SIZE]], %[[C2_0]], %[[C2_0]], %[[INT]], %[[SIZE_1]], %[[SIZE_2]] : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[RESHAPE:.*]] = torch.aten.reshape %arg0, %[[RESHAPE_LIST]] : !torch.vtensor<[1,8,2,3],f32>, !torch.list -> !torch.vtensor<[1,2,2,2,2,3],f32> + // CHECK: %[[C1_0:.*]] = torch.constant.int 1 + // CHECK: %[[C3_0:.*]] = torch.constant.int 3 + // CHECK: %[[TRANSPOSE:.*]] = torch.aten.transpose.int %[[RESHAPE]], %[[C1_0]], %[[C3_0]] : !torch.vtensor<[1,2,2,2,2,3],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,2,2,2,2,3],f32> + // CHECK: %[[C2_1:.*]] = torch.constant.int 2 + // CHECK: %[[C4_0:.*]] = torch.constant.int 4 + // CHECK: %[[TRANSPOSE_1:.*]] = torch.aten.transpose.int %[[TRANSPOSE]], %[[C2_1]], %[[C4_0]] : !torch.vtensor<[1,2,2,2,2,3],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,2,2,2,2,3],f32> + // CHECK: %[[C4_1:.*]] = torch.constant.int 4 + // CHECK: %[[C5:.*]] = torch.constant.int 5 + // CHECK: %[[TRANSPOSE_2:.*]] = torch.aten.transpose.int %[[TRANSPOSE_1]], %[[C4_1]], %[[C5]] : !torch.vtensor<[1,2,2,2,2,3],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,2,2,2,3,2],f32> + // CHECK: %[[MUL:.*]] = torch.aten.mul.int %[[SIZE_1]], %[[C2_0]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[MUL_0:.*]] = torch.aten.mul.int %[[SIZE_2]], %[[C2_0]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[RESHAPE_LIST_0:.*]] = torch.prim.ListConstruct %[[SIZE]], %5, %[[MUL]], %[[MUL_0]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[RESULT:.*]] = torch.aten.reshape %[[TRANSPOSE_2]], %[[RESHAPE_LIST_0]] : !torch.vtensor<[1,2,2,2,3,2],f32>, !torch.list -> !torch.vtensor<[1,2,4,6],f32> + // CHECK: return %[[RESULT]] : !torch.vtensor<[1,2,4,6],f32 + %0 = torch.operator "onnx.DepthToSpace"(%arg0) {torch.onnx.blocksize = 2 : si64, torch.onnx.mode = "DCR"} : (!torch.vtensor<[1,8,2,3],f32>) -> !torch.vtensor<[1,2,4,6],f32> + return %0 : !torch.vtensor<[1,2,4,6],f32> +} + +// CHECK-LABEL: @test_depthtospace_crd_mode_example +func.func @test_depthtospace_crd_mode_example(%arg0: !torch.vtensor<[1,8,2,3],f32>) -> !torch.vtensor<[1,2,4,6],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[C0:.*]] = torch.constant.int 0 + // CHECK: %[[SIZE:.*]] = torch.aten.size.int %arg0, %[[C0]] : !torch.vtensor<[1,8,2,3],f32>, !torch.int -> !torch.int + // CHECK: %[[C1:.*]] = torch.constant.int 1 + // CHECK: %[[SIZE_0:.*]] = torch.aten.size.int %arg0, %[[C1]] : !torch.vtensor<[1,8,2,3],f32>, !torch.int -> !torch.int + // CHECK: %[[C2:.*]] = torch.constant.int 2 + // CHECK: %[[SIZE_1:.*]] = torch.aten.size.int %arg0, %[[C2]] : !torch.vtensor<[1,8,2,3],f32>, !torch.int -> !torch.int + // CHECK: %[[C3:.*]] = torch.constant.int 3 + // CHECK: %[[SIZE_2:.*]] = torch.aten.size.int %arg0, %[[C3]] : !torch.vtensor<[1,8,2,3],f32>, !torch.int -> !torch.int + // CHECK: %[[C2_0:.*]] = torch.constant.int 2 + // CHECK: %[[C4:.*]] = torch.constant.int 4 + // CHECK: %[[DIV:.*]] = torch.aten.div.int %[[SIZE_0]], %[[C4]] : !torch.int, !torch.int -> !torch.float + // CHECK: %[[INT:.*]] = torch.aten.Int.float %[[DIV]] : !torch.float -> !torch.int + // CHECK: %[[RESHAPE_LIST:.*]] = torch.prim.ListConstruct %[[SIZE]], %[[C2_0]], %[[C2_0]], %[[INT]], %[[SIZE_1]], %[[SIZE_2]] : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[RESHAPE:.*]] = torch.aten.reshape %arg0, %[[RESHAPE_LIST]] : !torch.vtensor<[1,8,2,3],f32>, !torch.list -> !torch.vtensor<[1,2,2,2,2,3],f32> + // CHECK: %[[C2_1:.*]] = torch.constant.int 2 + // CHECK: %[[C4_0:.*]] = torch.constant.int 4 + // CHECK: %[[TRANSPOSE:.*]] = torch.aten.transpose.int %[[RESHAPE]], %[[C2_1]], %[[C4_0]] : !torch.vtensor<[1,2,2,2,2,3],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,2,2,2,2,3],f32> + // CHECK: %[[C3_0:.*]] = torch.constant.int 3 + // CHECK: %[[C4_1:.*]] = torch.constant.int 4 + // CHECK: %[[TRANSPOSE_1:.*]] = torch.aten.transpose.int %[[TRANSPOSE]], %[[C3_0]], %[[C4_1]] : !torch.vtensor<[1,2,2,2,2,3],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,2,2,2,2,3],f32> + // CHECK: %[[C4_1:.*]] = torch.constant.int 4 + // CHECK: %[[C5:.*]] = torch.constant.int 5 + // CHECK: %[[TRANSPOSE_2:.*]] = torch.aten.transpose.int %[[TRANSPOSE_1]], %[[C4_1]], %[[C5]] : !torch.vtensor<[1,2,2,2,2,3],f32>, !torch.int, !torch.int -> !torch.vtensor<[1,2,2,2,3,2],f32> + // CHECK: %[[MUL:.*]] = torch.aten.mul.int %[[SIZE_1]], %[[C2_0]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[MUL_0:.*]] = torch.aten.mul.int %[[SIZE_2]], %[[C2_0]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[RESHAPE_LIST_0:.*]] = torch.prim.ListConstruct %[[SIZE]], %5, %[[MUL]], %[[MUL_0]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[RESULT:.*]] = torch.aten.reshape %[[TRANSPOSE_2]], %[[RESHAPE_LIST_0]] : !torch.vtensor<[1,2,2,2,3,2],f32>, !torch.list -> !torch.vtensor<[1,2,4,6],f32> + // CHECK: return %[[RESULT]] : !torch.vtensor<[1,2,4,6],f32 + %0 = torch.operator "onnx.DepthToSpace"(%arg0) {torch.onnx.blocksize = 2 : si64, torch.onnx.mode = "CRD"} : (!torch.vtensor<[1,8,2,3],f32>) -> !torch.vtensor<[1,2,4,6],f32> + return %0 : !torch.vtensor<[1,2,4,6],f32> +} From 469c055190a042575a2259f8fe759da23963ce3f Mon Sep 17 00:00:00 2001 From: Vivek Khandelwal Date: Tue, 9 Jan 2024 13:10:28 +0000 Subject: [PATCH 068/283] build: manually update PyTorch version Set PyTorch and TorchVision version to nightly release 2024-01-09. Signed-Off By: Vivek Khandelwal --- pytorch-hash.txt | 2 +- pytorch-requirements.txt | 2 +- torchvision-requirements.txt | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch-hash.txt b/pytorch-hash.txt index 6d22e4c8b2c5..cf7d2b924e62 100644 --- a/pytorch-hash.txt +++ b/pytorch-hash.txt @@ -1 +1 @@ -7003edfbb4995b3712ba46aa7e39f1256b7fa4a6 +03969cb2d2e773af44b71f304d8de81107b2d41e diff --git a/pytorch-requirements.txt b/pytorch-requirements.txt index 1d07d05a1f36..a4291b382237 100644 --- a/pytorch-requirements.txt +++ b/pytorch-requirements.txt @@ -1,3 +1,3 @@ -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html --pre -torch==2.3.0.dev20240101 +torch==2.3.0.dev20240109 diff --git a/torchvision-requirements.txt b/torchvision-requirements.txt index c44f5222d172..d081d22aca9b 100644 --- a/torchvision-requirements.txt +++ b/torchvision-requirements.txt @@ -1,3 +1,3 @@ -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html --pre -torchvision==0.18.0.dev20240101 +torchvision==0.18.0.dev20240109 From 29569713f3878226a6c1054a183dc227934dbe69 Mon Sep 17 00:00:00 2001 From: kumardeepakamd <123522031+kumardeepakamd@users.noreply.github.com> Date: Wed, 10 Jan 2024 13:05:37 -0800 Subject: [PATCH 069/283] support for onnx.expand operator (#2729) maps onnx.expand to torch aten broadcast_to, three tests added --------- Co-authored-by: Kumar Deepak --- .../TorchOnnxToTorch/DefaultDomainAtoF.cpp | 45 ++++++++++++++++ .../TorchOnnxToTorch/simple_ops_a_to_f.mlir | 54 +++++++++++++++++++ 2 files changed, 99 insertions(+) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp index 8657b8f84b85..5b957c53f9dc 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp @@ -1213,6 +1213,51 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( binder.op, resultType, operand); return success(); }); + patterns.onOp( + "Expand", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + // uses ideas and code from onnx.Reshape + Torch::ValueTensorType resultType; + Value data, shape; + if (binder.tensorOperands(data, shape) || + binder.tensorResultType(resultType)) + return failure(); + Torch::BaseTensorType shapeType = + shape.getType().cast(); + SmallVector selectSizes; + selectSizes.push_back(1); + Type selectResultType = shapeType.getWithSizesAndDtype( + llvm::ArrayRef(selectSizes), shapeType.getOptionalDtype()); + // Variable to store 1-D onnx shape tensor, shapeSizes[0] has the + // dimension size + auto shapeSizes = + dyn_cast(shape.getType()).getSizes(); + // A constant zero value + Value zero = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(0)); + // Variable to store pytorch int list of shape (dimension) + SmallVector dimList; + + // Convert the shape tensor from vector of int64_t to torch int list as + // we are using torch implementation Torch::AtenBroadcastToOp which + // takes list of int + for (int i = 0; i < shapeSizes[0]; i++) { + Value selectIndex = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), i)); + Value extract = rewriter.create( + binder.getLoc(), selectResultType, shape, zero, selectIndex); + Value dim = rewriter.create( + binder.getLoc(), rewriter.getType(), extract); + dimList.push_back(dim); + } + Value dimValueList = rewriter.create( + binder.getLoc(), + Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), + dimList); + rewriter.replaceOpWithNewOp( + binder.op, resultType, data, dimValueList); + return success(); + }); patterns.onOp("Floor", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir index aab9728ced1b..fc9706127280 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir @@ -487,6 +487,7 @@ func.func @test_equal(%arg0: !torch.vtensor<[3,4,5],si32>, %arg1: !torch.vtensor return %0 : !torch.vtensor<[3,4,5],i1> } + // CHECK-LABEL: @test_floor_example func.func @test_floor_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.floor %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> @@ -800,6 +801,59 @@ func.func @test_concat_3d_axis_negative_3(%arg0: !torch.vtensor<[2,2,2],f32>, %a return %0 : !torch.vtensor<[4,2,2],f32> } +// CHECK-LABEL: @test_expand_dim2_shape2 +func.func @test_expand_dim2_shape2(%arg0: !torch.vtensor<[1,4],f32>, %arg1: !torch.vtensor<[2],si32>) + -> !torch.vtensor<[3,4],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT0:.+]] = torch.constant.int 0 + // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 + // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[2],si32>, !torch.int, !torch.int -> !torch.vtensor<[1],si32> + // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si32> -> !torch.int + // CHECK: %[[INT1:.+]] = torch.constant.int 1 + // CHECK: torch.aten.select.int %arg1, %int0, %int1 : !torch.vtensor<[2],si32>, !torch.int, !torch.int -> !torch.vtensor<[1],si32> + // CHECK: torch.aten.item %2 : !torch.vtensor<[1],si32> -> !torch.int + // CHECK: torch.prim.ListConstruct %1, %3 : (!torch.int, !torch.int) -> !torch.list + // CHECK: torch.aten.broadcast_to %arg0, %4 : !torch.vtensor<[1,4],f32>, !torch.list -> !torch.vtensor<[3,4],f32> + %0 = torch.operator "onnx.Expand"(%arg0, %arg1) : (!torch.vtensor<[1,4],f32>, !torch.vtensor<[2],si32>) -> !torch.vtensor<[3,4],f32> + return %0 : !torch.vtensor<[3,4],f32> +} +// CHECK-LABEL: @test_expand_dim2_shape3 +func.func @test_expand_dim2_shape3(%arg0: !torch.vtensor<[3,1],f32>, %arg1: !torch.vtensor<[3],si64>) -> !torch.vtensor<[2,3,6],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT0:.+]] = torch.constant.int 0 + // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 + // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[INT1:.+]] = torch.constant.int 1 + // CHECK: torch.aten.select.int %arg1, %int0, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: torch.aten.item %2 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[INT2:.+]] = torch.constant.int 2 + // CHECK: torch.aten.select.int %arg1, %int0, %int2 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: torch.aten.item %4 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: torch.prim.ListConstruct %1, %3, %5 : (!torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: torch.aten.broadcast_to %arg0, %6 : !torch.vtensor<[3,1],f32>, !torch.list -> !torch.vtensor<[2,3,6],f32> + %0 = torch.operator "onnx.Expand"(%arg0, %arg1) : (!torch.vtensor<[3,1],f32>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[2,3,6],f32> + return %0 : !torch.vtensor<[2,3,6],f32> +} + +// CHECK-LABEL: @test_expand_dim3_shape4 +func.func @test_expand_dim3_shape4(%arg0: !torch.vtensor<[1,3,1],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[3,3,3,3],f32> attributes {torch.onnx_meta.ir_version = 4 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT0:.+]] = torch.constant.int 0 + // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 + // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[INT1:.+]] = torch.constant.int 1 + // CHECK: torch.aten.select.int %arg1, %int0, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: torch.aten.item %2 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[INT2:.+]] = torch.constant.int 2 + // CHECK: torch.aten.select.int %arg1, %int0, %int2 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: torch.aten.item %4 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[INT3:.+]] = torch.constant.int 3 + // CHECK: torch.aten.select.int %arg1, %int0, %int3 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: torch.aten.item %6 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: torch.prim.ListConstruct %1, %3, %5, %7 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %9 = torch.aten.broadcast_to %arg0, %8 : !torch.vtensor<[1,3,1],f32>, !torch.list -> !torch.vtensor<[3,3,3,3],f32> + %0 = torch.operator "onnx.Expand"(%arg0, %arg1) : (!torch.vtensor<[1,3,1],f32>, !torch.vtensor<[4],si64>) -> !torch.vtensor<[3,3,3,3],f32> + return %0 : !torch.vtensor<[3,3,3,3],f32> +} // CHECK-LABEL: @test_dropout func.func @test_dropout(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.dropout %arg0, %float5.000000e-01, %false : !torch.vtensor<[3],f32>, !torch.float, !torch.bool -> !torch.vtensor<[3],f32 From aee1fca2517b8bff3b18e3b01beaafe8d57a7dd8 Mon Sep 17 00:00:00 2001 From: "Xida Ren (Cedar)" Date: Wed, 10 Jan 2024 17:24:37 -0500 Subject: [PATCH 070/283] Minor typo fix: in not implemented message for the exclusive and reverse attributes for cumsum (#2740) --- lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp index 5b957c53f9dc..2cfc9940a940 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp @@ -952,11 +952,11 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( int64_t exclusive; int64_t reverse; // if bind succeeds and either is set, fail because not implemented - if (binder.s64IntegerAttr(exclusive, "exclusive", 0)) + if (!binder.s64IntegerAttr(exclusive, "exclusive", 0)) if (exclusive != 0) return rewriter.notifyMatchFailure( binder.op, "unsupported onnx.CumSum conversion: exclusive"); - if (binder.s64IntegerAttr(reverse, "reverse", 0)) + if (!binder.s64IntegerAttr(reverse, "reverse", 0)) if (reverse != 0) return rewriter.notifyMatchFailure( binder.op, "unsupported onnx.CumSum conversion: reverse"); From 0860c41ee2a0bdec41f544f19eba170cf646c3ce Mon Sep 17 00:00:00 2001 From: Frederik Harwath Date: Fri, 22 Dec 2023 06:25:15 -0800 Subject: [PATCH 071/283] Implement aten.reflection_pad2d lowering to linalg --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 24 ++ lib/Conversion/TorchToLinalg/DataMovement.cpp | 290 ++++++++++++++++++ .../Transforms/AbstractInterpLibrary.cpp | 67 ++++ .../base_lazy_backend/shape_inference.cpp | 28 ++ .../build_tools/abstract_interp_lib_gen.py | 29 ++ .../build_tools/torch_ods_gen.py | 1 + .../test_suite/__init__.py | 1 + .../torch_mlir_e2e_test/test_suite/padding.py | 113 +++++++ 8 files changed, 553 insertions(+) create mode 100644 projects/pt1/python/torch_mlir_e2e_test/test_suite/padding.py diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 23e65d75d77f..74a2e2327d1b 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -7893,6 +7893,30 @@ def Torch_AtenReflectionPad1dOp : Torch_Op<"aten.reflection_pad1d", [ }]; } +def Torch_AtenReflectionPad2dOp : Torch_Op<"aten.reflection_pad2d", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::reflection_pad2d : (Tensor, int[]) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchListOfTorchIntType:$padding + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenReflectionPad2dOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenReflectionPad2dOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + def Torch_AtenPadOp : Torch_Op<"aten.pad", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp index 6534e859881e..49f5f0ec3321 100644 --- a/lib/Conversion/TorchToLinalg/DataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp @@ -244,6 +244,294 @@ class ConvertAtenReflectionPad1dOp }; } +namespace { + +// Lower the aten.reflection.pad_2d operator into a sequence of +// tensor.extract_slice, linalg.generic, and tensor_insert_slice +// operations. + +// To understand the lowering, consider this pytorch example: +// +// >>> t = torch.tensor([[[1.0,2,3],[4,5,6], [7,8,9]]]) +// >>> t +// tensor([[[1., 2., 3.], +// [4., 5., 6.], +// [7., 8., 9.]]]) +// >>> torch.ops.aten.reflection_pad2d(t, [1,2,1,2]) +// tensor([[[5., 4., 5., 6., 5., 4.], +// [2., 1., 2., 3., 2., 1.], +// [5., 4., 5., 6., 5., 4.], +// [8., 7., 8., 9., 8., 7.], +// [5., 4., 5., 6., 5., 4.], +// [2., 1., 2., 3., 2., 1.]]]) +// +// The result can be subdivided into "tiles" corresponding to either +// the input tensor (in the center) or slices of the input tensor +// whose width and height is determined by the padding sizes and which +// are reflected through the side of the central input tensor that +// they touch. +// In the example above, the tiles are: +// top left: [[5]] +// top center: [[4,5,6]] +// top right: [[5,4]] +// center left [[2,1],[5,4],[8,7]] +// center: copy of the input tensor +// center right: [[2,1],[5,4],[8,7]] +// bottom left: [[5,4],[2,1]] +// center bottom: [[2,3,2]] +// center right: [[2,1]] +// +// The lowering uses a tensor.extract_slice operation to create each tile, +// a linalg.generic for the reflection, and a tensor.insert_slice to +// insert the tile in the resulting tensor. +class ConvertAtenReflectionPad2dOp + : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(AtenReflectionPad2dOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (failed(verifyLinalgCompatibleTypes(op, rewriter))) + return failure(); + + SmallVector padInts; + if (!matchPattern(op.getPadding(), m_TorchListOfConstantInts(padInts))) + return rewriter.notifyMatchFailure( + op, "only support constant int pad ranges"); + + Location loc = op.getLoc(); + // Some generic helper functions for creating arithmetic operations. + auto createAdd = [&](Value x, Value y) { + return rewriter.create(loc, x, y); + }; + + auto createAdds = [&](std::initializer_list values) { + assert(values.size() >= 2); + return std::accumulate(values.begin() + 1, values.end(), data(values)[0], + createAdd); + }; + + auto createSub = [&](Value x, Value y) { + return rewriter.create(loc, x, y); + }; + + auto createSubs = [&](std::initializer_list values) { + assert(values.size() >= 2); + return std::accumulate(values.begin() + 1, values.end(), data(values)[0], + createSub); + }; + + // Enums for specifying the coordinates of a tile. An "h" prefix + // is used to stand for "horizontal" and "v" for "vertical" + // throughout. + enum PadHLoc { LEFT = 0, RIGHT = 1, HCENTER = 2 }; + enum PadVLoc { TOP = 0, BOTTOM = 1, VCENTER = 2 }; + + // Helper functions for obtaining information about the operator's + // padding arguments. + auto getHPadArgument = [&](PadHLoc l) { + assert(l < HCENTER); + return padInts[l]; + }; + + auto getVPadArgument = [&](PadVLoc l) { + assert(l < VCENTER); + return padInts[2 + l]; + }; + + auto shouldCreateTile = [&](PadVLoc v, PadHLoc h) { + if (!(h == HCENTER || getHPadArgument(h) > 0)) + return false; + if (!(v == VCENTER || getVPadArgument(v) > 0)) + return false; + + return true; + }; + + Value input = adaptor.getSelf(); + MLIRContext *context = rewriter.getContext(); + auto inputType = llvm::cast(input.getType()); + auto outputType = llvm::cast( + getTypeConverter()->convertType(op->getResult(0).getType())); + unsigned numDims = inputType.getRank(); + + assert(numDims >= 2 && "Not enough input dimensions"); + + SmallVector inputShape = getTensorSizes(rewriter, loc, input); + int64_t hDim = numDims - 1; + int64_t vDim = numDims - 2; + Value hDimSize = inputShape[hDim]; + Value vDimSize = inputShape[vDim]; + + assert(getHPadArgument(LEFT) < inputType.getShape()[hDim] && + "Left padding too large"); + assert(getHPadArgument(RIGHT) < inputType.getShape()[hDim] && + "Right padding too large"); + assert(getVPadArgument(TOP) < inputType.getShape()[vDim] && + "Top padding too large"); + assert(getVPadArgument(BOTTOM) < inputType.getShape()[vDim] && + "Bottom padding too large"); + + Type indexType = rewriter.getIndexType(); + Value zero = getConstant(rewriter, loc, 0, indexType); + Value one = getConstant(rewriter, loc, 1, indexType); + + Value tileWidth[3]; + tileWidth[HCENTER] = hDimSize; + for (auto h : {LEFT, RIGHT}) + tileWidth[h] = getConstant(rewriter, loc, getHPadArgument(h), indexType); + + Value tileHeight[3]; + tileHeight[VCENTER] = vDimSize; + for (auto v : {TOP, BOTTOM}) + tileHeight[v] = getConstant(rewriter, loc, getVPadArgument(v), indexType); + + // Helper to reflect/reverse the i-th dimension of an affine map + // without symbols. This only works if applied on a tensor + // for which the corresponding dimension has a statically + // known size which is good enough since we only apply + // it to reflect the padding slices. + auto reflectDim = [](AffineMap map, unsigned numDims, int64_t i, + int64_t size) { + AffineExpr d = map.getResult(i); + return map.replace(d, size - d - 1, numDims, 0); + }; + + // Create output shape and tensor + SmallVector resultShape{inputShape}; + resultShape[vDim] = + createAdds({resultShape[vDim], tileHeight[TOP], tileHeight[BOTTOM]}); + resultShape[hDim] = + createAdds({resultShape[hDim], tileWidth[LEFT], tileWidth[RIGHT]}); + + Value resultTensor = createZeroInitTensor(rewriter, loc, resultShape, + inputType.getElementType()); + + // Construction of the tiles + + // Example: central left tile + // + // Let m the width of the left padding as returned by getHPadargument(LEFT) + // and n the size of the input tensor's "horizontal" dimension, i.e. + // hDimSize. Assume that the subtensor of the input tensor in the relevant + // (i.e. last two) dimensions is: + // + // x_1,1 x_1,2 ... x_1,m + // x_2,1 x_2,2 ... x_2,m + // . + // . + // . + // x_n,1 x_n,2 ... x_n,m + // + // The padding tile consists of the columns 2, ..., m + 1 + // of the input in reverse order. The first column gets + // skipped because this is the column through which the + // reflection happens. + // + // x_1,m x_1,m-1 ... x_1,2 + // x_2,m x_1,m-1 ... x_2,2 + // . + // . + // . + // x_n,m x_n,m-1 ... x_n,2 + // + // The tile will be inserted to the left of the copy of the input tensor + // in the output tensor, i.e. with horizontal offset 0. + // The top padding determines the vertical offset. + + // Tiles on the diagonal (e.g. (TOP, LEFT)) are reflected through + // two sides, i.e. their columns and rows must be reversed. + + // Setup information about the tiles + + // Compute the offsets for extracting the slice from the + // input. We need to skip the row or column through which + // the tile should be reflected, if any (none for the center tile). + Value extractHOffset[3]; + extractHOffset[LEFT] = one; + extractHOffset[HCENTER] = zero; + extractHOffset[RIGHT] = createSubs({hDimSize, tileWidth[RIGHT], one}); + + Value extractVOffset[3]; + extractVOffset[TOP] = one; + extractVOffset[VCENTER] = zero; + extractVOffset[BOTTOM] = createSubs({vDimSize, tileHeight[BOTTOM], one}); + + // Compute the horizontal and vertical offsets for inserting + // the tiles in the resultTensor. + Value insertHOffset[3]; + insertHOffset[LEFT] = zero; + insertHOffset[HCENTER] = tileWidth[LEFT]; + insertHOffset[RIGHT] = createAdd(hDimSize, tileWidth[LEFT]); + + Value insertVOffset[3]; + insertVOffset[TOP] = zero; + insertVOffset[VCENTER] = tileHeight[TOP]; + insertVOffset[BOTTOM] = createAdd(vDimSize, tileHeight[TOP]); + + auto shouldHReflect = [](PadHLoc l) { return l == LEFT || l == RIGHT; }; + auto shouldVReflect = [](PadVLoc l) { return l == TOP || l == BOTTOM; }; + + SmallVector iteratorTypes{ + numDims, utils::IteratorType::parallel}; + auto idMap = AffineMap::getMultiDimIdentityMap(numDims, context); + SmallVector allOneStrides(numDims, one); + + auto createTile = [&](PadVLoc verticalPos, PadHLoc horizontalPos) { + // Create the tile by extracting a slice from the input tenor. + SmallVector extractShape{inputShape}; + extractShape[hDim] = tileWidth[horizontalPos]; + extractShape[vDim] = tileHeight[verticalPos]; + + SmallVector extractOffsets(numDims, zero); + extractOffsets[hDim] = extractHOffset[horizontalPos]; + extractOffsets[vDim] = extractVOffset[verticalPos]; + + Value tile = rewriter.create( + loc, input, extractOffsets, extractShape, allOneStrides); + + // Reverse the tile along the horizontal, vertical, or both + // dimensions. + auto inputMap = AffineMap::getMultiDimIdentityMap(numDims, context); + if (shouldHReflect(horizontalPos)) { + inputMap = + reflectDim(inputMap, numDims, hDim, getHPadArgument(horizontalPos)); + } + if (shouldVReflect(verticalPos)) { + inputMap = + reflectDim(inputMap, numDims, vDim, getVPadArgument(verticalPos)); + } + + tile = rewriter + .create( + loc, llvm::cast(tile.getType()), tile, + tile, ArrayRef({inputMap, idMap}), iteratorTypes, + [](OpBuilder &b, Location nestedLoc, ValueRange args) { + b.create(nestedLoc, args[0]); + }) + .getResult(0); + + // Insert the tile in the resultTensor. + SmallVector insertOffsets(numDims, zero); + insertOffsets[hDim] = insertHOffset[horizontalPos]; + insertOffsets[vDim] = insertVOffset[verticalPos]; + + resultTensor = rewriter.create( + loc, tile, resultTensor, insertOffsets, extractShape, allOneStrides); + }; + + for (auto v : {TOP, BOTTOM, VCENTER}) + for (auto h : {LEFT, RIGHT, HCENTER}) + if (shouldCreateTile(v, h)) + createTile(v, h); + + rewriter.replaceOpWithNewOp(op, outputType, resultTensor); + + return success(); + } +}; +} // namespace + namespace { class ConvertAtenFlattenUsingIntsOp : public OpConversionPattern { @@ -1552,6 +1840,8 @@ void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality( MLIRContext *context = patterns.getContext(); target.addIllegalOp(); patterns.add(typeConverter, context); + target.addIllegalOp(); + patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index 4adf55556a2e..55b9638dd0cc 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -8366,6 +8366,69 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %7 = call @__torch__.pad_shape_fn(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" " return %7 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.reflection_pad2d\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" +" %false = torch.constant.bool false\n" +" %str = torch.constant.str \"AssertionError: padding size expected to be 4\"\n" +" %int-1 = torch.constant.int -1\n" +" %int-2 = torch.constant.int -2\n" +" %none = torch.constant.none\n" +" %str_0 = torch.constant.str \"AssertionError: \"\n" +" %int2 = torch.constant.int 2\n" +" %int1 = torch.constant.int 1\n" +" %int4 = torch.constant.int 4\n" +" %int0 = torch.constant.int 0\n" +" %int3 = torch.constant.int 3\n" +" %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %1 = torch.aten.ge.int %0, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %2 = torch.aten.__getitem__.t %arg0, %int-2 : !torch.list, !torch.int -> !torch.int\n" +" %3 = torch.aten.__getitem__.t %arg0, %int-1 : !torch.list, !torch.int -> !torch.int\n" +" %4 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" +" %5 = torch.aten.eq.int %4, %int4 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %5 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %6 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %7 = torch.aten.__getitem__.t %arg1, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %8 = torch.aten.__getitem__.t %arg1, %int2 : !torch.list, !torch.int -> !torch.int\n" +" %9 = torch.aten.__getitem__.t %arg1, %int3 : !torch.list, !torch.int -> !torch.int\n" +" %10 = torch.aten.lt.int %6, %3 : !torch.int, !torch.int -> !torch.bool\n" +" %11 = torch.prim.If %10 -> (!torch.bool) {\n" +" %15 = torch.aten.lt.int %7, %3 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %15 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %11 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %12 = torch.aten.lt.int %8, %2 : !torch.int, !torch.int -> !torch.bool\n" +" %13 = torch.prim.If %12 -> (!torch.bool) {\n" +" %15 = torch.aten.lt.int %9, %2 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %15 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %13 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %14 = call @__torch__.pad_shape_fn(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" +" return %14 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.index.Tensor\"(%arg0: !torch.list, %arg1: !torch.list>>) -> !torch.list {\n" " %0 = call @__torch__.index_tensor_like(%arg0, %arg1) : (!torch.list, !torch.list>>) -> !torch.list\n" " return %0 : !torch.list\n" @@ -9002,6 +9065,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %0#1 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.reflection_pad2d\"(%arg0: !torch.tuple, %arg1: !torch.list) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.contiguous\"(%arg0: !torch.tuple, %arg1: !torch.int) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" diff --git a/projects/ltc/csrc/base_lazy_backend/shape_inference.cpp b/projects/ltc/csrc/base_lazy_backend/shape_inference.cpp index 244ee7b88cc0..3971fdd3258a 100644 --- a/projects/ltc/csrc/base_lazy_backend/shape_inference.cpp +++ b/projects/ltc/csrc/base_lazy_backend/shape_inference.cpp @@ -227,6 +227,34 @@ std::vector compute_shape_remainder( return {Shape(self.scalar_type(), self.sizes().vec())}; } +std::vector +compute_shape_reflection_pad2d(const at::Tensor &self, + at::IntArrayRef padding) { + std::vector paddings = padding.vec(); + std::vector in_sizes = self.sizes().vec(); + auto num_dims = in_sizes.size(); + + TORCH_CHECK(padding.size() == 4); + TORCH_CHECK(num_dims >= 2); + + auto vdim = num_dims - 2; + auto hdim = num_dims - 1; + auto padding_left = padding[0]; + auto padding_right = padding[1]; + auto padding_top = padding[2]; + auto padding_bottom = padding[3]; + TORCH_CHECK(padding_left < in_sizes[hdim]); + TORCH_CHECK(padding_right < in_sizes[hdim]); + TORCH_CHECK(padding_top < in_sizes[vdim]); + TORCH_CHECK(padding_bottom < in_sizes[vdim]); + + std::vector out_sizes(in_sizes); + out_sizes[hdim] += padding_left + padding_right; + out_sizes[vdim] += padding_top + padding_bottom; + + return {Shape(self.scalar_type(), out_sizes)}; +} + std::vector compute_shape_uniform( const at::Tensor& self, double from, double to, c10::optional generator) { diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index 48949c318e22..a16d778c79a7 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -1286,6 +1286,30 @@ def aten〇reflection_pad1d〡shape(self: List[int], padding: List[int]) -> List assert padding_left < hdim and padding_right < hdim return pad_shape_fn(self, padding) + +# Padding size must be smaller than corresponding dimension +@check_shape_function([ErrorInvocation(TensorOfShape(2, 2, 2), padding=[2,2,1,1]), + ErrorInvocation(TensorOfShape(2, 2, 2), padding=[2,1,1,1]), + ErrorInvocation(TensorOfShape(2, 2, 2), padding=[2,1,1,3]), + ErrorInvocation(TensorOfShape(2, 2, 2), padding=[2,1]), + Invocation(TensorOfShape(2, 2, 2), padding=[1,1,1,1]), + ErrorInvocation(TensorOfShape(2, 2, 2), padding=[1,1,2,1]), + ErrorInvocation(TensorOfShape(2, 2, 2), padding=[1,1,2,2])]) +def aten〇reflection_pad2d〡shape(self: List[int], padding: List[int]) -> List[int]: + assert len(self) >= 2 + vdim = self[-2] + hdim = self[-1] + + assert len(padding) == 4, 'padding size expected to be 4' + padding_left = padding[0] + padding_right = padding[1] + padding_top = padding[2] + padding_bottom = padding[3] + assert padding_left < hdim and padding_right < hdim + assert padding_top < vdim and padding_bottom < vdim + + return pad_shape_fn(self, padding) + # TODO: upstream this def index_tensor_like(self: List[int], indices: List[Optional[List[int]]]) -> List[int]: assert len(indices) <= len(self), "More indices than dimensions to index" @@ -1831,6 +1855,11 @@ def aten〇reflection_pad1d〡dtype(self_rank_dtype: Tuple[int, int], padding: L assert len(padding) == 2, 'padding size expected to be 2' return self_dtype +@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(4, 2, 2)], padding=[1,1,1,1])) +def aten〇reflection_pad2d〡dtype(self_rank_dtype: Tuple[int, int], padding: List[int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) def aten〇contiguous〡dtype(self_rank_dtype: Tuple[int, int], memory_format: int = 0) -> int: self_rank, self_dtype = self_rank_dtype diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index 4d5b65c1dcd6..9c0a0759b443 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -542,6 +542,7 @@ def emit_with_mutating_variants(key, **kwargs): # Misc tensor ops. emit("aten::constant_pad_nd : (Tensor, int[], Scalar) -> (Tensor)") emit("aten::reflection_pad1d : (Tensor, int[]) -> (Tensor)") + emit("aten::reflection_pad2d : (Tensor, int[]) -> (Tensor)") emit("aten::pad : (Tensor, int[], str, float?) -> (Tensor)") emit("aten::squeeze.dim : (Tensor, int) -> (Tensor)", has_folder=True) emit("aten::squeeze : (Tensor) -> (Tensor)", has_folder=True) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/__init__.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/__init__.py index 79712a16f65b..f24266c78df8 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/__init__.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/__init__.py @@ -59,3 +59,4 @@ def register_all_tests(): from . import return_types from . import control_flow from . import stats + from . import padding diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/padding.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/padding.py new file mode 100644 index 000000000000..6b7bdeab2b48 --- /dev/null +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/padding.py @@ -0,0 +1,113 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. + +import functorch +import torch + +from torch_mlir_e2e_test.framework import TestUtils +from torch_mlir_e2e_test.registry import register_test_case +from torch_mlir_e2e_test.annotations import annotate_args, export + +# ============================================================================== + +class ReflectionPad2dModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([1, 20, 20], torch.float32, True), + ]) + def forward(self, x): + return torch.ops.aten.reflection_pad2d(x, (10,10,10,10)) + + +@register_test_case(module_factory=lambda: ReflectionPad2dModule()) +def ReflectionPad2dModule_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 20, 20, low=-1)) + +# ============================================================================== + +class ReflectionPad2dModuleTop(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([1, 3, 4], torch.float32, True), + ]) + def forward(self, x): + return torch.ops.aten.reflection_pad2d(x, (0,0,2,0)) + + +@register_test_case(module_factory=lambda: ReflectionPad2dModuleTop()) +def ReflectionPad2dModule_Top(module, tu: TestUtils): + module.forward(tu.rand(1, 3, 4)) + +# ============================================================================== + +class ReflectionPad2dModuleBottom(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([2, 3, 10, 10], torch.float32, True), + ]) + def forward(self, x): + return torch.ops.aten.reflection_pad2d(x, (0,0,0,5)) + + +@register_test_case(module_factory=lambda: ReflectionPad2dModuleBottom()) +def ReflectionPad2dModule_Bottom(module, tu: TestUtils): + module.forward(tu.rand(2, 3, 10, 10)) + +# ============================================================================== + +class ReflectionPad2dModuleLeft(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([2, 3, 20, 20], torch.float32, True), + ]) + def forward(self, x): + return torch.ops.aten.reflection_pad2d(x, (15,0,0,0)) + + +@register_test_case(module_factory=lambda: ReflectionPad2dModuleLeft()) +def ReflectionPad2dModule_Left(module, tu: TestUtils): + module.forward(tu.rand(2, 3, 20, 20)) + +# ============================================================================== + +class ReflectionPad2dModuleRight(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([2, 3, 20, 20], torch.float32, True), + ]) + def forward(self, x): + return torch.ops.aten.reflection_pad2d(x, (0,11,0,0)) + + +@register_test_case(module_factory=lambda: ReflectionPad2dModuleRight()) +def ReflectionPad2dModule_Right(module, tu: TestUtils): + module.forward(tu.rand(2, 3, 20, 20)) + +# ============================================================================== From 5862854bc8011a94a54edeb4fa278908e9eb2c2b Mon Sep 17 00:00:00 2001 From: Andreas Falkenberg <149819731+afalkenberg1@users.noreply.github.com> Date: Thu, 11 Jan 2024 00:57:04 -0800 Subject: [PATCH 072/283] [ONNX][TORCH-MLIR] LayerNorm (#2716) Layer Normalization using the torch.aten.native_layer_norm https://github.com/nod-ai/SHARK-Turbine/issues/325 --- .../TorchOnnxToTorch/DefaultDomainGtoP.cpp | 43 +++++++++++++++++++ .../TorchOnnxToTorch/simple_ops_g_to_p.mlir | 13 ++++++ 2 files changed, 56 insertions(+) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index ee1ae6bb65be..fd7013afdb9d 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -410,6 +410,49 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( } return failure(); }); + patterns.onOp("LayerNormalization", 17, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType Y_type; + Torch::ValueTensorType Mean_type; + Torch::ValueTensorType InvStdDev_type; + Value X; + Value Scale; + Value B; + int64_t axis; + float epsilon; + int64_t stash_type; + if (binder.tensorOperandAtIndex(X, 0) || + binder.tensorOperandAtIndex(Scale, 1) || + binder.tensorOperandAtIndex(B, 2) || + binder.tensorResultTypeAtIndex(Y_type, 0) || + binder.tensorResultTypeAtIndex(Mean_type, 1) || + binder.tensorResultTypeAtIndex(InvStdDev_type, 2) || + binder.s64IntegerAttr(axis, "axis", -1) || + binder.f32FloatAttr(epsilon, "epsilon", 0.00001) || + binder.s64IntegerAttr(stash_type, "stash_type", 1)) + return failure(); + Value constEpsilon = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getF64FloatAttr(epsilon)); + unsigned rank = 1; + if(std::optional maybeRank = Torch::getTensorRank(X)) + rank = *maybeRank; + SmallVector normalized; + axis = Torch::toPositiveDim(axis, rank); + auto X_type = X.getType().cast(); + ArrayRef X_shape = X_type.getSizes(); + for (int64_t n = axis; n < rank ; n++) { + normalized.push_back(rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(X_shape[n]))); + } + Value normalized_shape = rewriter.create( + binder.getLoc(), + Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), + normalized); + rewriter.replaceOpWithNewOp( + binder.op, Y_type, Mean_type, InvStdDev_type, X, normalized_shape, Scale, B, constEpsilon); + return success(); + }); patterns.onOp("LeakyRelu", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir index e224ddfa2944..07ddf3e594ea 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir @@ -116,6 +116,19 @@ func.func @test_gemm_alpha_beta(%arg0: !torch.vtensor<[3,5],f32>, %arg1: !torch. // ----- +// CHECK-LABEL : func.func @test_layer_norm +func.func @test_layer_norm(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[3,4],f32>, %arg2: !torch.vtensor<[3,4],f32>) -> (!torch.vtensor<[3,4], f32>, !torch.vtensor<[1,1],f32>, !torch.vtensor<[1,1],f32>) + attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %int3 = torch.constant.int 3 + // CHECK: %int4 = torch.constant.int 4 + // CHECK: %0 = torch.prim.ListConstruct %int3, %int4 : (!torch.int, !torch.int) -> !torch.list + // CHECK: %result0, %result1, %result2 = torch.aten.native_layer_norm %arg0, %0, %arg1, %arg2 + %0:3 = torch.operator "onnx.LayerNormalization"(%arg0, %arg1, %arg2) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[3,4],f32>, !torch.vtensor<[3,4],f32>, !torch.vtensor<[3,4],f32>) -> (!torch.vtensor<[3,4],f32>, !torch.vtensor<[1,1],f32>, !torch.vtensor<[1,1],f32>) + return %0#0, %0#1, %0#2 : !torch.vtensor<[3,4],f32>, !torch.vtensor<[1,1],f32>, !torch.vtensor<[1,1],f32> +} + +// ----- + // CHECK-LABEL: func.func @test_leaky_relu func.func @test_leaky_relu(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.opset_version = 16 : si64} { // CHECK-DAG: %[[F2:.+]] = torch.constant.float 2 From e1a86e480a5687f78ad4b70047e978905bee5088 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ilija=20Kalini=C4=87?= Date: Thu, 11 Jan 2024 15:55:42 +0100 Subject: [PATCH 073/283] Implement lowering of torch.aten.logit (#2697) Closes nod-ai/SHARK-Turbine#290 --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 47 ++++++++++ .../TorchToLinalg/Uncategorized.cpp | 88 ++++++++++++++++++- .../Transforms/AbstractInterpLibrary.cpp | 9 ++ projects/pt1/e2e_testing/xfail_sets.py | 1 + .../build_tools/abstract_interp_lib_gen.py | 8 ++ .../build_tools/torch_ods_gen.py | 1 + .../test_suite/elementwise.py | 22 +++++ 7 files changed, 175 insertions(+), 1 deletion(-) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 74a2e2327d1b..12a2bf4a86e2 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -2793,6 +2793,53 @@ def Torch_AtenLog1p_Op : Torch_Op<"aten.log1p_", [ }]; } +def Torch_AtenLogitOp : Torch_Op<"aten.logit", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::logit : (Tensor, float?) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchOptionalFloatType:$eps + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenLogitOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenLogitOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + +def Torch_AtenLogit_Op : Torch_Op<"aten.logit_", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::logit_ : (Tensor, float?) -> (Tensor)`"; + let arguments = (ins + Torch_NonValueTensorType:$self, + AnyTorchOptionalFloatType:$eps + ); + let results = (outs + Torch_NonValueTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenLogit_Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenLogit_Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + def Torch_AtenRsqrtOp : Torch_Op<"aten.rsqrt", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index f742ded3f1bd..749945dee6e2 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -1969,7 +1969,6 @@ class ConvertPrimsCollapseOp : public OpConversionPattern { associations.push_back(ReassociationIndices{i}); } - rewriter.replaceOpWithNewOp( op, resultRankedTensorType, adaptor.getA(), associations); @@ -1996,6 +1995,91 @@ class ConvertTensorStaticInfoCastOp }; } // namespace +namespace { +class ConvertLogitOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(AtenLogitOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + Location loc = op->getLoc(); + Value input = adaptor.getSelf(); + Value eps = adaptor.getEps(); + + if (failed(verifyLinalgCompatibleTypes(op, rewriter))) + return failure(); + + bool handleEps = false; + if (succeeded(checkNotNone(rewriter, op, eps))) + handleEps = true; + + if (handleEps && !eps.getType().isa()) { + op.emitError("Logit does not support non-floating point type"); + return failure(); + } + + auto inputType = input.getType().cast(); + auto inputElementType = inputType.getElementType(); + + if (!inputElementType.isa()) { + op.emitError("Logit does not support non-floating point type"); + return failure(); + } + + auto inputRank = inputType.getRank(); + + SmallVector indexingMaps = { + rewriter.getMultiDimIdentityMap(inputRank), // input + rewriter.getMultiDimIdentityMap(inputRank), // output + }; + SmallVector iteratorTypes( + inputRank, utils::IteratorType::parallel); + Value logit = + rewriter + .create( + loc, input.getType(), + /*ins=*/input, + /*outs=*/input, + /*indexingMaps=*/indexingMaps, + /*iteratorTypes=*/iteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange args) { + Value input = args[0]; + + TypedAttr oneAttr = b.getFloatAttr(inputElementType, 1.0); + Value oneValue = b.create(loc, oneAttr); + + Value zI; + if (!handleEps) { + zI = input; + } else { + Value truncEps = + b.create(loc, inputElementType, eps); + Value oneMinusEps = + b.create(loc, oneValue, truncEps); + + Value min = + b.create(loc, input, oneMinusEps); + Value clampedInput = + b.create(loc, min, truncEps); + + zI = clampedInput; + } + + Value probability = + b.create(loc, oneValue, zI); + Value odds = b.create(loc, zI, probability); + Value result = b.create(loc, odds); + + b.create(loc, result); + }) + .getResult(0); + Type newResultType = getTypeConverter()->convertType(op.getType()); + rewriter.replaceOpWithNewOp(op, newResultType, logit); + return success(); + } +}; +} // namespace void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target) { @@ -2028,6 +2112,8 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality( patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); + target.addIllegalOp(); + patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index 55b9638dd0cc..3901cd34a4aa 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -6352,6 +6352,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.logit\"(%arg0: !torch.list, %arg1: !torch.optional) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.rsqrt\"(%arg0: !torch.list) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -8739,6 +8743,11 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" " return %1 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.logit\"(%arg0: !torch.tuple, %arg1: !torch.optional) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" +" return %1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.rsqrt\"(%arg0: !torch.tuple) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index ddb4865ec535..98cde05a8f73 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -1446,6 +1446,7 @@ "ViewCollapseDynamicWithAtenSizeIntModule_basic", "AtenEmbeddingBagSumExample_basic", "Aten_EmbeddingBagExample_basic", + "ElementwiseLogitModule_basic", "ElementwiseRemainderScalarModule_Int_Float_basic", "ElementwiseRemainderScalarModule_Bool_basic", "AtenIntTensorByteDtypeModule_basic", diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index a16d778c79a7..211023a9deec 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -141,6 +141,9 @@ def aten〇log10〡shape(self: List[int]) -> List[int]: def aten〇log1p〡shape(self: List[int]) -> List[int]: return upstream_shape_functions.unary(self) +def aten〇logit〡shape(self: List[int], eps: Optional[float] = None) -> List[int]: + return upstream_shape_functions.unary(self) + def aten〇rsqrt〡shape(self: List[int]) -> List[int]: return upstream_shape_functions.unary(self) @@ -1659,6 +1662,11 @@ def aten〇log1p〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype return _get_dtype_of_floating_point_op(self_dtype) +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇logit〡dtype(self_rank_dtype: Tuple[int, int], eps: Optional[float] = None) -> int: + self_rank, self_dtype = self_rank_dtype + return _get_dtype_of_floating_point_op(self_dtype) + @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) def aten〇rsqrt〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index 9c0a0759b443..a9f9ed96dce2 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -315,6 +315,7 @@ def emit_with_mutating_variants(key, **kwargs): "aten::log10 : (Tensor) -> (Tensor)", "aten::sqrt : (Tensor) -> (Tensor)", "aten::log1p : (Tensor) -> (Tensor)", + "aten::logit : (Tensor, float?) -> (Tensor)", "aten::rsqrt : (Tensor) -> (Tensor)", "aten::abs : (Tensor) -> (Tensor)", "aten::reciprocal : (Tensor) -> (Tensor)", diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py index c18c9103d888..5d6217b59072 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py @@ -1508,6 +1508,28 @@ def ElementwiseLog1pModule_basic(module, tu: TestUtils): # ============================================================================== +class ElementwiseLogitModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True), + ]) + def forward(self, a): + return torch.logit(a, eps=1e-7) + + +@register_test_case(module_factory=lambda: ElementwiseLogitModule()) +def ElementwiseLogitModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4)) + + +# ============================================================================== + + class ElementwiseErfModule(torch.nn.Module): def __init__(self): From 47ffc90db49f16c126c1fa456b92e24161847afb Mon Sep 17 00:00:00 2001 From: James Newling Date: Thu, 11 Jan 2024 09:46:46 -0800 Subject: [PATCH 074/283] signed/unsigned c++ compiler warning fixes (#2742) --- include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h | 4 ++-- lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp | 4 ++-- lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h b/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h index 1ce381005fcc..b6189b375c5b 100644 --- a/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h +++ b/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h @@ -60,7 +60,7 @@ struct OpBinder { int64_t numOperands) { if (op->getNumOperands() != numOperands) return failure(); - for (int i = 0; i < numOperands; i++) { + for (int64_t i = 0; i < numOperands; i++) { Value curr = op->getOperand(i); if (!toValidTensorType(curr.getType())) { return failure(); @@ -80,7 +80,7 @@ struct OpBinder { } ParseResult tensorOperandsList( llvm::SmallVectorImpl &values) { - for (int i = 0; i < op->getNumOperands(); i++) { + for (uint32_t i = 0; i < op->getNumOperands(); i++) { values.push_back(op->getOperand(i)); } return success(); diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index fd7013afdb9d..0102366fe01c 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -182,7 +182,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( return failure(); } Value result = operands[0]; - for (int i = 1; i < operands.size(); i++) { + for (uint64_t i = 1; i < operands.size(); i++) { result = rewriter.create( binder.getLoc(), resultType, result, operands[i]); } @@ -200,7 +200,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( return failure(); } Value result = operands[0]; - for (int i = 1; i < operands.size(); i++) { + for (uint64_t i = 1; i < operands.size(); i++) { result = rewriter.create( binder.getLoc(), resultType, result, operands[i]); } diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index d88ca9d6d5cb..11a05ea41105 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -672,7 +672,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( int64_t adjustmentInt = cast(data.getType()).getSizes().size(); // convert axes (tensor) into torch int list while dealing with neg axis - for (int i = 0; i < axes.size(); i++) { + for (uint64_t i = 0; i < axes.size(); i++) { // Go through the axes list and get each dim in the list int64_t dim = axes[i]; if (dim < 0) { From 670a99ae196da892310776f110cfe29dfb68a174 Mon Sep 17 00:00:00 2001 From: Ze Zhang Date: Thu, 11 Jan 2024 10:36:48 -0800 Subject: [PATCH 075/283] Handle torch.none type in tosa.clamp op (#2739) This PR updates the torch-to-tosa conversion with following changes: - Support torch.none as min/max input argument for tosa.clamp op - Support negative value as start index for tosa.slice op - Add tosa.logical_or lowering support e2e test: python -m e2e_testing.main --config=tosa LIT tests: cmake --build build --target tools/torch-mlir/all --------- Co-authored-by: Ze Zhang --- lib/Conversion/TorchToTosa/TorchToTosa.cpp | 78 +++++++++++++--------- projects/pt1/e2e_testing/xfail_sets.py | 12 ++++ test/Conversion/TorchToTosa/basic.mlir | 71 ++++++++++++++++++++ 3 files changed, 129 insertions(+), 32 deletions(-) diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index e123522a4542..6555f06e8702 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -18,6 +18,7 @@ #include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/Dialect/Traits.h" #include "mlir/IR/Matchers.h" +#include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/DialectConversion.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" @@ -3336,9 +3337,11 @@ LogicalResult ConvertAtenOp::matchAndRewrite( if (!matchPattern(op.getStart(), m_TorchConstantInt(&start))) return rewriter.notifyMatchFailure(op, "start must be a Scalar constant"); - if (start < 0) - return rewriter.notifyMatchFailure(op, "Currently unsupported: start < 0"); - + if (start < 0) { + start = toPositiveDim(start, selfType.getShape()[dim]); + if (!isValidDim(start, selfType.getShape()[dim])) + return rewriter.notifyMatchFailure(op, "start is not a valid index"); + } start = std::min(selfType.getShape()[dim], start); int64_t end; @@ -3984,36 +3987,46 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return rewriter.notifyMatchFailure( op, "only tensor types input are currently supported"); - IntegerAttr min_int, max_int; - FloatAttr min_fp, max_fp; - if (op.getMin().getType().isa()) { - double fp_min, fp_max; - if (!matchPattern(op.getMin(), m_TorchConstantFloat(&fp_min))) - return rewriter.notifyMatchFailure( - op, "unimplemented: value `fp_min` should be a torch constant float"); - - if (!matchPattern(op.getMax(), m_TorchConstantFloat(&fp_max))) - return rewriter.notifyMatchFailure( - op, "unimplemented: value `fp_max` should be a torch constant float"); - - min_int = rewriter.getI64IntegerAttr(static_cast(fp_min)); - max_int = rewriter.getI64IntegerAttr(static_cast(fp_max)); - min_fp = rewriter.getF32FloatAttr(static_cast(fp_min)); - max_fp = rewriter.getF32FloatAttr(static_cast(fp_max)); - } else { - int64_t int_min, int_max; - if (!matchPattern(op.getMin(), m_TorchConstantInt(&int_min))) - return rewriter.notifyMatchFailure( - op, "unimplemented: value `int_min` should be a torch constant int"); - - if (!matchPattern(op.getMax(), m_TorchConstantInt(&int_max))) - return rewriter.notifyMatchFailure( - op, "unimplemented: value `int_max` should be a torch constant int"); + IntegerAttr min_int = + rewriter.getI64IntegerAttr(std::numeric_limits::min()); + IntegerAttr max_int = + rewriter.getI64IntegerAttr(std::numeric_limits::max()); + FloatAttr min_fp = + rewriter.getF32FloatAttr(std::numeric_limits::lowest()); + FloatAttr max_fp = + rewriter.getF32FloatAttr(std::numeric_limits::max()); + + auto getValAttr = [&](Value operand, IntegerAttr &intAttr, + FloatAttr &fpAttr) -> LogicalResult { + double valFloat; + int64_t valInt; + if (matchPattern(operand, m_TorchConstantFloat(&valFloat))) { + intAttr = rewriter.getI64IntegerAttr(static_cast(valFloat)); + fpAttr = rewriter.getF32FloatAttr(static_cast(valFloat)); + } else if (matchPattern(operand, m_TorchConstantInt(&valInt))) { + intAttr = rewriter.getI64IntegerAttr(valInt); + fpAttr = rewriter.getF32FloatAttr(static_cast(valInt)); + } else { + return failure(); + } + return success(); + }; - min_int = rewriter.getI64IntegerAttr(int_min); - max_int = rewriter.getI64IntegerAttr(int_max); - min_fp = rewriter.getF32FloatAttr(static_cast(int_min)); - max_fp = rewriter.getF32FloatAttr(static_cast(int_max)); + LogicalResult minAttrResult = getValAttr(op.getMin(), min_int, min_fp); + LogicalResult maxAttrResult = getValAttr(op.getMax(), max_int, max_fp); + if (failed(minAttrResult) && failed(maxAttrResult)) { + return rewriter.notifyMatchFailure( + op, "either `min` or `max` should be a torch constant"); + } + if (failed(minAttrResult) && + succeeded(checkNotNone(rewriter, op, op.getMin()))) { + return rewriter.notifyMatchFailure(op, + "min attr should be a torch constant"); + } + if (failed(maxAttrResult) && + succeeded(checkNotNone(rewriter, op, op.getMax()))) { + return rewriter.notifyMatchFailure(op, + "max attr should be a torch constant"); } auto outType = getTypeConverter()->convertType(op.getType()); @@ -5025,6 +5038,7 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase { patterns.add>(typeConverter, context); INSERT_BINARY_PATTERN(AtenMaximumOp, tosa::MaximumOp) INSERT_BINARY_PATTERN(AtenMinimumOp, tosa::MinimumOp) + INSERT_BINARY_PATTERN(AtenLogicalOrOp, tosa::LogicalOrOp) #undef INSERT_BINARY_PATTERN #define INSERT_BINARY_ADDSUB_PATTERN(AtenOp, TosaOp) \ diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 98cde05a8f73..de68680a82f2 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -1035,6 +1035,15 @@ "ElementwiseAddScalar_TensorLiteralInt32_Module_basic", "ElementwiseAtenDivIntScalarModule_basic", "ElementwiseAtenIsinfOpModule_basic", + "ElementwiseAtenLogicalOrOpBrodcastModule_basic", + "ElementwiseAtenLogicalOrOpDiffArgs1Module_basic", + "ElementwiseAtenLogicalOrOpDiffArgs2Module_basic", + "ElementwiseAtenLogicalOrOpDiffArgs3Module_basic", + "ElementwiseAtenLogicalOrOpModule_basic", + "ElementwiseAtenLogicalOrOpNegativeModule_basic", + "ElementwiseAtenLogicalOrOpPromoteBroadcastStaticShapeModule_basic", + "ElementwiseAtenLogicalOrOpRandomFloatModule_basic", + "ElementwiseAtenLogicalOrOpRandomModule_basic", "ElementwiseAtenWhereSelfModule_basic", "ElementwiseBinaryModule_basic", "ElementwiseBinaryStaticShapeModule_basic", @@ -1047,6 +1056,9 @@ "ElementwiseBitwiseXorModule_basic", "ElementwiseBitwiseXorStaticShapeModule_basic", "ElementwiseCeilModule_basic", + "ElementwiseClampMaxModule_basic", + "ElementwiseClampMinModule_basic", + "ElementwiseClampModule_basic", "ElementwiseCloneChannelsLastMemoryFormatModule_basic", "ElementwiseCloneContiguousModule_basic", "ElementwiseCloneModule_basic", diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir index 180f48bcef2b..b36acc779547 100644 --- a/test/Conversion/TorchToTosa/basic.mlir +++ b/test/Conversion/TorchToTosa/basic.mlir @@ -645,6 +645,22 @@ func.func @torch.aten.ne.Tensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: ! // ----- +// CHECK-LABEL: func.func @torch.aten.logical_or$basic( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],i1>, +// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],i1>) -> !torch.vtensor<[?,?],i1> { +// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],i1> -> tensor +// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],i1> -> tensor +// CHECK: %[[VAL_4:.*]] = tosa.logical_or %[[VAL_2]], %[[VAL_3]] : (tensor, tensor) -> tensor +// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor -> !torch.vtensor<[?,?],i1> +// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],i1> +// CHECK: } +func.func @torch.aten.logical_or$basic(%arg0: !torch.vtensor<[?,?],i1>, %arg1: !torch.vtensor<[?,?],i1>) -> !torch.vtensor<[?,?],i1> { + %0 = torch.aten.logical_or %arg0, %arg1 : !torch.vtensor<[?,?],i1>, !torch.vtensor<[?,?],i1> -> !torch.vtensor<[?,?],i1> + return %0 : !torch.vtensor<[?,?],i1> +} + +// ----- + // CHECK-LABEL: func.func @forward( // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4,2],f32>) -> !torch.vtensor<[3,2,4],f32> { // CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4,2],f32> -> tensor<3x4x2xf32> @@ -1055,6 +1071,61 @@ func.func @torch.aten.Scalar$basic(%arg0: !torch.vtensor<[1,1,128,128],si64>) -> return %0 : !torch.vtensor<[1,1,128,128],si64> } +// ----- +// CHECK-LABEL: func.func @torch.aten.slice.negative_start( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[4,65,256],f32>) -> !torch.vtensor<[4,16,256],f32> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[4,65,256],f32> -> tensor<4x65x256xf32> +// CHECK: %[[VAL_2:.*]] = torch.constant.int 0 +// CHECK: %[[VAL_3:.*]] = torch.constant.int 1 +// CHECK: %[[VAL_4:.*]] = torch.constant.int 100 +// CHECK: %[[VAL_5:.*]] = torch.constant.int -16 +// CHECK: %[[VAL_4:.*]] = tosa.slice %[[VAL_1]] {size = array, start = array} : (tensor<4x65x256xf32>) -> tensor<4x16x256xf32> +// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<4x16x256xf32> -> !torch.vtensor<[4,16,256],f32> +// CHECK: return %[[VAL_5]] : !torch.vtensor<[4,16,256],f32> +// CHECK: } +func.func @torch.aten.slice.negative_start(%arg0: !torch.vtensor<[4,65,256],f32>) -> !torch.vtensor<[4,16,256],f32> { + %int0 = torch.constant.int 0 + %int1 = torch.constant.int 1 + %int100 = torch.constant.int 100 + %int-16 = torch.constant.int -16 + %0 = torch.aten.slice.Tensor %arg0, %int1, %int-16, %int100, %int1 : !torch.vtensor<[4,65,256],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[4,16,256],f32> + return %0 : !torch.vtensor<[4,16,256],f32> +} + +// ----- +// CHECK-LABEL: func.func @torch.aten.clamp.min_none( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,1,128,128],si64>) -> !torch.vtensor<[1,1,128,128],si64> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,1,128,128],si64> -> tensor<1x1x128x128xi64> +// CHECK: %[[VAL_2:.*]] = torch.constant.int 0 +// CHECK: %[[VAL_3:.*]] = torch.constant.none +// CHECK: %[[VAL_4:.*]] = tosa.clamp %[[VAL_1]] {max_fp = 0.000000e+00 : f32, max_int = 0 : i64, min_fp = -3.40282347E+38 : f32, min_int = -9223372036854775808 : i64} : (tensor<1x1x128x128xi64>) -> tensor<1x1x128x128xi64> +// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<1x1x128x128xi64> -> !torch.vtensor<[1,1,128,128],si64> +// CHECK: return %[[VAL_5]] : !torch.vtensor<[1,1,128,128],si64> +// CHECK: } +func.func @torch.aten.clamp.min_none(%arg0: !torch.vtensor<[1,1,128,128],si64>) -> !torch.vtensor<[1,1,128,128],si64> { + %int0 = torch.constant.int 0 + %none = torch.constant.none + %0 = torch.aten.clamp %arg0, %none, %int0 : !torch.vtensor<[1,1,128,128],si64>, !torch.none, !torch.int -> !torch.vtensor<[1,1,128,128],si64> + return %0 : !torch.vtensor<[1,1,128,128],si64> +} + +// ----- +// CHECK-LABEL: func.func @torch.aten.clamp.max_none( +// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,1,128,128],si64>) -> !torch.vtensor<[1,1,128,128],si64> { +// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,1,128,128],si64> -> tensor<1x1x128x128xi64> +// CHECK: %[[VAL_2:.*]] = torch.constant.int 0 +// CHECK: %[[VAL_3:.*]] = torch.constant.none +// CHECK: %[[VAL_4:.*]] = tosa.clamp %[[VAL_1]] {max_fp = 3.40282347E+38 : f32, max_int = 9223372036854775807 : i64, min_fp = 0.000000e+00 : f32, min_int = 0 : i64} : (tensor<1x1x128x128xi64>) -> tensor<1x1x128x128xi64> +// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<1x1x128x128xi64> -> !torch.vtensor<[1,1,128,128],si64> +// CHECK: return %[[VAL_5]] : !torch.vtensor<[1,1,128,128],si64> +// CHECK: } +func.func @torch.aten.clamp.max_none(%arg0: !torch.vtensor<[1,1,128,128],si64>) -> !torch.vtensor<[1,1,128,128],si64> { + %int0 = torch.constant.int 0 + %none = torch.constant.none + %0 = torch.aten.clamp %arg0, %int0, %none : !torch.vtensor<[1,1,128,128],si64>, !torch.int, !torch.none -> !torch.vtensor<[1,1,128,128],si64> + return %0 : !torch.vtensor<[1,1,128,128],si64> +} + // ----- // CHECK-LABEL: func.func @torch.aten.clamp( // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,1,128,128],si64>) -> !torch.vtensor<[1,1,128,128],si64> { From c7452af4fa7b4139dbd8b78b388b84a08b8c1b7a Mon Sep 17 00:00:00 2001 From: Chi_Liu <22491986+AmosLewis@users.noreply.github.com> Date: Fri, 12 Jan 2024 14:54:38 -0800 Subject: [PATCH 076/283] [MLIR][ONNX] Add OnnxToTorch support for Maxpool Op (#2695) Add Maxpool ONNX op support. Add Utils.h/cpp files to create a constant int list for ONNX. --- .../Conversion/TorchOnnxToTorch/Utils.h | 23 +++++ .../TorchOnnxToTorch/CMakeLists.txt | 1 + .../TorchOnnxToTorch/DefaultDomainGtoP.cpp | 79 +++++++++++++++++ lib/Conversion/TorchOnnxToTorch/Utils.cpp | 28 ++++++ .../TorchOnnxToTorch/simple_ops_g_to_p.mlir | 85 ++++++++++++++++++- 5 files changed, 215 insertions(+), 1 deletion(-) create mode 100644 include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h create mode 100644 lib/Conversion/TorchOnnxToTorch/Utils.cpp diff --git a/include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h b/include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h new file mode 100644 index 000000000000..058fee4da4a2 --- /dev/null +++ b/include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h @@ -0,0 +1,23 @@ +//===------------------------------------------------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#ifndef TORCHMLIR_CONVERSION_TORCHONNXTOTORCH_UTILS_H +#define TORCHMLIR_CONVERSION_TORCHONNXTOTORCH_UTILS_H + +#include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h" + +namespace mlir::torch::onnx_c { + +Value createConstantIntList(OpBinder binder, + ConversionPatternRewriter &rewriter, + SmallVector cstInput); + +} // namespace mlir::torch::onnx_c + +#endif // TORCHMLIR_CONVERSION_TORCHONNXTOTORCH_UTILS_H diff --git a/lib/Conversion/TorchOnnxToTorch/CMakeLists.txt b/lib/Conversion/TorchOnnxToTorch/CMakeLists.txt index 807db64eac64..4a5015816609 100644 --- a/lib/Conversion/TorchOnnxToTorch/CMakeLists.txt +++ b/lib/Conversion/TorchOnnxToTorch/CMakeLists.txt @@ -5,6 +5,7 @@ add_mlir_conversion_library(TorchMLIRTorchOnnxToTorch Passes.cpp Patterns.cpp TorchOnnxToTorch.cpp + Utils.cpp ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchOnnxToTorch diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index 0102366fe01c..c0a7473e4601 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -8,6 +8,7 @@ //===----------------------------------------------------------------------===// #include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h" +#include "torch-mlir/Conversion/TorchOnnxToTorch/Utils.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" using namespace mlir; @@ -148,6 +149,84 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( binder.op, resultType, lhs, rhs); return success(); }); + patterns.onOp( + "MaxPool", 12, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + std::string autoPad; + if (binder.customOpNameStringAttr(autoPad, "auto_pad", "NOTSET")) + return rewriter.notifyMatchFailure(binder.op, + "auto_pad bind failure"); + if (autoPad != "NOTSET") + return rewriter.notifyMatchFailure( + binder.op, "unsupported conversion: auto_pad != NOTSET"); + + Torch::ValueTensorType resultType; + Value operand; + bool ceilMode; + int64_t storageOrder; + // TODO: Add support for indices output and storage_order + if (binder.tensorOperand(operand) || + binder.s64BoolAttr(ceilMode, "ceil_mode", false) || + binder.s64IntegerAttr(storageOrder, "storage_order", 0) || + binder.tensorResultType(resultType)) + return rewriter.notifyMatchFailure( + binder.op, + "operand/ceil_mode/storage_order/resultType bind failure"); + if (storageOrder != 0) + return rewriter.notifyMatchFailure( + binder.op, "storage_order setting is not supported."); + // Determine the rank of input tensor. + std::optional maybeRank = Torch::getTensorRank(operand); + if (!maybeRank) + return rewriter.notifyMatchFailure(binder.op, + "Unimplemented: unranked tensor"); + unsigned rank = *maybeRank; + + SmallVector kernel, padding, strides, dilations; + if (binder.s64IntegerArrayAttr(kernel, "kernel_shape", {})) + return rewriter.notifyMatchFailure(binder.op, + "kernel_shape bind failure"); + if (kernel.size() != rank - 2) + return rewriter.notifyMatchFailure( + binder.op, "kernel list size does not match the number of axes"); + if (binder.s64IntegerArrayAttr(padding, "pads", {0})) + return rewriter.notifyMatchFailure(binder.op, "pads bind failure"); + if (padding.size() != 1 && padding.size() != rank - 2) + return rewriter.notifyMatchFailure( + binder.op, "padding list size does not match the number of axes"); + if (binder.s64IntegerArrayAttr(strides, "strides", {1})) + return rewriter.notifyMatchFailure(binder.op, "strides bind failure"); + if (strides.size() != 1 && strides.size() != rank - 2) + return rewriter.notifyMatchFailure( + binder.op, "strides list size does not match the number of axes"); + if (binder.s64IntegerArrayAttr(dilations, "dilations", {})) + return rewriter.notifyMatchFailure(binder.op, + "dilations bind failure"); + + Value kernelSizeList = createConstantIntList(binder, rewriter, kernel); + Value paddingList = createConstantIntList(binder, rewriter, padding); + Value stridesList = createConstantIntList(binder, rewriter, strides); + Value dilationsList = + createConstantIntList(binder, rewriter, dilations); + Value cstCeilMode = + rewriter.create(binder.getLoc(), ceilMode); + + if (rank == 3) + return rewriter.notifyMatchFailure(binder.op, + "Unimplemented: AtenMaxPool1dOp"); + if (rank == 4) { + rewriter.replaceOpWithNewOp( + binder.op, resultType, operand, kernelSizeList, stridesList, + paddingList, dilationsList, cstCeilMode); + return success(); + } + if (rank == 5) { + rewriter.replaceOpWithNewOp( + binder.op, resultType, operand, kernelSizeList, stridesList, + paddingList, dilationsList, cstCeilMode); + return success(); + } + return rewriter.notifyMatchFailure(binder.op, "No rank is matched."); + }); patterns.onOp("Greater", 16, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; diff --git a/lib/Conversion/TorchOnnxToTorch/Utils.cpp b/lib/Conversion/TorchOnnxToTorch/Utils.cpp new file mode 100644 index 000000000000..8f5a2e67c0cb --- /dev/null +++ b/lib/Conversion/TorchOnnxToTorch/Utils.cpp @@ -0,0 +1,28 @@ +//===------------------------------------------------------------*- C++ -*-===// +// +// This file is licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#include "torch-mlir/Conversion/TorchOnnxToTorch/Utils.h" + +using namespace mlir; +using namespace mlir::torch; +using namespace mlir::torch::onnx_c; + +Value mlir::torch::onnx_c::createConstantIntList( + OpBinder binder, ConversionPatternRewriter &rewriter, + SmallVector cstInput) { + SmallVector cstValue; + for (int64_t i : cstInput) { + cstValue.push_back(rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(i))); + } + return rewriter.create( + binder.getLoc(), + Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), + cstValue); +} diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir index 07ddf3e594ea..c85659c25aa8 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir @@ -13,6 +13,8 @@ func.func @test_greater(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtenso return %0 : !torch.vtensor<[3,4,5],i1> } +// ----- + // CHECK-LABEL: func.func @test_greater_or_equal func.func @test_greater_or_equal(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],i1> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 16 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[3,4,5],f32> @@ -22,6 +24,8 @@ func.func @test_greater_or_equal(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !tor return %0 : !torch.vtensor<[3,4,5],i1> } +// ----- + // CHECK-LABEL: func.func @test_less func.func @test_less(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],i1> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[3,4,5],f32> @@ -31,6 +35,8 @@ func.func @test_less(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[ return %0 : !torch.vtensor<[3,4,5],i1> } +// ----- + // CHECK-LABEL: func.func @test_gather_elements func.func @test_gather_elements(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5], si64>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.opset_version = 13 : si64} { // CHECK-DAG: %[[INT0:.+]] = torch.constant.int 0 @@ -99,7 +105,7 @@ func.func @test_gemm_beta(%arg0: !torch.vtensor<[3,5],f32>, %arg1: !torch.vtenso return %0 : !torch.vtensor<[3,4],f32> } - // ----- +// ----- // CHECK-LABEL: func.func @test_gemm_alpha_beta func.func @test_gemm_alpha_beta(%arg0: !torch.vtensor<[3,5],f32>, %arg1: !torch.vtensor<[5,4],f32>, %arg2: !torch.vtensor<[1,4],f32>) -> !torch.vtensor<[3,4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64} { @@ -137,6 +143,8 @@ func.func @test_leaky_relu(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor return %0 : !torch.vtensor<[3,4,5],f32> } +// ----- + // CHECK-LABEL: @test_matmul_2d func.func @test_matmul_2d(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[4,3],f32>) -> !torch.vtensor<[3,3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.matmul %arg0, %arg1 : !torch.vtensor<[3,4],f32>, !torch.vtensor<[4,3],f32> -> !torch.vtensor<[3,3],f32> @@ -173,6 +181,62 @@ func.func @test_matmul_4d(%arg0: !torch.vtensor<[1,2,3,4],f32>, %arg1: !torch.vt // ----- +// CHECK-LABEL: func.func @test_maxpool_2d_default +func.func @test_maxpool_2d_default(%arg0: !torch.vtensor<[1,3,32,32],f32>) -> !torch.vtensor<[1,3,31,31],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64} { + // CHECK: %[[I2:.*]] = torch.constant.int 2 + // CHECK: %[[I2_1:.*]] = torch.constant.int 2 + // CHECK: %[[LIST22:.*]] = torch.prim.ListConstruct %[[I2]], %[[I2_1]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[I0:.*]] = torch.constant.int 0 + // CHECK: %[[LIST0:.*]] = torch.prim.ListConstruct %[[I0]] : (!torch.int) -> !torch.list + // CHECK: %[[I1:.*]] = torch.constant.int 1 + // CHECK: %[[LIST1:.*]] = torch.prim.ListConstruct %[[I1]] : (!torch.int) -> !torch.list + // CHECK: %[[LIST:.*]] = torch.prim.ListConstruct : () -> !torch.list + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: torch.aten.max_pool2d %arg0, %[[LIST22]], %[[LIST1]], %[[LIST0]], %[[LIST]], %[[FALSE]] : !torch.vtensor<[1,3,32,32],f32>, !torch.list, !torch.list, !torch.list, !torch.list, !torch.bool -> !torch.vtensor<[1,3,31,31],f32> + %0 = torch.operator "onnx.MaxPool"(%arg0) {torch.onnx.kernel_shape = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,3,32,32],f32>) -> !torch.vtensor<[1,3,31,31],f32> + return %0 : !torch.vtensor<[1,3,31,31],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_maxpool_2d_ceil +func.func @test_maxpool_2d_ceil(%arg0: !torch.vtensor<[1,1,4,4],f32>) -> !torch.vtensor<[1,1,2,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64} { + // CHECK: %[[I3:.*]] = torch.constant.int 3 + // CHECK: %[[I3_1:.*]] = torch.constant.int 3 + // CHECK: %[[LIST33:.*]] = torch.prim.ListConstruct %[[I3]], %[[I3_1]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[I0:.*]] = torch.constant.int 0 + // CHECK: %[[LIST0:.*]] = torch.prim.ListConstruct %[[I0]] : (!torch.int) -> !torch.list + // CHECK: %[[I2:.*]] = torch.constant.int 2 + // CHECK: %[[I2_1:.*]] = torch.constant.int 2 + // CHECK: %[[LIST22:.*]] = torch.prim.ListConstruct %[[I2]], %[[I2_1]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[LIST:.*]] = torch.prim.ListConstruct : () -> !torch.list + // CHECK: %[[TRUE:.*]] = torch.constant.bool true + // CHECK: torch.aten.max_pool2d %arg0, %[[LIST33]], %[[LIST22]], %[[LIST0]], %[[LIST]], %[[TRUE]] : !torch.vtensor<[1,1,4,4],f32>, !torch.list, !torch.list, !torch.list, !torch.list, !torch.bool -> !torch.vtensor<[1,1,2,2],f32> + %0 = torch.operator "onnx.MaxPool"(%arg0) {torch.onnx.ceil_mode = 1 : si64, torch.onnx.kernel_shape = [3 : si64, 3 : si64], torch.onnx.strides = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,1,4,4],f32>) -> !torch.vtensor<[1,1,2,2],f32> + return %0 : !torch.vtensor<[1,1,2,2],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_maxpool_3d_default +func.func @test_maxpool_3d_default(%arg0: !torch.vtensor<[1,3,32,32,32],f32>) -> !torch.vtensor<[1,3,31,31,31],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64} { + // CHECK: %[[I2:.*]] = torch.constant.int 2 + // CHECK: %[[I2_1:.*]] = torch.constant.int 2 + // CHECK: %[[I2_2:.*]] = torch.constant.int 2 + // CHECK: %[[LIST222:.*]] = torch.prim.ListConstruct %[[I2]], %[[I2_1]], %[[I2_2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[I0:.*]] = torch.constant.int 0 + // CHECK: %[[LIST0:.*]] = torch.prim.ListConstruct %[[I0]] : (!torch.int) -> !torch.list + // CHECK: %[[I1:.*]] = torch.constant.int 1 + // CHECK: %[[LIST1:.*]] = torch.prim.ListConstruct %[[I1]] : (!torch.int) -> !torch.list + // CHECK: %[[LIST:.*]] = torch.prim.ListConstruct : () -> !torch.list + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: torch.aten.max_pool3d %arg0, %[[LIST222]], %[[LIST1]], %[[LIST0]], %[[LIST]], %[[FALSE]] : !torch.vtensor<[1,3,32,32,32],f32>, !torch.list, !torch.list, !torch.list, !torch.list, !torch.bool -> !torch.vtensor<[1,3,31,31,31],f32> + %0 = torch.operator "onnx.MaxPool"(%arg0) {torch.onnx.kernel_shape = [2 : si64, 2 : si64, 2 : si64]} : (!torch.vtensor<[1,3,32,32,32],f32>) -> !torch.vtensor<[1,3,31,31,31],f32> + return %0 : !torch.vtensor<[1,3,31,31,31],f32> +} + +// ----- + // CHECK-LABEL: @test_gelu_default_1 func.func @test_gelu_default_1(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[STR1:.*]] = torch.constant.str "none" @@ -222,6 +286,8 @@ func.func @test_less_or_equal(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch. return %0 : !torch.vtensor<[3,4,5],i1> } +// ----- + // CHECK-LABEL: func.func @test_pow func.func @test_pow(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 15 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.pow.Tensor_Tensor %arg0, %arg1 : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> @@ -229,6 +295,8 @@ func.func @test_less_or_equal(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch. return %0 : !torch.vtensor<[3,4,5],f32> } +// ----- + // CHECK-LABEL: @test_hardsigmoid_example func.func @test_hardsigmoid_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 6 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[ALPHA_FLOAT:.*]] = torch.constant.float 5.000000e-01 @@ -252,6 +320,8 @@ func.func @test_hardsigmoid_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vt return %0 : !torch.vtensor<[3],f32> } +// ----- + // CHECK-LABEL: @test_hardsigmoid func.func @test_hardsigmoid(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 6 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[ALPHA_FLOAT:.*]] = torch.constant.float 5.000000e-01 @@ -274,6 +344,8 @@ func.func @test_hardsigmoid(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtenso return %0 : !torch.vtensor<[3,4,5],f32> } +// ----- + // CHECK-LABEL: @test_hardsigmoid_default func.func @test_hardsigmoid_default(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 6 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[ALPHA_FLOAT:.*]] = torch.constant.float 0.20000000298023224 @@ -331,6 +403,8 @@ func.func @test_globalaveragepool_precomputed(%arg0: !torch.vtensor<[1,1,3,3],f3 return %0 : !torch.vtensor<[1,1,1,1],f32> } +// ----- + // CHECK-LABEL: func.func @test_max_example func.func @test_max_example(%arg0: !torch.vtensor<[3],f32>, %arg1: !torch.vtensor<[3],f32>, %arg2: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.maximum %arg0, %arg1 : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> @@ -338,6 +412,8 @@ func.func @test_globalaveragepool_precomputed(%arg0: !torch.vtensor<[1,1,3,3],f3 return %0 : !torch.vtensor<[3],f32> } +// ----- + // CHECK-LABEL: func.func @test_min_example func.func @test_min_example(%arg0: !torch.vtensor<[3],f32>, %arg1: !torch.vtensor<[3],f32>, %arg2: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.minimum %arg0, %arg1 : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> @@ -345,6 +421,7 @@ func.func @test_globalaveragepool_precomputed(%arg0: !torch.vtensor<[1,1,3,3],f3 return %0 : !torch.vtensor<[3],f32> } +// ----- // CHECK-LABEL: func.func @test_log func.func @test_log(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { @@ -353,6 +430,8 @@ func.func @test_globalaveragepool_precomputed(%arg0: !torch.vtensor<[1,1,3,3],f3 return %0 : !torch.vtensor<[3,4,5],f32> } +// ----- + // CHECK-LABEL: func.func @test_neg func.func @test_neg(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.neg %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> @@ -360,6 +439,8 @@ func.func @test_globalaveragepool_precomputed(%arg0: !torch.vtensor<[1,1,3,3],f3 return %0 : !torch.vtensor<[3,4,5],f32> } +// ----- + // CHECK-LABEL: func.func @test_not_2d func.func @test_not_2d(%arg0: !torch.vtensor<[3,4],i1>) -> !torch.vtensor<[3,4],i1> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 1 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.bitwise_not %arg0 : !torch.vtensor<[3,4],i1> -> !torch.vtensor<[3,4],i1> @@ -367,6 +448,8 @@ func.func @test_not_2d(%arg0: !torch.vtensor<[3,4],i1>) -> !torch.vtensor<[3,4], return %0 : !torch.vtensor<[3,4],i1> } +// ----- + // CHECK-LABEL: func.func @test_or2d func.func @test_or2d(%arg0: !torch.vtensor<[3,4],i1>, %arg1: !torch.vtensor<[3,4],i1>) -> !torch.vtensor<[3,4],i1> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 7 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.bitwise_or.Tensor %arg0, %arg1 : !torch.vtensor<[3,4],i1>, !torch.vtensor<[3,4],i1> -> !torch.vtensor<[3,4],i1> From dc37616d6773acc55c7452c242c7f13e838362f4 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Fri, 12 Jan 2024 19:11:14 -0800 Subject: [PATCH 077/283] [torch][quant] Support quantize and dequantize for torch (#2731) Handle both `torch.dequantize` and `torch.quantize_per_tensor` including the op based quantization parameter tracking. This includes adding `qint32` to torch types as it was missing during the initial type inclusion. For testing we only have `torch.int8` and `torch.float` types on function boundaries as the `qint8` types require passing the scale and zero point quantization information which is not supported yet. --- .../Conversion/TorchToLinalg/Utils.h | 2 + .../Dialect/Torch/IR/GeneratedTorchOps.td | 120 ++++++++++++++ .../torch-mlir/Dialect/Torch/IR/TorchTypes.td | 11 ++ .../TorchToLinalg/Uncategorized.cpp | 153 +++++++++++++++++- lib/Conversion/TorchToLinalg/Utils.cpp | 17 ++ lib/Dialect/Torch/IR/TorchTypes.cpp | 12 +- .../Transforms/AbstractInterpLibrary.cpp | 73 +++++++++ lib/Dialect/Torch/Utils/Utils.cpp | 12 ++ .../base_lazy_backend/shape_inference.cpp | 2 +- projects/pt1/e2e_testing/xfail_sets.py | 6 + .../build_tools/abstract_interp_lib_gen.py | 43 +++++ .../build_tools/torch_ods_gen.py | 7 + .../test_suite/elementwise.py | 46 ++++++ 13 files changed, 496 insertions(+), 8 deletions(-) diff --git a/include/torch-mlir/Conversion/TorchToLinalg/Utils.h b/include/torch-mlir/Conversion/TorchToLinalg/Utils.h index 134fbeca46dc..7c9257075824 100644 --- a/include/torch-mlir/Conversion/TorchToLinalg/Utils.h +++ b/include/torch-mlir/Conversion/TorchToLinalg/Utils.h @@ -95,6 +95,8 @@ FailureOr getBackendTypeForScalarType(MLIRContext *context, torch_upstream::ScalarType dtypeInt); +bool isUnsignedTorchType(Type type); + } // namespace torch_to_linalg } // namespace torch } // namespace mlir diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 12a2bf4a86e2..9525f9f9ffa6 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -14206,6 +14206,126 @@ def Torch_AtenLeakyReluBackwardOp : Torch_Op<"aten.leaky_relu_backward", [ }]; } +def Torch_AtenQuantizePerTensorOp : Torch_Op<"aten.quantize_per_tensor", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::quantize_per_tensor : (Tensor, float, int, int) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + Torch_FloatType:$scale, + Torch_IntType:$zero_point, + Torch_IntType:$dtype + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenQuantizePerTensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 4, 1); + } + void AtenQuantizePerTensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 4, 1); + } + }]; +} + +def Torch_AtenDequantizeSelfOp : Torch_Op<"aten.dequantize.self", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::dequantize.self : (Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenDequantizeSelfOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenDequantizeSelfOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + +def Torch_AtenDequantizeTensorOp : Torch_Op<"aten.dequantize.tensor", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::dequantize.tensor : (Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$qtensor + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenDequantizeTensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenDequantizeTensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + +def Torch_AtenIntReprOp : Torch_Op<"aten.int_repr", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::int_repr : (Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenIntReprOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenIntReprOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + +def Torch_Aten_MakePerTensorQuantizedTensorOp : Torch_Op<"aten._make_per_tensor_quantized_tensor", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::_make_per_tensor_quantized_tensor : (Tensor, float, int) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + Torch_FloatType:$scale, + Torch_IntType:$zero_point + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult Aten_MakePerTensorQuantizedTensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); + } + void Aten_MakePerTensorQuantizedTensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); + } + }]; +} + def Torch_PrimLayoutOp : Torch_Op<"prim.layout", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/include/torch-mlir/Dialect/Torch/IR/TorchTypes.td b/include/torch-mlir/Dialect/Torch/IR/TorchTypes.td index 1f7231b3500a..c3b5c1582c02 100644 --- a/include/torch-mlir/Dialect/Torch/IR/TorchTypes.td +++ b/include/torch-mlir/Dialect/Torch/IR/TorchTypes.td @@ -306,6 +306,17 @@ def Torch_QUInt8Type : Torch_Type<"QUInt8", "quint8"> { }]; } +def Torch_QInt32Type : Torch_Type<"QInt32", "qint32"> { + let summary = "Type modeling `ScalarType::QInt32`"; + let description = [{ + This is intended to be a 1:1 match for the Torch `ScalarType` types. + + Looking at the variety / ad-hocness (e.g. `QUInt4x2`) of that set of + types, it is deemed preferable to import them as one-off ad-hoc types + instead of a single parameterized type. + }]; +} + def Torch_LinearParamsType : Torch_Type<"LinearParams", "LinearParams"> { let summary = "Torch packed linear params type"; let description = [{ diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index 749945dee6e2..e35136e333f0 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -1316,6 +1316,106 @@ static Value createLinalgPayloadCalculationForElementwiseOp( return b.create(loc, payloadArgs[0], allOnesVal); } + if (isa(op)) { + auto value = payloadArgs[0]; + auto valueTy = value.getType(); + auto qtensor = op->getOperand(0); + auto qtensorTy = qtensor.getType().cast().getDtype(); + auto makeQTensor = + qtensor.getDefiningOp(); + if (!makeQTensor) { + op->emitError( + "unimplemented: dequantizing tensor of unknown scale / zero-point"); + return nullptr; + } + + auto outFpTy = payloadArgs[1].getType(); + auto outBw = outFpTy.getIntOrFloatBitWidth(); + auto outIntTy = b.getIntegerType(outBw); + + if (valueTy != outIntTy) { + if (torch_to_linalg::isUnsignedTorchType(qtensorTy)) { + value = b.create(loc, outIntTy, value); + } else { + value = b.create(loc, outIntTy, value); + } + } + + Value zp = makeQTensor.getZeroPoint(); + zp = converter->materializeTargetConversion( + b, loc, converter->convertType(zp.getType()), + makeQTensor.getZeroPoint()); + auto zpTy = zp.getType(); + + if (zpTy != outIntTy) { + zp = b.create(loc, outIntTy, zp); + } + + value = b.create(loc, value, zp); + + if (torch_to_linalg::isUnsignedTorchType(qtensorTy)) { + value = b.create(loc, outFpTy, value); + } else { + value = b.create(loc, outFpTy, value); + } + + Value scale = makeQTensor.getScale(); + scale = converter->materializeTargetConversion( + b, loc, converter->convertType(scale.getType()), + makeQTensor.getScale()); + if (scale.getType() != value.getType()) { + scale = b.create(loc, value.getType(), scale); + } + value = b.create(loc, value, scale); + return value; + } + + if (auto quant = dyn_cast(op)) { + Value value = payloadArgs[0]; + Value scale = quant.getScale(); + Value zp = quant.getZeroPoint(); + auto valueTy = value.getType(); + + zp = converter->materializeTargetConversion( + b, loc, converter->convertType(zp.getType()), zp); + zp = b.create(loc, valueTy, zp); + + scale = converter->materializeTargetConversion( + b, loc, converter->convertType(scale.getType()), scale); + scale = b.create(loc, valueTy, scale); + + value = b.create(loc, value, scale); + value = b.create(loc, value); + value = b.create(loc, value, zp); + + auto destTy = payloadArgs[1].getType(); + auto bitwidth = destTy.getIntOrFloatBitWidth(); + bool isUnsigned = torch_to_linalg::isUnsignedTorchType(quant.getType()); + APInt min = isUnsigned ? APInt::getMinValue(bitwidth) + : APInt::getSignedMinValue(bitwidth); + APInt max = isUnsigned ? APInt::getMaxValue(bitwidth) + : APInt::getSignedMaxValue(bitwidth); + + Value minVal = b.create( + loc, b.getFloatAttr(valueTy, min.getSExtValue())); + Value maxVal = b.create( + loc, b.getFloatAttr(valueTy, max.getSExtValue())); + Value minCmp = + b.create(loc, arith::CmpFPredicate::ULT, value, minVal); + Value maxCmp = + b.create(loc, arith::CmpFPredicate::UGT, value, maxVal); + value = b.create(loc, minCmp, minVal, value); + value = b.create(loc, maxCmp, maxVal, value); + + if (isUnsigned) { + value = b.create(loc, destTy, value); + } else { + value = b.create(loc, destTy, value); + } + + return value; + } + op->emitError("unimplemented lowering in " "createLinalgPayloadCalculationForElementwiseOp"); return nullptr; @@ -1368,9 +1468,10 @@ class ConvertElementwiseOp : public ConversionPattern { AtenThresholdBackwardOp, AtenHardtanhBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenNegOp, AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp, - AtenLogicalXorOp, AtenLogicalNotOp, AtenIsinfOp, AtenTriuOp, AtenTrilOp, - AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp, - AtenAtanOp, AtenAcosOp, AtenRealOp, AtenImagOp>(op)) + AtenLogicalXorOp, AtenLogicalNotOp, AtenIsinfOp, AtenTriuOp, + AtenTrilOp, AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, + AtenFillTensorOp, AtenAtanOp, AtenAcosOp, AtenRealOp, AtenImagOp, + AtenDequantizeSelfOp, AtenDequantizeTensorOp, AtenQuantizePerTensorOp>(op)) return rewriter.notifyMatchFailure(op, "not a supported elementwise op"); if (failed(verifyLinalgCompatibleTypes(op, rewriter))) @@ -2080,6 +2181,42 @@ class ConvertLogitOp : public OpConversionPattern { } }; } // namespace + +namespace { +class ConvertAtenIntReprOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(AtenIntReprOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + RankedTensorType resultType = getTypeConverter() + ->convertType(op->getResult(0).getType()) + .cast(); + rewriter.replaceOpWithNewOp(op, resultType, + adaptor.getSelf()); + return success(); + } +}; +} // namespace + +namespace { +class ConvertMakePerTensorQuantizedTensorOp + : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(Aten_MakePerTensorQuantizedTensorOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + RankedTensorType resultType = getTypeConverter() + ->convertType(op->getResult(0).getType()) + .cast(); + rewriter.replaceOpWithNewOp(op, resultType, + adaptor.getSelf()); + return success(); + } +}; +} // namespace + void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target) { @@ -2102,9 +2239,9 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality( AtenHardtanhBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp, AtenAtanOp, AtenAcosOp, AtenLogicalXorOp, AtenLogicalNotOp, AtenIsinfOp, AtenTriuOp, - AtenTrilOp, - AtenRemainderScalarOp, AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, - AtenFillTensorOp, AtenRealOp, AtenImagOp>(); + AtenTrilOp, AtenRemainderScalarOp, AtenBitwiseNotOp, AtenRoundOp, + AtenFillScalarOp, AtenFillTensorOp, AtenRealOp, AtenImagOp, + AtenDequantizeSelfOp, AtenDequantizeTensorOp, AtenQuantizePerTensorOp>(); patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); @@ -2122,4 +2259,8 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality( patterns.add(typeConverter, context); patterns.add(typeConverter, context); target.addIllegalOp(); + patterns.add(typeConverter, context); + target.addIllegalOp(); + patterns.add(typeConverter, context); + target.addIllegalOp(); } diff --git a/lib/Conversion/TorchToLinalg/Utils.cpp b/lib/Conversion/TorchToLinalg/Utils.cpp index ccc78985dc6c..77459aca3a60 100644 --- a/lib/Conversion/TorchToLinalg/Utils.cpp +++ b/lib/Conversion/TorchToLinalg/Utils.cpp @@ -559,3 +559,20 @@ FailureOr torch_to_linalg::getBackendTypeForScalarType( } return type; } + +bool torch_to_linalg::isUnsignedTorchType(Type type) { + if (auto tty = dyn_cast(type)) + return isUnsignedTorchType(tty.getDtype()); + if (isa(type)) + return false; + if (isa(type)) + return false; + if (isa(type)) + return true; + if (isa(type)) + return false; + if (auto intTy = dyn_cast(type)) + return intTy.isUnsigned(); + llvm_unreachable("Unknown type checked for signedness"); + return false; +} diff --git a/lib/Dialect/Torch/IR/TorchTypes.cpp b/lib/Dialect/Torch/IR/TorchTypes.cpp index cf832b1b755e..33ef459081c4 100644 --- a/lib/Dialect/Torch/IR/TorchTypes.cpp +++ b/lib/Dialect/Torch/IR/TorchTypes.cpp @@ -184,7 +184,7 @@ static bool isValidTorchDtype(Type dtype) { dtype = dtype.cast().getElementType(); } // Torch quantized types. - if (dtype.isa()) + if (dtype.isa()) return true; // Builtin floating point types. if (dtype.isa()) @@ -410,6 +410,16 @@ static Type convertDtypeToBuiltinElementType(MLIRContext *context, Type dtype) { } else if (dtype.isa()){ return dtype; } + + if (isa(dtype)) + return IntegerType::get(context, 8, IntegerType::Signless); + + if (isa(dtype)) + return IntegerType::get(context, 8, IntegerType::Signless); + + if (isa(dtype)) + return IntegerType::get(context, 32, IntegerType::Signless); + emitError(UnknownLoc::get(context)) << "unimplemented: conversion of dtype " << dtype << " to builtin tensor element type"; diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index 3901cd34a4aa..c286168080e4 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -6481,6 +6481,26 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.quantize_per_tensor\"(%arg0: !torch.list, %arg1: !torch.float, %arg2: !torch.int, %arg3: !torch.int) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.dequantize.self\"(%arg0: !torch.list) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.dequantize.tensor\"(%arg0: !torch.list) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.int_repr\"(%arg0: !torch.list) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten._make_per_tensor_quantized_tensor\"(%arg0: !torch.list, %arg1: !torch.float, %arg2: !torch.int) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.prims.convert_element_type\"(%arg0: !torch.list, %arg1: !torch.int) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -11783,6 +11803,59 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.quantize_per_tensor\"(%arg0: !torch.tuple, %arg1: !torch.float, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n" +" return %arg3 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.dequantize.self\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %int6 = torch.constant.int 6\n" +" return %int6 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.dequantize.tensor\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %int6 = torch.constant.int 6\n" +" return %int6 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.int_repr\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %int3 = torch.constant.int 3\n" +" %int1 = torch.constant.int 1\n" +" %int12 = torch.constant.int 12\n" +" %int0 = torch.constant.int 0\n" +" %int13 = torch.constant.int 13\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.eq.int %0#1, %int13 : !torch.int, !torch.int -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %int0 : !torch.int\n" +" } else {\n" +" %3 = torch.aten.eq.int %0#1, %int12 : !torch.int, !torch.int -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.int) {\n" +" torch.prim.If.yield %int1 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %int3 : !torch.int\n" +" }\n" +" torch.prim.If.yield %4 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten._make_per_tensor_quantized_tensor\"(%arg0: !torch.tuple, %arg1: !torch.float, %arg2: !torch.int) -> !torch.int {\n" +" %int14 = torch.constant.int 14\n" +" %int12 = torch.constant.int 12\n" +" %int1 = torch.constant.int 1\n" +" %int13 = torch.constant.int 13\n" +" %int0 = torch.constant.int 0\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.eq.int %0#1, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %int13 : !torch.int\n" +" } else {\n" +" %3 = torch.aten.eq.int %0#1, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.int) {\n" +" torch.prim.If.yield %int12 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %int14 : !torch.int\n" +" }\n" +" torch.prim.If.yield %4 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" "}\n" ""; // clang-format on diff --git a/lib/Dialect/Torch/Utils/Utils.cpp b/lib/Dialect/Torch/Utils/Utils.cpp index 12ac1d58ee59..06330f16a57e 100644 --- a/lib/Dialect/Torch/Utils/Utils.cpp +++ b/lib/Dialect/Torch/Utils/Utils.cpp @@ -64,6 +64,12 @@ torch_upstream::ScalarType Torch::getScalarTypeForType(Type type) { return torch_upstream::ScalarType::Byte; if (type.isSignedInteger(8)) return torch_upstream::ScalarType::Char; + if (type.isa()) + return torch_upstream::ScalarType::QUInt8; + if (type.isa()) + return torch_upstream::ScalarType::QInt8; + if (type.isa()) + return torch_upstream::ScalarType::QInt32; if (type.isa()) { mlir::Type complexElemType = type.cast().getElementType(); if (complexElemType.isF16()) @@ -109,6 +115,12 @@ Torch::getTypeForScalarType(MLIRContext *context, return mlir::IntegerType::get(context, 8, mlir::IntegerType::Unsigned); case torch_upstream::ScalarType::Char: return mlir::IntegerType::get(context, 8, mlir::IntegerType::Signed); + case torch_upstream::ScalarType::QUInt8: + return QUInt8Type::get(context); + case torch_upstream::ScalarType::QInt8: + return QInt8Type::get(context); + case torch_upstream::ScalarType::QInt32: + return QInt32Type::get(context); case torch_upstream::ScalarType::ComplexHalf: return mlir::ComplexType::get(Float16Type::get(context)); case torch_upstream::ScalarType::ComplexFloat: diff --git a/projects/ltc/csrc/base_lazy_backend/shape_inference.cpp b/projects/ltc/csrc/base_lazy_backend/shape_inference.cpp index 3971fdd3258a..15080f9764cc 100644 --- a/projects/ltc/csrc/base_lazy_backend/shape_inference.cpp +++ b/projects/ltc/csrc/base_lazy_backend/shape_inference.cpp @@ -431,4 +431,4 @@ std::vector compute_shape_linspace(const at::Scalar & start, } // namespace lazy -} // namespace torch \ No newline at end of file +} // namespace torch diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index de68680a82f2..e04657df4d2c 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -311,6 +311,10 @@ # ERROR: 'torch.aten.add.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float' "GroupNormModule_basic", "GroupNormNoWeightAndBiasModule_basic", + + # Dynamo does not support tracing quantized tensors + "ElementwiseDequantizePerTensorModule_basic", + "ElementwiseQuantizePerTensorModule_basic", } TORCHDYNAMO_CRASHING_SET = { @@ -1507,4 +1511,6 @@ "ElementwiseBitwiseAndScalarInt64Module_basic", "ElementwiseBitwiseAndScalarInt32Module_basic", "ElementwiseBitwiseAndScalarInt8Module_basic", + "ElementwiseQuantizePerTensorModule_basic", + "ElementwiseDequantizePerTensorModule_basic" } diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index 211023a9deec..6a8fbf34e911 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -221,6 +221,21 @@ def aten〇clamp_max〡shape(self: List[int], max: float) -> List[int]: def aten〇rsub〇Scalar〡shape(self: List[int], other: float, alpha: float = 1) -> List[int]: return upstream_shape_functions.unary(self) +def aten〇quantize_per_tensor〡shape(self: List[int], scale: float, zero_point: int, dtype: int) -> List[int]: + return upstream_shape_functions.unary(self) + +def aten〇dequantize〇self〡shape(self: List[int]) -> List[int]: + return upstream_shape_functions.unary(self) + +def aten〇dequantize〇tensor〡shape(qtensor: List[int]) -> List[int]: + return upstream_shape_functions.unary(qtensor) + +def aten〇int_repr〡shape(self: List[int]) -> List[int]: + return upstream_shape_functions.unary(self) + +def aten〇_make_per_tensor_quantized_tensor〡shape(self: List[int], scale: float, zero_point: int) -> List[int]: + return upstream_shape_functions.unary(self) + def prims〇convert_element_type〡shape(a: List[int], dtype: int) -> List[int]: return upstream_shape_functions.unary(a) @@ -3958,6 +3973,34 @@ def prims〇collapse〡dtype(a_rank_dtype: Tuple[int, int], start: int, end: int return a_dtype +def aten〇quantize_per_tensor〡dtype(self_rank_dtype: Tuple[int, int], scale: float, zero_point: int, dtype: int) -> int: + return dtype + +def aten〇dequantize〇self〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + return torch.float32 + +def aten〇dequantize〇tensor〡dtype(qtensor_rank_dtype: Tuple[int, int]) -> int: + return torch.float32 + +def aten〇int_repr〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + if (self_dtype == torch.quint8): + return torch.uint8 + if (self_dtype == torch.qint8): + return torch.int8 + return torch.int32 + +def aten〇_make_per_tensor_quantized_tensor〡dtype(self_rank_dtype: Tuple[int, int], scale: float, zero_point: int) -> int: + self_rank, self_dtype = self_rank_dtype + if (self_dtype == torch.uint8): + return torch.quint8 + if (self_dtype == torch.int8): + return torch.qint8 + return torch.qint32 + + + + # ============================================================================== # Main diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index a9f9ed96dce2..249c25628a82 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -805,6 +805,13 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::elu_backward : (Tensor, Scalar, Scalar, Scalar, bool, Tensor) -> (Tensor)") emit("aten::leaky_relu_backward : (Tensor, Tensor, Scalar, bool) -> (Tensor)") + # quantized ops + emit("aten::quantize_per_tensor : (Tensor, float, int, int) -> (Tensor)") + emit("aten::dequantize.self : (Tensor) -> (Tensor)") + emit("aten::dequantize.tensor : (Tensor) -> (Tensor)") + emit("aten::int_repr : (Tensor) -> (Tensor)") + emit("aten::_make_per_tensor_quantized_tensor : (Tensor, float, int) -> (Tensor)") + # ========================================================================== # `prim::` namespace. # ========================================================================== diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py index 5d6217b59072..23a22142c4d5 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py @@ -4136,6 +4136,52 @@ def ElementwiseBitwiseAndScalarInt8Module_basic(module, tu: TestUtils): # ============================================================================== +class ElementwiseQuantizePerTensorModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float, True), + ]) + def forward(self, x): + scale = 0.04 + zp = -110 + dtype = torch.qint8 + # We return the int representation as we can not map to quint8 type yet on boundaries. + q = torch.quantize_per_tensor(x, scale, zp, dtype).int_repr() + return q + +@register_test_case(module_factory=lambda: ElementwiseQuantizePerTensorModule()) +def ElementwiseQuantizePerTensorModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4)) + + +# ============================================================================== + +class ElementwiseDequantizePerTensorModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int8, True), + ]) + def forward(self, x): + qx = torch._make_per_tensor_quantized_tensor(x, 0.1, 8) + qx = torch.dequantize(qx) + return qx + +@register_test_case(module_factory=lambda: ElementwiseDequantizePerTensorModule()) +def ElementwiseDequantizePerTensorModule_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 4, low=-128, high=127).to(torch.int8)) + +# ============================================================================== + class GluStaticModule(torch.nn.Module): def __init__(self): super().__init__() From 10acea71be38c409470c66b57a77330b7b8bebd0 Mon Sep 17 00:00:00 2001 From: Han-Chung Wang Date: Mon, 15 Jan 2024 07:12:12 -0800 Subject: [PATCH 078/283] Bump LLVM to llvm/llvm-project@0cb024b (#2753) - Add fixes for https://github.com/llvm/llvm-project/commit/af78e5daf0791135485dbd7972ffedb927727a6b - Add fixes for https://github.com/llvm/llvm-project/commit/bb6d5c220004a5d7e466a669324001285a688918 --- externals/llvm-project | 2 +- lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp | 16 +++++++++------- test/Dialect/Torch/canonicalize.mlir | 8 ++++---- test/Dialect/Torch/decompose-complex-ops.mlir | 4 ++-- .../Torch/simplify-shape-calculations.mlir | 14 +++++++------- 5 files changed, 23 insertions(+), 21 deletions(-) diff --git a/externals/llvm-project b/externals/llvm-project index 6b65d79fbb46..0cb024b357af 160000 --- a/externals/llvm-project +++ b/externals/llvm-project @@ -1 +1 @@ -Subproject commit 6b65d79fbb4682468333cea42b62f15c2dffd8f3 +Subproject commit 0cb024b357aff294b1ba0f9d3de8f48ab684962b diff --git a/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp b/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp index 14e6f351ed97..b8f719792476 100644 --- a/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp +++ b/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp @@ -33,8 +33,9 @@ Value buildRescale(PatternRewriter &rewriter, Operation *op, rewriter.getI32IntegerAttr(static_cast(input_zp)), rewriter.getI32IntegerAttr(static_cast(output_zp)), rewriter.getDenseI32ArrayAttr({multiplier}), - rewriter.getDenseI32ArrayAttr({shift}), rewriter.getBoolAttr(scale32), - rewriter.getBoolAttr(double_round), rewriter.getBoolAttr(false)); + rewriter.getDenseI8ArrayAttr({static_cast(shift)}), + rewriter.getBoolAttr(scale32), rewriter.getBoolAttr(double_round), + rewriter.getBoolAttr(false)); return rescale_op.getResult(); } @@ -86,8 +87,9 @@ Value buildRescaleOpConvOutput(PatternRewriter &rewriter, Operation *op, rewriter, op->getLoc(), output_type, conv_val, rewriter.getI32IntegerAttr(0), rewriter.getI32IntegerAttr(output_zp), rewriter.getDenseI32ArrayAttr({multiplier}), - rewriter.getDenseI32ArrayAttr({shift}), rewriter.getBoolAttr(scale32), - rewriter.getBoolAttr(true), rewriter.getBoolAttr(false)); + rewriter.getDenseI8ArrayAttr({static_cast(shift)}), + rewriter.getBoolAttr(scale32), rewriter.getBoolAttr(true), + rewriter.getBoolAttr(false)); return rescale_op.getResult(); @@ -96,7 +98,7 @@ Value buildRescaleOpConvOutput(PatternRewriter &rewriter, Operation *op, .dyn_cast()) { // Per-channel quantization SmallVector multiplier_arr; - SmallVector shift_arr; + SmallVector shift_arr; SmallVector weight_scale_arr( weight_per_channel_qtype.getScales().begin(), @@ -115,14 +117,14 @@ Value buildRescaleOpConvOutput(PatternRewriter &rewriter, Operation *op, scale_width); multiplier_arr.push_back(multiplier); - shift_arr.push_back(shift); + shift_arr.push_back(static_cast(shift)); } auto rescale_op = CreateOpAndInfer( rewriter, op->getLoc(), output_type, conv_val, rewriter.getI32IntegerAttr(0), rewriter.getI32IntegerAttr(output_zp), rewriter.getDenseI32ArrayAttr(multiplier_arr), - rewriter.getDenseI32ArrayAttr(shift_arr), rewriter.getBoolAttr(scale32), + rewriter.getDenseI8ArrayAttr(shift_arr), rewriter.getBoolAttr(scale32), rewriter.getBoolAttr(true), rewriter.getBoolAttr(true)); return rescale_op.getResult(); diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir index 5dfd8daa9d44..abb990cccc8c 100644 --- a/test/Dialect/Torch/canonicalize.mlir +++ b/test/Dialect/Torch/canonicalize.mlir @@ -1,10 +1,10 @@ // RUN: torch-mlir-opt %s -canonicalize | FileCheck %s // CHECK-LABEL: func.func @torch.aten.__range_length$fold() -> (!torch.int, !torch.int, !torch.int, !torch.int) { -// CHECK: %[[INT1:.*]] = torch.constant.int 1 -// CHECK: %[[INT2:.*]] = torch.constant.int 2 -// CHECK: %[[INT3:.*]] = torch.constant.int 3 -// CHECK: %[[INTM1:.*]] = torch.constant.int -1 +// CHECK-DAG: %[[INT1:.*]] = torch.constant.int 1 +// CHECK-DAG: %[[INT2:.*]] = torch.constant.int 2 +// CHECK-DAG: %[[INT3:.*]] = torch.constant.int 3 +// CHECK-DAG: %[[INTM1:.*]] = torch.constant.int -1 // CHECK: %[[NEG_STEP:.*]] = torch.aten.__range_length %[[INT1]], %[[INT3]], %[[INTM1]] : !torch.int, !torch.int, !torch.int -> !torch.int // CHECK: return %[[INT2]], %[[INT2]], %[[INT1]], %[[NEG_STEP]] : !torch.int, !torch.int, !torch.int, !torch.int func.func @torch.aten.__range_length$fold() -> (!torch.int, !torch.int, !torch.int, !torch.int) { diff --git a/test/Dialect/Torch/decompose-complex-ops.mlir b/test/Dialect/Torch/decompose-complex-ops.mlir index 17767f9f4e02..d223bb21ec43 100644 --- a/test/Dialect/Torch/decompose-complex-ops.mlir +++ b/test/Dialect/Torch/decompose-complex-ops.mlir @@ -84,8 +84,8 @@ func.func @torch.aten.adaptive_avg_pool2d$unit_output_size(%arg0: !torch.vtensor // CHECK-LABEL: func.func @torch.aten.type_as$basic( // CHECK-SAME: %[[ARG_0:.*]]: !torch.tensor, %[[ARG_1:.*]]: !torch.tensor) -> !torch.tensor { -// CHECK: %[[FALSE:.*]] = torch.constant.bool false -// CHECK: %[[NONE:.*]] = torch.constant.none +// CHECK-DAG: %[[FALSE:.*]] = torch.constant.bool false +// CHECK-DAG: %[[NONE:.*]] = torch.constant.none // CHECK: %[[DTYPE:.*]] = torch.prim.dtype %[[ARG_1]] : !torch.tensor -> !torch.int // CHECK: %[[VAR:.*]] = torch.aten.to.dtype %[[ARG_0]], %[[DTYPE]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.tensor, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.tensor // CHECK: return %[[VAR]] : !torch.tensor diff --git a/test/Dialect/Torch/simplify-shape-calculations.mlir b/test/Dialect/Torch/simplify-shape-calculations.mlir index 10a65a527873..0d3b1f661bde 100644 --- a/test/Dialect/Torch/simplify-shape-calculations.mlir +++ b/test/Dialect/Torch/simplify-shape-calculations.mlir @@ -105,9 +105,9 @@ func.func @refine_shape_calculate_result$user_allows_type_refinement(%arg0: !tor // CHECK-LABEL: func.func @fully_unroll_prim_loop$unroll( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor, // CHECK-SAME: %[[ARG1:.*]]: !torch.list) -> !torch.vtensor { -// CHECK: %[[INT1:.*]] = torch.constant.int 1 -// CHECK: %[[INT2:.*]] = torch.constant.int 2 -// CHECK: %[[INT0:.*]] = torch.constant.int 0 +// CHECK-DAG: %[[INT1:.*]] = torch.constant.int 1 +// CHECK-DAG: %[[INT2:.*]] = torch.constant.int 2 +// CHECK-DAG: %[[INT0:.*]] = torch.constant.int 0 // CHECK: %[[RESULT:.*]] = torch.shape.calculate { // CHECK: torch.shape.calculate.yield %[[ARG0]] : !torch.vtensor // CHECK: } shapes { @@ -375,8 +375,8 @@ func.func @abstractly_interpret_list_ops$miscompile$list_identity(%arg0: !torch. // missing. // CHECK-LABEL: func.func @basic_integration( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],unk>) -> !torch.vtensor { -// CHECK: %[[INT0:.*]] = torch.constant.int 0 -// CHECK: %[[INT1:.*]] = torch.constant.int 1 +// CHECK-DAG: %[[INT0:.*]] = torch.constant.int 0 +// CHECK-DAG: %[[INT1:.*]] = torch.constant.int 1 // CHECK: %[[RESULT:.*]] = torch.shape.calculate { // CHECK: %[[TANH:.*]] = torch.aten.tanh %[[ARG0]] : !torch.vtensor<[?,?],unk> -> !torch.vtensor<[?,?],unk> // CHECK: torch.shape.calculate.yield %[[TANH]] : !torch.vtensor<[?,?],unk> @@ -410,8 +410,8 @@ func.func @basic_integration(%arg0: !torch.vtensor<[?,?],unk>) -> !torch.vtensor // CHECK-LABEL: func.func @fold_prim_unchecked_cast_op( // CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor, // CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],si64>) -> !torch.vtensor { -// CHECK: %[[VAL_2:.*]] = torch.constant.int 0 -// CHECK: %[[VAL_3:.*]] = torch.constant.int 1 +// CHECK-DAG: %[[VAL_2:.*]] = torch.constant.int 0 +// CHECK-DAG: %[[VAL_3:.*]] = torch.constant.int 1 // CHECK: %[[VAL_4:.*]] = torch.shape.calculate { // CHECK: %[[VAL_5:.*]] = torch.tensor_static_info_cast %[[VAL_0]] : !torch.vtensor to !torch.vtensor<[?,?],unk> // CHECK: torch.shape.calculate.yield %[[VAL_5]] : !torch.vtensor<[?,?],unk> From 197b3b475c2fa4c452f08c79f2cab1c7482d6ccc Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Mon, 15 Jan 2024 09:31:22 -0800 Subject: [PATCH 079/283] [onnx] Convert `onnx.constant` to `torch` literal tensor (#2748) Handles the multiple cases of `onnx` constant values and converts them to `torch` literal tensors. This can include splats with a single integer or floating point value, a set of explicit integer values, or an elements array attr of values. --- .../Conversion/TorchOnnxToTorch/Patterns.h | 13 +++++ .../TorchOnnxToTorch/DefaultDomainAtoF.cpp | 53 +++++++++++++++++++ .../TorchOnnxToTorch/simple_ops_a_to_f.mlir | 41 ++++++++++++++ 3 files changed, 107 insertions(+) diff --git a/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h b/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h index b6189b375c5b..44e33ab09741 100644 --- a/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h +++ b/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h @@ -190,6 +190,19 @@ struct OpBinder { return failure(); } + ParseResult denseElementsAttr(ElementsAttr elementsattr, + StringRef nameSuffix) { + SmallString<64> name("torch.onnx."); + name.append(nameSuffix); + Attribute attr = op->getAttr(name); + if (!attr || !isa(attr)) { + return failure(); + } + + elementsattr = cast(attr); + return success(); + } + ParseResult customOpNameStringAttr(std::string &value, StringRef nameSuffix, std::string defaultValue = "") { SmallString<64> name("torch.onnx."); diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp index 2cfc9940a940..aa3b5fc012d0 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp @@ -590,6 +590,59 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( tensorList, cstDim); return success(); }); + patterns.onOp( + "Constant", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + if (binder.tensorResultType(resultType)) + return failure(); + auto dtype = resultType.getDtype(); + Value scalarValue; + + float floatValue; + if (binder.op->hasAttr("torch.onnx.value_float") && + !binder.f32FloatAttr(floatValue, "value_float", 0.0)) { + auto splatAttr = + SplatElementsAttr::get(resultType.toBuiltinTensor().clone(dtype), + rewriter.getFloatAttr(dtype, floatValue)); + rewriter.replaceOpWithNewOp( + binder.op, resultType, splatAttr); + return success(); + } + + int64_t intValue; + if (binder.op->hasAttr("torch.onnx.value_int") && + !binder.s64IntegerAttr(intValue, "value_int", 0)) { + auto splatAttr = + SplatElementsAttr::get(resultType.toBuiltinTensor().clone(dtype), + rewriter.getIntegerAttr(dtype, intValue)); + rewriter.replaceOpWithNewOp( + binder.op, resultType, splatAttr); + return success(); + } + + if (ElementsAttr attr = binder.op->getAttr("torch.onnx.value") + .dyn_cast_or_null()) { + rewriter.replaceOpWithNewOp( + binder.op, resultType, attr); + return success(); + } + + llvm::SmallVector intValues; + if (!binder.s64IntegerArrayAttr(intValues, "value_ints", {}) && + !intValues.empty()) { + llvm::SmallVector apValues; + for (auto intVal : intValues) { + apValues.push_back(APInt(dtype.getIntOrFloatBitWidth(), intVal)); + } + auto attr = DenseElementsAttr::get( + resultType.toBuiltinTensor().clone(dtype), apValues); + rewriter.replaceOpWithNewOp( + binder.op, resultType, attr); + return success(); + } + + return failure(); + }); patterns.onOp( "Conv", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) { std::string autoPad; diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir index fc9706127280..f8bc219dcb48 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir @@ -979,3 +979,44 @@ func.func @test_depthtospace_crd_mode_example(%arg0: !torch.vtensor<[1,8,2,3],f3 %0 = torch.operator "onnx.DepthToSpace"(%arg0) {torch.onnx.blocksize = 2 : si64, torch.onnx.mode = "CRD"} : (!torch.vtensor<[1,8,2,3],f32>) -> !torch.vtensor<[1,2,4,6],f32> return %0 : !torch.vtensor<[1,2,4,6],f32> } + +// ----- + +// CHECK-LABEL: @float_constant +func.func @float_constant() -> !torch.vtensor<[], f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64} { + // CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<2.500000e-01> : tensor) : !torch.vtensor<[],f32> + // CHECK: return %[[CST]] + %0 = torch.operator "onnx.Constant"() {torch.onnx.value_float = 0.25 : f32} : () -> !torch.vtensor<[],f32> + return %0 : !torch.vtensor<[],f32> +} + +// ----- + +// CHECK-LABEL: @int_constant +func.func @int_constant() -> !torch.vtensor<[], si64> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64} { + // CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<79> : tensor) : !torch.vtensor<[],si64> + // CHECK: return %[[CST]] + %0 = torch.operator "onnx.Constant"() {torch.onnx.value_int = 79 : si64} : () -> !torch.vtensor<[],si64> + return %0 : !torch.vtensor<[],si64> +} + +// ----- + +// CHECK-LABEL: @dense_constant +func.func @dense_constant() -> !torch.vtensor<[1], si64> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64} { + // CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<13> : tensor<1xsi64>) : !torch.vtensor<[1],si64> + // CHECK: return %[[CST]] + %0 = torch.operator "onnx.Constant"() {torch.onnx.value = dense<13> : tensor<1xsi64>} : () -> !torch.vtensor<[1],si64> + return %0 : !torch.vtensor<[1],si64> +} + +// ----- + +// CHECK-LABEL: @ints_constant +func.func @ints_constant() -> !torch.vtensor<[2], si64> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64} { + // CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<[7, 9]> : tensor<2xsi64>) : !torch.vtensor<[2],si64> + // CHECK: return %[[CST]] + %0 = "torch.operator"() <{name = "onnx.Constant"}> {torch.onnx.value_ints = [7 : si64, 9 : si64]} : () -> !torch.vtensor<[2],si64> + return %0 : !torch.vtensor<[2],si64> +} + From 09421b1cf3153e9794187fe665395713a048803f Mon Sep 17 00:00:00 2001 From: lisaliu1 Date: Mon, 15 Jan 2024 20:02:27 +0100 Subject: [PATCH 080/283] [TorchToLinalg] Add lowering for aten.replication_pad2d (#2715) Co-authored-by: Lisa Liu --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 24 ++ .../TorchToLinalg/TensorConstructors.cpp | 225 ++++++++++++++++++ .../Transforms/AbstractInterpLibrary.cpp | 29 +++ .../build_tools/abstract_interp_lib_gen.py | 9 + .../build_tools/torch_ods_gen.py | 1 + .../torch_mlir_e2e_test/test_suite/basic.py | 99 ++++++++ 6 files changed, 387 insertions(+) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 9525f9f9ffa6..758b10315391 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -7916,6 +7916,30 @@ def Torch_AtenConstantPadNdOp : Torch_Op<"aten.constant_pad_nd", [ }]; } +def Torch_AtenReplicationPad2dOp : Torch_Op<"aten.replication_pad2d", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::replication_pad2d : (Tensor, int[]) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchListOfTorchIntType:$padding + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenReplicationPad2dOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenReplicationPad2dOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + def Torch_AtenReflectionPad1dOp : Torch_Op<"aten.reflection_pad1d", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Conversion/TorchToLinalg/TensorConstructors.cpp b/lib/Conversion/TorchToLinalg/TensorConstructors.cpp index 434b50b034dd..9429d1e8caca 100644 --- a/lib/Conversion/TorchToLinalg/TensorConstructors.cpp +++ b/lib/Conversion/TorchToLinalg/TensorConstructors.cpp @@ -24,6 +24,8 @@ #include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" +#include + using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; @@ -79,6 +81,227 @@ class ConvertAtenConstantPadNdOp }; } // namespace +namespace { + + // Lower aten.replication_pad2d operator into a sequence of + // tensor.extract_slice and tensor.concat operations. + + class ConvertAtenReplicationPad2dOp + : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(AtenReplicationPad2dOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (failed(verifyLinalgCompatibleTypes(op, rewriter))) + return failure(); + + Location loc = op->getLoc(); + Value input = adaptor.getSelf(); + MLIRContext *context = rewriter.getContext(); + auto inputType = llvm::cast(input.getType()); + int64_t inputRank = inputType.getRank(); + auto outputType = llvm::cast( + getTypeConverter()->convertType(op->getResult(0).getType())); + unsigned numDims = inputType.getRank(); + assert(numDims >= 2 && "Not enough input dimensions"); + + SmallVector padInts; + if (!matchPattern(op.getPadding(), m_TorchListOfConstantInts(padInts))) + return rewriter.notifyMatchFailure( + op, "only support constant int pad ranges"); + uint64_t padRank = padInts.size() / 2; + if (padRank * 2 != padInts.size()) + return rewriter.notifyMatchFailure(op, "pad range size is not even"); + if (inputRank < 0 || padRank > (uint64_t)inputRank) + return rewriter.notifyMatchFailure(op, "padding exceeds tensor rank"); + + SmallVector inputShape = getTensorSizes(rewriter, loc, input); + int64_t hDim = numDims - 1; + int64_t vDim = numDims - 2; + Value hDimSize = inputShape[hDim]; + Value vDimSize = inputShape[vDim]; + + enum tileHLoc { LEFT = 0, HCENTER = 1, RIGHT = 2 }; + enum tileVLoc { TOP = 0, VCENTER = 2, BOTTOM = 1, }; + // vTile denotes the vertical size of the tile + // hTile denotes the horizontal size of the tile + // The padding results are composed of following tiles: + // vTile[TOP]hTile[LEFT], vTile[TOP]hTile[HCENTER], vTile[TOP]hTile[RIGHT] + // vTile[VCENTER]hTile[LEFT], vTile[VCENTER]hTile[HCENTER], vTile[VCENTER]hTile[RIGHT] + // vTile[BOTTOM]hTile[LEFT], vTile[BOTTOM]hTile[HCENTER], vTile[BOTTOM]hTile[RIGHT] + // vTile[VCENTER]hTile[HCENTER] is the original input tensor + Type indexType = rewriter.getIndexType(); + Value vTile[3]; + Value hTile[3]; + vTile[VCENTER] = vDimSize; + hTile[HCENTER] = hDimSize; + vTile[TOP] = getConstant(rewriter, loc, padInts[2], indexType); + vTile[BOTTOM] = getConstant(rewriter, loc, padInts[3], indexType); + hTile[LEFT] = getConstant(rewriter, loc, padInts[0], indexType); + hTile[RIGHT] = getConstant(rewriter, loc, padInts[1], indexType); + + bool hasLeftPadding = false; + bool hasRightPadding = false; + bool hasTopPadding = false; + bool hasBottomPadding = false; + + for (auto i: {TOP, VCENTER, BOTTOM}){ + for (auto j: {LEFT, HCENTER, RIGHT}) { + auto constVtile{ + mlir::dyn_cast(vTile[i].getDefiningOp()) + .getValue() + .dyn_cast_or_null()}; + + auto constHtile{ + mlir::dyn_cast(hTile[j].getDefiningOp()) + .getValue() + .dyn_cast_or_null()}; + auto vSize = constVtile.getInt(); + auto hSize = constHtile.getInt(); + + if ((i == TOP) && (vSize > 0)) + hasTopPadding = true; + if ((i == BOTTOM) && (vSize > 0)) + hasBottomPadding = true; + if ((j == LEFT) && (hSize > 0)) + hasLeftPadding = true; + if ((j == RIGHT) && (hSize > 0)) + hasRightPadding = true; + } + } + + // Some generic helper functions to aid in constructing basic arithmetic. + auto createAdd = [&](Value x, Value y) { + return rewriter.create(loc, x, y); + }; + + auto createAdds = [&](std::initializer_list values) { + assert(values.size() >= 2); + return std::accumulate(values.begin() + 1, values.end(), data(values)[0], + createAdd); + }; + auto createSub = [&](Value x, Value y) { + return rewriter.create(loc, x, y); + }; + + // Extract left and right pad tiles. + Value zero = getConstant(rewriter, loc, 0, indexType); + Value one = getConstant(rewriter, loc, 1, indexType); + Value hDimSizeMinusOne = createSub(hDimSize, one); + Value vDimSizeMinusOne = createSub(vDimSize, one); + SmallVector allOneStrides(numDims, one); + + SmallVector extractOffsetsLT(numDims, zero); + extractOffsetsLT[hDim] = zero; + extractOffsetsLT[vDim] = zero; + SmallVector extractShapeLR(numDims, one); + extractShapeLR[hDim] = one; + extractShapeLR[vDim] = vDimSize; + + SmallVector extractOffsetsRight(numDims, zero); + extractOffsetsRight[hDim] = hDimSizeMinusOne; + extractOffsetsRight[vDim] = zero; + + SmallVector extractOffsetsBottom(numDims, zero); + extractOffsetsBottom[hDim] = zero; + extractOffsetsBottom[vDim] = vDimSizeMinusOne; + + SmallVector extractShapeTB(numDims, one); + extractShapeTB[hDim] = hDimSize; + extractShapeTB[vDim] = one; + + SmallVector tensorsLeft; + SmallVector tensorsRight; + SmallVector tensorsCenter; + Value centerTile; + SmallVector tensorsRes; + + if (hasLeftPadding) { + Value vCenterLeftSlice = rewriter.create( + loc, input, extractOffsetsLT, extractShapeLR, allOneStrides); + Value vLeftSlice = vCenterLeftSlice; + if (hasTopPadding) { + Value topLeftValue = rewriter.create( + loc, input, ValueRange{zero, zero, zero, zero}); + //pad vCenterLeftSlice on the top + SmallVector lowPadding(4, 0); + SmallVector highPadding(4, 0); + lowPadding[2] = padInts[2]; + vLeftSlice = torch_to_linalg::getPaddedTensor(op, rewriter, vLeftSlice, lowPadding, highPadding, topLeftValue); + } + if (hasBottomPadding) { + Value bottomLeftValue = rewriter.create (loc, input, ValueRange{zero, zero, vDimSizeMinusOne, zero}); + + //pad vLeftSlice at the bottom + SmallVector lowPadding(4, 0); + SmallVector highPadding(4, 0); + highPadding[2] = padInts[3]; + vLeftSlice = torch_to_linalg::getPaddedTensor(op, rewriter, vLeftSlice, lowPadding, highPadding, bottomLeftValue); + } + for (auto i=0; i(loc, 3, tensorsLeft); + tensorsRes.push_back(leftPadTile); + } + if (hasTopPadding) { + Value topLeftValue = rewriter.create( + loc, input, ValueRange{zero, zero, zero, zero}); + Value topHcenterSlice = rewriter.create( + loc, input, extractOffsetsLT, extractShapeTB, allOneStrides); + for (auto i = 0; i < padInts[2]; ++i) { + tensorsCenter.push_back(topHcenterSlice); + } + } + tensorsCenter.push_back(input); + if (hasBottomPadding) { + Value bottomHcenterSlice = rewriter.create( + loc, input, extractOffsetsBottom, extractShapeTB, allOneStrides); + for (auto i = 0; i < padInts[3]; ++i) { + tensorsCenter.push_back(bottomHcenterSlice); + } + } + centerTile = rewriter.create(loc, 2, tensorsCenter); + tensorsRes.push_back(centerTile); + + if (hasRightPadding) { + Value vCenterRightSlice = rewriter.create( + loc, input, extractOffsetsRight, extractShapeLR, allOneStrides); + Value vRightSlice = vCenterRightSlice; + if (hasTopPadding) { + Value topRightValue = rewriter.create (loc, input, ValueRange{zero, zero, zero, hDimSizeMinusOne}); + + //pad vCenterRightSlice on the top + SmallVector lowPadding(4, 0); + SmallVector highPadding(4, 0); + lowPadding[2] = padInts[2]; + vRightSlice = torch_to_linalg::getPaddedTensor(op, rewriter, vRightSlice, lowPadding, highPadding, topRightValue); + } + if (hasBottomPadding) { + Value bottomRightValue = rewriter.create (loc, input, ValueRange{zero, zero, vDimSizeMinusOne, hDimSizeMinusOne}); + + // Pad vCenterRightSlice or vRightTopPaddedSlice at the bottom. + SmallVector lowPadding(4, 0); + SmallVector highPadding(4, 0); + highPadding[2] = padInts[3]; + vRightSlice = torch_to_linalg::getPaddedTensor(op, rewriter, vRightSlice, lowPadding, highPadding, bottomRightValue); + } + for (auto i=0; i(loc, 3, tensorsRight); + tensorsRes.push_back(rightPadTile); + } + Value resTensor = rewriter.create(loc, 3, tensorsRes); + Type newResultType = getTypeConverter()->convertType(op.getType()); + rewriter.replaceOpWithNewOp(op, newResultType, resTensor); + return success(); + } + }; +} + namespace { // Converts constant tensor allocation like ops. template @@ -348,6 +571,8 @@ void mlir::torch::torch_to_linalg:: RewritePatternSet &patterns, ConversionTarget &target) { MLIRContext *context = patterns.getContext(); + target.addIllegalOp(); + patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index c286168080e4..62aa96086e83 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -8351,6 +8351,35 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " } : (!torch.int, !torch.bool) -> ()\n" " return %arg0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.replication_pad2d\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" +" %str = torch.constant.str \"AssertionError: padding size expected to be 4\"\n" +" %none = torch.constant.none\n" +" %str_0 = torch.constant.str \"AssertionError: \"\n" +" %int2 = torch.constant.int 2\n" +" %int4 = torch.constant.int 4\n" +" %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %1 = torch.aten.ge.int %0, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %2 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" +" %3 = torch.aten.eq.int %2, %int4 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %3 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %4 = call @__torch__.pad_shape_fn(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" +" return %4 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.replication_pad2d\"(%arg0: !torch.tuple, %arg1: !torch.list) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.pad\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.str, %arg3: !torch.optional) -> !torch.list {\n" " %0 = call @__torch__.pad_shape_fn(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" " return %0 : !torch.list\n" diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index 6a8fbf34e911..0bcc3f02343b 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -1286,6 +1286,15 @@ def pad_shape_fn(input: List[int], pad: List[int]): def aten〇constant_pad_nd〡shape(self: List[int], pad: List[int], value: float = 0) -> List[int]: return pad_shape_fn(self, pad) +def aten〇replication_pad2d〡shape(self: List[int], padding: List[int]) -> List[int]: + assert len(self) >= 2 + assert len(padding) == 4, 'padding size expected to be 4' + return pad_shape_fn(self, padding) + +def aten〇replication_pad2d〡dtype(self_rank_dtype: Tuple[int, int], padding: List[int]) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + def aten〇pad〡shape(self: List[int], pad: List[int], mode: str = "constant", value: Optional[float] = None) -> List[int]: return pad_shape_fn(self, pad) diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index 249c25628a82..604294f7409b 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -542,6 +542,7 @@ def emit_with_mutating_variants(key, **kwargs): # Misc tensor ops. emit("aten::constant_pad_nd : (Tensor, int[], Scalar) -> (Tensor)") + emit("aten::replication_pad2d : (Tensor, int[]) -> (Tensor)") emit("aten::reflection_pad1d : (Tensor, int[]) -> (Tensor)") emit("aten::reflection_pad2d : (Tensor, int[]) -> (Tensor)") emit("aten::pad : (Tensor, int[], str, float?) -> (Tensor)") diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py index a68d229faf39..51deffb6175a 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py @@ -565,6 +565,105 @@ def __init__(self): def forward(self, x): return torch.ops.aten.reflection_pad1d(x, (3,1)) +class ReplicationPad2dModule_basic_module(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([1, 1, 3, 3], torch.float32, True), + ]) + def forward(self, x): + return torch.ops.aten.replication_pad2d(x, (1, 2, 3, 4)) + + +@register_test_case(module_factory=lambda: ReplicationPad2dModule_basic_module()) +def ReplicationPad2dModule_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 1, 3, 3, low=-1)) + +# ============================================================================== + +class ReplicationPad2dModule_left0_module(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([1, 1, 3, 3], torch.float32, True), + ]) + def forward(self, x): + return torch.ops.aten.replication_pad2d(x, (0, 2, 3, 4)) + + +@register_test_case(module_factory=lambda: ReplicationPad2dModule_left0_module()) +def ReplicationPad2dModule_left0(module, tu: TestUtils): + module.forward(tu.rand(1, 1, 3, 3, low=-1)) + +# ============================================================================== + +class ReplicationPad2dModule_right0_module(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([1, 1, 3, 3], torch.float32, True), + ]) + def forward(self, x): + return torch.ops.aten.replication_pad2d(x, (1, 0, 3, 4)) + + +@register_test_case(module_factory=lambda: ReplicationPad2dModule_right0_module()) +def ReplicationPad2dModule_right0(module, tu: TestUtils): + module.forward(tu.rand(1, 1, 3, 3, low=-1)) + +# ============================================================================== + +class ReplicationPad2dModule_top0_module(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([1, 1, 3, 3], torch.float32, True), + ]) + def forward(self, x): + return torch.ops.aten.replication_pad2d(x, (1, 2, 0, 4)) + + +@register_test_case(module_factory=lambda: ReplicationPad2dModule_top0_module()) +def ReplicationPad2dModule_top0(module, tu: TestUtils): + module.forward(tu.rand(1, 1, 3, 3, low=-1)) + +# ============================================================================== + +class ReplicationPad2dModule_bottom0_module(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([1, 1, 3, 3], torch.float32, True), + ]) + def forward(self, x): + return torch.ops.aten.replication_pad2d(x, (1, 2, 3, 0)) + + +@register_test_case(module_factory=lambda: ReplicationPad2dModule_bottom0_module()) +def ReplicationPad2dModule_bottom0(module, tu: TestUtils): + module.forward(tu.rand(1, 1, 3, 3, low=-1)) + +# ============================================================================== @register_test_case(module_factory=lambda: ReflectionPad1dModule3dInput()) def ReflectionPad1dModule3dInput_basic(module, tu: TestUtils): From 87389f0762c1626a56f3afaafcf51bd9f5e28518 Mon Sep 17 00:00:00 2001 From: kumardeepakamd <123522031+kumardeepakamd@users.noreply.github.com> Date: Mon, 15 Jan 2024 11:26:46 -0800 Subject: [PATCH 081/283] [ONNXToTorch] Add conversion for Onnx range (#2752) Implemented ONNX.Range. The spec says the data type for start, limit, delta are 0-D can be double, float, int16, int32, int64, All int types mapped to !torch.int and all float types mapped to !torch.float --------- Co-authored-by: Kumar Deepak --- .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 58 +++++++++++++++++++ .../TorchOnnxToTorch/simple_ops_q_to_z.mlir | 55 ++++++++++++++++++ 2 files changed, 113 insertions(+) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 11a05ea41105..0833af54d43f 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -27,6 +27,18 @@ using namespace mlir::torch::onnx_c; // to be more normal and a direct translation vs a special case. This // results in a lot of ONNX test cases that all reduce to the exact same // thing here, so we simplify. + +// utilities +// Templatized function to get an item op of a type +namespace { +template +Value getItemOp(OpBinder binder, ConversionPatternRewriter &rewriter, + Value &ofItem) { + return rewriter.create(binder.getLoc(), + rewriter.getType(), ofItem); +} +} // namespace + void mlir::torch::onnx_c::populateDefaultDomainQtoZ( OnnxCustomOpConversionPattern &patterns) { patterns.onOp("Reciprocal", 1, @@ -1336,4 +1348,50 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( data, dimValueList); return success(); }); + patterns.onOp( + "Range", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + // ONNX.Range(start, limit, delta) -- limit is exclusive + + Torch::ValueTensorType resultType; + Value start, limit, delta; + auto loc = binder.getLoc(); + Value none = rewriter.create(loc); + if (binder.tensorOperandAtIndex(start, 0) || + binder.tensorOperandAtIndex(limit, 1) || + binder.tensorOperandAtIndex(delta, 2) || + binder.tensorResultType(resultType)) + return failure(); + + // Convert a 0-dimensional/Scalar Tensor ([]) to Scalar Torch Numeric + // Value torch.tensor(1.1) equivalent in ONNX to 1.1 as an example + // type of start, limit, delta can be one of: double, float, int16, + // int32, int64 Assuming start, limit and delta to be same type (could + // they be different?) + Torch::BaseTensorType startTensorType = + start.getType().cast(); + bool isFloatDType = startTensorType.getDtype().isF64() || + startTensorType.getDtype().isF32(); + bool isIntDType = startTensorType.getDtype().isInteger(16) || + startTensorType.getDtype().isInteger(32) || + startTensorType.getDtype().isInteger(64); + if (!isFloatDType && !isIntDType) { + return rewriter.notifyMatchFailure( + binder.op, "Expected the start, limit, delta to be one of " + "double, float, int16, int32, int64"); + } + Value scalarStart, scalarLimit, scalarDelta; + if (isFloatDType) { + scalarStart = getItemOp(binder, rewriter, start); + scalarLimit = getItemOp(binder, rewriter, limit); + scalarDelta = getItemOp(binder, rewriter, delta); + } else { + scalarStart = getItemOp(binder, rewriter, start); + scalarLimit = getItemOp(binder, rewriter, limit); + scalarDelta = getItemOp(binder, rewriter, delta); + } + rewriter.replaceOpWithNewOp( + binder.op, resultType, scalarStart, scalarLimit, scalarDelta, none, + none, none, none); + return success(); + }); } diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir index 91421d944129..593a993c8968 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir @@ -1179,3 +1179,58 @@ func.func @test_reshape_zero_and_negative_dim(%arg0: !torch.vtensor<[2,3,4],f32> %0 = torch.operator "onnx.Reshape"(%arg0, %arg1) : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[4],si64>) -> !torch.vtensor<[2,3,1,4],f32> return %0 : !torch.vtensor<[2,3,1,4],f32> } + +// CHECK-LABEL: func.func @test_range_float64_type + func.func @test_range_float64_type(%arg0: !torch.vtensor<[],f64>, %arg1: !torch.vtensor<[],f64>, %arg2: !torch.vtensor<[],f64>) -> !torch.vtensor<[2],f64> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[NONE:.*]] torch.constant.none + // CHECK: torch.aten.item %arg0 : !torch.vtensor<[],f64> -> !torch.float + // CHECK: torch.aten.item %arg1 : !torch.vtensor<[],f64> -> !torch.float + // CHECK: torch.aten.item %arg2 : !torch.vtensor<[],f64> -> !torch.float + // CHECK: torch.aten.arange.start_step %0, %1, %2, %none, %none, %none, %none : !torch.float, !torch.float, !torch.float, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2],f64> + %0 = torch.operator "onnx.Range"(%arg0, %arg1, %arg2) : (!torch.vtensor<[],f64>, !torch.vtensor<[],f64>, !torch.vtensor<[],f64>) -> !torch.vtensor<[2],f64> + return %0 : !torch.vtensor<[2],f64> + } + +// CHECK-LABEL: func.func @test_range_float32_type + func.func @test_range_float32_type(%arg0: !torch.vtensor<[],f32>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],f32>) -> !torch.vtensor<[2],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[NONE:.*]] torch.constant.none + // CHECK: torch.aten.item %arg0 : !torch.vtensor<[],f32> -> !torch.float + // CHECK: torch.aten.item %arg1 : !torch.vtensor<[],f32> -> !torch.float + // CHECK: torch.aten.item %arg2 : !torch.vtensor<[],f32> -> !torch.float + // CHECK: torch.aten.arange.start_step %0, %1, %2, %none, %none, %none, %none : !torch.float, !torch.float, !torch.float, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2],f32> + %0 = torch.operator "onnx.Range"(%arg0, %arg1, %arg2) : (!torch.vtensor<[],f32>, !torch.vtensor<[],f32>, !torch.vtensor<[],f32>) -> !torch.vtensor<[2],f32> + return %0 : !torch.vtensor<[2],f32> + } + +// CHECK-LABEL: func.func @test_range_int64_type + func.func @test_range_int64_type(%arg0: !torch.vtensor<[],si64>, %arg1: !torch.vtensor<[],si64>, %arg2: !torch.vtensor<[],si64>) -> !torch.vtensor<[2],si64> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[NONE:.*]] torch.constant.none + // CHECK: torch.aten.item %arg0 : !torch.vtensor<[],si64> -> !torch.int + // CHECK: torch.aten.item %arg1 : !torch.vtensor<[],si64> -> !torch.int + // CHECK: torch.aten.item %arg2 : !torch.vtensor<[],si64> -> !torch.int + // CHECK: torch.aten.arange.start_step %0, %1, %2, %none, %none, %none, %none : !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2],si64> + %0 = torch.operator "onnx.Range"(%arg0, %arg1, %arg2) : (!torch.vtensor<[],si64>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64>) -> !torch.vtensor<[2],si64> + return %0 : !torch.vtensor<[2],si64> + } + +// CHECK-LABEL: func.func @test_range_int32_type + func.func @test_range_int32_type(%arg0: !torch.vtensor<[],si32>, %arg1: !torch.vtensor<[],si32>, %arg2: !torch.vtensor<[],si32>) -> !torch.vtensor<[2],si32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[NONE:.*]] torch.constant.none + // CHECK: torch.aten.item %arg0 : !torch.vtensor<[],si32> -> !torch.int + // CHECK: torch.aten.item %arg1 : !torch.vtensor<[],si32> -> !torch.int + // CHECK: torch.aten.item %arg2 : !torch.vtensor<[],si32> -> !torch.int + // CHECK: torch.aten.arange.start_step %0, %1, %2, %none, %none, %none, %none : !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2],si32> + %0 = torch.operator "onnx.Range"(%arg0, %arg1, %arg2) : (!torch.vtensor<[],si32>, !torch.vtensor<[],si32>, !torch.vtensor<[],si32>) -> !torch.vtensor<[2],si32> + return %0 : !torch.vtensor<[2],si32> + } + + // CHECK-LABEL: func.func @test_range_int16_type + func.func @test_range_int16_type(%arg0: !torch.vtensor<[],si16>, %arg1: !torch.vtensor<[],si16>, %arg2: !torch.vtensor<[],si16>) -> !torch.vtensor<[2],si16> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[NONE:.*]] torch.constant.none + // CHECK: torch.aten.item %arg0 : !torch.vtensor<[],si16> -> !torch.int + // CHECK: torch.aten.item %arg1 : !torch.vtensor<[],si16> -> !torch.int + // CHECK: torch.aten.item %arg2 : !torch.vtensor<[],si16> -> !torch.int + // CHECK: torch.aten.arange.start_step %0, %1, %2, %none, %none, %none, %none : !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2],si16> + %0 = torch.operator "onnx.Range"(%arg0, %arg1, %arg2) : (!torch.vtensor<[],si16>, !torch.vtensor<[],si16>, !torch.vtensor<[],si16>) -> !torch.vtensor<[2],si16> + return %0 : !torch.vtensor<[2],si16> + } From f78ec78ac85bf551952d5421befbc6005b232546 Mon Sep 17 00:00:00 2001 From: James Newling Date: Mon, 15 Jan 2024 11:44:45 -0800 Subject: [PATCH 082/283] Adjust bound check to be the same as PyTorch native (i.e. stricter) (#2755) prims.expand expects the start and end dimensions to be strictly less than the rank of the tensor. --- lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp | 4 ++-- .../build_tools/abstract_interp_lib_gen.py | 10 +++++++--- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index 62aa96086e83..fb21984917cd 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -6514,7 +6514,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %int0 = torch.constant.int 0\n" " %int1 = torch.constant.int 1\n" " %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" -" %1 = torch.aten.le.int %arg1, %0 : !torch.int, !torch.int -> !torch.bool\n" +" %1 = torch.aten.lt.int %arg1, %0 : !torch.int, !torch.int -> !torch.bool\n" " torch.prim.If %1 -> () {\n" " torch.prim.If.yield\n" " } else {\n" @@ -6522,7 +6522,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " torch.prim.If.yield\n" " }\n" " %2 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" -" %3 = torch.aten.le.int %arg2, %2 : !torch.int, !torch.int -> !torch.bool\n" +" %3 = torch.aten.lt.int %arg2, %2 : !torch.int, !torch.int -> !torch.bool\n" " torch.prim.If %3 -> () {\n" " torch.prim.If.yield\n" " } else {\n" diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index 0bcc3f02343b..a4ade77364f4 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -241,17 +241,21 @@ def prims〇convert_element_type〡shape(a: List[int], dtype: int) -> List[int]: def prims〇collapse〡shape(a: List[int], start: int, end: int) -> List[int]: # Obtained through trial and error on a few examples in PyTorch: - assert start <= len(a), "start out of bounds" - assert end <= len(a), "end out of bounds" + assert start < len(a), "start out of bounds" + assert end < len(a), "end out of bounds" assert start >= 0, "start out of bounds" assert end >= 0, "end out of bounds" assert start <= end, "start must be less than or equal to end" - # Example: + # Examples: # # torch._prims.collapse(torch.empty(2,3,4), 1,2).shape # is # torch.Size([2, 12]) + # + # torch._prims.collapse(torch.empty(2,3,4), 1,3).shape + # gives + # --> 524 assert idx >= 0 and idx < rank or idx == 0 collapsed: List[int] = [] for i in range(start): From f85e5c932bb6462d66d290ac11860e916f77243a Mon Sep 17 00:00:00 2001 From: lonely eagle <75576166+linuxlonelyeagle@users.noreply.github.com> Date: Tue, 16 Jan 2024 14:29:34 +0800 Subject: [PATCH 083/283] [Torch Dialect] support aten.isneginf, aten.isposinf, aten.nan_to_num (#2743) --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 72 +++++++++++++++++ .../Transforms/AbstractInterpLibrary.cpp | 62 ++++++++++++++ .../Torch/Transforms/DecomposeComplexOps.cpp | 80 +++++++++++++++++++ .../Transforms/LowerToBackendContract.cpp | 3 + projects/pt1/e2e_testing/xfail_sets.py | 7 ++ .../build_tools/abstract_interp_lib_gen.py | 29 +++++++ .../build_tools/torch_ods_gen.py | 3 + .../test_suite/elementwise.py | 79 ++++++++++++++++++ 8 files changed, 335 insertions(+) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 758b10315391..b2c7bf32e368 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -8543,6 +8543,52 @@ def Torch_AtenIsinfOp : Torch_Op<"aten.isinf", [ }]; } +def Torch_AtenIsneginfOp : Torch_Op<"aten.isneginf", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::isneginf : (Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenIsneginfOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenIsneginfOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + +def Torch_AtenIsposinfOp : Torch_Op<"aten.isposinf", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::isposinf : (Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenIsposinfOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenIsposinfOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + def Torch_AtenAllOp : Torch_Op<"aten.all", [ AllowsTypeRefinement, HasValueSemantics, @@ -10473,6 +10519,32 @@ def Torch_AtenWhereScalarSelfOp : Torch_Op<"aten.where.ScalarSelf", [ }]; } +def Torch_AtenNanToNumOp : Torch_Op<"aten.nan_to_num", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::nan_to_num : (Tensor, float?, float?, float?) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchOptionalFloatType:$nan, + AnyTorchOptionalFloatType:$posinf, + AnyTorchOptionalFloatType:$neginf + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenNanToNumOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 4, 1); + } + void AtenNanToNumOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 4, 1); + } + }]; +} + def Torch_AtenSliceTensorOp : Torch_Op<"aten.slice.Tensor", [ AllowsTypeRefinement, ReadOnly diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index fb21984917cd..38a88fe16ac4 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -6702,6 +6702,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.isneginf\"(%arg0: !torch.list) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.isposinf\"(%arg0: !torch.list) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.ne.Tensor\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -7874,6 +7882,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg2) : (!torch.list, !torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.nan_to_num\"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.lerp.Tensor\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg1, %arg2) : (!torch.list, !torch.list) -> !torch.list\n" " %1 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %0) : (!torch.list, !torch.list) -> !torch.list\n" @@ -9739,6 +9751,52 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %int11 = torch.constant.int 11\n" " return %int11 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.isneginf\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %int11 = torch.constant.int 11\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %false = torch.constant.bool false\n" +" %int9 = torch.constant.int 9\n" +" %int10 = torch.constant.int 10\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.ne.int %0#1, %int10 : !torch.int, !torch.int -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.bool) {\n" +" %3 = torch.aten.ne.int %0#1, %int9 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %3 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" return %int11 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.isposinf\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %int11 = torch.constant.int 11\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %false = torch.constant.bool false\n" +" %int9 = torch.constant.int 9\n" +" %int10 = torch.constant.int 10\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.ne.int %0#1, %int10 : !torch.int, !torch.int -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.bool) {\n" +" %3 = torch.aten.ne.int %0#1, %int9 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %3 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" return %int11 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.ne.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" " %int11 = torch.constant.int 11\n" " return %int11 : !torch.int\n" @@ -10742,6 +10800,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list>, !torch.list) -> !torch.int\n" " return %4 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.nan_to_num\"(%arg0: !torch.tuple, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.nll_loss_forward\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.optional>, %arg3: !torch.int, %arg4: !torch.int) -> !torch.tuple {\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: \"\n" diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 0a3ce2ea7797..9c4776231cc8 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -932,6 +932,40 @@ class DecomposeAtenIsinfOp : public OpRewritePattern { }; } // namespace +namespace { +class DecomposeAtenIsneginfOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenIsneginfOp op, + PatternRewriter &rewriter) const override { + mlir::FloatType f64Type = rewriter.getF64Type(); + Value inf = rewriter.create( + op.getLoc(), + rewriter.getFloatAttr( + f64Type, APFloat::getInf(f64Type.getFloatSemantics(), true))); + rewriter.replaceOpWithNewOp(op, op.getType(), op.getSelf(), + inf); + return success(); + } +}; +} // namespace + +namespace { +class DecomposeAtenIsposinfOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenIsposinfOp op, + PatternRewriter &rewriter) const override { + mlir::FloatType f64Type = rewriter.getF64Type(); + Value inf = rewriter.create( + op.getLoc(), + rewriter.getFloatAttr(f64Type, + APFloat::getInf(f64Type.getFloatSemantics()))); + rewriter.replaceOpWithNewOp(op, op.getType(), op.getSelf(), + inf); + return success(); + } +}; +} // namespace + namespace { class DecomposeAtenReshapeOp : public OpRewritePattern { public: @@ -2471,6 +2505,49 @@ class DecomposeAtenWhereScalarSelfOp }; } // namespace +namespace { +class DecomposeAtenNanToNumOp : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenNanToNumOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + mlir::FloatType f64Type = rewriter.getF64Type(); + Value nan = op.getNan(); + Value posinf = op.getPosinf(); + Value neginf = op.getNeginf(); + auto baseType = + ValueTensorType::getWithLeastStaticInformation(op.getContext()); + if (dyn_cast_or_null(nan.getDefiningOp())) + nan = rewriter.create( + loc, rewriter.getFloatAttr( + f64Type, APFloat::getZero(f64Type.getFloatSemantics()))); + if (dyn_cast_or_null(posinf.getDefiningOp())) + posinf = rewriter.create( + loc, rewriter.getFloatAttr( + f64Type, APFloat::getInf(f64Type.getFloatSemantics()))); + if (dyn_cast_or_null(neginf.getDefiningOp())) + neginf = rewriter.create( + loc, + rewriter.getFloatAttr( + f64Type, APFloat::getInf(f64Type.getFloatSemantics(), true))); + Value isNan = + rewriter.create(loc, baseType, op.getSelf()); + Value where = rewriter.create( + loc, baseType, isNan, nan, op.getSelf()); + Value isposinf = + rewriter.create(loc, baseType, where); + where = rewriter.create( + loc, baseType, isposinf, posinf, where); + Value isneginf = + rewriter.create(loc, baseType, where); + rewriter.replaceOpWithNewOp( + op, op.getType(), isneginf, neginf, where); + return success(); + } +}; +} // namespace + // Decompose aten.masked_fill.Scalar into aten.where.self op. namespace { class DecomposeAtenMaskedFillScalarOp @@ -6393,6 +6470,7 @@ class DecomposeComplexOpsPass addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); @@ -6448,6 +6526,8 @@ class DecomposeComplexOpsPass addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp index 933140d3013d..e76adb9b89dc 100644 --- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp +++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp @@ -431,8 +431,11 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index e04657df4d2c..8a440c16b882 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -473,6 +473,7 @@ "ElementwiseAtenWhereSelfModule_basic", "ElementwiseWhereScalarOtherStaticModule_basic", "ElementwiseWhereScalarSelfStaticModule_basic", + "ElementwiseNanToNumModule_Basic", "ElementwiseBitwiseAndStaticShapeModule_basic", "ElementwiseBitwiseNotInt64Module_basic", "ElementwiseBitwiseNotInt32Module_basic", @@ -1039,6 +1040,8 @@ "ElementwiseAddScalar_TensorLiteralInt32_Module_basic", "ElementwiseAtenDivIntScalarModule_basic", "ElementwiseAtenIsinfOpModule_basic", + "ElementwiseAtenIsneginfOpModule_basic", + "ElementwiseAtenIsposinfOpModule_basic", "ElementwiseAtenLogicalOrOpBrodcastModule_basic", "ElementwiseAtenLogicalOrOpDiffArgs1Module_basic", "ElementwiseAtenLogicalOrOpDiffArgs2Module_basic", @@ -1090,6 +1093,8 @@ "ElementwiseGtIntTensorModule_basic", "ElementwiseGtMixed2ScalarModule_basic", "ElementwiseIsinfModule_basic", + "ElementwiseAtenIsneginfOpModule_basic", + "ElementwiseAtenIsposinfOpModule_basic", "ElementwiseIsnanModule_basic", "ElementwiseLeFloatTensorModule_basic", "ElementwiseLeIntTensorModule_basic", @@ -1146,6 +1151,7 @@ "ElementwiseUnaryModule_basic", "ElementwiseUnsqueezeBroadcastModule_basic", "ElementwiseWhereScalarModule_basic", + "ElementwiseNanToNumModule_Basic", "EmbeddingModule1DIndices_basic", "EmbeddingModuleI32Static_basic", "FlattenRank0Module_basic", @@ -1511,6 +1517,7 @@ "ElementwiseBitwiseAndScalarInt64Module_basic", "ElementwiseBitwiseAndScalarInt32Module_basic", "ElementwiseBitwiseAndScalarInt8Module_basic", + "ElementwiseNanToNumModule_Basic", "ElementwiseQuantizePerTensorModule_basic", "ElementwiseDequantizePerTensorModule_basic" } diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index a4ade77364f4..640c0bbfdc28 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -341,6 +341,12 @@ def aten〇isnan〡shape(self: List[int]) -> List[int]: def aten〇isinf〡shape(self: List[int]) -> List[int]: return upstream_shape_functions.unary(self) +def aten〇isneginf〡shape(self: List[int]) -> List[int]: + return upstream_shape_functions.unary(self) + +def aten〇isposinf〡shape(self: List[int]) -> List[int]: + return upstream_shape_functions.unary(self) + def aten〇ne〇Tensor〡shape(self: List[int], other: List[int]) -> List[int]: return upstream_shape_functions.broadcast(self, other) @@ -1062,6 +1068,9 @@ def aten〇where〇ScalarOther〡shape(condition: List[int], self: List[int], ot def aten〇where〇ScalarSelf〡shape(condition: List[int], self: float, other: List[int]) -> List[int]: return upstream_shape_functions.broadcast(condition, other) +def aten〇nan_to_num〡shape(self: List[int], nan: Optional[float] = None, posinf: Optional[float] = None, neginf: Optional[float] = None) -> List[int]: + return upstream_shape_functions.unary(self) + def aten〇lerp〇Tensor〡shape(self: List[int], end: List[int], weight: List[int]) -> List[int]: return upstream_shape_functions.broadcast(self, upstream_shape_functions.broadcast(end, weight)) @@ -2529,6 +2538,20 @@ def aten〇isnan〡dtype(self_rank_dtype: Tuple[int, int]) -> int: def aten〇isinf〡dtype(self_rank_dtype: Tuple[int, int]) -> int: return torch.bool +@check_dtype_function( + _check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.complex128, torch.complex64})) +def aten〇isneginf〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + assert self_dtype != torch.complex128 and self_dtype != torch.complex64 + return torch.bool + +@check_dtype_function( + _check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.complex128, torch.complex64})) +def aten〇isposinf〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + assert self_dtype != torch.complex128 and self_dtype != torch.complex64 + return torch.bool + @check_dtype_function(_check_two_tensor_op()) def aten〇ne〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: return torch.bool @@ -3260,6 +3283,12 @@ def aten〇where〇ScalarSelf〡dtype(condition_rank_dtype: Tuple[int, int], sel dtypes = [get_dtype_of_scalar(self), other_dtype] return promote_dtypes(ranks, dtypes) +@check_dtype_function( + _check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇nan_to_num〡dtype(self_rank_dtype: Tuple[int, int], nan: Optional[float] = None, posinf: Optional[float] = None, neginf: Optional[float] = None) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + @check_dtype_function( [Invocation(TensorOfShape(2, 3, dtype=torch.float32), TensorOfShape(2, dtype=torch.int64), TensorOfShape(3, dtype=torch.float32), reduction=0, ignore_index=0), diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index 604294f7409b..93f1e741dd6e 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -567,6 +567,8 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::_shape_as_tensor : (Tensor) -> (Tensor)") emit("aten::isnan : (Tensor) -> (Tensor)") emit("aten::isinf : (Tensor) -> (Tensor)") + emit("aten::isneginf : (Tensor) -> (Tensor)") + emit("aten::isposinf : (Tensor) -> (Tensor)") emit("aten::all : (Tensor) -> (Tensor)") emit("aten::all.bool : (bool[]) -> (bool)") emit("aten::all.dim : (Tensor, int, bool) -> (Tensor)") @@ -641,6 +643,7 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::where.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)") emit("aten::where.ScalarOther : (Tensor, Tensor, Scalar) -> (Tensor)") emit("aten::where.ScalarSelf : (Tensor, Scalar, Tensor) -> (Tensor)") + emit("aten::nan_to_num : (Tensor, float?, float?, float?) -> (Tensor)") emit("aten::slice.Tensor : (Tensor, int, int?, int?, int) -> (Tensor)", has_folder=True) emit("aten::len.Tensor : (Tensor) -> (int)") emit("aten::cpu : (Tensor) -> (Tensor)") diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py index 23a22142c4d5..9b857839db2b 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py @@ -339,6 +339,33 @@ def ElementwiseWhereScalarSelfStaticModule_basic(module, tu: TestUtils): # ============================================================================== +class ElementwiseNanToNumModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([3, 4], torch.float32, True) + ]) + def forward(self, a): + return torch.ops.aten.nan_to_num(a, 0.0, 1.0, -1.0) + +@register_test_case(module_factory=lambda: ElementwiseNanToNumModule()) +def ElementwiseNanToNumModule_Basic(module, tu: TestUtils): + module.forward(torch.tensor( + [ + [float('nan'), 0.0, float('nan'), 0.0], + [float('inf'), 0.0, float('inf'), 0.0], + [float('-inf'), 0.0, float('-inf'), 0.0] + ] + )) + + +# ============================================================================== + + # Addition is an interesting special case of a binary op, because under the hood # it carries a third scalar "alpha" parameter, which needs special handling. class ElementwiseAddModule(torch.nn.Module): @@ -3463,6 +3490,58 @@ def ElementwiseAtenIsinfOpModule_basic(module, tu: TestUtils): # ============================================================================== +class ElementwiseAtenIsneginfOpModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([2, 5], torch.float32, True), + ]) + def forward(self, x): + return torch.ops.aten.isneginf(x) + +@register_test_case(module_factory=lambda: ElementwiseAtenIsneginfOpModule()) +def ElementwiseAtenIsneginfOpModule_basic(module, tu:TestUtils): + test_input = torch.tensor( + [ + [1, float('-inf'), 2, float('inf'), float('nan')], + [1, float('-inf'), float('inf'), float('nan'), 3], + ] + ) + module.forward(test_input) + + +# ============================================================================== + + +class ElementwiseAtenIsposinfOpModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([2, 5], torch.float32, True), + ]) + def forward(self, x): + return torch.ops.aten.isposinf(x) + +@register_test_case(module_factory=lambda: ElementwiseAtenIsposinfOpModule()) +def ElementwiseAtenIsposinfOpModule_basic(module, tu:TestUtils): + test_input = torch.tensor( + [ + [1, float('-inf'), 2, float('inf'), float('nan')], + [1, float('-inf'), float('inf'), float('nan'), 3], + ] + ) + module.forward(test_input) + + +# ============================================================================== + + class ElementwiseAtenLogicalNotOpPromoteModule(torch.nn.Module): def __init__(self): super().__init__() From a8538e1e3fb98c5d2c6170dd0e5d549d1d6b2632 Mon Sep 17 00:00:00 2001 From: Sungsoon Cho <1025787+godot73@users.noreply.github.com> Date: Mon, 15 Jan 2024 22:49:29 -0800 Subject: [PATCH 084/283] Decompose AtenNormalFunctionalOp into AtenRandn* and other arithmetic. (#2737) --- .../Transforms/AbstractInterpLibrary.cpp | 17 ++++++++++ .../Torch/Transforms/DecomposeComplexOps.cpp | 34 +++++++++++++++++-- .../Transforms/LowerToBackendContract.cpp | 1 + projects/pt1/e2e_testing/xfail_sets.py | 1 + .../build_tools/abstract_interp_lib_gen.py | 13 +++++++ .../torch_mlir_e2e_test/test_suite/rng.py | 21 ++++++++++++ 6 files changed, 85 insertions(+), 2 deletions(-) diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index 38a88fe16ac4..a558db372d45 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -7655,6 +7655,9 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " func.func @\"__torch_mlir_shape_fn.aten.randn.generator\"(%arg0: !torch.list, %arg1: !torch.any, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional) -> !torch.list {\n" " return %arg0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.normal_functional\"(%arg0: !torch.list, %arg1: !torch.float, %arg2: !torch.float, %arg3: !torch.any) -> !torch.list {\n" +" return %arg0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.arange.start_step\"(%arg0: !torch.float, %arg1: !torch.float, %arg2: !torch.float, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional, %arg6: !torch.optional) -> !torch.list {\n" " %0 = torch.derefine %arg0 : !torch.float to !torch.union\n" " %1 = torch.derefine %arg1 : !torch.float to !torch.union\n" @@ -11557,6 +11560,20 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %1 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.normal_functional\"(%arg0: !torch.tuple, %arg1: !torch.float, %arg2: !torch.float, %arg3: !torch.any) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" return %0#1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.randn.generator\"(%arg0: !torch.list, %arg1: !torch.any, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional) -> !torch.int {\n" " %str = torch.constant.str \"AssertionError: \"\n" " %int6 = torch.constant.int 6\n" diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 9c4776231cc8..8afccbba0346 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -3669,9 +3669,38 @@ class DecomposeAtenExponentialOp : public OpRewritePattern { return success(); } }; -} // namespace -namespace { +// aten.normal_functional(mean, sigma) = randn() * sigma + mean. +class DecomposeAtenNormalFunctionalOp + : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenNormalFunctionalOp op, + PatternRewriter &rewriter) const override { + if (!op.getGenerator().getType().isa()) + return rewriter.notifyMatchFailure( + op, "The generator has to be None because only global default " + "generator is supported"); + + Location loc = op.getLoc(); + Type resultType = op.getType(); + Value std = op.getStd(); + Value mean = op.getMean(); + + Value none = rewriter.create(loc); + Value one = + rewriter.create(loc, rewriter.getF64FloatAttr(1.0)); + Value randN = rewriter.create( + loc, resultType, op.getSelf(), /*dtype=*/none, /*layout=*/none, + /*device=*/none, /*pin_memory=*/none, /*memory_format=*/none); + Value stdRandN = + rewriter.create(loc, resultType, randN, std); + rewriter.replaceOpWithNewOp(op, resultType, stdRandN, + mean, /*alpha=*/one); + return success(); + } +}; + template class DecomposeAtenAddCLikeOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -6591,6 +6620,7 @@ class DecomposeComplexOpsPass addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp index e76adb9b89dc..da7811ad0a3f 100644 --- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp +++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp @@ -494,6 +494,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 8a440c16b882..7ba4c309c14f 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -1484,6 +1484,7 @@ "VarMeanUnbiasedModule_basic", "RandnLikeModule_basic", "RandnLikeDtypeModule_basic", + "NormalFunctionalModule_basic", "BernoulliFloatModule_basic", "BernoulliModule_basic", "BernoulliPModule_basic", diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index 640c0bbfdc28..bf2f45e378a0 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -902,6 +902,9 @@ def aten〇randn〡shape(size: List[int], dtype: Optional[int] = None, layout: O def aten〇randn〇generator〡shape(size: List[int], generator: Any, dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]: return size +def aten〇normal_functional〡shape(self: List[int], mean: float = 0., std: float = 1., generator: Any = None) -> List[int]: + return self + def aten〇arange〇start_step〡shape(start: float, end: float, step: float = 1, dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]: return upstream_shape_functions.arange_start_step(start, end, step, dtype, layout, device, pin_memory) @@ -3822,6 +3825,16 @@ def aten〇randn〡dtype(size: List[int], dtype: Optional[int] = None, layout: O assert not is_integer_dtype(dtype) return dtype +@check_dtype_function(_check_tensors_with_the_same_dtype( + num_of_tensors=1, + error_types={torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64})) +def aten〇normal_functional〡dtype(self_rank_dtype: Tuple[int, int], mean: float = 0., std: float = 1., generator: Any = None) -> int: + self_rank, self_dtype = self_rank_dtype + if self_dtype is None: + return torch.float32 + assert not is_integer_dtype(self_dtype) + return self_dtype + @check_dtype_function([Invocation(size=[1], generator=None), Invocation(size=[1], generator=None, dtype=torch.float32), ErrorInvocation(size=[1], generator=None, dtype=torch.int32), diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/rng.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/rng.py index dedd2b398bd4..2b8e186ff401 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/rng.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/rng.py @@ -605,3 +605,24 @@ def forward(self, x): @register_test_case(module_factory=lambda: RandnLikeDtypeModule()) def RandnLikeDtypeModule_basic(module, tu: TestUtils): module.forward(tu.rand(256, 1024).double()) +# ============================================================================== + +class NormalFunctionalModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float64, True), + ]) + def forward(self, x): + a = torch.ops.aten.normal_functional(x, mean=-5.0, std=2.0) + mean = torch.mean(a) + std = torch.std(a) + return mean, std + + +@register_test_case(module_factory=lambda: NormalFunctionalModule()) +def NormalFunctionalModule_basic(module, tu: TestUtils): + module.forward(tu.rand(2048, 4096).double()) From eed144bfbc4cf9eafebba1949aa81f615a865eea Mon Sep 17 00:00:00 2001 From: Phaneesh Barwaria Date: Tue, 16 Jan 2024 19:06:54 +0530 Subject: [PATCH 085/283] [ONNX][MLIR] add Identity op support (#2754) --- .../TorchOnnxToTorch/DefaultDomainGtoP.cpp | 13 +++++++++++++ .../TorchOnnxToTorch/simple_ops_g_to_p.mlir | 8 ++++++++ 2 files changed, 21 insertions(+) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index c0a7473e4601..d6b4e5bf1046 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -560,4 +560,17 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( binder.op, resultType, lhs, rhs); return success(); }); + patterns.onOp( + "Identity", 14, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value tensor; + if (binder.tensorOperand(tensor) || + binder.tensorResultType(resultType)) { + return failure(); + } + Value noneVal = rewriter.create(binder.getLoc()); + rewriter.replaceOpWithNewOp( + binder.op, resultType, tensor, /*memory_format=*/noneVal); + return success(); + }); } diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir index c85659c25aa8..5d6b86172597 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir @@ -456,3 +456,11 @@ func.func @test_not_2d(%arg0: !torch.vtensor<[3,4],i1>) -> !torch.vtensor<[3,4], %0 = torch.operator "onnx.Or"(%arg0, %arg1) : (!torch.vtensor<[3,4],i1>, !torch.vtensor<[3,4],i1>) -> !torch.vtensor<[3,4],i1> return %0 : !torch.vtensor<[3,4],i1> } + +// CHECK-LABEL: func.func @test_identity + func.func @test_identity(%arg0: !torch.vtensor<[3,4], f32>) -> !torch.vtensor<[3,4], f32> attributes {torch.onnx_meta.ir_version = 14 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %0 = torch.aten.clone %arg0, %[[NONE]] : !torch.vtensor<[3,4],f32>, !torch.none -> !torch.vtensor<[3,4],f32> + %0 = torch.operator "onnx.Identity"(%arg0) : (!torch.vtensor<[3,4], f32>) -> !torch.vtensor<[3,4], f32> + return %0 : !torch.vtensor<[3,4], f32> + } From 77a03f20690801451528ada8de7ba9ddf266f1e1 Mon Sep 17 00:00:00 2001 From: Ze Zhang Date: Thu, 18 Jan 2024 12:32:23 -0800 Subject: [PATCH 086/283] torch-to-tosa lowering support for AtenLinalgVectorNormOp (#2734) This PR add torch-to-tosa lowering support for AtenLinalgVectorNormOp e2e test: python -m e2e_testing.main --config=tosa LIT tests: cmake --build build --target tools/torch-mlir/all --------- Co-authored-by: Ze Zhang --- .../TorchToTosa/TosaLegalizeCommon.h | 6 ++ lib/Conversion/TorchToTosa/TorchToTosa.cpp | 2 + .../TorchToTosa/TosaLegalizeCommon.cpp | 71 +++++++++++++++++++ projects/pt1/e2e_testing/xfail_sets.py | 9 +++ test/Conversion/TorchToTosa/basic.mlir | 28 ++++++++ 5 files changed, 116 insertions(+) diff --git a/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h b/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h index c1b355e3c50d..16bf235de89e 100644 --- a/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h +++ b/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h @@ -106,6 +106,12 @@ convertReduceMeanOp(PatternRewriter &rewriter, Operation *op, RankedTensorType output_type, Value input_value, ElementsAttr axes_elems, bool keep_dims); +// Lowers LinalgVectorNorm to a sequence of TOSA ops. +std::optional +convertLinalgVectorNormOp(PatternRewriter &rewriter, Operation *op, + RankedTensorType output_type, Value input_value, + ElementsAttr axes_elems, bool keep_dims); + } // namespace tosa } // namespace mlir diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index 6555f06e8702..919fe73b2092 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -5089,6 +5089,8 @@ class ConvertTorchToTosa : public ConvertTorchToTosaBase { mlir::tosa::convertReduceMeanOp) INSERT_NDIMS_REDUCTION_OP_PATTERN(AtenSumDimIntListOp, mlir::tosa::convertReduceSumOp) + INSERT_NDIMS_REDUCTION_OP_PATTERN(AtenLinalgVectorNormOp, + mlir::tosa::convertLinalgVectorNormOp) #undef INSERT_NDIMS_REDUCTION_OP_PATTERN #define INSERT_ONEDIM_REDUCTION_OP_PATTERN(AtenOp, ConversionFunc) \ diff --git a/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp b/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp index 287a8943594a..acaa60ffc9ad 100644 --- a/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp +++ b/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp @@ -9,6 +9,7 @@ #include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h" #include "torch-mlir/Conversion/Utils/Utils.h" +#include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include #include @@ -971,5 +972,75 @@ convertReduceMeanOp(PatternRewriter &rewriter, Operation *op, return val; } +// Lowers LinalgVectorNorm to a sequence of TOSA ops. +std::optional +convertLinalgVectorNormOp(PatternRewriter &rewriter, Operation *op, + RankedTensorType output_type, Value input_value, + ElementsAttr axes_elems, bool keep_dims) { + RankedTensorType input_type = + input_value.getType().dyn_cast(); + if (!input_type) + return std::nullopt; + + Type elemType = output_type.getElementType(); + if (!elemType.isa()) { + op->emitOpError("Only floating-point datatype legalization supported for " + "AtenLinalgVectorNorm op"); + return std::nullopt; + } + + auto linalgVectorNormOp = cast(op); + // TODO: Add support for ord = {0, +inf, -inf}. + auto epsilon = 1e-5; + double ordLiteralFloat = 1.0; + int64_t ordLiteralInt = 1; + Value ordVal; + if (matchPattern(linalgVectorNormOp.getOrd(), + torch::Torch::m_TorchConstantFloat(&ordLiteralFloat))) { + ordVal = tosa::getConstTensor(rewriter, op, + {static_cast(ordLiteralFloat)}, + {}, elemType) + .value(); + } else if (matchPattern(linalgVectorNormOp.getOrd(), + torch::Torch::m_TorchConstantInt(&ordLiteralInt))) { + ordVal = tosa::getConstTensor(rewriter, op, + {static_cast(ordLiteralInt)}, + {}, elemType) + .value(); + } else { + op->emitOpError("only support FP or INT type ord parameter"); + return std::nullopt; + } + + if (fabs(ordLiteralFloat) < epsilon || + fabs(static_cast(ordLiteralInt)) < epsilon) { + op->emitOpError("unimplemented: L0 norm"); + return std::nullopt; + } + + if (std::isinf(ordLiteralFloat) || + std::isinf(static_cast(ordLiteralInt))) { + op->emitOpError("unimplemented: ord = +/- inf"); + return std::nullopt; + } + + auto absVal = CreateOpAndInfer(rewriter, op->getLoc(), + input_type, input_value) + .getResult(); + auto powVal = CreateOpAndInfer(rewriter, op->getLoc(), + input_type, absVal, ordVal) + .getResult(); + std::optional result = convertReduceSumOp( + rewriter, op, output_type, powVal, axes_elems, keep_dims); + if (!result) + return std::nullopt; + auto reciprocalVal = CreateOpAndInfer( + rewriter, op->getLoc(), ordVal.getType(), ordVal) + .getResult(); + return CreateOpAndInfer(rewriter, op->getLoc(), output_type, + result.value(), reciprocalVal) + .getResult(); +} + } // namespace tosa } // namespace mlir diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 7ba4c309c14f..0ffc3c4fd606 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -1024,6 +1024,7 @@ "Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier", "Conv2dWithPaddingModule_basic", "Convolution2DStaticModule_basic", + "CosineSimilarityStaticModule_basic", "DetachModule_basic", "DropoutEvalFloatModule_basic", "DropoutEvalIntModule_basic", @@ -1181,6 +1182,8 @@ "LeakyReluBackwardModule_basic", "LeakyReluBackwardStaticModule_basic", "LiftFreshCopyModule_basic", + "LinalgVectorNormKeepDimModule_basic", + "LinalgVectorNormModule_basic", "MaskedFillScalarDefaultModule_basic", "MaskedFillScalarIntValueModule_basic", "MaskedFillScalarIntValueStaticModule_basic", @@ -1217,6 +1220,9 @@ "NewZerosModuleInt2D_basic", "NewZerosModuleInt3D_basic", "NewZerosStaticModuleLayoutStrided_basic", + "NormalizeModule_basic", + "NormScalarOptDimKeepDimModule_basic", + "NormScalarOptDimModule_basic", "NumToTensorFloatModule_basic", "NumToTensorIntModule_basic", "NumpyTRank0Module_basic", @@ -1349,7 +1355,10 @@ "AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic", "AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic", "AdaptiveAvgPool1dStaticEvenMultiple_basic", + "CosineSimilarityModule_basic", "NativeGroupNormBackwardModule_basic", + "ReduceFrobeniusNormKeepDimModule_basic", + "ReduceFrobeniusNormModule_basic", "SliceWholeTensorModule_basic", "TensorFloatModule_basic", "TensorIntModule_basic", diff --git a/test/Conversion/TorchToTosa/basic.mlir b/test/Conversion/TorchToTosa/basic.mlir index b36acc779547..c6369e6fa769 100644 --- a/test/Conversion/TorchToTosa/basic.mlir +++ b/test/Conversion/TorchToTosa/basic.mlir @@ -261,6 +261,34 @@ func.func @test_reduce_sum_dims$basic(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> ! // ----- +// CHECK-LABEL: func.func @test_linalg_vector_norm$basic( +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[3,151,64],f32>) -> !torch.vtensor<[3,151,1],f32> { +// CHECK: %[[ARG0_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[3,151,64],f32> -> tensor<3x151x64xf32> +// CHECK: %[[ARG1:.*]] = torch.constant.float 2.000000e+00 +// CHECK: %[[ARG2:.*]] = torch.constant.int -1 +// CHECK: %[[ARG3:.*]] = torch.constant.bool true +// CHECK: %[[ARG4:.*]] = torch.constant.none +// CHECK: %[[ARG5:.*]] = torch.prim.ListConstruct %[[ARG2]] : (!torch.int) -> !torch.list +// CHECK: %[[ARG6:.*]] = "tosa.const"() <{value = dense<2.000000e+00> : tensor}> : () -> tensor +// CHECK: %[[ARG7:.*]] = tosa.abs %[[ARG0_BUILTIN]] : (tensor<3x151x64xf32>) -> tensor<3x151x64xf32> +// CHECK: %[[ARG8:.*]] = tosa.pow %[[ARG7]], %[[ARG6]] : (tensor<3x151x64xf32>, tensor) -> tensor<3x151x64xf32> +// CHECK: %[[ARG9:.*]] = tosa.reduce_sum %[[ARG8]] {axis = 2 : i32} : (tensor<3x151x64xf32>) -> tensor<3x151x1xf32> +// CHECK: %[[ARG10:.*]] = tosa.reciprocal %[[ARG6]] : (tensor) -> tensor +// CHECK: %[[ARG11:.*]] = tosa.pow %[[ARG9]], %[[ARG10]] : (tensor<3x151x1xf32>, tensor) -> tensor<3x151x1xf32> +// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[ARG11]] : tensor<3x151x1xf32> -> !torch.vtensor<[3,151,1],f32> +// CHECK: return %[[RESULT]] : !torch.vtensor<[3,151,1],f32> +func.func @test_linalg_vector_norm$basic(%arg0: !torch.vtensor<[3,151,64],f32>) -> (!torch.vtensor<[3,151,1],f32>) { + %float2.000000e00 = torch.constant.float 2.000000e+00 + %int-1 = torch.constant.int -1 + %true = torch.constant.bool true + %none = torch.constant.none + %1 = torch.prim.ListConstruct %int-1 : (!torch.int) -> !torch.list + %2 = torch.aten.linalg_vector_norm %arg0, %float2.000000e00, %1, %true, %none : !torch.vtensor<[3,151,64],f32>, !torch.float, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[3,151,1],f32> + return %2 : !torch.vtensor<[3,151,1],f32> +} + +// ----- + // CHECK-LABEL: func.func @test_reduce_sum$basic( // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[1],f32> { // CHECK: %[[ARG0_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor From bd11877f6f4c741d0b709b58ba647766f4be5c99 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Thu, 18 Jan 2024 16:33:10 -0800 Subject: [PATCH 087/283] [onnx] Support lowering quantize linear to `torch` (#2751) We can map the per_tensor case to the `torch.aten.quantize_per_linear` operation. In this case we extract the `scale` and `zeropoint` values and directly invoke the quantization, then return the integer representation value. --- .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 51 +++++++++++++++++++ .../TorchOnnxToTorch/simple_ops_q_to_z.mlir | 44 ++++++++++++++++ 2 files changed, 95 insertions(+) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 0833af54d43f..25f46d3f210a 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -8,6 +8,7 @@ //===----------------------------------------------------------------------===// #include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h" +#include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" @@ -41,6 +42,56 @@ Value getItemOp(OpBinder binder, ConversionPatternRewriter &rewriter, void mlir::torch::onnx_c::populateDefaultDomainQtoZ( OnnxCustomOpConversionPattern &patterns) { + patterns.onOp("QuantizeLinear", 1, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + llvm::SmallVector operands; + if (binder.tensorOperands(operands, 3) || + binder.tensorResultType(resultType)) + return failure(); + + Value operand = operands[0]; + Value scale = operands[1]; + Value zeropoint = operands[2]; + + auto scaleTy = scale.getType().dyn_cast(); + if (!scaleTy || !scaleTy.hasSizes()) return rewriter.notifyMatchFailure(binder.op, "requires known rank"); + if (!resultType.hasDtype()) + return rewriter.notifyMatchFailure( + binder.op, "requires known result dtype"); + + if (scaleTy.getSizes().size() == 0) { + Type qTy = resultType.getDtype(); + + if (qTy.isUnsignedInteger(8)) { + qTy = rewriter.getType(); + } else if (qTy.isSignedInteger(8)) { + qTy = rewriter.getType(); + } else if (qTy.isSignedInteger(32)) { + qTy = rewriter.getType(); + } else { + return rewriter.notifyMatchFailure(binder.op, "unsupported result dtype"); + } + + auto qTensorTy = rewriter.getType(resultType.getOptionalSizes(), qTy); + auto torchqTy = Torch::getScalarTypeForType(qTy); + + Value tyConst = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), static_cast(torchqTy))); + + scale = rewriter.create(binder.getLoc(), rewriter.getType(), scale); + zeropoint = rewriter.create(binder.getLoc(), rewriter.getType(), zeropoint); + + auto quantize = rewriter.create(binder.getLoc(), qTensorTy, operand, scale, zeropoint, tyConst); + rewriter.replaceOpWithNewOp(binder.op, resultType, quantize); + return success(); + } + + return failure(); + + } + ); patterns.onOp("Reciprocal", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir index 593a993c8968..1725b2a15911 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir @@ -4,6 +4,50 @@ // level constants. This is a pragmatic choice which lets us have a lot // of tests in this file, whereas the others tend to be more bespoke. + +// CHECK-LABEL: @test_quantizelinear_si8 +func.func @test_quantizelinear_si8(%arg0: !torch.vtensor<[6],f32>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],si8>) -> !torch.vtensor<[6],si8> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64} { + %0 = torch.operator "onnx.QuantizeLinear"(%arg0, %arg1, %arg2) : (!torch.vtensor<[6],f32>, !torch.vtensor<[],f32>, !torch.vtensor<[],si8>) -> !torch.vtensor<[6],si8> + + // CHECK: %[[C12:.+]] = torch.constant.int 12 + // CHECK: %[[SCALE:.+]] = torch.aten.item %arg1 : !torch.vtensor<[],f32> -> !torch.float + // CHECK: %[[ZP:.+]] = torch.aten.item %arg2 : !torch.vtensor<[],si8> -> !torch.int + // CHECK: %[[QUANT:.+]] = torch.aten.quantize_per_tensor %arg0, %[[SCALE]], %[[ZP]], %[[C12]] + // CHECK: %[[REPR:.+]] = torch.aten.int_repr %[[QUANT]] + // CHECK: return %[[REPR]] + return %0 : !torch.vtensor<[6],si8> +} + +// ----- + +// CHECK-LABEL: @test_quantizelinear_ui8 +func.func @test_quantizelinear_ui8(%arg0: !torch.vtensor<[6],f32>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],ui8>) -> !torch.vtensor<[6],ui8> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64} { + %0 = torch.operator "onnx.QuantizeLinear"(%arg0, %arg1, %arg2) : (!torch.vtensor<[6],f32>, !torch.vtensor<[],f32>, !torch.vtensor<[],ui8>) -> !torch.vtensor<[6],ui8> + // CHECK: %[[C13:.+]] = torch.constant.int 13 + // CHECK: %[[SCALE:.+]] = torch.aten.item %arg1 : !torch.vtensor<[],f32> -> !torch.float + // CHECK: %[[ZP:.+]] = torch.aten.item %arg2 : !torch.vtensor<[],ui8> -> !torch.int + // CHECK: %[[QUANT:.+]] = torch.aten.quantize_per_tensor %arg0, %[[SCALE]], %[[ZP]], %[[C13]] + // CHECK: %[[REPR:.+]] = torch.aten.int_repr %[[QUANT]] + // CHECK: return %[[REPR]] + return %0 : !torch.vtensor<[6],ui8> +} + +// ----- + +// CHECK-LABEL: @test_quantizelinear_i32 +func.func @test_quantizelinear_i32(%arg0: !torch.vtensor<[6],f32>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],si32>) -> !torch.vtensor<[6],si32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64} { + %0 = torch.operator "onnx.QuantizeLinear"(%arg0, %arg1, %arg2) : (!torch.vtensor<[6],f32>, !torch.vtensor<[],f32>, !torch.vtensor<[],si32>) -> !torch.vtensor<[6],si32> + // CHECK: %[[C14:.+]] = torch.constant.int 14 + // CHECK: %[[SCALE:.+]] = torch.aten.item %arg1 : !torch.vtensor<[],f32> -> !torch.float + // CHECK: %[[ZP:.+]] = torch.aten.item %arg2 : !torch.vtensor<[],si32> -> !torch.int + // CHECK: %[[QUANT:.+]] = torch.aten.quantize_per_tensor %arg0, %[[SCALE]], %[[ZP]], %[[C14]] + // CHECK: %[[REPR:.+]] = torch.aten.int_repr %[[QUANT]] + // CHECK: return %[[REPR]] + return %0 : !torch.vtensor<[6],si32> +} + +// ----- + // CHECK-LABEL: func.func @test_reciprocal func.func @test_reciprocal(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.reciprocal %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> From b5387c0f29e6cce9e1586b92df4a975ad539baa4 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Thu, 18 Jan 2024 16:47:21 -0800 Subject: [PATCH 088/283] [onnx] Lowering `onnx.dequantize_linear` to `torch` (#2759) We can make the per-tensor version of the operation to the dequantize operation via marking with the make quantized tensor component. This introductions the `qint*` and `quint*` tensor type that can be lowered to teh appropriate dequantization behavior during the torch-to-linalg conversion. --- .../TorchOnnxToTorch/DefaultDomainAtoF.cpp | 53 +++++++++++++++++++ .../TorchOnnxToTorch/simple_ops_a_to_f.mlir | 44 ++++++++++++++- 2 files changed, 96 insertions(+), 1 deletion(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp index aa3b5fc012d0..c18d681055aa 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp @@ -1156,6 +1156,59 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( binder.op, resultType, transposedInput, reshapeSizesList); return success(); }); + patterns.onOp( + "DequantizeLinear", 1, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + llvm::SmallVector operands; + if (binder.tensorOperands(operands, 3) || + binder.tensorResultType(resultType)) + return failure(); + + Value operand = operands[0]; + Value scale = operands[1]; + Value zeropoint = operands[2]; + + auto operandTy = operand.getType().cast(); + + auto scaleTy = scale.getType().dyn_cast(); + if (!scaleTy || !scaleTy.hasSizes()) + return rewriter.notifyMatchFailure(binder.op, "requires known rank"); + if (!resultType.hasDtype()) + return rewriter.notifyMatchFailure(binder.op, + "requires known resulty dtype"); + + if (scaleTy.getSizes().size() == 0) { + Type qTy = operandTy.getDtype(); + + if (qTy.isUnsignedInteger(8)) { + qTy = rewriter.getType(); + } else if (qTy.isSignedInteger(8)) { + qTy = rewriter.getType(); + } else if (qTy.isSignedInteger(32)) { + qTy = rewriter.getType(); + } else { + return rewriter.notifyMatchFailure(binder.op, + "unsupported result dtype"); + } + + auto qTensorTy = rewriter.getType( + resultType.getOptionalSizes(), qTy); + scale = rewriter.create( + binder.getLoc(), rewriter.getType(), scale); + zeropoint = rewriter.create( + binder.getLoc(), rewriter.getType(), zeropoint); + + auto quantize = + rewriter.create( + binder.getLoc(), qTensorTy, operand, scale, zeropoint); + rewriter.replaceOpWithNewOp( + binder.op, resultType, quantize); + return success(); + } + + return failure(); + }); patterns.onOp("Div", 14, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir index f8bc219dcb48..42a0fe743bc2 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir @@ -1,4 +1,4 @@ -// RUN: torch-mlir-opt <%s -convert-torch-onnx-to-torch | FileCheck %s +// RUN: torch-mlir-opt <%s --split-input-file -convert-torch-onnx-to-torch | FileCheck %s // Generally, the test cases accumulated here come from running the importer // over all included backend tests that involve simple ops with no model // level constants. This is a pragmatic choice which lets us have a lot @@ -438,6 +438,48 @@ func.func @test_cos(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5 return %0 : !torch.vtensor<[3,4,5],f32> } +// ----- + +// CHECK-LABEL: @test_dequantizelinear_si8 +func.func @test_dequantizelinear_si8(%arg0: !torch.vtensor<[6],si8>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],si8>) -> !torch.vtensor<[6],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64} { + %0 = torch.operator "onnx.DequantizeLinear"(%arg0, %arg1, %arg2) : (!torch.vtensor<[6],si8>, !torch.vtensor<[],f32>, !torch.vtensor<[],si8>) -> !torch.vtensor<[6],f32> + // CHECK: %[[SCALE:.+]] = torch.aten.item %arg1 : !torch.vtensor<[],f32> -> !torch.float + // CHECK: %[[ZP:.+]] = torch.aten.item %arg2 : !torch.vtensor<[],si8> -> !torch.int + // CHECK: %[[MAKE:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[SCALE]], %[[ZP]] + // CHECK: %[[DEQ:.+]] = torch.aten.dequantize.self %[[MAKE]] + // CHECK: return %[[DEQ]] + return %0 : !torch.vtensor<[6],f32> +} + +// ----- + +// CHECK-LABEL: @test_dequantizelinear_ui8 +func.func @test_dequantizelinear_ui8(%arg0: !torch.vtensor<[6],ui8>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],ui8>) -> !torch.vtensor<[6],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64} { + %0 = torch.operator "onnx.DequantizeLinear"(%arg0, %arg1, %arg2) : (!torch.vtensor<[6],ui8>, !torch.vtensor<[],f32>, !torch.vtensor<[],ui8>) -> !torch.vtensor<[6],f32> + // CHECK: %[[SCALE:.+]] = torch.aten.item %arg1 : !torch.vtensor<[],f32> -> !torch.float + // CHECK: %[[ZP:.+]] = torch.aten.item %arg2 : !torch.vtensor<[],ui8> -> !torch.int + // CHECK: %[[MAKE:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[SCALE]], %[[ZP]] + // CHECK: %[[DEQ:.+]] = torch.aten.dequantize.self %[[MAKE]] + // CHECK: return %[[DEQ]] + return %0 : !torch.vtensor<[6],f32> +} + +// ----- + +// CHECK-LABEL: @test_dequantizelinear_i32 +func.func @test_dequantizelinear_i32(%arg0: !torch.vtensor<[6],si32>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],si32>) -> !torch.vtensor<[6],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64} { + %0 = torch.operator "onnx.DequantizeLinear"(%arg0, %arg1, %arg2) : (!torch.vtensor<[6],si32>, !torch.vtensor<[],f32>, !torch.vtensor<[],si32>) -> !torch.vtensor<[6],f32> + // CHECK: %[[SCALE:.+]] = torch.aten.item %arg1 : !torch.vtensor<[],f32> -> !torch.float + // CHECK: %[[ZP:.+]] = torch.aten.item %arg2 : !torch.vtensor<[],si32> -> !torch.int + // CHECK: %[[MAKE:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[SCALE]], %[[ZP]] + // CHECK: %[[DEQ:.+]] = torch.aten.dequantize.self %[[MAKE]] + // CHECK: return %[[DEQ]] + return %0 : !torch.vtensor<[6],f32> +} + +// ----- + + // CHECK-LABEL: @test_div_bcast func.func @test_div_bcast(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.div.Tensor %arg0, %arg1 : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[5],f32> -> !torch.vtensor<[3,4,5],f32> From 4de4d38b870598f766fd78653028d5bd4b27fff2 Mon Sep 17 00:00:00 2001 From: Andreas Falkenberg <149819731+afalkenberg1@users.noreply.github.com> Date: Thu, 18 Jan 2024 17:23:13 -0800 Subject: [PATCH 089/283] Initial commit of NonZero op (#2766) --- .../TorchOnnxToTorch/DefaultDomainGtoP.cpp | 15 +++++++++++++-- .../TorchOnnxToTorch/simple_ops_g_to_p.mlir | 9 +++++++++ 2 files changed, 22 insertions(+), 2 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index d6b4e5bf1046..24089e87476d 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -100,8 +100,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( rewriter.replaceOpWithNewOp( binder.op, resultType, lhs, rhs); return success(); - }); - + }); patterns.onOp("LessOrEqual", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; @@ -149,6 +148,18 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( binder.op, resultType, lhs, rhs); return success(); }); + patterns.onOp("NonZero", 13, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value operand; + if (binder.tensorOperand(operand) || + binder.tensorResultType(resultType)) { + return failure(); + } + rewriter.replaceOpWithNewOp( + binder.op, resultType, operand); + return success(); + }); patterns.onOp( "MaxPool", 12, [](OpBinder binder, ConversionPatternRewriter &rewriter) { std::string autoPad; diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir index 5d6b86172597..147b3f9551c5 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir @@ -450,6 +450,15 @@ func.func @test_not_2d(%arg0: !torch.vtensor<[3,4],i1>) -> !torch.vtensor<[3,4], // ----- +// CHECK-LABEL: func.func @test_nonzero + func.func @test_nonzero(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.nonzero %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],si64> + %0 = torch.operator "onnx.NonZero"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],si64> + return %0 : !torch.vtensor<[3,4,5],si64> + } + +// ----- + // CHECK-LABEL: func.func @test_or2d func.func @test_or2d(%arg0: !torch.vtensor<[3,4],i1>, %arg1: !torch.vtensor<[3,4],i1>) -> !torch.vtensor<[3,4],i1> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 7 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.bitwise_or.Tensor %arg0, %arg1 : !torch.vtensor<[3,4],i1>, !torch.vtensor<[3,4],i1> -> !torch.vtensor<[3,4],i1> From faa4517e83d82348259165412d0744ba776360b3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ilija=20Kalini=C4=87?= Date: Fri, 19 Jan 2024 13:39:08 +0100 Subject: [PATCH 090/283] Implement lowering of torch.aten.remainder.Tensor (#2763) Closes nod-ai/SHARK-Turbine#349 --- .../TorchToLinalg/Uncategorized.cpp | 34 ++++++++-- .../Transforms/AbstractInterpLibrary.cpp | 12 ++++ .../build_tools/abstract_interp_lib_gen.py | 11 +++ .../test_suite/elementwise.py | 67 +++++++++++++++++++ 4 files changed, 118 insertions(+), 6 deletions(-) diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index e35136e333f0..593afeb1aa84 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -1195,6 +1195,26 @@ static Value createLinalgPayloadCalculationForElementwiseOp( return result; } + if (auto remTensor = dyn_cast(op)) { + Type newResultType = converter->convertType(remTensor.getType()) + .cast() + .getElementType(); + + Value self = convertScalarToDtype(b, loc, payloadArgs[0], newResultType); + Value other = convertScalarToDtype(b, loc, payloadArgs[1], newResultType); + Value result; + + if (newResultType.isa()) { + result = b.create(loc, self, other); + } else if (newResultType.isa()) { + result = b.create(loc, self, other); + } else { + remTensor.emitError( + "Unsupported type encountered for AtenRemainderTensorOp."); + } + + return result; + } if (auto reciprocal = dyn_cast(op)) { Type dtype = converter->convertType(reciprocal.getType()) .cast() @@ -1457,8 +1477,8 @@ class ConvertElementwiseOp : public ConversionPattern { AtenMulScalarOp, AtenLogOp, AtenErfOp, AtenSqrtOp, AtenFloorOp, AtenPowScalarOp, AtenPowTensorScalarOp, AtenPowTensorTensorOp, AtenLog2Op, AtenLog10Op, AtenLog1pOp, AtenRsqrtOp, AtenDivScalarOp, - AtenRemainderScalarOp, AtenAbsOp, AtenReciprocalOp, - AtenBitwiseAndTensorOp, AtenBitwiseAndScalarOp, + AtenRemainderScalarOp, AtenRemainderTensorOp, AtenAbsOp, + AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenBitwiseAndScalarOp, AtenBitwiseOrTensorOp, AtenBitwiseXorTensorOp, AtenBitwiseLeftShiftTensorOp, AtenBitwiseRightShiftTensorOp, AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp, @@ -1471,7 +1491,8 @@ class ConvertElementwiseOp : public ConversionPattern { AtenLogicalXorOp, AtenLogicalNotOp, AtenIsinfOp, AtenTriuOp, AtenTrilOp, AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp, AtenAtanOp, AtenAcosOp, AtenRealOp, AtenImagOp, - AtenDequantizeSelfOp, AtenDequantizeTensorOp, AtenQuantizePerTensorOp>(op)) + AtenDequantizeSelfOp, AtenDequantizeTensorOp, + AtenQuantizePerTensorOp>(op)) return rewriter.notifyMatchFailure(op, "not a supported elementwise op"); if (failed(verifyLinalgCompatibleTypes(op, rewriter))) @@ -2239,9 +2260,10 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality( AtenHardtanhBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp, AtenAtanOp, AtenAcosOp, AtenLogicalXorOp, AtenLogicalNotOp, AtenIsinfOp, AtenTriuOp, - AtenTrilOp, AtenRemainderScalarOp, AtenBitwiseNotOp, AtenRoundOp, - AtenFillScalarOp, AtenFillTensorOp, AtenRealOp, AtenImagOp, - AtenDequantizeSelfOp, AtenDequantizeTensorOp, AtenQuantizePerTensorOp>(); + AtenTrilOp, AtenRemainderScalarOp, AtenRemainderTensorOp, + AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp, + AtenRealOp, AtenImagOp, AtenDequantizeSelfOp, AtenDequantizeTensorOp, + AtenQuantizePerTensorOp>(); patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index a558db372d45..33247e6639e1 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -6758,6 +6758,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.remainder.Tensor\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.floor_divide.Scalar\"(%arg0: !torch.list, %arg1: !torch.float) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -10725,6 +10729,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list>, !torch.list) -> !torch.int\n" " return %4 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.remainder.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.baddbmm\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.tuple, %arg3: !torch.number, %arg4: !torch.number) -> !torch.int {\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: \"\n" diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index bf2f45e378a0..3c4d41fd1d7f 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -383,6 +383,9 @@ def aten〇div〇Scalar〡shape(self: List[int], other: float) -> List[int]: def aten〇remainder〇Scalar〡shape(self: List[int], other: float) -> List[int]: return upstream_shape_functions.unary(self) +def aten〇remainder〇Tensor〡shape(self: List[int], other: List[int]) -> List[int]: + return upstream_shape_functions.unary(self) + def aten〇floor_divide〇Scalar〡shape(self: List[int], other: float) -> List[int]: return upstream_shape_functions.unary(self) @@ -3224,6 +3227,14 @@ def aten〇remainder〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: U dtypes = [self_dtype, get_dtype_of_scalar(other)] return promote_dtypes(ranks, dtypes) +@check_dtype_function(_check_two_tensor_op()) +def aten〇remainder〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + other_rank, other_dtype = other_rank_dtype + self_rank, self_dtype = self_rank_dtype + ranks: List[Optional[int]] = [self_rank, other_rank] + dtypes = [self_dtype, other_dtype] + return promote_dtypes(ranks, dtypes) + # TODO: This should be fixed by switching to FakeTensor instead of Meta tensor @check_dtype_function( _check_tensors_with_the_same_dtype(tensor_shapes=[(1, 1, 1), (1, 1, 1), (1, 1, 1)], tensor_device="cpu", error_types={torch.bool}) + diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py index 9b857839db2b..a422772fc298 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py @@ -2265,6 +2265,73 @@ def ElementwiseRemainderScalarModule_Bool_basic(module, tu: TestUtils): # ============================================================================== +class ElementwiseRemainderTensorModule_Int_Float(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int32, True), + ([-1, -1], torch.float32, True), + ]) + def forward(self, a, b): + return torch.remainder(a, b) + + +@register_test_case(module_factory=lambda: ElementwiseRemainderTensorModule_Int_Float()) +def ElementwiseRemainderTensorModule_Int_Float_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 4, high=10).to(torch.int32), tu.rand(3, 4, high=10)) + + +# ============================================================================== + + +class ElementwiseRemainderTensorModule_Float(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True), + ([-1, -1], torch.float32, True), + ]) + def forward(self, a, b): + return torch.remainder(a, b) + + +@register_test_case(module_factory=lambda: ElementwiseRemainderTensorModule_Float()) +def ElementwiseRemainderTensorModule_Float_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4, high=10), tu.rand(3, 4, high=10)) + + +# ============================================================================== + +class ElementwiseRemainderTensorModule_Int(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int32, True), + ([-1, -1], torch.int32, True), + ]) + def forward(self, a, b): + return torch.remainder(a, b) + + +@register_test_case(module_factory=lambda: ElementwiseRemainderTensorModule_Int()) +def ElementwiseRemainderTensorModule_Int_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 4, high=10, dtype=torch.int32), tu.randint(3, 4, high=10, dtype=torch.int32)) + +# ============================================================================== + + class ElementwiseDivTensorFloatModule(torch.nn.Module): def __init__(self): From 704cfdaf0893fa4078ab744a915a56367282dcaa Mon Sep 17 00:00:00 2001 From: John Wu Date: Fri, 19 Jan 2024 07:39:46 -0800 Subject: [PATCH 091/283] Add aten.pool_max3d support to torch-to-linalg (#2735) Added verification logic to the abstract_interpreter_lib_gen.py Also made some unit tests Initially, I thought we can use `linalg::pooling_ndhwc_max` to help implement this problem. However, on a 5-dimensional matrix it does the pooling on dimensions (2, 3, 4) which is not what we want. We want pooling on dimensions (3, 4, 5). To achieve this, we would need to lower our code using the `linalg` dialect. Turns out the pooling code in `linalg` looks like this. ``` func @max_pooling_ncdhw(%I: memref, %K: memref<3xindex>, %O: memref, %strides: memref<3xindex>, %dilations: memref<3xindex>) { %c0 = arith.constant 0 : index %c1 = arith.constant 1 : index %N = memref.dim %I, %c0 : memref %C = memref.dim %I, %c1 : memref %D = memref.dim %I, 2 : memref %H = memref.dim %I, 3 : memref %W = memref.dim %I, 4 : memref %kernel_d = memref.load %K[%c0] : memref<3xindex> %kernel_h = memref.load %K[%c1] : memref<3xindex> %kernel_w = memref.load %K[2] : memref<3xindex> %stride_d = memref.load %strides[%c0] : memref<3xindex> %stride_h = memref.load %strides[%c1] : memref<3xindex> %stride_w = memref.load %strides[2] : memref<3xindex> %dilation_d = memref.load %dilations[%c0] : memref<3xindex> %dilation_h = memref.load %dilations[%c1] : memref<3xindex> %dilation_w = memref.load %dilations[2] : memref<3xindex> linalg.generic { indexing_maps = [ affine_map<(n, c, d, h, w, kd, kh, kw) -> (n, c, d * %stride_d + kd * %dilation_d, h * %stride_h + kh * %dilation_h, w * %stride_w + kw * %dilation_w)>, // Map for input tensor affine_map<(n, c, d, h, w, kd, kh, kw) -> (kd, kh, kw)>, // Map for kernel tensor affine_map<(n, c, d, h, w, kd, kh, kw) -> (n, c, d, h, w)> // Map for output tensor ], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"], doc = "3D Max Pooling NCDHW with Strides, Dilations, and Kernel Size" } ins(%I, %K : memref, memref<3xindex>) outs(%O : memref) { ^bb0(%input_elem: f32, %kernel_elem: index, %output_elem: f32): %max_val = arith.maxf %input_elem, %output_elem : f32 linalg.yield %max_val : f32 } return } ``` This was implemented based on it's source code with the adjustments mentioned above: https://github.com/llvm/llvm-project/blob/4ca1b5e094280ef1af40412e3cfcb62dc3cf15bc/mlir/include/mlir/Dialect/Linalg/IR/LinalgNamedStructuredOps.yaml#L5647 Issues related to this can be found here https://github.com/nod-ai/SHARK-Turbine/issues/324 --- lib/Conversion/TorchToLinalg/Pooling.cpp | 267 +++++++++--- .../Transforms/AbstractInterpLibrary.cpp | 405 ++++++++++++++++++ projects/pt1/e2e_testing/xfail_sets.py | 8 + .../build_tools/abstract_interp_lib_gen.py | 121 ++++++ .../torch_mlir_e2e_test/test_suite/pooling.py | 148 +++++++ test/Conversion/TorchToLinalg/pooling.mlir | 51 ++- 6 files changed, 937 insertions(+), 63 deletions(-) diff --git a/lib/Conversion/TorchToLinalg/Pooling.cpp b/lib/Conversion/TorchToLinalg/Pooling.cpp index 20c03f5ffeec..85354aad4f12 100644 --- a/lib/Conversion/TorchToLinalg/Pooling.cpp +++ b/lib/Conversion/TorchToLinalg/Pooling.cpp @@ -72,36 +72,15 @@ checkAndGetPoolingParameters(OpTy op, ConversionPatternRewriter &rewriter, return success(); } -// Creates a pooling operation based on the type specified by `OpTy` and -// arguments passed. -template -static LogicalResult createPoolingOp( - Operation *op, ConversionPatternRewriter &rewriter, Value self, - bool supportNonFPInput, bool ceilMode, int64_t dimensionality, - SmallVectorImpl &kernelSizeIntValues, - SmallVectorImpl &strideInts, SmallVectorImpl &paddingInts, - SmallVectorImpl &dilationInts, Attribute initValueAttr, - SmallVectorImpl &outTensorShape, Value &paddedInput, Value &result) { - Location loc = op->getLoc(); +static Value computeOutputTensor(Operation *op, ConversionPatternRewriter &rewriter, + Value self, int64_t dimensionality, bool ceilMode, + SmallVectorImpl &strideInts, + SmallVectorImpl &paddingInts, + SmallVectorImpl &dilationInts, + SmallVectorImpl &kernelSizeIntValues, + SmallVectorImpl &outTensorShape, Value initValue) { Type elementType = self.getType().cast().getElementType(); - if (!elementType.isa() && !supportNonFPInput) - return op->emitError("unimplemented: non-floating point type"); - - SmallVector lowPaddingIncludingNC = {0, 0}; - lowPaddingIncludingNC.append(paddingInts); - SmallVector highPaddingIncludingNC = lowPaddingIncludingNC; - - if (ceilMode) { - for (int64_t i = 0; i < dimensionality; ++i) { - highPaddingIncludingNC[i + 2] += strideInts[i]; - } - } - - Value initValue = - rewriter.create(loc, cast(initValueAttr)); - paddedInput = torch_to_linalg::getPaddedTensor( - op, rewriter, self, lowPaddingIncludingNC, highPaddingIncludingNC, - initValue); + Location loc = op->getLoc(); Value N = getDimOp(rewriter, loc, self, 0); Value C = getDimOp(rewriter, loc, self, 1); @@ -124,8 +103,54 @@ static LogicalResult createPoolingOp( // Create output tensor initialized with smallest floating point value. outTensorShape.insert(outTensorShape.begin(), {N, C}); - Value outTensorInitialized = - createInitTensor(rewriter, loc, outTensorShape, elementType, initValue); + return createInitTensor(rewriter, loc, outTensorShape, elementType, + initValue); +} + +static Value padInputTensor(Operation *op, ConversionPatternRewriter &rewriter, + Value self, bool ceilMode, int64_t dimensionality, + SmallVectorImpl &strideInts, + SmallVectorImpl &paddingInts, + Value initValue) { + SmallVector lowPaddingIncludingNC = {0, 0}; + lowPaddingIncludingNC.append(paddingInts); + SmallVector highPaddingIncludingNC = lowPaddingIncludingNC; + + if (ceilMode) { + for (int64_t i = 0; i < dimensionality; ++i) { + highPaddingIncludingNC[i + 2] += strideInts[i]; + } + } + + return torch_to_linalg::getPaddedTensor(op, rewriter, self, + lowPaddingIncludingNC, + highPaddingIncludingNC, initValue); +} + +// Creates a pooling operation based on the type specified by `OpTy` and +// arguments passed. +template +static LogicalResult createPoolingOp( + Operation *op, ConversionPatternRewriter &rewriter, Value self, + bool supportNonFPInput, bool ceilMode, int64_t dimensionality, + SmallVectorImpl &kernelSizeIntValues, + SmallVectorImpl &strideInts, SmallVectorImpl &paddingInts, + SmallVectorImpl &dilationInts, Attribute initValueAttr, + SmallVectorImpl &outTensorShape, Value &paddedInput, Value &result) { + Location loc = op->getLoc(); + Type elementType = self.getType().cast().getElementType(); + if (!elementType.isa() && !supportNonFPInput) + return op->emitError("unimplemented: non-floating point type"); + + Value initValue = + rewriter.create(loc, cast(initValueAttr)); + + paddedInput = padInputTensor(op, rewriter, self, ceilMode, dimensionality, + strideInts, paddingInts, initValue); + + auto outTensorInitialized = computeOutputTensor( + op, rewriter, self, dimensionality, ceilMode, strideInts, paddingInts, + dilationInts, kernelSizeIntValues, outTensorShape, initValue); auto stridesAttr = rewriter.getI64VectorAttr(strideInts); auto dilationAttr = rewriter.getI64VectorAttr(dilationInts); @@ -138,57 +163,174 @@ static LogicalResult createPoolingOp( ValueRange{paddedInput, windowTensor}, outTensorInitialized, stridesAttr, dilationAttr) .getResult(0); - return success(); } namespace { -class ConvertAtenMaxPool2dOp : public OpConversionPattern { +template +class ConvertAtenMaxPoolOp : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + +private: + template struct DimensionTraits; + + template <> struct DimensionTraits { + static const int64_t Dim = 2; + }; + + template <> struct DimensionTraits { + static const int64_t Dim = 3; + }; + + static const int64_t Dim = DimensionTraits::Dim; + + LogicalResult createPoolingMax3D(AtenMaxPool3dOp &op, + typename OpTy::Adaptor adaptor, + ConversionPatternRewriter &rewriter, + SmallVectorImpl &kernelSizeIntValues, + SmallVectorImpl &strideInts, + SmallVectorImpl &paddingInts, + SmallVectorImpl &dilationInts, + bool ceilMode) const { + SmallVector outTensorShape; + Value self = adaptor.getSelf(); + Type elementType = self.getType().cast().getElementType(); + TypedAttr smallestFPValueAttr = rewriter.getFloatAttr( + elementType, + APFloat::getInf(elementType.cast().getFloatSemantics(), + /*Negative=*/true)); + Value initValue = + rewriter.create(op->getLoc(), smallestFPValueAttr); + + Value paddedInput = padInputTensor(op, rewriter, self, ceilMode, 3, + strideInts, paddingInts, initValue); + + auto outTensorInitialized = computeOutputTensor( + op, rewriter, self, 3, ceilMode, strideInts, paddingInts, dilationInts, + kernelSizeIntValues, outTensorShape, initValue); + + auto shape = + castIntVectorToIndexVector(rewriter, op->getLoc(), kernelSizeIntValues); + Value windowTensor = rewriter.create( + op->getLoc(), getAsOpFoldResult(shape), elementType); + + MLIRContext *context = rewriter.getContext(); + + auto mapInput = mlir::AffineMap::get( + 8, 0, + { + rewriter.getAffineDimExpr(0), // n + rewriter.getAffineDimExpr(1), // c + // dim_d * stride_d + kernal_d * dilation_d + rewriter.getAffineDimExpr(2) * + getAffineConstantExpr(strideInts[0], context) + + rewriter.getAffineDimExpr(5) * + getAffineConstantExpr(dilationInts[0], context), + // dim_h * stride_h + kernal_h * dilation_h + rewriter.getAffineDimExpr(3) * + getAffineConstantExpr(strideInts[1], context) + + rewriter.getAffineDimExpr(6) * + getAffineConstantExpr(dilationInts[1], context), + // dim_w * stride_w + kernal_w * dilation_w + rewriter.getAffineDimExpr(4) * + getAffineConstantExpr(strideInts[2], context) + + rewriter.getAffineDimExpr(7) * + getAffineConstantExpr(dilationInts[2], context), + }, + context); + auto mapKernel = + mlir::AffineMap::get(8, 0, + { + rewriter.getAffineDimExpr(5), // kd + rewriter.getAffineDimExpr(6), // kh + rewriter.getAffineDimExpr(7) // kw + }, + context); + auto mapOutput = mlir::AffineMap::get( + 8, 0, + {rewriter.getAffineDimExpr(0), rewriter.getAffineDimExpr(1), + rewriter.getAffineDimExpr(2), rewriter.getAffineDimExpr(3), + rewriter.getAffineDimExpr(4)}, + context); + auto iteratorTypes = + SmallVector(5, utils::IteratorType::parallel); + iteratorTypes.append(3, utils::IteratorType::reduction); + SmallVector indexingMaps = {mapInput, mapKernel, mapOutput}; + Value poolingOp = + rewriter + .create( + op->getLoc(), + /* result types */ outTensorInitialized.getType(), + /* operands */ ValueRange({paddedInput, windowTensor}), + /* outputs */ outTensorInitialized, + /*indexingMaps=*/indexingMaps, + /*iteratorTypes=*/iteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange args) { + Value currentVal = args[0], accMaxValue = args[2]; + Value max_result = + b.create(loc, currentVal, accMaxValue); + ; + b.create(loc, max_result); + }) + .getResult(0); + Type newResultType = this->getTypeConverter()->convertType(op.getType()); + rewriter.replaceOpWithNewOp(op, newResultType, poolingOp); + return success(); + } + public: - using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(AtenMaxPool2dOp op, OpAdaptor adaptor, + matchAndRewrite(OpTy op, typename OpTy::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); - const TypeConverter *typeConverter = getTypeConverter(); + const TypeConverter *typeConverter = this->getTypeConverter(); Value self = adaptor.getSelf(); int64_t selfRank = self.getType().cast().getRank(); - // TODO: Add support for 3D inputs. - if (selfRank == 3) + + if (selfRank != Dim + 2) return rewriter.notifyMatchFailure( - op, "unimplemented: only support 4D input"); + op, "unimplemented: Does not support inputs with rank"); bool ceilMode; - SmallVector kernelSizeIntValues; - SmallVector strideInts, paddingInts, dilationInts; + SmallVector kernelSizeIntValues; + SmallVector strideInts, paddingInts, dilationInts; if (!matchPattern(op.getDilation(), m_TorchListOfConstantInts(dilationInts))) return rewriter.notifyMatchFailure(op, "only support constant int dilations"); - if (failed(checkAndGetPoolingParameters( - op, rewriter, typeConverter, ceilMode, kernelSizeIntValues, - strideInts, paddingInts))) + + if (failed(checkAndGetPoolingParameters(op, rewriter, typeConverter, + ceilMode, kernelSizeIntValues, + strideInts, paddingInts))) return rewriter.notifyMatchFailure(op, "invalid pooling parameters"); Type elementType = self.getType().cast().getElementType(); - TypedAttr smallestFPValueAttr = rewriter.getFloatAttr( - elementType, - APFloat::getInf(elementType.cast().getFloatSemantics(), - /*Negative=*/true)); - SmallVector outTensorShape; - // `maxpool2d` contains the result of maxpool2d operation over the input. - Value maxPool2d, paddedInput; - if (failed(createPoolingOp( - op, rewriter, self, /*supportNonFPInput=*/false, ceilMode, - /*dimensionality=*/2, kernelSizeIntValues, strideInts, paddingInts, - dilationInts, smallestFPValueAttr, outTensorShape, paddedInput, - maxPool2d))) - return rewriter.notifyMatchFailure(op, "unable to compute maxpool2d"); - Type newResultType = getTypeConverter()->convertType(op.getType()); - rewriter.replaceOpWithNewOp(op, newResultType, maxPool2d); - return success(); + + if constexpr (Dim == 2) { + SmallVector outTensorShape; + // `maxpool2d` contains the result of maxpool2d operation over the input. + Value maxPool2d, paddedInput; + TypedAttr smallestFPValueAttr = rewriter.getFloatAttr( + elementType, + APFloat::getInf( + elementType.cast().getFloatSemantics(), + /*Negative=*/true)); + if (failed(createPoolingOp( + op, rewriter, self, /*supportNonFPInput=*/true, ceilMode, + /*dimensionality=*/2, kernelSizeIntValues, strideInts, + paddingInts, dilationInts, smallestFPValueAttr, outTensorShape, + paddedInput, maxPool2d))) + return rewriter.notifyMatchFailure(op, "unable to compute maxpool2d"); + Type newResultType = this->getTypeConverter()->convertType(op.getType()); + rewriter.replaceOpWithNewOp(op, newResultType, maxPool2d); + return success(); + } else { + return createPoolingMax3D(op, adaptor, rewriter, + kernelSizeIntValues, strideInts, paddingInts, + dilationInts, ceilMode); + } } }; } // namespace @@ -650,7 +792,10 @@ void mlir::torch::torch_to_linalg::populatePoolingPatternsAndLegality( ConversionTarget &target) { MLIRContext *context = patterns.getContext(); target.addIllegalOp(); - patterns.add(typeConverter, context); + target.addIllegalOp(); + patterns.add>(typeConverter, context); + patterns.add>(typeConverter, context); + target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index 33247e6639e1..d058d0fdb97c 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -7174,6 +7174,403 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.max_pool2d(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) : (!torch.list, !torch.list, !torch.list, !torch.list, !torch.list, !torch.bool) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.max_pool3d\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.bool) -> !torch.list {\n" +" %0 = call @__torch__._max_pool3d(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) : (!torch.list, !torch.list, !torch.list, !torch.list, !torch.list, !torch.bool) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @__torch__._max_pool3d(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.bool) -> !torch.list {\n" +" %int-1 = torch.constant.int -1\n" +" %int-2 = torch.constant.int -2\n" +" %int-3 = torch.constant.int -3\n" +" %int-4 = torch.constant.int -4\n" +" %int-5 = torch.constant.int -5\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %str_0 = torch.constant.str \"AssertionError: max_pool3d: dilation must be either a single int, or a tuple of three ints\"\n" +" %str_1 = torch.constant.str \"AssertionError: max_pool3d: padding must either be a single int, or a tuple of thee ints\"\n" +" %str_2 = torch.constant.str \"AssertionError: max_pool3d: stride must either be omitted, a single int, or a tuple of three ints\"\n" +" %none = torch.constant.none\n" +" %str_3 = torch.constant.str \"AssertionError: max_pool3d: kernel_size must either be a single int, or a tuple of three ints\"\n" +" %true = torch.constant.bool true\n" +" %int1 = torch.constant.int 1\n" +" %int3 = torch.constant.int 3\n" +" %int0 = torch.constant.int 0\n" +" %int2 = torch.constant.int 2\n" +" %int4 = torch.constant.int 4\n" +" %int5 = torch.constant.int 5\n" +" %0 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" +" %1 = torch.aten.eq.int %0, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %45 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" +" %46 = torch.aten.eq.int %45, %int3 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %46 : !torch.bool\n" +" }\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_3, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %3 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" +" %4 = torch.aten.eq.int %3, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %5 = torch.prim.If %4 -> (!torch.tuple) {\n" +" %45 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %46 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %47 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %48 = torch.prim.TupleConstruct %45, %46, %47 : !torch.int, !torch.int, !torch.int -> !torch.tuple\n" +" torch.prim.If.yield %48 : !torch.tuple\n" +" } else {\n" +" %45 = torch.aten.__getitem__.t %arg1, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %46 = torch.aten.__getitem__.t %arg1, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %47 = torch.aten.__getitem__.t %arg1, %int2 : !torch.list, !torch.int -> !torch.int\n" +" %48 = torch.prim.TupleConstruct %45, %46, %47 : !torch.int, !torch.int, !torch.int -> !torch.tuple\n" +" torch.prim.If.yield %48 : !torch.tuple\n" +" }\n" +" %6:3 = torch.prim.TupleUnpack %5 : !torch.tuple -> !torch.int, !torch.int, !torch.int\n" +" %7 = torch.aten.len.t %arg2 : !torch.list -> !torch.int\n" +" %8 = torch.aten.eq.int %7, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %9 = torch.prim.If %8 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %45 = torch.aten.len.t %arg2 : !torch.list -> !torch.int\n" +" %46 = torch.aten.eq.int %45, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %46 : !torch.bool\n" +" }\n" +" %10 = torch.prim.If %9 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %45 = torch.aten.len.t %arg2 : !torch.list -> !torch.int\n" +" %46 = torch.aten.eq.int %45, %int3 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %46 : !torch.bool\n" +" }\n" +" torch.prim.If %10 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_2, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %11 = torch.aten.len.t %arg2 : !torch.list -> !torch.int\n" +" %12 = torch.aten.eq.int %11, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %13:3 = torch.prim.If %12 -> (!torch.int, !torch.int, !torch.int) {\n" +" torch.prim.If.yield %6#0, %6#0, %6#0 : !torch.int, !torch.int, !torch.int\n" +" } else {\n" +" %45 = torch.aten.len.t %arg2 : !torch.list -> !torch.int\n" +" %46 = torch.aten.eq.int %45, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %47:3 = torch.prim.If %46 -> (!torch.int, !torch.int, !torch.int) {\n" +" %48 = torch.aten.__getitem__.t %arg2, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %49 = torch.aten.__getitem__.t %arg2, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %50 = torch.aten.__getitem__.t %arg2, %int0 : !torch.list, !torch.int -> !torch.int\n" +" torch.prim.If.yield %48, %49, %50 : !torch.int, !torch.int, !torch.int\n" +" } else {\n" +" %48 = torch.aten.__getitem__.t %arg2, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %49 = torch.aten.__getitem__.t %arg2, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %50 = torch.aten.__getitem__.t %arg2, %int2 : !torch.list, !torch.int -> !torch.int\n" +" torch.prim.If.yield %48, %49, %50 : !torch.int, !torch.int, !torch.int\n" +" }\n" +" torch.prim.If.yield %47#0, %47#1, %47#2 : !torch.int, !torch.int, !torch.int\n" +" }\n" +" %14 = torch.aten.len.t %arg3 : !torch.list -> !torch.int\n" +" %15 = torch.aten.eq.int %14, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %16 = torch.prim.If %15 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %45 = torch.aten.len.t %arg3 : !torch.list -> !torch.int\n" +" %46 = torch.aten.eq.int %45, %int3 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %46 : !torch.bool\n" +" }\n" +" torch.prim.If %16 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %17 = torch.aten.len.t %arg3 : !torch.list -> !torch.int\n" +" %18 = torch.aten.eq.int %17, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %19 = torch.prim.If %18 -> (!torch.tuple) {\n" +" %45 = torch.aten.__getitem__.t %arg3, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %46 = torch.aten.__getitem__.t %arg3, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %47 = torch.aten.__getitem__.t %arg3, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %48 = torch.prim.TupleConstruct %45, %46, %47 : !torch.int, !torch.int, !torch.int -> !torch.tuple\n" +" torch.prim.If.yield %48 : !torch.tuple\n" +" } else {\n" +" %45 = torch.aten.__getitem__.t %arg3, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %46 = torch.aten.__getitem__.t %arg3, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %47 = torch.aten.__getitem__.t %arg3, %int2 : !torch.list, !torch.int -> !torch.int\n" +" %48 = torch.prim.TupleConstruct %45, %46, %47 : !torch.int, !torch.int, !torch.int -> !torch.tuple\n" +" torch.prim.If.yield %48 : !torch.tuple\n" +" }\n" +" %20:3 = torch.prim.TupleUnpack %19 : !torch.tuple -> !torch.int, !torch.int, !torch.int\n" +" %21 = torch.aten.len.t %arg4 : !torch.list -> !torch.int\n" +" %22 = torch.aten.eq.int %21, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %23 = torch.prim.If %22 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %45 = torch.aten.len.t %arg4 : !torch.list -> !torch.int\n" +" %46 = torch.aten.eq.int %45, %int3 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %46 : !torch.bool\n" +" }\n" +" torch.prim.If %23 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %24 = torch.aten.len.t %arg4 : !torch.list -> !torch.int\n" +" %25 = torch.aten.eq.int %24, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %26 = torch.prim.If %25 -> (!torch.tuple) {\n" +" %45 = torch.aten.__getitem__.t %arg4, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %46 = torch.aten.__getitem__.t %arg4, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %47 = torch.aten.__getitem__.t %arg4, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %48 = torch.prim.TupleConstruct %45, %46, %47 : !torch.int, !torch.int, !torch.int -> !torch.tuple\n" +" torch.prim.If.yield %48 : !torch.tuple\n" +" } else {\n" +" %45 = torch.aten.__getitem__.t %arg4, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %46 = torch.aten.__getitem__.t %arg4, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %47 = torch.aten.__getitem__.t %arg4, %int2 : !torch.list, !torch.int -> !torch.int\n" +" %48 = torch.prim.TupleConstruct %45, %46, %47 : !torch.int, !torch.int, !torch.int -> !torch.tuple\n" +" torch.prim.If.yield %48 : !torch.tuple\n" +" }\n" +" %27:3 = torch.prim.TupleUnpack %26 : !torch.tuple -> !torch.int, !torch.int, !torch.int\n" +" %28 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %29 = torch.aten.eq.int %28, %int4 : !torch.int, !torch.int -> !torch.bool\n" +" %30 = torch.prim.If %29 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %45 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %46 = torch.aten.eq.int %45, %int5 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %46 : !torch.bool\n" +" }\n" +" torch.prim.If %30 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %31 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %32 = torch.aten.eq.int %31, %int5 : !torch.int, !torch.int -> !torch.bool\n" +" %33 = torch.prim.If %32 -> (!torch.int) {\n" +" %45 = torch.aten.__getitem__.t %arg0, %int-5 : !torch.list, !torch.int -> !torch.int\n" +" torch.prim.If.yield %45 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %int1 : !torch.int\n" +" }\n" +" %34 = torch.aten.__getitem__.t %arg0, %int-4 : !torch.list, !torch.int -> !torch.int\n" +" %35 = torch.aten.__getitem__.t %arg0, %int-3 : !torch.list, !torch.int -> !torch.int\n" +" %36 = torch.aten.__getitem__.t %arg0, %int-2 : !torch.list, !torch.int -> !torch.int\n" +" %37 = torch.aten.__getitem__.t %arg0, %int-1 : !torch.list, !torch.int -> !torch.int\n" +" %38 = call @__torch__.torch.jit._shape_functions.pooling_output_shape(%35, %6#0, %20#0, %13#0, %27#0, %arg5) : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.bool) -> !torch.int\n" +" %39 = call @__torch__.torch.jit._shape_functions.pooling_output_shape(%36, %6#1, %20#1, %13#1, %27#1, %arg5) : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.bool) -> !torch.int\n" +" %40 = call @__torch__.torch.jit._shape_functions.pooling_output_shape(%37, %6#2, %20#2, %13#2, %27#2, %arg5) : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.bool) -> !torch.int\n" +" %41 = call @__torch__._pool3d_shape_check(%arg0, %6#0, %6#1, %6#2, %13#0, %13#1, %13#2, %20#0, %20#1, %20#2, %27#0, %27#1, %27#2, %38, %39, %40) : (!torch.list, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.none\n" +" %42 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %43 = torch.aten.eq.int %42, %int4 : !torch.int, !torch.int -> !torch.bool\n" +" %44 = torch.prim.If %43 -> (!torch.list) {\n" +" %45 = torch.prim.ListConstruct %34, %38, %39, %40 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list\n" +" torch.prim.If.yield %45 : !torch.list\n" +" } else {\n" +" %45 = torch.prim.ListConstruct %33, %34, %38, %39, %40 : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list\n" +" torch.prim.If.yield %45 : !torch.list\n" +" }\n" +" return %44 : !torch.list\n" +" }\n" +" func.func @__torch__._pool3d_shape_check(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int, %arg4: !torch.int, %arg5: !torch.int, %arg6: !torch.int, %arg7: !torch.int, %arg8: !torch.int, %arg9: !torch.int, %arg10: !torch.int, %arg11: !torch.int, %arg12: !torch.int, %arg13: !torch.int, %arg14: !torch.int, %arg15: !torch.int) -> !torch.none {\n" +" %str = torch.constant.str \"AssertionError: pool3d: input dimensions must be 4 or 5\"\n" +" %true = torch.constant.bool true\n" +" %none = torch.constant.none\n" +" %str_0 = torch.constant.str \"AssertionError: \"\n" +" %false = torch.constant.bool false\n" +" %int0 = torch.constant.int 0\n" +" %int4 = torch.constant.int 4\n" +" %int5 = torch.constant.int 5\n" +" %int1 = torch.constant.int 1\n" +" %int2 = torch.constant.int 2\n" +" %int3 = torch.constant.int 3\n" +" %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %1 = torch.aten.gt.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.bool) {\n" +" %20 = torch.aten.gt.int %arg2, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %20 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %3 = torch.prim.If %2 -> (!torch.bool) {\n" +" %20 = torch.aten.gt.int %arg3, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %20 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %3 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %4 = torch.aten.gt.int %arg4, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %5 = torch.prim.If %4 -> (!torch.bool) {\n" +" %20 = torch.aten.gt.int %arg5, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %20 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %6 = torch.prim.If %5 -> (!torch.bool) {\n" +" %20 = torch.aten.gt.int %arg6, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %20 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %6 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %7 = torch.aten.gt.int %arg10, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %8 = torch.prim.If %7 -> (!torch.bool) {\n" +" %20 = torch.aten.gt.int %arg11, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %20 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %9 = torch.prim.If %8 -> (!torch.bool) {\n" +" %20 = torch.aten.gt.int %arg12, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %20 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %9 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %10 = torch.aten.eq.int %0, %int4 : !torch.int, !torch.int -> !torch.bool\n" +" %11 = torch.prim.If %10 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %20 = torch.aten.eq.int %0, %int5 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %20 : !torch.bool\n" +" }\n" +" torch.prim.If %11 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %12 = torch.aten.eq.int %0, %int4 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %12 -> () {\n" +" %20 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %21 = torch.aten.ne.int %20, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %22 = torch.prim.If %21 -> (!torch.bool) {\n" +" %25 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %26 = torch.aten.ne.int %25, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %26 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %23 = torch.prim.If %22 -> (!torch.bool) {\n" +" %25 = torch.aten.__getitem__.t %arg0, %int2 : !torch.list, !torch.int -> !torch.int\n" +" %26 = torch.aten.ne.int %25, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %26 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %24 = torch.prim.If %23 -> (!torch.bool) {\n" +" %25 = torch.aten.__getitem__.t %arg0, %int3 : !torch.list, !torch.int -> !torch.int\n" +" %26 = torch.aten.ne.int %25, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %26 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %24 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.If.yield\n" +" } else {\n" +" %20 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %21 = torch.aten.ne.int %20, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %22 = torch.prim.If %21 -> (!torch.bool) {\n" +" %26 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %27 = torch.aten.ne.int %26, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %27 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %23 = torch.prim.If %22 -> (!torch.bool) {\n" +" %26 = torch.aten.__getitem__.t %arg0, %int2 : !torch.list, !torch.int -> !torch.int\n" +" %27 = torch.aten.ne.int %26, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %27 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %24 = torch.prim.If %23 -> (!torch.bool) {\n" +" %26 = torch.aten.__getitem__.t %arg0, %int3 : !torch.list, !torch.int -> !torch.int\n" +" %27 = torch.aten.ne.int %26, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %27 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %25 = torch.prim.If %24 -> (!torch.bool) {\n" +" %26 = torch.aten.__getitem__.t %arg0, %int4 : !torch.list, !torch.int -> !torch.int\n" +" %27 = torch.aten.ne.int %26, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %27 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %25 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.If.yield\n" +" }\n" +" %13 = torch.aten.floordiv.int %arg1, %int2 : !torch.int, !torch.int -> !torch.int\n" +" %14 = torch.aten.ge.int %13, %arg7 : !torch.int, !torch.int -> !torch.bool\n" +" %15 = torch.prim.If %14 -> (!torch.bool) {\n" +" %20 = torch.aten.floordiv.int %arg3, %int2 : !torch.int, !torch.int -> !torch.int\n" +" %21 = torch.aten.ge.int %20, %arg9 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %21 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %16 = torch.prim.If %15 -> (!torch.bool) {\n" +" %20 = torch.aten.floordiv.int %arg2, %int2 : !torch.int, !torch.int -> !torch.int\n" +" %21 = torch.aten.ge.int %20, %arg8 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %21 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %16 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %17 = torch.aten.ge.int %arg13, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %18 = torch.prim.If %17 -> (!torch.bool) {\n" +" %20 = torch.aten.ge.int %arg15, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %20 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" %19 = torch.prim.If %18 -> (!torch.bool) {\n" +" %20 = torch.aten.ge.int %arg14, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %20 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %19 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" return %none : !torch.none\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.max_pool2d_with_indices\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.bool) -> !torch.tuple, list> {\n" " %0 = call @__torch__.torch.jit._shape_functions.max_pool2d(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5) : (!torch.list, !torch.list, !torch.list, !torch.list, !torch.list, !torch.bool) -> !torch.list\n" " %1 = torch.prim.TupleConstruct %0, %0 : !torch.list, !torch.list -> !torch.tuple, list>\n" @@ -8950,6 +9347,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.avg_pool3d\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.list, %arg4: !torch.bool, %arg5: !torch.bool, %arg6: !torch.optional) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.batch_norm\"(%arg0: !torch.tuple, %arg1: !torch.optional>, %arg2: !torch.optional>, %arg3: !torch.optional>, %arg4: !torch.optional>, %arg5: !torch.bool, %arg6: !torch.float, %arg7: !torch.float, %arg8: !torch.bool) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" @@ -9356,6 +9757,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.max_pool3d\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.bool) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.max_pool2d_with_indices\"(%arg0: !torch.tuple, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.bool) -> !torch.tuple {\n" " %int4 = torch.constant.int 4\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 0ffc3c4fd606..341dad1e6192 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -349,6 +349,13 @@ "TransposeIntModule_basic", "TransposeIntNegDimsModule_basic", "IndexPutImpl2DNoneIndexStaticModule_basic", + "MaxPool3dCeilModeTrueModule_basic", + "MaxPool3dEmptyStrideStaticModule_basic", + "MaxPool3dLargeDatadModule_basic", + "MaxPool3dModuleRandomSimple_basic", + "MaxPool3dModule_basic", + "MaxPool3dStaticCeilModeTrueModule_basic", + "MaxPool3dStaticModule_basic", } STABLEHLO_PASS_SET = { @@ -1419,6 +1426,7 @@ "_ConvolutionDeprecated2DBenchmarkModule_basic", "_ConvolutionDeprecated2DCudnnModule_basic", "_ConvolutionDeprecated2DDeterministicModule_basic", + "MaxPool3dEmptyStrideStaticModule_basic", "AddIntModule_basic", "ArangeStartOutViewModule_basic", "AtenIntBoolOpModule_basic", diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index 3c4d41fd1d7f..76b3d6af3ee7 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -648,9 +648,120 @@ def aten〇_unsafe_view〡shape(self: List[int], size: List[int]) -> List[int]: def aten〇resize_〡shape(self: List[int], size: List[int], memory_format: Optional[int] = None) -> List[int]: return size +def _pool3d_shape_check( + input: List[int], + kD: int, + kH: int, + kW: int, + dD: int, + dH: int, + dW: int, + padD: int, + padH: int, + padW: int, + dilationD: int, + dilationH: int, + dilationW: int, + outputDepth: int, + outputHeight: int, + outputWidth: int, +): + ndim = len(input) + + assert kD > 0 and kH > 0 and kW > 0 + assert dD > 0 and dH > 0 and dW > 0 + assert dilationD > 0 and dilationH > 0 and dilationW > 0 + assert ndim == 4 or ndim == 5, "pool3d: input dimensions must be 4 or 5" + if ndim == 4: + assert input[0] != 0 and input[1] != 0 and input[2] != 0 and input[3] != 0 + else: + assert input[0] != 0 and input[1] != 0 and input[2] != 0 and input[3] != 0 and input[4] != 0 + + assert kD // 2 >= padD and kW // 2 >= padW and kH // 2 >= padH + assert outputDepth >= 1 and outputWidth >= 1 and outputHeight >= 1 + +def _max_pool3d( + input: List[int], + kernel_size: List[int], + stride: List[int], + padding: List[int], + dilation: List[int], + ceil_mode: bool, +): + assert ( + len(kernel_size) == 1 or len(kernel_size) == 3 + ), "max_pool3d: kernel_size must either be a single int, or a tuple of three ints" + (kD, kH, kW) = (kernel_size[0], kernel_size[0], kernel_size[0]) if len(kernel_size) == 1 else (kernel_size[0], kernel_size[1], kernel_size[2]) + + assert ( + len(stride) == 0 or len(stride) == 1 or len(stride) == 3 + ), "max_pool3d: stride must either be omitted, a single int, or a tuple of three ints" + + if len(stride) == 0: + (dD, dH, dW) = (kD, kD, kD) + elif len(stride) == 1: + (dD, dH, dW) = (stride[0], stride[0], stride[0]) + else: # len(stride) == 3 + (dD, dH, dW) = (stride[0], stride[1], stride[2]) + + assert ( + len(padding) == 1 or len(padding) == 3 + ), "max_pool3d: padding must either be a single int, or a tuple of thee ints" + (padD, padH, padW) = (padding[0], padding[0], padding[0]) if len(padding) == 1 else (padding[0], padding[1], padding[2]) + + assert ( + len(dilation) == 1 or len(dilation) == 3 + ), "max_pool3d: dilation must be either a single int, or a tuple of three ints" + (dilationD, dilationH, dilationW) = (dilation[0], dilation[0], dilation[0]) if len(dilation) == 1 else (dilation[0], dilation[1], dilation[2]) + + assert len(input) == 4 or len(input) == 5 + nbatch = input[-5] if len(input) == 5 else 1 + nInputPlane = input[-4] + inputDepth = input[-3] + inputHeight = input[-2] + inputWidth = input[-1] + + outputDepth = upstream_shape_functions.pooling_output_shape(inputDepth, kD, padD, dD, dilationD, ceil_mode) + outputHeight = upstream_shape_functions.pooling_output_shape(inputHeight, kH, padH, dH, dilationH, ceil_mode) + outputWidth = upstream_shape_functions.pooling_output_shape(inputWidth, kW, padW, dW, dilationW, ceil_mode) + + _pool3d_shape_check( + input, + kD, + kH, + kW, + dD, + dH, + dW, + padD, + padH, + padW, + dilationD, + dilationH, + dilationW, + outputDepth, + outputHeight, + outputWidth, + ) + + if len(input) == 4: + return [nInputPlane, outputDepth, outputHeight, outputWidth] + else: + return [nbatch, nInputPlane, outputDepth, outputHeight, outputWidth] + def aten〇max_pool2d〡shape(self: List[int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0, 0,), dilation: List[int] = (1, 1,), ceil_mode: bool = False) -> List[int]: return upstream_shape_functions.max_pool2d(self, kernel_size, stride, padding, dilation, ceil_mode) +@check_shape_function([ + Invocation(TensorOfShape(3, 6, 10, 10, 10), [2]), # Basic using defaults + Invocation(TensorOfShape(3, 6, 10, 10, 10), [4], [2], [2], [2]), # Using single values for each parameter + Invocation(TensorOfShape(3, 6, 64, 64, 64), [4, 6, 8], [2, 4, 2], [1, 2, 4], [1, 2, 4]), # Using dimensions should be + ErrorInvocation(TensorOfShape(3, 6, 2, 2, 2), [4]), # Input is too small + ErrorInvocation(TensorOfShape(3, 6, 10, 10, 10), [4], [2], [4], [2]), # The following relationship between kernel and padding needs to apply: Kernel size >= 2 * padding size +]) +def aten〇max_pool3d〡shape(self: List[int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0, 0, 0,), dilation: List[int] = (1, 1, 1,), ceil_mode: bool = False) -> List[int]: + return _max_pool3d(self, kernel_size, stride, padding, dilation, ceil_mode) + def aten〇max_pool2d_with_indices〡shape(self: List[int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0, 0,), dilation: List[int] = (1, 1,), ceil_mode: bool = False) -> Tuple[List[int], List[int]]: maxpool2d = indices = upstream_shape_functions.max_pool2d(self, kernel_size, stride, padding, dilation, ceil_mode) return maxpool2d, indices @@ -1780,6 +1891,11 @@ def aten〇avg_pool2d〡dtype(self_rank_dtype: Tuple[int, int], kernel_size: Lis self_rank, self_dtype = self_rank_dtype return self_dtype +@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7, 8)], kernel_size=[2, 2, 2])) +def aten〇avg_pool3d〡dtype(self_rank_dtype: Tuple[int, int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0, 0, 0,), ceil_mode: bool = False, count_include_pad: bool = True, divisor_override: Optional[int] = None) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + @check_dtype_function(_check_tensors_with_the_same_dtype( tensor_shapes=[(2, 3, 5), (3,), (3,), (3,), (3,)], training=False, momentum=0.1, eps=1e-5, cudnn_enabled=True)) def aten〇batch_norm〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Optional[Tuple[int, int]], bias_rank_dtype: Optional[Tuple[int, int]], running_mean_rank_dtype: Optional[Tuple[int, int]], running_var_rank_dtype: Optional[Tuple[int, int]], training: bool, momentum: float, eps: float, cudnn_enabled: bool) -> int: @@ -2140,6 +2256,11 @@ def aten〇max_pool2d〡dtype(self_rank_dtype: Tuple[int, int], kernel_size: Lis self_rank, self_dtype = self_rank_dtype return self_dtype +@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7, 8)], kernel_size=[2, 2, 2])) +def aten〇max_pool3d〡dtype(self_rank_dtype: Tuple[int, int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0, 0, 0,), dilation: List[int] = (1, 1, 1,), ceil_mode: bool = False) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + @check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7)], kernel_size=[2, 2])) def aten〇max_pool2d_with_indices〡dtype(self_rank_dtype: Tuple[int, int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0, 0,), dilation: List[int] = (1, 1,), ceil_mode: bool = False) -> Tuple[int, int]: self_rank, self_dtype = self_rank_dtype diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py index 1c6748538a6b..b19596be7031 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py @@ -211,6 +211,154 @@ def MaxPool2dCeilModeTrueModule_basic(module, tu: TestUtils): module.forward(tu.rand(1, 1, 20, 20, low=0.5, high=1.0)) +# ============================================================================== + +class MaxPool3dModule(torch.nn.Module): + + def __init__(self): + super().__init__() + self.mp3d = torch.nn.MaxPool3d(kernel_size=[4, 4, 4], + stride=[2, 2, 2], + padding=[1, 1, 1], + dilation=1) + + @export + @annotate_args([ + None, + ([-1, -1, -1, -1, -1], torch.float32, True), + ]) + def forward(self, x): + return self.mp3d(x) + + +@register_test_case(module_factory=lambda: MaxPool3dModule()) +def MaxPool3dModule_basic(module, tu: TestUtils): + module.forward(torch.arange(8*8*8).view(1, 1, 8, 8, 8).float()) + +class MaxPool3dRandomSimpleModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.mp3d = torch.nn.MaxPool3d(kernel_size=[4, 4, 4], + stride=[2, 2, 2], + padding=[1, 1, 1], + dilation=1) + + @export + @annotate_args([ + None, + ([-1, -1, -1, -1, -1], torch.float32, True), + ]) + def forward(self, x): + return self.mp3d(x) + + +@register_test_case(module_factory=lambda: MaxPool3dRandomSimpleModule()) +def MaxPool3dModuleRandomSimple_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 1, 20, 20, 20, low=-1)) + +class MaxPool3dLargeDataModule(torch.nn.Module): + + def __init__(self): + super().__init__() + self.mp3d = torch.nn.MaxPool3d(kernel_size=[6, 8, 8], + stride=[2, 2, 2], + padding=[3, 4, 4], + dilation=2) + + @export + @annotate_args([ + None, + ([-1, -1, -1, -1, -1], torch.float32, True), + ]) + def forward(self, x): + return self.mp3d(x) + + +@register_test_case(module_factory=lambda: MaxPool3dLargeDataModule()) +def MaxPool3dLargeDatadModule_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 1, 20, 20, 20, low=-1)) + +class MaxPool3dEmptyStrideStaticModule(torch.nn.Module): + def __init__(self): + super().__init__() + @export + @annotate_args([ + None, + ([1, 1, 20, 20, 20], torch.float32, True), + ]) + def forward(self, x): + return torch.ops.aten.max_pool3d(x, kernel_size=2, stride=[]) + + +@register_test_case(module_factory=lambda: MaxPool3dEmptyStrideStaticModule()) +def MaxPool3dEmptyStrideStaticModule_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 1, 20, 20, 20, low=-1)) + + +class MaxPool3dStaticModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.mp3d = torch.nn.MaxPool3d(kernel_size=[3, 3, 3], + stride=[2, 2, 2], + padding=[1, 1, 1], + dilation=[1, 1, 1]) + @export + @annotate_args([ + None, + ([1, 64, 112, 112, 112], torch.float32, True), + ]) + def forward(self, x): + return self.mp3d(x) + + +@register_test_case(module_factory=lambda: MaxPool3dStaticModule()) +def MaxPool3dStaticModule_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 64, 112, 112, 112)) + +class MaxPool3dStaticCeilModeTrueModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.mp3d = torch.nn.MaxPool3d(kernel_size=[3, 3, 3], + stride=[2, 2, 2], + padding=[1, 1, 1], + dilation=[1, 1, 1], + ceil_mode=True) + + @export + @annotate_args([ + None, + ([1, 64, 112, 112, 112], torch.float32, True), + ]) + def forward(self, x): + return self.mp3d(x) + + +@register_test_case(module_factory=lambda: MaxPool3dStaticCeilModeTrueModule()) +def MaxPool3dStaticCeilModeTrueModule_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 64, 112, 112, 112)) + + +class MaxPool3dCeilModeTrueModule(torch.nn.Module): + def __init__(self): + super().__init__() + self.mp3d = torch.nn.MaxPool3d(kernel_size=[6, 8, 8], + stride=[2, 2, 2], + padding=[3, 4, 4], + dilation=2, + ceil_mode=True) + @export + @annotate_args([ + None, + ([-1, -1, -1, -1, -1], torch.float32, True), + ]) + def forward(self, x): + return self.mp3d(x) + +@register_test_case(module_factory=lambda: MaxPool3dCeilModeTrueModule()) +def MaxPool3dCeilModeTrueModule_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 1, 20, 20, 20, low=0.5, high=1.0)) + + # ============================================================================== diff --git a/test/Conversion/TorchToLinalg/pooling.mlir b/test/Conversion/TorchToLinalg/pooling.mlir index df19ef7645e8..70f543ad4f74 100644 --- a/test/Conversion/TorchToLinalg/pooling.mlir +++ b/test/Conversion/TorchToLinalg/pooling.mlir @@ -1,7 +1,7 @@ // RUN: torch-mlir-opt <%s -convert-torch-to-linalg -split-input-file -verify-diagnostics | FileCheck %s -// CHECK-LABEL: func @forward -func.func @forward(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> { +// CHECK-LABEL: func @forward_max_pool2d +func.func @forward_max_pool2d(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> { %int1 = torch.constant.int 1 %int2 = torch.constant.int 2 %int3 = torch.constant.int 3 @@ -27,3 +27,50 @@ func.func @forward(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?, %4 = torch.aten.max_pool2d %arg0, %kernel_size, %stride, %padding, %dilation, %false : !torch.vtensor<[?,?,?,?],f32>, !torch.list, !torch.list, !torch.list, !torch.list, !torch.bool -> !torch.vtensor<[?,?,?,?],f32> return %4 : !torch.vtensor<[?,?,?,?],f32> } + +// ----- + +// CHECK: #map = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2 * 2 + d5 * 3, d3 * 2 + d6 * 3, d4 * 2 + d7 * 3)> +// CHECK: #map1 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d5, d6, d7)> +// CHECK: #map2 = affine_map<(d0, d1, d2, d3, d4, d5, d6, d7) -> (d0, d1, d2, d3, d4)> +// CHECK-LABEL: func @forward_max_pool3d +func.func @forward_max_pool3d(%arg0: !torch.vtensor<[?,?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?,?],f32> { + %kernel_size1 = torch.constant.int 8 + %kernel_size2 = torch.constant.int 8 + %kernel_size3 = torch.constant.int 8 + + %stride1 = torch.constant.int 2 + %stride2 = torch.constant.int 2 + %stride3 = torch.constant.int 2 + + %padding1 = torch.constant.int 4 + %padding2 = torch.constant.int 4 + %padding3 = torch.constant.int 4 + + %dilation1 = torch.constant.int 3 + %dilation2 = torch.constant.int 3 + %dilation3 = torch.constant.int 3 + + %false = torch.constant.bool false + %kernel_size = torch.prim.ListConstruct %kernel_size1, %kernel_size2, %kernel_size3 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %stride = torch.prim.ListConstruct %stride1, %stride2, %stride3 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %padding = torch.prim.ListConstruct %padding1, %padding2, %padding3 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %dilation = torch.prim.ListConstruct %dilation1, %dilation2, %dilation3 : (!torch.int, !torch.int, !torch.int) -> !torch.list + + %4 = torch.aten.max_pool3d %arg0, %kernel_size, %stride, %padding, %dilation, %false : !torch.vtensor<[?,?,?,?,?],f32>, !torch.list, !torch.list, !torch.list, !torch.list, !torch.bool -> !torch.vtensor<[?,?,?,?,?],f32> + + // CHECK: %[[MIN_VALUE:.*]] = arith.constant 0xFF800000 : f32 + // CHECK: %[[PADDED_INPUT_TENSOR:.*]] = tensor.pad %{{.*}} low[0, 0, 4, 4, 4] high[0, 0, 4, 4, 4] { + // CHECK-NEXT: ^bb0(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index): + // CHECK-NEXT: tensor.yield %[[MIN_VALUE:.*]] : f32 + // CHECK: } : tensor to tensor + + // CHECK: %[[OUTPUT_TENSOR:.*]] = linalg.fill ins(%[[MIN_VALUE:.*]] : f32) outs(%{{.*}} : tensor) -> tensor + // CHECK: %[[MAX_3D_POOL:.*]] = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%[[PADDED_INPUT_TENSOR:.*]], %{{.*}} : tensor, tensor) outs(%[[OUTPUT_TENSOR:.*]] : tensor) { + // CHECK-NEXT: ^bb0(%[[CURRENT_VALUE:.*]]: f32, %[[KERNEL:.*]]: f32, %[[ACC_OUT:.*]]: f32): + // CHECK-NEXT: %[[MAXF:.*]] = arith.maximumf %[[CURRENT_VALUE:.*]], %[[ACC_OUT:.*]] : f32 + // CHECK-NEXT: linalg.yield %[[MAXF:.*]] : f32 + // CHECK: } -> tensor + + return %4 : !torch.vtensor<[?,?,?,?,?],f32> +} From 3b85c70748ce7177f344ff0c8a99545f849c328a Mon Sep 17 00:00:00 2001 From: Gaurav Shukla Date: Fri, 19 Jan 2024 21:58:29 +0530 Subject: [PATCH 092/283] [ONNX][MLIR] Add support for onnx.gather op (#2726) This commit adds support for gather op in the onnx pipeline. https://github.com/nod-ai/SHARK-Turbine/issues/242 Signed-off-by: Gaurav Shukla --- .../TorchOnnxToTorch/DefaultDomainGtoP.cpp | 156 +++++++++++++++++- .../TorchOnnxToTorch/simple_ops_g_to_p.mlir | 37 +++++ 2 files changed, 189 insertions(+), 4 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index 24089e87476d..4834e7af4d5e 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -275,10 +275,9 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( for (uint64_t i = 1; i < operands.size(); i++) { result = rewriter.create( binder.getLoc(), resultType, result, operands[i]); - } - rewriter.replaceOp( - binder.op, result.getDefiningOp()); - return success(); + } + rewriter.replaceOp(binder.op, result.getDefiningOp()); + return success(); }); patterns.onOp("Min", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { @@ -334,6 +333,155 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( binder.op, resultType, lhs, rhs); return success(); }); + patterns.onOp( + "Gather", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value data, indices; + int64_t axis; + if (binder.tensorOperandAtIndex(data, 0) || + binder.tensorOperandAtIndex(indices, 1) || + binder.tensorResultType(resultType) || + binder.s64IntegerAttr(axis, "axis", 0)) + return failure(); + Location loc = binder.getLoc(); + + // 1. Get data shape and rank. + auto dataTensorType = data.getType().cast(); + if (!dataTensorType || !dataTensorType.hasSizes()) { + return rewriter.notifyMatchFailure(binder.op, + "Expect non empty input data"); + } + ArrayRef dataShape = dataTensorType.getSizes(); + unsigned dataRank = dataShape.size(); + + // 2. Get indices shape and rank. + auto indexType = indices.getType().cast(); + if (!indexType || !indexType.hasSizes()) { + return rewriter.notifyMatchFailure(binder.op, + "Expect non empty index tensor"); + } + ArrayRef indexShape = indexType.getSizes(); + unsigned indexRank = indexShape.size(); + + // 3. Compute total elements in the indices tensor, as we will collapse + // the indices tensor to a unary tensor. Also compute index shape and + // data shape tensors as they will be used for creating output types. + int64_t indexElemCount = 1; + for (int64_t dim : indexShape) { + if (dim == Torch::kUnknownSize) { + indexElemCount = Torch::kUnknownSize; + break; + } + indexElemCount *= dim; + } + + Value constOne = rewriter.create( + loc, rewriter.getI64IntegerAttr(1)); + SmallVector indexShapeTensor; + Value indexElemCountVal = constOne; + for (unsigned i = 0; i < indexRank; ++i) { + Value indexDimVal = rewriter.create( + loc, indices, + rewriter.create( + loc, rewriter.getI64IntegerAttr(i))); + indexShapeTensor.emplace_back(indexDimVal); + indexElemCountVal = rewriter.create( + loc, indexElemCountVal, indexDimVal); + } + + SmallVector dataShapeTensor; + for (unsigned i = 0; i < dataRank; ++i) { + dataShapeTensor.emplace_back(rewriter.create( + loc, data, + rewriter.create( + loc, rewriter.getI64IntegerAttr(i)))); + } + + // 4. We can not directly perform torch.gather as the onnx.gather op + // collects the input data at different location of output compared to + // torch.gather op. The output of torch.gather and onnx.gather ops are + // indexed differently. + // check https://onnx.ai/onnx/operators/onnx__Gather.html for more + // details. So we will collapse indices tensor to a unary tensor and + // materialize to non-axis dimension of data tensor. For example, + // assuming indices is of shape (4, 5, 6), data is (8, 10, 11, 12) and + // axis=1. we will collapse indices into a (120,) unary tensor, + // materialize to non-axis dimension of data i.e. reshaping the unary + // indices tensor to (1, 120, 1, 1) and then perform the torch.gather + // operation. Now broadcast the output of gather operation to non-axis + // dimensions of data tensor. This would make the result of shape (8, + // 10, 120, 12). Post the broadcasting, expand the indices dimensions by + // reshaping (8, 10, 120, 12) to (8, 10, 4, 5, 6, 12) tensor, which is + // our expected final result. + SmallVector collapsedIndexShape(dataRank, 1); + collapsedIndexShape[axis] = indexElemCount; + Type collapsedIndexType = Torch::ValueTensorType::get( + indexType.getContext(), llvm::ArrayRef(collapsedIndexShape), + indexType.getOptionalDtype()); + + SmallVector collapsedIndexSize(dataRank, constOne); + collapsedIndexSize[axis] = indexElemCountVal; + auto collapsedIndexSizeList = + rewriter.create( + loc, + rewriter.getType( + rewriter.getType()), + collapsedIndexSize); + + auto collapsedIndices = rewriter.create( + loc, collapsedIndexType, indices, collapsedIndexSizeList); + + // 5. Compute gather result type and perform gather operation. + Type gatherResultType = Torch::ValueTensorType::get( + dataTensorType.getContext(), llvm::ArrayRef(collapsedIndexShape), + dataTensorType.getOptionalDtype()); + Value constAxis = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), axis)); + Value constFalse = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getBoolAttr(false)); + auto gatherOp = rewriter.create( + loc, gatherResultType, data, constAxis, collapsedIndices, + /*sparseGrad=*/constFalse); + + // 6. Broadcast the gather output to non-axis dimensions of data tensor. + SmallVector dataShapeVector(dataShape); + dataShapeVector[axis] = indexElemCount; + Type expandResultType = Torch::ValueTensorType::get( + dataTensorType.getContext(), llvm::ArrayRef(dataShapeVector), + dataTensorType.getOptionalDtype()); + + dataShapeTensor[axis] = indexElemCountVal; + auto expandSizeList = rewriter.create( + loc, Torch::ListType::get(Torch::IntType::get(data.getContext())), + dataShapeTensor); + auto expandedGather = rewriter.create( + loc, expandResultType, gatherOp, expandSizeList, + /*implicit=*/constFalse); + + // 7. Compute the result type of reshape op which expands the collapsed + // indices shapes back to the original indices shapes and reshape the + // output produced at step 6. This will produce our expected result of + // onnx.gather op. + SmallVector resultShapeTensor; + for (unsigned i = 0; i < dataRank; ++i) { + if (i == axis) { + resultShapeTensor.insert(resultShapeTensor.end(), + indexShapeTensor.begin(), + indexShapeTensor.end()); + continue; + } + resultShapeTensor.emplace_back(dataShapeTensor[i]); + } + auto resultSizeList = rewriter.create( + loc, Torch::ListType::get(Torch::IntType::get(data.getContext())), + resultShapeTensor); + + rewriter.replaceOpWithNewOp( + binder.op, resultType, expandedGather, resultSizeList); + return success(); + }); patterns.onOp( "GatherElements", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) { diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir index 147b3f9551c5..59049e40614c 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir @@ -37,6 +37,43 @@ func.func @test_less(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[ // ----- +// CHECK-LABEL: func.func @test_gather +func.func @test_gather(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[8,10,20,40], si64>) -> !torch.vtensor<[8,10,20,40,4,5],f32> attributes {torch.onnx_meta.opset_version = 13 : si64} { + // CHECK-DAG: %[[INT1:.+]] = torch.constant.int 1 + // CHECK-DAG: %[[INT0:.+]] = torch.constant.int 0 + // CHECK: %[[ARG1_SIZE0:.+]] = torch.aten.size.int %arg1, %[[INT0]] + // CHECK: %[[MUL1:.+]] = torch.aten.mul.int %[[INT1]], %[[ARG1_SIZE0]] + // CHECK: %[[INT1_0:.+]] = torch.constant.int 1 + // CHECK: %[[ARG1_SIZE1:.+]] = torch.aten.size.int %arg1, %[[INT1_0]] + // CHECK: %[[MUL2:.+]] = torch.aten.mul.int %[[MUL1]], %[[ARG1_SIZE1]] + // CHECK-DAG: %[[INT2:.+]] = torch.constant.int 2 + // CHECK: %[[ARG1_SIZE2:.+]] = torch.aten.size.int %arg1, %[[INT2]] : !torch.vtensor<[8,10,20,40],si64>, !torch.int -> !torch.int + // CHECK: %[[MUL3:.+]] = torch.aten.mul.int %[[MUL2]], %[[ARG1_SIZE2]] : !torch.int, !torch.int -> !torch.int + // CHECK-DAG: %[[INT3:.+]] = torch.constant.int 3 + // CHECK: %[[ARG1_SIZE3:.+]] = torch.aten.size.int %arg1, %[[INT3]] : !torch.vtensor<[8,10,20,40],si64>, !torch.int -> !torch.int + // CHECK: %[[MUL4:.+]] = torch.aten.mul.int %[[MUL3]], %[[ARG1_SIZE3]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[INT0_2:.+]] = torch.constant.int 0 + // CHECK: %[[ARG0_SIZE0:.+]] = torch.aten.size.int %arg0, %[[INT0_2]] : !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.int + // CHECK: %[[INT1_3:.+]] = torch.constant.int 1 + // CHECK: %[[ARG0_SIZE1:.+]] = torch.aten.size.int %arg0, %[[INT1_3]] : !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.int + // CHECK: %[[INT2_4:.+]] = torch.constant.int 2 + // CHECK: %[[ARG0_SIZE2:.+]] = torch.aten.size.int %arg0, %[[INT2_4]] : !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.int + // CHECK: %[[LIST1:.+]] = torch.prim.ListConstruct %[[MUL4]], %[[INT1]], %[[INT1]] : (!torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[VIEW1:.+]] = torch.aten.view %arg1, %[[LIST1]] : !torch.vtensor<[8,10,20,40],si64>, !torch.list -> !torch.vtensor<[64000,1,1],si64> + // CHECK: %[[INT0_1:.+]] = torch.constant.int 0 + // CHECK: %[[FALSE:.+]] = torch.constant.bool false + // CHECK: %[[GATHER:.+]] = torch.aten.gather %arg0, %[[INT0_1]], %[[VIEW1]], %[[FALSE]] : !torch.vtensor<[3,4,5],f32>, !torch.int, !torch.vtensor<[64000,1,1],si64>, !torch.bool -> !torch.vtensor<[64000,1,1],f32> + // CHECK: %[[LIST2:.+]] = torch.prim.ListConstruct %[[MUL4]], %[[ARG0_SIZE1]], %[[ARG0_SIZE2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[EXPAND:.+]] = torch.aten.expand %[[GATHER]], %[[LIST2]], %[[FALSE]] : !torch.vtensor<[64000,1,1],f32>, !torch.list, !torch.bool -> !torch.vtensor<[64000,4,5],f32> + // CHECK: %[[LIST3:.+]] = torch.prim.ListConstruct %[[ARG1_SIZE0]], %[[ARG1_SIZE1]], %[[ARG1_SIZE2]], %[[ARG1_SIZE3]], %[[ARG0_SIZE1]], %[[ARG0_SIZE2]] : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[RES:.+]] = torch.aten.view %[[EXPAND]], %[[LIST3]] : !torch.vtensor<[64000,4,5],f32>, !torch.list -> !torch.vtensor<[8,10,20,40,4,5],f32> + // CHECK: return %[[RES]] : !torch.vtensor<[8,10,20,40,4,5],f32> + %0 = torch.operator "onnx.Gather"(%arg0, %arg1) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[8,10,20,40], si64>) -> !torch.vtensor<[8,10,20,40,4,5],f32> + return %0 : !torch.vtensor<[8,10,20,40,4,5],f32> +} + +// ----- + // CHECK-LABEL: func.func @test_gather_elements func.func @test_gather_elements(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5], si64>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.opset_version = 13 : si64} { // CHECK-DAG: %[[INT0:.+]] = torch.constant.int 0 From 18669b38cbc42af1eff625af86a5f7faa356d76a Mon Sep 17 00:00:00 2001 From: "Xida Ren (Cedar)" Date: Fri, 19 Jan 2024 13:44:45 -0500 Subject: [PATCH 093/283] Create add_ops.md (#2770) --- docs/add_ops.md | 159 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 159 insertions(+) create mode 100644 docs/add_ops.md diff --git a/docs/add_ops.md b/docs/add_ops.md new file mode 100644 index 000000000000..0809283bbeae --- /dev/null +++ b/docs/add_ops.md @@ -0,0 +1,159 @@ +# How to Add Ops to Torch-Mlir + +Collected links and contacts for how to add ops to torch-mlir. + + +
+Turbine Camp: Start Here +This document was previously known as `turbine-camp.md` to Nod.ai. "Turbine Camp" is part of Nod.ai's onboarding process. Welcome to turbine camp. This document originated at Nod.ai as a part of onboardding process, where new nod-ai folks learn about the architecture of our work by adding support for 2 ops to torch-mlir. I decided to put this into torch mlir because a lot of this is about torch-mlir. + +Written & maintained by @renxida + +Guides by other folks that were used during the creation of this document: +- [Chi Liu](https://gist.github.com/AmosLewis/dd31ab37517977b1c499d06495b4adc2) +- [Sunsoon](https://docs.google.com/document/d/1H79DwW_wnVzUU81EogwY5ueXgnl-QzKet1p2lnqPar4/edit?pli=1) + +## Before you begin... + +Nod-ai maintains the pipeline below, which allows us to take a ML model from e.g. huggingface, and compile it to a variety of devices including llvm-cpu, rocm and cuda and more as an optimized `vmfb` binary. + +1. The pipeline begins with a huggingface model, or some other supported source like llama.cpp. +2. [nod-ai/SHARK-Turbine](https://github.com/nod-ai/SHARK-Turbine) takes a huggingface model and exports a `.mlir` file. +3. **[llvm/torch-mlir](https://github.com/llvm/torch-mlir)**, which you will be working on in turbine-camp, will lower torchscript, torch dialect, and torch aten ops further into a mixture `linalg` or `math` MLIR dialects (with occasionally other dialects in the mix) +4. [IREE](https://github.com/openxla/iree) converts the final `.mlir` file into a binary (typically `.vmfb`) for running on a device (llvm-cpu, rocm, vulcan, cuda, etc). + +The details of how we do it and helpful commands to help you set up each repo is in [Sungsoon's Shark Getting Started Google Doc](https://docs.google.com/document/d/1H79DwW_wnVzUU81EogwY5ueXgnl-QzKet1p2lnqPar4/edit?pli=1) + +PS: IREE is pronounced Eerie, and hence the ghost icon. + +## How to begin +1. You will start by adding support for 2 ops in torch-mlir, to get you familiar with the center of our pipeline. Begin by reading [torch-mlir's documentation on how to implement a new torch op](https://github.com/llvm/torch-mlir/blob/main/docs/Torch-ops-E2E-implementation.md), and set up `llvm/torch_mlir` using https://github.com/llvm/torch-mlir/blob/main/docs/development.md +2. Pick 1 of the yet-unimplemented from the following. You should choose something that looks easy to you. **Make sure you create an issue by clicking the little "target" icon to the right of the op, thereby marking the op as yours** + - [TorchToLinalg ops tracking issue](https://github.com/nod-ai/SHARK-Turbine/issues/347) + - [TorchOnnnxToTorch ops tracking issue](https://github.com/nod-ai/SHARK-Turbine/issues/215) +3. Implement it. For torch -> linalg, see the how to torchop section below. For Onnx ops, see how to onnx below. +5. Make a pull request and reference your issue. When the pull request is closed, also close your issue to mark the op as done + +
+ +### How to TorchToLinalg + +You will need to do 4 things: +- make sure the op exists in `torch_ods_gen.py`, and then run `build_tools/update_torch_ods.sh`, and then build. This generates `GeneratedTorchOps.td`, which is used to generate the cpp and h files where ops function signatures are defined. + - Reference [torch op registry](https://github.com/pytorch/pytorch/blob/7451dd058564b5398af79bfc1e2669d75f9ecfa2/torch/csrc/jit/passes/utils/op_registry.cpp#L21) +- make sure the op exists in `abstract_interp_lib_gen.py`, and then run `build_tools/update_abstract_interp_lib.sh`, and then build. This generates `AbstractInterpLib.cpp`, which is used to generate the cpp and h files where ops function signatures are defined. + - Reference [torch shape functions](https://github.com/pytorch/pytorch/blob/7451dd058564b5398af79bfc1e2669d75f9ecfa2/torch/jit/_shape_functions.py#L1311) +- write test cases. They live in `projects/pt1`. See the [Dec 2023 example](https://github.com/llvm/torch-mlir/pull/2640/files). +- implement the op in one of the `lib/Conversion/TorchToLinalg/*.cpp` files + +Reference Examples +- [A Dec 2023 example with the most up to date lowering](https://github.com/llvm/torch-mlir/pull/2640/files) +- [Chi's simple example of adding op lowering](https://github.com/llvm/torch-mlir/pull/1454) useful instructions and referring links for you to understand the op lowering pipeline in torch-mlir in the comments + +Resources: +- how to set up torch-mlir: [https://github.com/llvm/torch-mlir/blob/main/docs/development.md](https://github.com/llvm/torch-mlir/blob/main/docs/development.md#checkout-and-build-from-source) +- torch-mlir doc on how to debug and test: [ttps://github.com/llvm/torch-mlir/blob/main/docs/development.md#testing](https://github.com/llvm/torch-mlir/blob/main/docs/development.md#testing) +- [torch op registry](https://github.com/pytorch/pytorch/blob/7451dd058564b5398af79bfc1e2669d75f9ecfa2/torch/csrc/jit/passes/utils/op_registry.cpp#L21) +- [torch shape functions](https://github.com/pytorch/pytorch/blob/7451dd058564b5398af79bfc1e2669d75f9ecfa2/torch/jit/_shape_functions.py#L1311) + +### How to TorchOnnxToTorch +0. Generate the big folder of ONNX IR. Use https://github.com/llvm/torch-mlir/blob/main/test/python/onnx_importer/import_smoke_test.py . Alternatively, if you're trying to support a certain model, convert that model to onnx IR with + ``` + optimum-cli export onnx --model facebook/opt-125M fb-opt + python -m torch_mlir.tools.import_onnx fb-opt/model.onnx -o fb-opt-125m.onnx.mlir + ``` +2. Find an instance of the Op that you're trying to implement inside the smoke tests folder or the generated model IR, and write a test case. Later you will save it to one of the files in `torch-mlir/test/Conversion/TorchOnnxToTorch`, but for now feel free to put it anywhere. +3. Implement the op in `lib/Conversion/TorchOnnxToTorch/something.cpp`. +4. Test the conversion by running `./build/bin/torch-mlir-opt -split-input-file -verify-diagnostics -convert-torch-onnx-to-torch your_mlir_file.mlir`. For more details, see https://github.com/llvm/torch-mlir/blob/main/docs/development.md#testing . Xida usually creates a separate MLIR file to test it to his satisfaction before integrating it into one of the files at `torch-mlir/test/Conversion/TorchOnnxToTorch`. + +Helpful examples: +- [A Dec 2023 example where an ONNX op is implemented](https://github.com/llvm/torch-mlir/pull/2641/files#diff-b584b152020af6d2e5dbf62a08b2f25ed5afc2c299228383b9651d22d44b5af4R493) +- [Vivek's example of ONNX op lowering](https://github.com/llvm/torch-mlir/commit/dc9ea08db5ac295b4b3f91fc776fef6a702900b9) + +## Contacts +People who've worked on this for a while +- Vivek (@vivek97 on discord) +- Chi.Liu@amd.com + +Recent Turbine Camp Attendees, from recent to less recent +- Xida.ren@amd.com (@xida_ren on discord) +- Sungsoon.Cho@amd.com + +## Links + +- Tutorials + - [Sungsoon's Shark Getting Started Google Doc](https://docs.google.com/document/d/1H79DwW_wnVzUU81EogwY5ueXgnl-QzKet1p2lnqPar4/edit?pli=1) + - This document contains commands that would help you set up shark and run demos + - [How to implement ONNX op lowering](https://github.com/llvm/torch-mlir/blob/main/docs/importers/onnx_importer.md) +- Examples + - [A Dec 2023 example with the most up to date lowering](https://github.com/llvm/torch-mlir/pull/2640/files) + - Chi's Example Lowering + - Github issue and code detailing how to implement the lowring of an OP. + - [Chi's simple example of adding op lowering](https://github.com/llvm/torch-mlir/pull/1454) useful instructions and referring links for you to understand the op lowering pipeline in torch-mlir in the comments + - If you have questions, reach out to [Chi on Discord](https://discordapp.com/channels/973663919757492264/1104195883307892837/1180233875058868224) + - [Vivek's example of ONNX op lowering](https://github.com/llvm/torch-mlir/commit/dc9ea08db5ac295b4b3f91fc776fef6a702900b9) +- Find Ops To Lower + - [Torch MLIR + ONNX Unimplemented Ops on Sharepoint](https://amdcloud-my.sharepoint.com/:x:/r/personal/esaimana_amd_com/Documents/Torch%20MLIR%20+%20ONNX%20Unimplemented%20Ops.xlsx?d=w438f26fac8fd44eeafb89bc99e2c563b&csf=1&web=1&e=Qd4eHm) + - If you don't have access yet, request it. + - nod-ai/SHARK-Turbine ssues tracking op support + - [Model and Op Support](https://github.com/nod-ai/SHARK-Turbine/issues/119) + - [ONNX op support](https://github.com/nod-ai/SHARK-Turbine/issues/215) + + +## Chi's useful commands for debugging torch mlir + +https://gist.github.com/AmosLewis/dd31ab37517977b1c499d06495b4adc2 + +## How to write test cases and test your new op + +https://github.com/llvm/torch-mlir/blob/main/docs/development.md#testing + + + +## How to set up vs code and intellisence for [torch-mlir] +Xida: This is optional. If you're using VS code like me, you might want to set it up so you can use the jump to definition / references, auto fix, and other features. + +Feel free to contact me on discord if you have trouble figuring this out. + +You may need to write something like this into your + +```.vscode/settings.json``` + +under `torch-mlir` + +```json +{ + "files.associations": { + "*.inc": "cpp", + "ranges": "cpp", + "regex": "cpp", + "functional": "cpp", + "chrono": "cpp", + "__functional_03": "cpp", + "target": "cpp" + }, + "cmake.sourceDirectory": ["/home/xida/torch-mlir/externals/llvm-project/llvm"], + "cmake.buildDirectory": "${workspaceFolder}/build", + "cmake.generator": "Ninja", + "cmake.configureArgs": [ + "-DLLVM_ENABLE_PROJECTS=mlir", + "-DLLVM_EXTERNAL_PROJECTS=\"torch-mlir\"", + "-DLLVM_EXTERNAL_TORCH_MLIR_SOURCE_DIR=\"/home/xida/torch-mlir\"", + "-DCMAKE_BUILD_TYPE=Release", + "-DCMAKE_C_COMPILER_LAUNCHER=ccache", + "-DCMAKE_CXX_COMPILER_LAUNCHER=ccache", + "-DLLVM_ENABLE_PROJECTS=mlir", + "-DLLVM_EXTERNAL_PROJECTS=torch-mlir", + "-DLLVM_EXTERNAL_TORCH_MLIR_SOURCE_DIR=${workspaceFolder}", + "-DMLIR_ENABLE_BINDINGS_PYTHON=ON", + "-DLLVM_ENABLE_ASSERTIONS=ON", + "-DLLVM_TARGETS_TO_BUILD=host", + ], + "C_Cpp.default.configurationProvider": "ms-vscode.cmake-tools", + "cmake.configureEnvironment": { + "PATH": "/home/xida/miniconda/envs/torch-mlir/bin:/home/xida/miniconda/condabin:/home/xida/miniconda/bin:/home/xida/miniconda/bin:/home/xida/miniconda/condabin:/home/xida/miniconda/bin:/home/xida/miniconda/bin:/home/xida/miniconda/condabin:/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin:/usr/games:/usr/local/games:/snap/bin" + }, + "cmake.cmakePath": "/home/xida/miniconda/envs/torch-mlir/bin/cmake", // make sure this is a cmake that knows where your python is +} +``` +The important things to note are the `cmake.configureArgs`, which specify the location of your torch mlir, and the `cmake.sourceDirectory`, which indicates that CMAKE should not build from the current directory and should instead build from `externals/llvm-project/llvm` From b3a3ad4e2a5a6c2a5c649d21cb3568e238cca132 Mon Sep 17 00:00:00 2001 From: Scott Todd Date: Fri, 19 Jan 2024 15:13:32 -0800 Subject: [PATCH 094/283] Generalize install instructions to not exclude Windows. (#2771) Overly specific docs can get stale easily. It looks like https://llvm.github.io/torch-mlir/package-index/ has included Windows packages since around https://github.com/llvm/torch-mlir/pull/1521. Here's an example release: https://github.com/llvm/torch-mlir/releases/tag/snapshot-20240118.1087 ``` torch-2.3.0.dev20240109+cpu-cp311-cp311-linux_x86_64.whl torch-2.3.0.dev20240109+cpu-cp311-cp311-win_amd64.whl torch-2.3.0.dev20240109+cpu-cp38-cp38-linux_x86_64.whl torch-2.3.0.dev20240109-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl torch-2.3.0.dev20240109-cp311-none-macosx_10_9_x86_64.whl torch_mlir-20240118.1087-cp311-cp311-linux_aarch64.whl torch_mlir-20240118.1087-cp311-cp311-linux_x86_64.whl torch_mlir-20240118.1087-cp311-cp311-macosx_11_0_universal2.whl torch_mlir-20240118.1087-cp311-cp311-win_amd64.whl torch_mlir-20240118.1087-cp38-cp38-linux_x86_64.whl ``` --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index cc479d8d35eb..a10b9ac36bb5 100644 --- a/README.md +++ b/README.md @@ -52,7 +52,7 @@ Meeting links can be found [here](https://discourse.llvm.org/t/new-community-mee ## Install torch-mlir snapshot -At the time of writing, we release pre-built snapshot of torch-mlir for Python 3.11 on Linux and macOS. +At the time of writing, we release pre-built snapshots of torch-mlir for Python 3.11. If you have Python 3.11, the following commands initialize a virtual environment. ```shell From 2f4924015de915f7cd8f12500af2018fe3977fbe Mon Sep 17 00:00:00 2001 From: Dave Liddell <44620210+daveliddell@users.noreply.github.com> Date: Fri, 19 Jan 2024 17:18:16 -0700 Subject: [PATCH 095/283] [onnx] Added flatten (#2760) [https://github.com/nod-ai/SHARK-Turbine/issues/328](url) --------- Co-authored-by: Dave Liddell --- .../TorchOnnxToTorch/DefaultDomainAtoF.cpp | 73 +++++++++++ .../TorchOnnxToTorch/simple_ops_a_to_f.mlir | 113 ++++++++++++++++++ 2 files changed, 186 insertions(+) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp index c18d681055aa..4df2e0f88eb1 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp @@ -1364,6 +1364,79 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( binder.op, resultType, data, dimValueList); return success(); }); + patterns.onOp( + "Flatten", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + // Flatten means to partition the input tensor's dimensions + // into a "left range" spanning 0 to axis - 1 and a "right range" + // spanning axis to rank - 1. Each range is then collapsed + // into a single dimension, resulting in a 2-D tensor. + // If either range is empty, it is replaced with a single + // dimension of size 1. + // + // For example, for a 4-D input tensor of shape (a, b, c, d) + // and axis==2, flatten produces a 2-D tensor of shape + // (a*b, c*d). + // + // If instead axis==0, the left range is empty, and the result + // is (1, a*b*c*d). + + Torch::ValueTensorType resultType; + Value operand; + int64_t axis; + if (binder.tensorOperand(operand) || + binder.s64IntegerAttr(axis, "axis", 1) || + binder.tensorResultType(resultType)) + return failure(); + + // If axis is negative, count from the right instead of left + int64_t rank = + cast(operand.getType()).getSizes().size(); + if (axis < 0) + axis = rank + axis; + + Value collapsedRight; + auto baseType = Torch::ValueTensorType::getWithLeastStaticInformation( + binder.op->getContext()); + + if (axis >= rank) { + // If the right range is empty, add a dim of size 1 to the + // right side of the shape: + // cr = torch.unsqueeze(x, x.ndim) + Value rankConst = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(rank)); + collapsedRight = rewriter.create( + binder.getLoc(), baseType, operand, rankConst); + } else { + // Otherwise, collapse the right range into a single dimension: + // cr = torch._prims.collapse(x, axis, x.ndim - 1) + Value axisConst = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(axis)); + Value rankLess1Const = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(rank - 1)); + collapsedRight = rewriter.create( + binder.getLoc(), baseType, operand, axisConst, rankLess1Const); + } + + Value zero = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(0)); + + if (axis <= 0) { + // If the left range is empty, add a dim of size 1 to the + // left side of the shape: + // torch.unsqueeze(cr, 0) + rewriter.replaceOpWithNewOp( + binder.op, resultType, collapsedRight, zero); + return success(); + } + + // Otherwise, collapse the left range into a single dimension: + // torch._prims.collapse(cr, 0, axis - 1) + Value axisLess1Const = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(axis - 1)); + rewriter.replaceOpWithNewOp( + binder.op, resultType, collapsedRight, zero, axisLess1Const); + return success(); + }); patterns.onOp("Floor", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir index 42a0fe743bc2..c2d3c12a7b92 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir @@ -1062,3 +1062,116 @@ func.func @ints_constant() -> !torch.vtensor<[2], si64> attributes {torch.onnx_m return %0 : !torch.vtensor<[2],si64> } +// CHECK-LABEL: @test_flatten_4d_axis_2 +func.func @test_flatten_4d_axis_2(%arg0: !torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[6,20],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 2 + // CHECK-DAG: %[[RIGHT_END:.*]] = torch.constant.int 3 + // CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2,3,4,5],f32>, !torch.int, !torch.int -> !torch.vtensor + // CHECK-DAG: %[[LEFT_START:.*]] = torch.constant.int 0 + // CHECK-DAG: %[[LEFT_END:.*]] = torch.constant.int 1 + // CHECK: torch.prims.collapse %[[CR]], %[[LEFT_START]], %[[LEFT_END]] : !torch.vtensor, !torch.int, !torch.int -> !torch.vtensor<[6,20],f32> + %0 = torch.operator "onnx.Flatten"(%arg0) {torch.onnx.axis = 2 : si64} : (!torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[6,20],f32> + return %0 : !torch.vtensor<[6,20],f32> +} + +// CHECK-LABEL: @test_flatten_4d_axis_0 +func.func @test_flatten_4d_axis_0(%arg0: !torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[1,120],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 0 + // CHECK-DAG: %[[RIGHT_END:.*]] = torch.constant.int 3 + // CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2,3,4,5],f32>, !torch.int, !torch.int -> !torch.vtensor + // CHECK-DAG: %[[LEFT_INDEX:.*]] = torch.constant.int 0 + // CHECK: torch.aten.unsqueeze %[[CR]], %[[LEFT_INDEX]] : !torch.vtensor, !torch.int -> !torch.vtensor<[1,120],f32> + %0 = torch.operator "onnx.Flatten"(%arg0) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[1,120],f32> + return %0 : !torch.vtensor<[1,120],f32> +} + +// CHECK-LABEL: @test_flatten_4d_axis_4 +func.func @test_flatten_4d_axis_4(%arg0: !torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[120,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK-DAG: %[[RIGHT_INDEX:.*]] = torch.constant.int 4 + // CHECK-DAG: %[[CR:.*]] = torch.aten.unsqueeze %arg0, %[[RIGHT_INDEX]] : !torch.vtensor<[2,3,4,5],f32>, !torch.int -> !torch.vtensor + // CHECK-DAG: %[[LEFT_START:.*]] = torch.constant.int 0 + // CHECK-DAG: %[[LEFT_END:.*]] = torch.constant.int 3 + // CHECK: torch.prims.collapse %[[CR]], %[[LEFT_START]], %[[LEFT_END]] : !torch.vtensor, !torch.int, !torch.int -> !torch.vtensor<[120,1],f32> + %0 = torch.operator "onnx.Flatten"(%arg0) {torch.onnx.axis = 4 : si64} : (!torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[120,1],f32> + return %0 : !torch.vtensor<[120,1],f32> +} + +// CHECK-LABEL: @test_flatten_4d_axis_negative_2 +func.func @test_flatten_4d_axis_negative_2(%arg0: !torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[6,20],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 2 + // CHECK-DAG: %[[RIGHT_END:.*]] = torch.constant.int 3 + // CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2,3,4,5],f32>, !torch.int, !torch.int -> !torch.vtensor + // CHECK-DAG: %[[LEFT_START:.*]] = torch.constant.int 0 + // CHECK-DAG: %[[LEFT_END:.*]] = torch.constant.int 1 + // CHECK: torch.prims.collapse %[[CR]], %[[LEFT_START]], %[[LEFT_END]] : !torch.vtensor, !torch.int, !torch.int -> !torch.vtensor<[6,20],f32> + %0 = torch.operator "onnx.Flatten"(%arg0) {torch.onnx.axis = -2 : si64} : (!torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[6,20],f32> + return %0 : !torch.vtensor<[6,20],f32> +} + +// CHECK-LABEL: @test_flatten_4d_axis_negative_1 +func.func @test_flatten_4d_axis_negative_1(%arg0: !torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[24,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 3 + // CHECK-DAG: %[[RIGHT_END:.*]] = torch.constant.int 3 + // CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2,3,4,5],f32>, !torch.int, !torch.int -> !torch.vtensor + // CHECK-DAG: %[[LEFT_START:.*]] = torch.constant.int 0 + // CHECK-DAG: %[[LEFT_END:.*]] = torch.constant.int 2 + // CHECK: torch.prims.collapse %[[CR]], %[[LEFT_START]], %[[LEFT_END]] : !torch.vtensor, !torch.int, !torch.int -> !torch.vtensor<[24,5],f32> + %0 = torch.operator "onnx.Flatten"(%arg0) {torch.onnx.axis = -1 : si64} : (!torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[24,5],f32> + return %0 : !torch.vtensor<[24,5],f32> +} + +// CHECK-LABEL: @test_flatten_4d_axis_negative_4 +func.func @test_flatten_4d_axis_negative_4(%arg0: !torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[1,120],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 0 + // CHECK-DAG: %[[RIGHT_END:.*]] = torch.constant.int 3 + // CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2,3,4,5],f32>, !torch.int, !torch.int -> !torch.vtensor + // CHECK-DAG: %[[LEFT_INDEX:.*]] = torch.constant.int 0 + // CHECK: torch.aten.unsqueeze %[[CR]], %[[LEFT_INDEX]] : !torch.vtensor, !torch.int -> !torch.vtensor<[1,120],f32> + %0 = torch.operator "onnx.Flatten"(%arg0) {torch.onnx.axis = -4 : si64} : (!torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[1,120],f32> + return %0 : !torch.vtensor<[1,120],f32> +} + +// CHECK-LABEL: @test_flatten_2d_axis_1 +func.func @test_flatten_2d_axis_1(%arg0: !torch.vtensor<[2,3],f32>) -> !torch.vtensor<[2,3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 1 + // CHECK-DAG: %[[RIGHT_END:.*]] = torch.constant.int 1 + // CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2,3],f32>, !torch.int, !torch.int -> !torch.vtensor + // CHECK-DAG: %[[LEFT_START:.*]] = torch.constant.int 0 + // CHECK-DAG: %[[LEFT_END:.*]] = torch.constant.int 0 + // CHECK: torch.prims.collapse %[[CR]], %[[LEFT_START]], %[[LEFT_END]] : !torch.vtensor, !torch.int, !torch.int -> !torch.vtensor<[2,3],f32> + %0 = torch.operator "onnx.Flatten"(%arg0) {torch.onnx.axis = 1 : si64} : (!torch.vtensor<[2,3],f32>) -> !torch.vtensor<[2,3],f32> + return %0 : !torch.vtensor<[2,3],f32> +} + +// CHECK-LABEL: @test_flatten_1d_axis_0 +func.func @test_flatten_1d_axis_0(%arg0: !torch.vtensor<[2],f32>) -> !torch.vtensor<[1,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 0 + // CHECK-DAG: %[[RIGHT_END:.*]] = torch.constant.int 0 + // CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2],f32>, !torch.int, !torch.int -> !torch.vtensor + // CHECK-DAG: %[[LEFT_INDEX:.*]] = torch.constant.int 0 + // CHECK: torch.aten.unsqueeze %[[CR]], %[[LEFT_INDEX]] : !torch.vtensor, !torch.int -> !torch.vtensor<[1,2],f32> + %0 = torch.operator "onnx.Flatten"(%arg0) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[2],f32>) -> !torch.vtensor<[1,2],f32> + return %0 : !torch.vtensor<[1,2],f32> +} + +// CHECK-LABEL: @test_flatten_1d_axis_negative_1 +func.func @test_flatten_1d_axis_negative_1(%arg0: !torch.vtensor<[2],f32>) -> !torch.vtensor<[1,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 0 + // CHECK-DAG: %[[RIGHT_END:.*]] = torch.constant.int 0 + // CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2],f32>, !torch.int, !torch.int -> !torch.vtensor + // CHECK-DAG: %[[LEFT_INDEX:.*]] = torch.constant.int 0 + // CHECK: torch.aten.unsqueeze %[[CR]], %[[LEFT_INDEX]] : !torch.vtensor, !torch.int -> !torch.vtensor<[1,2],f32> + %0 = torch.operator "onnx.Flatten"(%arg0) {torch.onnx.axis = -1 : si64} : (!torch.vtensor<[2],f32>) -> !torch.vtensor<[1,2],f32> + return %0 : !torch.vtensor<[1,2],f32> +} + +// COM: CHECK-LABEL: @test_flatten_1d_axis_1 +func.func @test_flatten_1d_axis_1(%arg0: !torch.vtensor<[2],f32>) -> !torch.vtensor<[2,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK-DAG: %[[RIGHT_INDEX:.*]] = torch.constant.int 1 + // CHECK-DAG: %[[CR:.*]] = torch.aten.unsqueeze %arg0, %[[RIGHT_INDEX]] : !torch.vtensor<[2],f32>, !torch.int -> !torch.vtensor + // CHECK-DAG: %[[LEFT_START:.*]] = torch.constant.int 0 + // CHECK-DAG: %[[LEFT_END:.*]] = torch.constant.int 0 + // CHECK: torch.prims.collapse %[[CR]], %[[LEFT_START]], %[[LEFT_END]] : !torch.vtensor, !torch.int, !torch.int -> !torch.vtensor<[2,1],f32> + %0 = torch.operator "onnx.Flatten"(%arg0) {torch.onnx.axis = 1 : si64} : (!torch.vtensor<[2],f32>) -> !torch.vtensor<[2,1],f32> + return %0 : !torch.vtensor<[2,1],f32> +} From 50ac3b1912516b064a94d2414be295be79b09bc5 Mon Sep 17 00:00:00 2001 From: James Newling Date: Fri, 19 Jan 2024 19:12:29 -0800 Subject: [PATCH 096/283] g++ build fix (#2778) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Introduced in 704cfdaf0893fa4078ab744a915a56367282dcaa of @wu-s-john g++ compiler error: Pooling.cpp:177:13: error: explicit specialization in non-namespace scope ‘class Design looks good, g++ is just freaking out for no good reason. Un-nesting the template classes fixes the error. We don't have g++ CI. This hopefully happens infrequently enough that we can just fix manually. My service to those folks who really like building with g++... :) --- lib/Conversion/TorchToLinalg/Pooling.cpp | 36 ++++++++++++++---------- 1 file changed, 21 insertions(+), 15 deletions(-) diff --git a/lib/Conversion/TorchToLinalg/Pooling.cpp b/lib/Conversion/TorchToLinalg/Pooling.cpp index 85354aad4f12..14d2c71dbc92 100644 --- a/lib/Conversion/TorchToLinalg/Pooling.cpp +++ b/lib/Conversion/TorchToLinalg/Pooling.cpp @@ -72,7 +72,8 @@ checkAndGetPoolingParameters(OpTy op, ConversionPatternRewriter &rewriter, return success(); } -static Value computeOutputTensor(Operation *op, ConversionPatternRewriter &rewriter, +static Value +computeOutputTensor(Operation *op, ConversionPatternRewriter &rewriter, Value self, int64_t dimensionality, bool ceilMode, SmallVectorImpl &strideInts, SmallVectorImpl &paddingInts, @@ -167,24 +168,29 @@ static LogicalResult createPoolingOp( } namespace { + +template struct DimensionTraits {}; + +template <> struct DimensionTraits { + static constexpr int64_t Dim = 2; + // unused const variable warning suppression: + static_assert(Dim == Dim); +}; + +template <> struct DimensionTraits { + static constexpr int64_t Dim = 3; + // unused const variable warning suppression: + static_assert(Dim == Dim); +}; + template class ConvertAtenMaxPoolOp : public OpConversionPattern { using OpConversionPattern::OpConversionPattern; private: - template struct DimensionTraits; - - template <> struct DimensionTraits { - static const int64_t Dim = 2; - }; - - template <> struct DimensionTraits { - static const int64_t Dim = 3; - }; - static const int64_t Dim = DimensionTraits::Dim; - LogicalResult createPoolingMax3D(AtenMaxPool3dOp &op, + LogicalResult createPoolingMax3D(AtenMaxPool3dOp &op, typename OpTy::Adaptor adaptor, ConversionPatternRewriter &rewriter, SmallVectorImpl &kernelSizeIntValues, @@ -327,9 +333,9 @@ class ConvertAtenMaxPoolOp : public OpConversionPattern { rewriter.replaceOpWithNewOp(op, newResultType, maxPool2d); return success(); } else { - return createPoolingMax3D(op, adaptor, rewriter, - kernelSizeIntValues, strideInts, paddingInts, - dilationInts, ceilMode); + return createPoolingMax3D(op, adaptor, rewriter, kernelSizeIntValues, + strideInts, paddingInts, dilationInts, + ceilMode); } } }; From b9806cfa3823bb68faef39fbbb284eda92c0e906 Mon Sep 17 00:00:00 2001 From: Franz Haniel <77495327+frafranz@users.noreply.github.com> Date: Mon, 22 Jan 2024 18:47:13 +0100 Subject: [PATCH 097/283] [TorchToLinalg] Add lowering for torch.aten.diagonal (#2632) --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 25 ++++ lib/Conversion/TorchToLinalg/DataMovement.cpp | 138 ++++++++++++++++++ .../Transforms/AbstractInterpLibrary.cpp | 72 +++++++++ lib/Dialect/Torch/Utils/Utils.cpp | 2 +- .../build_tools/abstract_interp_lib_gen.py | 35 +++++ .../build_tools/torch_ods_gen.py | 1 + .../test_suite/__init__.py | 1 + .../test_suite/diagonal.py | 123 ++++++++++++++++ 8 files changed, 396 insertions(+), 1 deletion(-) create mode 100644 projects/pt1/python/torch_mlir_e2e_test/test_suite/diagonal.py diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index b2c7bf32e368..7e74e698f60a 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -11306,6 +11306,31 @@ def Torch_AtenAsStridedCopyOp : Torch_Op<"aten.as_strided_copy", [ }]; } +def Torch_AtenDiagonalOp : Torch_Op<"aten.diagonal", [ + AllowsTypeRefinement, + ReadOnly + ]> { + let summary = "Generated op for `aten::diagonal : (Tensor, int, int, int) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + Torch_IntType:$offset, + Torch_IntType:$dim1, + Torch_IntType:$dim2 + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenDiagonalOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 4, 1); + } + void AtenDiagonalOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 4, 1); + } + }]; +} + def Torch_AtenDiagonalCopyOp : Torch_Op<"aten.diagonal_copy", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp index 49f5f0ec3321..297a0f4c2be6 100644 --- a/lib/Conversion/TorchToLinalg/DataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp @@ -1834,6 +1834,142 @@ class ConvertAtenViewAsRealOp : public OpConversionPattern { }; } // namespace +namespace { +class ConvertAtenDiagonalOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(AtenDiagonalOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + if (failed(verifyLinalgCompatibleTypes(op, rewriter))) + return failure(); + + int64_t offset; + if (!matchPattern(op.getOffset(), m_TorchConstantInt(&offset))) + return rewriter.notifyMatchFailure(op, "offset must be constant"); + int64_t dim1; + if (!matchPattern(op.getDim1(), m_TorchConstantInt(&dim1))) + return rewriter.notifyMatchFailure(op, "dim1 must be constant"); + int64_t dim2; + if (!matchPattern(op.getDim2(), m_TorchConstantInt(&dim2))) + return rewriter.notifyMatchFailure(op, "dim2 must be constant"); + + Value inputMatrix = adaptor.getSelf(); + RankedTensorType inputType = inputMatrix.getType().cast(); + int64_t inputRank = inputType.getRank(); + + if (inputRank < 2) + return rewriter.notifyMatchFailure( + op, "input must have at least two dimensions"); + int64_t outputRank = inputRank - 1; + + dim1 = toPositiveDim(dim1, inputRank); + if (!isValidDim(dim1, inputRank)) + return rewriter.notifyMatchFailure(op, "dim1 out of range"); + dim2 = toPositiveDim(dim2, inputRank); + if (!isValidDim(dim2, inputRank)) + return rewriter.notifyMatchFailure(op, "dim2 out of range"); + if (dim1 == dim2) + return rewriter.notifyMatchFailure( + op, "diagonal dimensions cannot be identical"); + + Type elementType = inputType.getElementType(); + RankedTensorType outputType = getTypeConverter() + ->convertType(op->getResult(0).getType()) + .cast(); + Location loc = op.getLoc(); + + Value dim1Size, dim2Size; + dim1Size = getDimOp(rewriter, loc, inputMatrix, dim1); + dim2Size = getDimOp(rewriter, loc, inputMatrix, dim2); + + // compute the length of the diagonal with possible offset + // if the offset is very large or very small, diagSize=0 and an empty tensor + // is returned + Value indexZero = rewriter.create(loc, 0); + Value indexMinusOne = rewriter.create(loc, -1); + Value indexOffset = rewriter.create(loc, offset); + Value offsetIsNegative = rewriter.create( + loc, arith::CmpIPredicate::sle, indexOffset, indexZero); + Value sizeForNegativeOffset = rewriter.create( + loc, + rewriter.create( + loc, rewriter.create(loc, dim1Size, indexOffset), + dim2Size), + indexZero); + Value sizeForPositiveOffset = rewriter.create( + loc, + rewriter.create( + loc, rewriter.create(loc, dim2Size, indexOffset), + dim1Size), + indexZero); + Value diagSize = rewriter.create( + loc, offsetIsNegative, sizeForNegativeOffset, sizeForPositiveOffset); + + // depending on its sign, the offset affects only the row or column indices + // of the diagonal + Value diagStart1 = rewriter.create( + loc, offsetIsNegative, + rewriter.create(loc, indexOffset, indexMinusOne), + indexZero); + Value diagStart2 = rewriter.create(loc, offsetIsNegative, + indexZero, indexOffset); + + SmallVector outputDims; + for (auto i = 0; i < inputRank; i++) { + if (!(i == dim1 || i == dim2)) + outputDims.push_back(getDimOp(rewriter, loc, inputMatrix, i)); + } + outputDims.push_back(diagSize); + + Value outputMatrix = rewriter.create( + loc, getAsOpFoldResult(outputDims), elementType); + + SmallVector indexingMaps = { + AffineMap::getMultiDimIdentityMap(outputRank, rewriter.getContext())}; + SmallVector iteratorTypes( + outputRank, utils::IteratorType::parallel); + + auto diagonal = + rewriter + .create( + loc, outputMatrix.getType(), ValueRange{}, outputMatrix, + indexingMaps, iteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange args) { + SmallVector diagIndices; + Value indexOnDiag = + b.create(loc, outputRank - 1); + Value dim1Index = + b.create(loc, indexOnDiag, diagStart1); + Value dim2Index = + b.create(loc, indexOnDiag, diagStart2); + + // specify at which input indices the diagonal values are + // extracted + for (int indIn = 0, indOut = 0; indIn < inputRank; indIn++) { + if (indIn == dim1) + diagIndices.push_back(dim1Index); + else if (indIn == dim2) + diagIndices.push_back(dim2Index); + else { + diagIndices.push_back( + b.create(loc, indOut)); + indOut++; + } + } + Value diagElt = b.create( + loc, elementType, inputMatrix, diagIndices); + b.create(loc, diagElt); + }) + .getResult(0); + + rewriter.replaceOpWithNewOp(op, outputType, diagonal); + return success(); + } +}; +} // namespace + void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target) { @@ -1872,4 +2008,6 @@ void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality( patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); + target.addIllegalOp(); + patterns.add(typeConverter, context); } diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index d058d0fdb97c..4b4ae748f9e9 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -6238,6 +6238,74 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.diagonal\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.list {\n" +" %str = torch.constant.str \"AssertionError: diagonal dimensions cannot be identical\"\n" +" %true = torch.constant.bool true\n" +" %none = torch.constant.none\n" +" %str_0 = torch.constant.str \"AssertionError: input must have at least two dimensions\"\n" +" %int2 = torch.constant.int 2\n" +" %int9223372036854775807 = torch.constant.int 9223372036854775807\n" +" %int0 = torch.constant.int 0\n" +" %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %1 = torch.aten.ge.int %0, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_0, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %2 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %3 = call @__torch__.torch.jit._shape_functions.maybe_wrap_dim(%arg2, %2, %true) : (!torch.int, !torch.int, !torch.bool) -> !torch.int\n" +" %4 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %5 = call @__torch__.torch.jit._shape_functions.maybe_wrap_dim(%arg3, %4, %true) : (!torch.int, !torch.int, !torch.bool) -> !torch.int\n" +" %6 = torch.aten.ne.int %3, %5 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %6 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %7 = torch.prim.ListConstruct : () -> !torch.list\n" +" %8 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %9 = torch.prim.ListConstruct %int9223372036854775807, %8 : (!torch.int, !torch.int) -> !torch.list\n" +" %10 = torch.prim.min.self_int %9 : !torch.list -> !torch.int\n" +" torch.prim.Loop %10, %true, init() {\n" +" ^bb0(%arg4: !torch.int):\n" +" %19 = torch.aten.__getitem__.t %arg0, %arg4 : !torch.list, !torch.int -> !torch.int\n" +" %20 = torch.aten.eq.int %arg4, %3 : !torch.int, !torch.int -> !torch.bool\n" +" %21 = torch.prim.If %20 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %22 = torch.aten.eq.int %arg4, %5 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %22 : !torch.bool\n" +" }\n" +" torch.prim.If %21 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" %22 = torch.aten.append.t %7, %19 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" %11 = torch.aten.__getitem__.t %arg0, %3 : !torch.list, !torch.int -> !torch.int\n" +" %12 = torch.aten.__getitem__.t %arg0, %5 : !torch.list, !torch.int -> !torch.int\n" +" %13 = torch.aten.sub.int %12, %arg1 : !torch.int, !torch.int -> !torch.int\n" +" %14 = torch.prim.min.int %11, %13 : !torch.int, !torch.int -> !torch.int\n" +" %15 = torch.prim.max.int %14, %int0 : !torch.int, !torch.int -> !torch.int\n" +" %16 = torch.aten.lt.int %arg1, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %17 = torch.prim.If %16 -> (!torch.int) {\n" +" %19 = torch.aten.__getitem__.t %arg0, %3 : !torch.list, !torch.int -> !torch.int\n" +" %20 = torch.aten.add.int %19, %arg1 : !torch.int, !torch.int -> !torch.int\n" +" %21 = torch.aten.__getitem__.t %arg0, %5 : !torch.list, !torch.int -> !torch.int\n" +" %22 = torch.prim.min.int %20, %21 : !torch.int, !torch.int -> !torch.int\n" +" %23 = torch.prim.max.int %22, %int0 : !torch.int, !torch.int -> !torch.int\n" +" torch.prim.If.yield %23 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %15 : !torch.int\n" +" }\n" +" %18 = torch.aten.append.t %7, %17 : !torch.list, !torch.int -> !torch.list\n" +" return %7 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.tan\"(%arg0: !torch.list) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -9980,6 +10048,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.diagonal\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.uniform\"(%arg0: !torch.tuple, %arg1: !torch.float, %arg2: !torch.float, %arg3: !torch.any) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" diff --git a/lib/Dialect/Torch/Utils/Utils.cpp b/lib/Dialect/Torch/Utils/Utils.cpp index 06330f16a57e..bf371d7c4687 100644 --- a/lib/Dialect/Torch/Utils/Utils.cpp +++ b/lib/Dialect/Torch/Utils/Utils.cpp @@ -222,7 +222,7 @@ bool Torch::isViewLikeOp(Operation *op) { AtenNarrowOp, AtenNarrowTensorOp, AtenToDeviceOp, PrimsSqueezeOp, AtenMovedimIntOp, PrimsViewOfOp, AtenRealOp, AtenImagOp, PrimsSplitDimOp, AtenViewAsComplexOp, AtenViewAsRealOp, - AtenPixelShuffleOp>(op); + AtenPixelShuffleOp, AtenDiagonalOp>(op); } Value Torch::getConstantWithGivenDtypeAndValue(PatternRewriter &rewriter, diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index 76b3d6af3ee7..9876864a86d8 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -59,6 +59,36 @@ def aten〇triu〡shape(self: List[int], diagonal: int = 0) -> List[int]: def aten〇tril〡shape(self: List[int], diagonal: int = 0) -> List[int]: return upstream_shape_functions.unary(self) +@check_shape_function([ + Invocation(TensorOfShape(2, 3, 4)), # Basic case. + Invocation(TensorOfShape(2, 3, 4), dim1=1, dim2=2), # Test explicit `dim1` and `dim2`. + Invocation(TensorOfShape(2, 3, 4), dim1=-1, dim2=-2, offset=1), # Positive `offset`. + Invocation(TensorOfShape(2, 3, 4), offset=-1), # Negative `offset``. + Invocation(TensorOfShape(2, 3, 4), offset=3), # Empty result due to large `offset`. + ErrorInvocation(TensorOfShape(2)), # Input one-dimensional. + ErrorInvocation(TensorOfShape(2, 3, 4), dim1=1, dim2=1), # `dim1` and `dim2` equal. + ErrorInvocation(TensorOfShape(2, 3, 4), dim1=3, dim2=1), # `dim1` out of bounds. +]) +def aten〇diagonal〡shape(self: List[int], offset: int = 0, dim1: int = 0, dim2: int = 1) -> List[int]: + assert len(self) >= 2, "input must have at least two dimensions" + dim1 = upstream_shape_functions.maybe_wrap_dim(dim1, len(self)) + dim2 = upstream_shape_functions.maybe_wrap_dim(dim2, len(self)) + assert dim1 != dim2, "diagonal dimensions cannot be identical" + + diagonal: List[int] = [] + for i, self_dim in enumerate(self): + if (i==dim1) or (i==dim2): + pass + else: + diagonal.append(self_dim) + + diag_size = max(min(self[dim1], self[dim2] - offset), 0) + if offset<0: + diag_size = max(min(self[dim1] + offset, self[dim2]), 0) + diagonal.append(diag_size) + + return diagonal + def aten〇tan〡shape(self: List[int]) -> List[int]: return upstream_shape_functions.unary(self) @@ -2493,6 +2523,11 @@ def aten〇tril〡dtype(self_rank_dtype: Tuple[int, int], diagonal: int = 0) -> self_rank, self_dtype = self_rank_dtype return self_dtype +@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3)], dim1=0, dim2=1)) +def aten〇diagonal〡dtype(self_rank_dtype: Tuple[int, int], offset: int = 0, dim1: int = 0, dim2: int = 1) -> int: + self_rank, self_dtype = self_rank_dtype + return self_dtype + @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) def aten〇uniform〡dtype(self_rank_dtype: Tuple[int, int], from_: float = 0., to: float = 1., generator: Any = None) -> int: self_rank, self_dtype = self_rank_dtype diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index 93f1e741dd6e..6099fd64e5f4 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -672,6 +672,7 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::alias_copy : (Tensor) -> (Tensor)") emit("aten::alias : (Tensor) -> (Tensor)", has_folder=True) emit("aten::as_strided_copy : (Tensor, int[], int[], int?) -> (Tensor)") + emit("aten::diagonal : (Tensor, int, int, int) -> (Tensor)") emit("aten::diagonal_copy : (Tensor, int, int, int) -> (Tensor)") emit("aten::expand_copy : (Tensor, int[], bool) -> (Tensor)") emit("aten::permute_copy : (Tensor, int[]) -> (Tensor)") diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/__init__.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/__init__.py index f24266c78df8..10130a73fe85 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/__init__.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/__init__.py @@ -60,3 +60,4 @@ def register_all_tests(): from . import control_flow from . import stats from . import padding + from . import diagonal diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/diagonal.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/diagonal.py new file mode 100644 index 000000000000..13d49cea0737 --- /dev/null +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/diagonal.py @@ -0,0 +1,123 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +import torch + +from torch_mlir_e2e_test.framework import TestUtils +from torch_mlir_e2e_test.registry import register_test_case +from torch_mlir_e2e_test.annotations import annotate_args, export + +# ============================================================================== + +class DiagonalModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True), + ]) + + def forward(self, a): + return torch.ops.aten.diagonal(a) + + +@register_test_case(module_factory=lambda: DiagonalModule()) +def DiagonalModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 3)) + +@register_test_case(module_factory=lambda: DiagonalModule()) +def DiagonalModule_nonsquare(module, tu: TestUtils): + module.forward(tu.rand(3, 4)) + +# ============================================================================== + +class DiagonalTransposedModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True), + ]) + def forward(self, a): + return torch.ops.aten.diagonal(a, dim1=1, dim2=0) + +@register_test_case(module_factory=lambda: DiagonalTransposedModule()) +def DiagonalModule_transposed(module, tu: TestUtils): + module.forward(tu.rand(3, 4)) + +# ============================================================================== + +class DiagonalWithDimsModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True), + ]) + def forward(self, a): + return torch.ops.aten.diagonal(a, dim1=0, dim2=1) + +@register_test_case(module_factory=lambda: DiagonalWithDimsModule()) +def DiagonalModule_with_dims(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 5)) + +# ============================================================================== + +class DiagonalWithNegativeDimsModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True), + ]) + def forward(self, a): + return torch.ops.aten.diagonal(a, dim1=-2, dim2=-1) + +@register_test_case(module_factory=lambda: DiagonalWithNegativeDimsModule()) +def DiagonalModule_with_negative_dims(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 5)) + +# ============================================================================== + +class DiagonalWithOffsetModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True), + ]) + def forward(self, a): + return torch.ops.aten.diagonal(a, offset=1) + +@register_test_case(module_factory=lambda: DiagonalWithOffsetModule()) +def DiagonalModule_with_offset(module, tu: TestUtils): + module.forward(tu.rand(4, 6)) + +# ============================================================================== + +class DiagonalWithDimsOffsetModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True), + ]) + def forward(self, a): + return torch.ops.aten.diagonal(a, dim1=0, dim2=1, offset=-1) + +@register_test_case(module_factory=lambda: DiagonalWithDimsOffsetModule()) +def DiagonalModule_with_dims_and_offset(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 5)) From 73b30604da6cd7d24865fd5aa9c855a89c69377e Mon Sep 17 00:00:00 2001 From: Srinath Avadhanula Date: Mon, 22 Jan 2024 13:57:56 -0500 Subject: [PATCH 098/283] Do not try to legalize transposed convolution (#2721) Currently transposed convolution is not handled correctly by `TorchToTosa`. This PR allows transposed convolutions to pass through the conversion so that they can be handled by other conversion passes later in a pipeline. An example input which produces a compilation error is: ``` func.func @forward(%input: !torch.vtensor<[1,64,1,100],f32>) -> !torch.vtensor<[1,64,2,200],f32> { %true = torch.constant.bool true %int1 = torch.constant.int 1 %int2 = torch.constant.int 2 %weight = torch.vtensor.literal(dense<0.0> : tensor<64x64x3x3xf32>) : !torch.vtensor<[64,64,3,3],f32> %bias = torch.vtensor.literal(dense<0.0> : tensor<64xf32>) : !torch.vtensor<[64],f32> %stride = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list %int1x1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list %output = torch.aten.convolution %input, %weight, %bias, %stride, %int1x1, %int1x1, %true, %int1x1, %int1 : !torch.vtensor<[1,64,1,100],f32>, !torch.vtensor<[64,64,3,3],f32>, !torch.vtensor<[64],f32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[1,64,2,200],f32> return %output : !torch.vtensor<[1,64,2,200],f32> } ``` This MLIR produces an error about a cast operation with a size mismatch when passed through `torch-to-tosa`: ``` error: 'tensor.cast' op operand type 'tensor<1x64x1x50xf32>' and result type 'tensor<1x64x2x200xf32>' are cast incompatible ``` --------- Co-authored-by: Srinath Avadhanula --- lib/Conversion/TorchToTosa/TorchToTosa.cpp | 8 ++++++++ .../TorchToTosa/conv2d_transpose.mlir | 18 ++++++++++++++++++ 2 files changed, 26 insertions(+) create mode 100644 test/Conversion/TorchToTosa/conv2d_transpose.mlir diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index 919fe73b2092..b49c9af8adce 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -1850,6 +1850,14 @@ LogicalResult ConvertAtenOp::matchAndRewrite( AtenConvolutionOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const { + bool transposed; + if (!matchPattern(op.getTransposed(), m_TorchConstantBool(&transposed))) + return rewriter.notifyMatchFailure( + op, "Unimplemented: non-constant value for transposed not supported"); + if (transposed) + return rewriter.notifyMatchFailure( + op, "Unimplemented: transposed convolution not supported"); + auto input = adaptor.getInput(); auto weight = adaptor.getWeight(); diff --git a/test/Conversion/TorchToTosa/conv2d_transpose.mlir b/test/Conversion/TorchToTosa/conv2d_transpose.mlir new file mode 100644 index 000000000000..678034cb8405 --- /dev/null +++ b/test/Conversion/TorchToTosa/conv2d_transpose.mlir @@ -0,0 +1,18 @@ +// RUN: torch-mlir-opt <%s -convert-torch-to-tosa -split-input-file -verify-diagnostics + +// The following test ensures that a tranposed convolution op is not +// lowered in the torch-to-tosa conversion pass. + +func.func @forward(%input: !torch.vtensor<[1,64,1,100],f32>) -> !torch.vtensor<[1,64,2,200],f32> { + %true = torch.constant.bool true + %int1 = torch.constant.int 1 + %int2 = torch.constant.int 2 + %weight = torch.vtensor.literal(dense<0.0> : tensor<64x64x3x3xf32>) : !torch.vtensor<[64,64,3,3],f32> + %bias = torch.vtensor.literal(dense<0.0> : tensor<64xf32>) : !torch.vtensor<[64],f32> + %stride = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list + %int1x1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list + // expected-error@+1 {{failed to legalize operation 'torch.aten.convolution' that was explicitly marked illegal}} + %output = torch.aten.convolution %input, %weight, %bias, %stride, %int1x1, %int1x1, %true, %int1x1, %int1 : !torch.vtensor<[1,64,1,100],f32>, !torch.vtensor<[64,64,3,3],f32>, !torch.vtensor<[64],f32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[1,64,2,200],f32> + return %output : !torch.vtensor<[1,64,2,200],f32> +} + From 5883ef0f21b526e399f1f6bd6fd8f548ef3a7115 Mon Sep 17 00:00:00 2001 From: Ramiro Leal-Cavazos Date: Mon, 22 Jan 2024 19:05:55 +0000 Subject: [PATCH 099/283] Fix unused variable warnings (#2775) --- .../TorchOnnxToTorch/DefaultDomainAtoF.cpp | 1 - .../TorchToLinalg/TensorConstructors.cpp | 15 --------------- 2 files changed, 16 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp index 4df2e0f88eb1..b9b60e7748b4 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp @@ -596,7 +596,6 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( if (binder.tensorResultType(resultType)) return failure(); auto dtype = resultType.getDtype(); - Value scalarValue; float floatValue; if (binder.op->hasAttr("torch.onnx.value_float") && diff --git a/lib/Conversion/TorchToLinalg/TensorConstructors.cpp b/lib/Conversion/TorchToLinalg/TensorConstructors.cpp index 9429d1e8caca..6afae47c1325 100644 --- a/lib/Conversion/TorchToLinalg/TensorConstructors.cpp +++ b/lib/Conversion/TorchToLinalg/TensorConstructors.cpp @@ -98,11 +98,8 @@ namespace { Location loc = op->getLoc(); Value input = adaptor.getSelf(); - MLIRContext *context = rewriter.getContext(); auto inputType = llvm::cast(input.getType()); int64_t inputRank = inputType.getRank(); - auto outputType = llvm::cast( - getTypeConverter()->convertType(op->getResult(0).getType())); unsigned numDims = inputType.getRank(); assert(numDims >= 2 && "Not enough input dimensions"); @@ -171,16 +168,6 @@ namespace { } } - // Some generic helper functions to aid in constructing basic arithmetic. - auto createAdd = [&](Value x, Value y) { - return rewriter.create(loc, x, y); - }; - - auto createAdds = [&](std::initializer_list values) { - assert(values.size() >= 2); - return std::accumulate(values.begin() + 1, values.end(), data(values)[0], - createAdd); - }; auto createSub = [&](Value x, Value y) { return rewriter.create(loc, x, y); }; @@ -247,8 +234,6 @@ namespace { tensorsRes.push_back(leftPadTile); } if (hasTopPadding) { - Value topLeftValue = rewriter.create( - loc, input, ValueRange{zero, zero, zero, zero}); Value topHcenterSlice = rewriter.create( loc, input, extractOffsetsLT, extractShapeTB, allOneStrides); for (auto i = 0; i < padInts[2]; ++i) { From cad98e81136c289694661baadd57759e817eaff7 Mon Sep 17 00:00:00 2001 From: Chi_Liu <22491986+AmosLewis@users.noreply.github.com> Date: Mon, 22 Jan 2024 12:56:39 -0800 Subject: [PATCH 100/283] [ONNX][TORCH-MLIR] Add TopK support (#2774) https://github.com/nod-ai/SHARK-Turbine/issues/331 --- .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 31 +++++++++++ .../TorchOnnxToTorch/simple_ops_q_to_z.mlir | 55 +++++++++++++++++++ 2 files changed, 86 insertions(+) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 25f46d3f210a..8d2c50c08dcc 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -1445,4 +1445,35 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( none, none, none); return success(); }); + patterns.onOp( + "Topk", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType Values_type, Indices_type; + Value X, K; + int64_t axis; + bool largest, sorted; + if (binder.tensorOperandAtIndex(X, 0) || + binder.tensorOperandAtIndex(K, 1) || + binder.s64IntegerAttr(axis, "axis", -1) || + binder.s64BoolAttr(largest, "largest", true) || + binder.s64BoolAttr(sorted, "sorted", true) || + binder.tensorResultTypeAtIndex(Values_type, 0) || + binder.tensorResultTypeAtIndex(Indices_type, 1)) + return failure(); + std::optional maybeRank = Torch::getTensorRank(X); + if (!maybeRank) + return rewriter.notifyMatchFailure(binder.op, + "Unimplemented: unranked tensor"); + unsigned rank = *maybeRank; + axis = Torch::toPositiveDim(axis, rank); + Value cstAxis = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(axis)); + Value cstLargest = + rewriter.create(binder.getLoc(), largest); + Value cstSorted = + rewriter.create(binder.getLoc(), sorted); + rewriter.replaceOpWithNewOp( + binder.op, Values_type, Indices_type, X, K, cstAxis, cstLargest, + cstSorted); + return success(); + }); } diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir index 1725b2a15911..f18f28a60d2d 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir @@ -990,6 +990,8 @@ func.func @test_slice_default_axes_and_slices(%arg0: !torch.vtensor<[20,10,5],f3 return %0 : !torch.vtensor<[20,10,1],f32> } +// ----- + // CHECK-LABEL: func.func @test_slice_default_steps func.func @test_slice_default_steps(%arg0: !torch.vtensor<[20,10,5],f32>, %arg1: !torch.vtensor<[3],si64>, %arg2: !torch.vtensor<[3],si64>, %arg3: !torch.vtensor<[3],si64>) -> !torch.vtensor<[20,10,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { //CHECK: %[[NONE:.*]] = torch.constant.none @@ -1036,6 +1038,9 @@ func.func @test_slice_default_steps(%arg0: !torch.vtensor<[20,10,5],f32>, %arg1: %0 = torch.operator "onnx.Slice"(%arg0, %arg1, %arg2, %arg3) : (!torch.vtensor<[20,10,5],f32>, !torch.vtensor<[3],si64>, !torch.vtensor<[3],si64>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[20,10,1],f32> return %0 : !torch.vtensor<[20,10,1],f32> } + +// ----- + // CHECK-LABEL: func.func @test_reshape_negative_dim func.func @test_reshape_negative_dim(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torch.vtensor<[3],si64>) -> !torch.vtensor<[2,6,2],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT0:.+]] = torch.constant.int 0 @@ -1069,6 +1074,8 @@ func.func @test_reshape_negative_dim(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: return %0 : !torch.vtensor<[2,6,2],f32> } +// ----- + // CHECK-LABEL: func.func @test_reshape_negative_extended_dims func.func @test_reshape_negative_extended_dims(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[1,2,3,4],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT0:.+]] = torch.constant.int 0 @@ -1109,6 +1116,8 @@ func.func @test_reshape_negative_extended_dims(%arg0: !torch.vtensor<[2,3,4],f32 return %0 : !torch.vtensor<[1,2,3,4],f32> } +// ----- + // CHECK-LABEL: func.func @test_reshape_one_dim func.func @test_reshape_one_dim(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[24],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT0:.+]] = torch.constant.int 0 @@ -1126,6 +1135,8 @@ func.func @test_reshape_one_dim(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torc return %0 : !torch.vtensor<[24],f32> } +// ----- + // CHECK-LABEL: func.func @test_reshape_reduced_dims func.func @test_reshape_reduced_dims(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torch.vtensor<[2],si64>) -> !torch.vtensor<[2,12],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT0:.+]] = torch.constant.int 0 @@ -1151,6 +1162,8 @@ func.func @test_reshape_reduced_dims(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: return %0 : !torch.vtensor<[2,12],f32> } +// ----- + // CHECK-LABEL: func.func @test_reshape_reordered_all_dims func.func @test_reshape_reordered_all_dims(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torch.vtensor<[3],si64>) -> !torch.vtensor<[4,2,3],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT0:.+]] = torch.constant.int 0 @@ -1184,6 +1197,8 @@ func.func @test_reshape_reordered_all_dims(%arg0: !torch.vtensor<[2,3,4],f32>, % return %0 : !torch.vtensor<[4,2,3],f32> } +// ----- + // CHECK-LABEL: func.func @test_reshape_zero_and_negative_dim func.func @test_reshape_zero_and_negative_dim(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[2,3,1,4],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT0:.+]] = torch.constant.int 0 @@ -1224,6 +1239,8 @@ func.func @test_reshape_zero_and_negative_dim(%arg0: !torch.vtensor<[2,3,4],f32> return %0 : !torch.vtensor<[2,3,1,4],f32> } +// ----- + // CHECK-LABEL: func.func @test_range_float64_type func.func @test_range_float64_type(%arg0: !torch.vtensor<[],f64>, %arg1: !torch.vtensor<[],f64>, %arg2: !torch.vtensor<[],f64>) -> !torch.vtensor<[2],f64> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[NONE:.*]] torch.constant.none @@ -1235,6 +1252,8 @@ func.func @test_reshape_zero_and_negative_dim(%arg0: !torch.vtensor<[2,3,4],f32> return %0 : !torch.vtensor<[2],f64> } +// ----- + // CHECK-LABEL: func.func @test_range_float32_type func.func @test_range_float32_type(%arg0: !torch.vtensor<[],f32>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],f32>) -> !torch.vtensor<[2],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[NONE:.*]] torch.constant.none @@ -1246,6 +1265,8 @@ func.func @test_reshape_zero_and_negative_dim(%arg0: !torch.vtensor<[2,3,4],f32> return %0 : !torch.vtensor<[2],f32> } +// ----- + // CHECK-LABEL: func.func @test_range_int64_type func.func @test_range_int64_type(%arg0: !torch.vtensor<[],si64>, %arg1: !torch.vtensor<[],si64>, %arg2: !torch.vtensor<[],si64>) -> !torch.vtensor<[2],si64> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[NONE:.*]] torch.constant.none @@ -1257,6 +1278,8 @@ func.func @test_reshape_zero_and_negative_dim(%arg0: !torch.vtensor<[2,3,4],f32> return %0 : !torch.vtensor<[2],si64> } +// ----- + // CHECK-LABEL: func.func @test_range_int32_type func.func @test_range_int32_type(%arg0: !torch.vtensor<[],si32>, %arg1: !torch.vtensor<[],si32>, %arg2: !torch.vtensor<[],si32>) -> !torch.vtensor<[2],si32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[NONE:.*]] torch.constant.none @@ -1268,6 +1291,8 @@ func.func @test_reshape_zero_and_negative_dim(%arg0: !torch.vtensor<[2,3,4],f32> return %0 : !torch.vtensor<[2],si32> } +// ----- + // CHECK-LABEL: func.func @test_range_int16_type func.func @test_range_int16_type(%arg0: !torch.vtensor<[],si16>, %arg1: !torch.vtensor<[],si16>, %arg2: !torch.vtensor<[],si16>) -> !torch.vtensor<[2],si16> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[NONE:.*]] torch.constant.none @@ -1277,4 +1302,34 @@ func.func @test_reshape_zero_and_negative_dim(%arg0: !torch.vtensor<[2,3,4],f32> // CHECK: torch.aten.arange.start_step %0, %1, %2, %none, %none, %none, %none : !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2],si16> %0 = torch.operator "onnx.Range"(%arg0, %arg1, %arg2) : (!torch.vtensor<[],si16>, !torch.vtensor<[],si16>, !torch.vtensor<[],si16>) -> !torch.vtensor<[2],si16> return %0 : !torch.vtensor<[2],si16> + } + +// ----- + +// CHECK-LABEL : func.func @test_top_k + func.func @test_top_k(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[1],si64>) -> (!torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64>) attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64} { + // CHECK: %[[RESULTS:.*]]:2 = torch.operator "onnx.TopK"(%arg0, %arg1) {torch.onnx.axis = 1 : si64} : (!torch.vtensor<[3,4],f32>, !torch.vtensor<[1],si64>) -> (!torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64>) + // CHECK: return %[[RESULTS]]#0, %[[RESULTS]]#1 : !torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64> + %0:2 = torch.operator "onnx.TopK"(%arg0, %arg1) {torch.onnx.axis = 1 : si64} : (!torch.vtensor<[3,4],f32>, !torch.vtensor<[1],si64>) -> (!torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64>) + return %0#0, %0#1 : !torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64> } + +// ----- + +// CHECK-LABEL: func.func @test_top_k_smallest + func.func @test_top_k_smallest(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[1],si64>) -> (!torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64>) attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[RESULTS:.*]]:2 = torch.operator "onnx.TopK"(%arg0, %arg1) {torch.onnx.axis = 1 : si64, torch.onnx.largest = 0 : si64, torch.onnx.sorted = 1 : si64} : (!torch.vtensor<[3,4],f32>, !torch.vtensor<[1],si64>) -> (!torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64>) + // CHECK: return %[[RESULTS]]#0, %[[RESULTS]]#1 : !torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64> + %0:2 = torch.operator "onnx.TopK"(%arg0, %arg1) {torch.onnx.axis = 1 : si64, torch.onnx.largest = 0 : si64, torch.onnx.sorted = 1 : si64} : (!torch.vtensor<[3,4],f32>, !torch.vtensor<[1],si64>) -> (!torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64>) + return %0#0, %0#1 : !torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64> + } + +// ----- + +// CHECK-LABEL: func.func @test_top_k_negative_axis + func.func @test_top_k_negative_axis(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[1],si64>) -> (!torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64>) attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64} { + // CHECK: %[[RESULTS:.*]]:2 = torch.operator "onnx.TopK"(%arg0, %arg1) {torch.onnx.axis = -1 : si64} : (!torch.vtensor<[3,4],f32>, !torch.vtensor<[1],si64>) -> (!torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64>) + // CHECK: return %[[RESULTS]]#0, %[[RESULTS]]#1 : !torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64> + %0:2 = torch.operator "onnx.TopK"(%arg0, %arg1) {torch.onnx.axis = -1 : si64} : (!torch.vtensor<[3,4],f32>, !torch.vtensor<[1],si64>) -> (!torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64>) + return %0#0, %0#1 : !torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64> + } From d452c4f4c0d5b8fe748eebb1b9801f44a7d13374 Mon Sep 17 00:00:00 2001 From: Dave Liddell <44620210+daveliddell@users.noreply.github.com> Date: Mon, 22 Jan 2024 14:00:05 -0700 Subject: [PATCH 101/283] Fix onnx importer to treat Constant values as static (#2780) Fixes https://github.com/llvm/torch-mlir/issues/2764 In the case of OPT, there are ConstantOfShape ops whose input shape is not static (that is, an initializer), but rather comes from a Constant op. The importer can't handle such non-static input shapes. The fix here is to create initializers for a subset of Constant ops (ones with "value" attributes), so that their outputs can be used statically. Additionally, there was no case for creating a splat of int64, so I added that as well. --------- Co-authored-by: Dave Liddell --- python/torch_mlir/extras/onnx_importer.py | 45 +++++++++++++++++++---- 1 file changed, 37 insertions(+), 8 deletions(-) diff --git a/python/torch_mlir/extras/onnx_importer.py b/python/torch_mlir/extras/onnx_importer.py index dbf0adc490bd..59a2682bbba9 100644 --- a/python/torch_mlir/extras/onnx_importer.py +++ b/python/torch_mlir/extras/onnx_importer.py @@ -276,10 +276,13 @@ def import_node(self, node: onnx.NodeProto): with InsertionPoint(self._b), Location.name(node.name): op_type = node.op_type # Handle special op types that materialize to non-op IR constructs. + # Handlers return True if the op was handled, else this function + # should process it as a general node. special_key = f"_handle_node_{op_type}" if hasattr(self, special_key): - getattr(self, special_key)(node) - return + was_handled = getattr(self, special_key)(node) + if was_handled: + return # General node import. input_values = [] @@ -333,8 +336,11 @@ def import_attributes( ) attrs[f"torch.onnx.{onnx_attr.name}"] = handler(onnx_attr, self._cc) - def import_initializer(self, initializer: onnx.TensorProto) -> Value: - with InsertionPoint(self._b), Location.name(initializer.name): + def import_initializer(self, initializer: onnx.TensorProto, extern_name: str = None) -> Value: + # If an explicitly specified name is given, use that; otherwise, pick + # up the name from the tensor proto itself + iname = extern_name if extern_name else initializer.name + with InsertionPoint(self._b), Location.name(iname): value_attr = self._cc.tensor_proto_to_attr(initializer) vtensor_type = self._cc.tensor_proto_to_type(initializer) literal_op = Operation.create( @@ -342,7 +348,7 @@ def import_initializer(self, initializer: onnx.TensorProto) -> Value: results=[vtensor_type], attributes={"value": value_attr}, ) - self._nv_map[initializer.name] = literal_op.result + self._nv_map[iname] = literal_op.result return literal_op.result def _get_immediate_tensor(self, name: str) -> np.array: @@ -366,7 +372,23 @@ def _get_immediate_tensor(self, name: str) -> np.array: f"Unhandled ONNX TensorProto immediate data: {initializer}" ) - def _handle_node_ConstantOfShape(self, node: onnx.NodeProto): + def _handle_node_Constant(self, node: onnx.NodeProto) -> bool: + # Special case only for constants specified by value attribute (for now) + value_proto = _get_attr(node, "value", False) + if not value_proto: + return False + + # Produce an initializer for the constant, so that it can be used in + # combination with other ops, such as ConstantOfShape, requiring + # a constant input + assert value_proto.type == onnx.AttributeProto.AttributeType.TENSOR + assert len(node.output) == 1 + const_name = node.output[0] + self.import_initializer(value_proto.t, const_name) + self._gi.initializer_map[const_name] = value_proto.t + return True + + def _handle_node_ConstantOfShape(self, node: onnx.NodeProto) -> bool: # This op is special: It has an input of the shape, and in full generality # could involve eager production of constants of variable size. In # practice, the DNN profile for ONNX makes this very difficult to do @@ -394,6 +416,7 @@ def _handle_node_ConstantOfShape(self, node: onnx.NodeProto): attributes={"value": value_attr}, ) self._nv_map[node.output[0]] = literal_op.result + return True class ContextCache: @@ -515,6 +538,11 @@ def tensor_proto_to_attr(self, tp: onnx.TensorProto) -> Attribute: onnx.TensorProto.DataType.FLOAT: lambda tp, shape: DenseElementsAttr.get_splat( RankedTensorType.get(shape, F32Type.get()), FloatAttr.get_f32(tp.float_data[0]) ), + onnx.TensorProto.DataType.INT64: lambda tp, shape: DenseElementsAttr.get_splat( + RankedTensorType.get(shape, IntegerType.get_signed(64)), IntegerAttr.get( + IntegerType.get_signed(64), int.from_bytes(tp.raw_data, "little", + signed=True) if tp.HasField("raw_data") else tp.int64_data[0]) + ), # TODO: All the rest from ELEM_TYPE_TO_IR_TYPE_CB } @@ -605,9 +633,10 @@ def tensor_proto_to_attr(self, tp: onnx.TensorProto) -> Attribute: } -def _get_attr(node: onnx.NodeProto, attr_name: str) -> onnx.AttributeProto: +def _get_attr(node: onnx.NodeProto, attr_name: str, is_required: bool = True) -> onnx.AttributeProto: for attr in node.attribute: if attr.name == attr_name: return attr - else: + if is_required: raise OnnxImportError(f"Required attribute {attr_name} not found in {node}") + return None From b7a032967678491bb4765a10b73e8a99c1c76b60 Mon Sep 17 00:00:00 2001 From: Gaurav Shukla Date: Tue, 23 Jan 2024 19:23:01 +0530 Subject: [PATCH 102/283] [ONNX][MLIR] Fix padding size constraint for onnx.maxpool op (#2782) Signed-off-by: Gaurav Shukla --- .../TorchOnnxToTorch/DefaultDomainGtoP.cpp | 5 ++-- .../TorchOnnxToTorch/simple_ops_g_to_p.mlir | 23 +++++++++++++++++++ 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index 4834e7af4d5e..859215d287d9 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -201,9 +201,10 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( binder.op, "kernel list size does not match the number of axes"); if (binder.s64IntegerArrayAttr(padding, "pads", {0})) return rewriter.notifyMatchFailure(binder.op, "pads bind failure"); - if (padding.size() != 1 && padding.size() != rank - 2) + if (padding.size() != 1 && padding.size() != 2 * (rank - 2)) return rewriter.notifyMatchFailure( - binder.op, "padding list size does not match the number of axes"); + binder.op, "padding list must contain (begin,end) pair for each " + "spatial axis"); if (binder.s64IntegerArrayAttr(strides, "strides", {1})) return rewriter.notifyMatchFailure(binder.op, "strides bind failure"); if (strides.size() != 1 && strides.size() != rank - 2) diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir index 59049e40614c..ba4487152b5c 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir @@ -274,6 +274,29 @@ func.func @test_maxpool_3d_default(%arg0: !torch.vtensor<[1,3,32,32,32],f32>) -> // ----- +// CHECK-LABEL: func.func @test_maxpool_pad +func.func @test_maxpool_pad(%arg0: !torch.vtensor<[1,64,112,112],f32>) -> !torch.vtensor<[1,64,56,56],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64} { + // CHECK: %[[INT3:.*]] = torch.constant.int 3 + // CHECK: %[[INT3_0:.*]] = torch.constant.int 3 + // CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[INT3]], %[[INT3_0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: %[[INT1_1:.*]] = torch.constant.int 1 + // CHECK: %[[INT1_2:.*]] = torch.constant.int 1 + // CHECK: %[[INT1_3:.*]] = torch.constant.int 1 + // CHECK: %[[LIST2:.*]] = torch.prim.ListConstruct %[[INT1]], %[[INT1_1]], %[[INT1_2]], %[[INT1_3]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[INT2:.*]] = torch.constant.int 2 + // CHECK: %[[INT2_4:.*]] = torch.constant.int 2 + // CHECK: %[[LIST3:.*]] = torch.prim.ListConstruct %[[INT2]], %[[INT2_4]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[EMPTY_LIST:.*]] = torch.prim.ListConstruct : () -> !torch.list + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: %[[OUT:.*]] = torch.aten.max_pool2d %arg0, %[[LIST]], %[[LIST3]], %[[LIST2]], %[[EMPTY_LIST]], %[[FALSE]] : !torch.vtensor<[1,64,112,112],f32>, !torch.list, !torch.list, !torch.list, !torch.list, !torch.bool -> !torch.vtensor<[1,64,56,56],f32> + // CHECK: return %[[OUT]] : !torch.vtensor<[1,64,56,56],f32> + %0 = torch.operator "onnx.MaxPool"(%arg0) {torch.onnx.ceil_mode = 0 : si64, torch.onnx.kernel_shape = [3 : si64, 3 : si64], torch.onnx.pads = [1 : si64, 1 : si64, 1 : si64, 1 : si64], torch.onnx.strides = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,64,112,112],f32>) -> !torch.vtensor<[1,64,56,56],f32> + return %0 : !torch.vtensor<[1,64,56,56],f32> +} + +// ----- + // CHECK-LABEL: @test_gelu_default_1 func.func @test_gelu_default_1(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[STR1:.*]] = torch.constant.str "none" From c9d8ffb414df5451ee7f193af501afbd9c03a102 Mon Sep 17 00:00:00 2001 From: Vivek Khandelwal Date: Tue, 23 Jan 2024 21:05:19 +0530 Subject: [PATCH 103/283] build: manually update PyTorch version (#2788) Set PyTorch and TorchVision version to nightly release 2024-01-22. Signed-Off By: Vivek Khandelwal --- pytorch-hash.txt | 2 +- pytorch-requirements.txt | 2 +- torchvision-requirements.txt | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch-hash.txt b/pytorch-hash.txt index cf7d2b924e62..16be42d6c147 100644 --- a/pytorch-hash.txt +++ b/pytorch-hash.txt @@ -1 +1 @@ -03969cb2d2e773af44b71f304d8de81107b2d41e +72fcb9ad662bb941a266e3d747835382634c2be6 diff --git a/pytorch-requirements.txt b/pytorch-requirements.txt index a4291b382237..1de47ff9a195 100644 --- a/pytorch-requirements.txt +++ b/pytorch-requirements.txt @@ -1,3 +1,3 @@ -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html --pre -torch==2.3.0.dev20240109 +torch==2.3.0.dev20240122 diff --git a/torchvision-requirements.txt b/torchvision-requirements.txt index d081d22aca9b..fad713123493 100644 --- a/torchvision-requirements.txt +++ b/torchvision-requirements.txt @@ -1,3 +1,3 @@ -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html --pre -torchvision==0.18.0.dev20240109 +torchvision==0.18.0.dev20240122 From dc056e58e6e6b19a1b686ab7c04c12274864ffd7 Mon Sep 17 00:00:00 2001 From: James Newling Date: Tue, 23 Jan 2024 07:36:25 -0800 Subject: [PATCH 104/283] [MLIR][TORCH] Add onnx.cast cases used by OPT-1.25M (#2787) --- .../TorchOnnxToTorch/DefaultDomainAtoF.cpp | 148 ++++++++++-------- .../TorchOnnxToTorch/simple_ops_a_to_f.mlir | 10 ++ 2 files changed, 91 insertions(+), 67 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp index b9b60e7748b4..3b47162711cb 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp @@ -10,30 +10,39 @@ #include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" +#include "llvm/Support/FormatVariadic.h" using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::onnx_c; static int64_t onnxDtypeIntToTorchDtypeInt(int64_t dtypeIntOnnx) { - int64_t dtypeIntTorch; // TODO: Add complete mapping. - switch (dtypeIntOnnx) { - case 1: - dtypeIntTorch = 6; // float - break; - case 10: - dtypeIntTorch = 5; // half - break; - case 11: - dtypeIntTorch = 7; // double - break; - case 16: - dtypeIntTorch = 15; // bfloat16 - break; - default: - dtypeIntTorch = -1; // No dtype - } + // Where are the ONNX and PyTorch dtype enums defined? + // ONNX: + // https://github.com/shouxieai/tensorRT_Pro/blob/main/onnx/onnx-ml.proto + // PyTorch: + // https://github.com/llvm/torch-mlir/blob/main/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h#L88 + + int64_t dtypeIntTorch = [dtypeIntOnnx]() { + switch (dtypeIntOnnx) { + case 1: + return 6; // float + case 7: + return 5; // int64 + case 9: + return 11; // bool + case 10: + return 5; // half + case 11: + return 7; // double + case 16: + return 15; // bfloat16 + default: + return -1; // No dtype + } + }(); + return dtypeIntTorch; } @@ -415,30 +424,30 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( } return success(); }); - patterns.onOp( - "BitwiseAnd", 18, [](OpBinder binder, ConversionPatternRewriter &rewriter) { - Torch::ValueTensorType resultType; - Value lhs, rhs; - std::string direction; - if (binder.tensorOperands(lhs, rhs) || - binder.tensorResultType(resultType)) - return failure(); - rewriter.replaceOpWithNewOp( - binder.op, resultType, lhs, rhs); - return success(); - }); - patterns.onOp( - "BitwiseOr", 18, [](OpBinder binder, ConversionPatternRewriter &rewriter) { - Torch::ValueTensorType resultType; - Value lhs, rhs; - std::string direction; - if (binder.tensorOperands(lhs, rhs) || - binder.tensorResultType(resultType)) - return failure(); - rewriter.replaceOpWithNewOp( - binder.op, resultType, lhs, rhs); - return success(); - }); + patterns.onOp("BitwiseAnd", 18, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value lhs, rhs; + std::string direction; + if (binder.tensorOperands(lhs, rhs) || + binder.tensorResultType(resultType)) + return failure(); + rewriter.replaceOpWithNewOp( + binder.op, resultType, lhs, rhs); + return success(); + }); + patterns.onOp("BitwiseOr", 18, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value lhs, rhs; + std::string direction; + if (binder.tensorOperands(lhs, rhs) || + binder.tensorResultType(resultType)) + return failure(); + rewriter.replaceOpWithNewOp( + binder.op, resultType, lhs, rhs); + return success(); + }); patterns.onOp("BitwiseNot", 18, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; @@ -450,18 +459,18 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( binder.op, resultType, operand); return success(); }); - patterns.onOp( - "BitwiseXor", 18, [](OpBinder binder, ConversionPatternRewriter &rewriter) { - Torch::ValueTensorType resultType; - Value lhs, rhs; - std::string direction; - if (binder.tensorOperands(lhs, rhs) || - binder.tensorResultType(resultType)) - return failure(); - rewriter.replaceOpWithNewOp( - binder.op, resultType, lhs, rhs); - return success(); - }); + patterns.onOp("BitwiseXor", 18, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value lhs, rhs; + std::string direction; + if (binder.tensorOperands(lhs, rhs) || + binder.tensorResultType(resultType)) + return failure(); + rewriter.replaceOpWithNewOp( + binder.op, resultType, lhs, rhs); + return success(); + }); patterns.onOp( "Cast", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; @@ -474,9 +483,13 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( dtypeIntTorch = onnxDtypeIntToTorchDtypeInt(dtypeIntOnnx); if (dtypeIntTorch == -1) { - return rewriter.notifyMatchFailure( - binder.op, - "unimplemented support for the given dtype conversion"); + auto message = llvm::formatv("unimplemented support for the given " + "dtype conversion (onnx 'type' = {0})", + dtypeIntOnnx); + llvm::errs() << message << "\n"; + auto y = rewriter.notifyMatchFailure(binder.op, message); + + return y; } Value constDtype = rewriter.create( binder.getLoc(), rewriter.getType(), @@ -864,7 +877,8 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( unsigned rank = *maybeRank; SmallVector padding, strides, dilations, outputPadding; - SmallVector defaultPadding, defaultStrides, defaultDilations, defaultOutputPadding; + SmallVector defaultPadding, defaultStrides, defaultDilations, + defaultOutputPadding; for (unsigned i = 0; i < rank - 2; i++) { defaultPadding.push_back(0); defaultStrides.push_back(1); @@ -1018,30 +1032,30 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( cast(operand.getType()).getSizes().size(); Value rankVal = rewriter.create( binder.getLoc(), rewriter.getType(), - rewriter.getIntegerAttr(rewriter.getIntegerType(64), - rank)); + rewriter.getIntegerAttr(rewriter.getIntegerType(64), rank)); Value zero = rewriter.create( loc, rewriter.getI64IntegerAttr(0)); - + Value axisScalar = rewriter.create( binder.getLoc(), rewriter.getType(), axisTensor); - Value isNegative = - rewriter.create(binder.getLoc(), axisScalar, zero); - isNegative = rewriter.create(binder.getLoc(), - isNegative); + Value isNegative = rewriter.create( + binder.getLoc(), axisScalar, zero); + isNegative = + rewriter.create(binder.getLoc(), isNegative); Value finalOffset = rewriter.create( binder.getLoc(), isNegative, rankVal); Value dim = rewriter.create( binder.getLoc(), axisScalar, finalOffset); - Torch::BaseTensorType resultTensorType = resultType.cast(); + Torch::BaseTensorType resultTensorType = + resultType.cast(); if (!resultTensorType.hasDtype()) { return rewriter.notifyMatchFailure( binder.op, "expected result type to have a dtype"); } // resultTensorType.print(llvm::outs()); - Value resultDType = - Torch::getDtypeIntValueForType(rewriter, loc, resultTensorType.getDtype()); + Value resultDType = Torch::getDtypeIntValueForType( + rewriter, loc, resultTensorType.getDtype()); rewriter.replaceOpWithNewOp( binder.op, resultType, operand, dim, resultDType); diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir index c2d3c12a7b92..bb02a29cb592 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir @@ -332,6 +332,16 @@ func.func @test_cast_FLOAT16_to_DOUBLE(%arg0: !torch.vtensor<[3,4],f16>) -> !tor return %0 : !torch.vtensor<[3,4],f64> } +// CHECK-LABEL: @test_cast_FLOAT_to_BOOL +func.func @test_cast_FLOAT_to_BOOL(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT:.*]] = torch.constant.int 11 + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: torch.aten.to.dtype %arg0, %[[INT]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,4],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[3,4],i1> + %0 = torch.operator "onnx.Cast"(%arg0) {torch.onnx.to = 9 : si64} : (!torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],i1> + return %0 : !torch.vtensor<[3,4],i1> +} + // CHECK-LABEL: @test_cast_FLOAT16_to_FLOAT func.func @test_cast_FLOAT16_to_FLOAT(%arg0: !torch.vtensor<[3,4],f16>) -> !torch.vtensor<[3,4],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT:.*]] = torch.constant.int 6 From 77ae56337dbf95eb809f4a7d218a9fb3dc1f41b0 Mon Sep 17 00:00:00 2001 From: Chi_Liu <22491986+AmosLewis@users.noreply.github.com> Date: Tue, 23 Jan 2024 13:45:00 -0800 Subject: [PATCH 105/283] [ONNX][MLIR] Add support for onnx.Exp op (#2792) https://github.com/nod-ai/SHARK-Turbine/issues/312 --- .../TorchOnnxToTorch/DefaultDomainAtoF.cpp | 11 + .../TorchOnnxToTorch/simple_ops_a_to_f.mlir | 228 ++++++++++++++++++ 2 files changed, 239 insertions(+) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp index 3b47162711cb..54cfb3e2ab13 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp @@ -1332,6 +1332,17 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( binder.op, resultType, operand); return success(); }); + patterns.onOp("Exp", 6, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value operand; + if (binder.tensorOperand(operand) || + binder.tensorResultType(resultType)) + return failure(); + rewriter.replaceOpWithNewOp( + binder.op, resultType, operand); + return success(); + }); patterns.onOp( "Expand", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { // uses ideas and code from onnx.Reshape diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir index bb02a29cb592..77cdd5786ea8 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir @@ -11,6 +11,8 @@ func.func @test_abs(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5 return %0 : !torch.vtensor<[3,4,5],f32> } +// ----- + // CHECK-LABEL: @test_add func.func @test_add(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT1:.*]] = torch.constant.int 1 @@ -19,6 +21,8 @@ func.func @test_add(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3 return %0 : !torch.vtensor<[3,4,5],f32> } +// ----- + // CHECK-LABEL: @test_add_bcast func.func @test_add_bcast(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT1:.*]] = torch.constant.int 1 @@ -27,6 +31,8 @@ func.func @test_add_bcast(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vten return %0 : !torch.vtensor<[3,4,5],f32> } +// ----- + // CHECK-LABEL: @test_add_uint8 func.func @test_add_uint8(%arg0: !torch.vtensor<[3,4,5],ui8>, %arg1: !torch.vtensor<[3,4,5],ui8>) -> !torch.vtensor<[3,4,5],ui8> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT1:.*]] = torch.constant.int 1 @@ -35,6 +41,8 @@ func.func @test_add_uint8(%arg0: !torch.vtensor<[3,4,5],ui8>, %arg1: !torch.vten return %0 : !torch.vtensor<[3,4,5],ui8> } +// ----- + // CHECK-LABEL: @test_and_bcast3v1d func.func @test_and_bcast3v1d(%arg0: !torch.vtensor<[3,4,5],i1>, %arg1: !torch.vtensor<[5],i1>) -> !torch.vtensor<[3,4,5],i1> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 7 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.logical_and %arg0, %arg1 : !torch.vtensor<[3,4,5],i1>, !torch.vtensor<[5],i1> -> !torch.vtensor<[3,4,5],i1> @@ -42,6 +50,8 @@ func.func @test_and_bcast3v1d(%arg0: !torch.vtensor<[3,4,5],i1>, %arg1: !torch.v return %0 : !torch.vtensor<[3,4,5],i1> } +// ----- + // CHECK-LABEL: @test_argmax_default_axis_example func.func @test_argmax_default_axis_example(%arg0: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[1,2],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT:.*]] = torch.constant.int 0 @@ -51,6 +61,8 @@ func.func @test_argmax_default_axis_example(%arg0: !torch.vtensor<[2,2],f32>) -> return %0 : !torch.vtensor<[1,2],si64> } +// ----- + // CHECK-LABEL: @test_argmax_negative_axis_keepdims_example func.func @test_argmax_negative_axis_keepdims_example(%arg0: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2,1],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT:.*]] = torch.constant.int 1 @@ -60,6 +72,8 @@ func.func @test_argmax_negative_axis_keepdims_example(%arg0: !torch.vtensor<[2,2 return %0 : !torch.vtensor<[2,1],si64> } +// ----- + // CHECK-LABEL: @test_argmax_no_keepdims_example func.func @test_argmax_no_keepdims_example(%arg0: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT:.*]] = torch.constant.int 1 @@ -69,6 +83,8 @@ func.func @test_argmax_no_keepdims_example(%arg0: !torch.vtensor<[2,2],f32>) -> return %0 : !torch.vtensor<[2],si64> } +// ----- + // CHECK-LABEL: @test_argmin_default_axis_example func.func @test_argmin_default_axis_example(%arg0: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[1,2],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT:.*]] = torch.constant.int 0 @@ -78,6 +94,8 @@ func.func @test_argmin_default_axis_example(%arg0: !torch.vtensor<[2,2],f32>) -> return %0 : !torch.vtensor<[1,2],si64> } +// ----- + // CHECK-LABEL: @test_argmin_negative_axis_keepdims_example func.func @test_argmin_negative_axis_keepdims_example(%arg0: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2,1],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT:.*]] = torch.constant.int 1 @@ -87,6 +105,8 @@ func.func @test_argmin_negative_axis_keepdims_example(%arg0: !torch.vtensor<[2,2 return %0 : !torch.vtensor<[2,1],si64> } +// ----- + // CHECK-LABEL: @test_argmin_no_keepdims_example func.func @test_argmin_no_keepdims_example(%arg0: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT:.*]] = torch.constant.int 1 @@ -96,6 +116,8 @@ func.func @test_argmin_no_keepdims_example(%arg0: !torch.vtensor<[2,2],f32>) -> return %0 : !torch.vtensor<[2],si64> } +// ----- + // CHECK-LABEL: @test_atan func.func @test_atan(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 7 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.atan %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> @@ -103,6 +125,8 @@ func.func @test_atan(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4, return %0 : !torch.vtensor<[3,4,5],f32> } +// ----- + // CHECK-LABEL: @test_acos func.func @test_acos(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 7 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.acos %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> @@ -110,6 +134,8 @@ func.func @test_acos(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4, return %0 : !torch.vtensor<[3,4,5],f32> } +// ----- + // CHECK-LABEL: @test_bernoulli func.func @test_bernoulli(%arg0: !torch.vtensor<[10],f64>) -> !torch.vtensor<[10],f64> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 15 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[NONE:.*]] = torch.constant.none @@ -118,6 +144,8 @@ func.func @test_bernoulli(%arg0: !torch.vtensor<[10],f64>) -> !torch.vtensor<[10 return %0 : !torch.vtensor<[10],f64> } +// ----- + // CHECK-LABEL: @test_bernoulli_double func.func @test_bernoulli_double(%arg0: !torch.vtensor<[10],f32>) -> !torch.vtensor<[10],f64> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 15 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[NONE:.*]] = torch.constant.none @@ -129,6 +157,8 @@ func.func @test_bernoulli_double(%arg0: !torch.vtensor<[10],f32>) -> !torch.vten return %0 : !torch.vtensor<[10],f64> } +// ----- + // CHECK-LABEL: @test_bitshift_left_uint8 func.func @test_bitshift_left_uint8(%arg0: !torch.vtensor<[3],ui8>, %arg1: !torch.vtensor<[3],ui8>) -> !torch.vtensor<[3],ui8> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.bitwise_left_shift.Tensor %arg0, %arg1 : !torch.vtensor<[3],ui8>, !torch.vtensor<[3],ui8> -> !torch.vtensor<[3],ui8> @@ -136,6 +166,8 @@ func.func @test_bitshift_left_uint8(%arg0: !torch.vtensor<[3],ui8>, %arg1: !torc return %0 : !torch.vtensor<[3],ui8> } +// ----- + // CHECK-LABEL: @test_bitshift_left_uint16 func.func @test_bitshift_left_uint16(%arg0: !torch.vtensor<[3],ui16>, %arg1: !torch.vtensor<[3],ui16>) -> !torch.vtensor<[3],ui16> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.bitwise_left_shift.Tensor %arg0, %arg1 : !torch.vtensor<[3],ui16>, !torch.vtensor<[3],ui16> -> !torch.vtensor<[3],ui16> @@ -143,6 +175,8 @@ func.func @test_bitshift_left_uint16(%arg0: !torch.vtensor<[3],ui16>, %arg1: !to return %0 : !torch.vtensor<[3],ui16> } +// ----- + // CHECK-LABEL: @test_bitshift_left_uint32 func.func @test_bitshift_left_uint32(%arg0: !torch.vtensor<[3],ui32>, %arg1: !torch.vtensor<[3],ui32>) -> !torch.vtensor<[3],ui32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.bitwise_left_shift.Tensor %arg0, %arg1 : !torch.vtensor<[3],ui32>, !torch.vtensor<[3],ui32> -> !torch.vtensor<[3],ui32> @@ -150,6 +184,8 @@ func.func @test_bitshift_left_uint32(%arg0: !torch.vtensor<[3],ui32>, %arg1: !to return %0 : !torch.vtensor<[3],ui32> } +// ----- + // CHECK-LABEL: @test_bitshift_left_uint64 func.func @test_bitshift_left_uint64(%arg0: !torch.vtensor<[3],ui64>, %arg1: !torch.vtensor<[3],ui64>) -> !torch.vtensor<[3],ui64> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.bitwise_left_shift.Tensor %arg0, %arg1 : !torch.vtensor<[3],ui64>, !torch.vtensor<[3],ui64> -> !torch.vtensor<[3],ui64> @@ -157,6 +193,8 @@ func.func @test_bitshift_left_uint64(%arg0: !torch.vtensor<[3],ui64>, %arg1: !to return %0 : !torch.vtensor<[3],ui64> } +// ----- + // CHECK-LABEL: @test_bitshift_right_uint8 func.func @test_bitshift_right_uint8(%arg0: !torch.vtensor<[3],ui8>, %arg1: !torch.vtensor<[3],ui8>) -> !torch.vtensor<[3],ui8> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.bitwise_right_shift.Tensor %arg0, %arg1 : !torch.vtensor<[3],ui8>, !torch.vtensor<[3],ui8> -> !torch.vtensor<[3],ui8> @@ -164,6 +202,8 @@ func.func @test_bitshift_right_uint8(%arg0: !torch.vtensor<[3],ui8>, %arg1: !tor return %0 : !torch.vtensor<[3],ui8> } +// ----- + // CHECK-LABEL: @test_bitshift_right_uint16 func.func @test_bitshift_right_uint16(%arg0: !torch.vtensor<[3],ui16>, %arg1: !torch.vtensor<[3],ui16>) -> !torch.vtensor<[3],ui16> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.bitwise_right_shift.Tensor %arg0, %arg1 : !torch.vtensor<[3],ui16>, !torch.vtensor<[3],ui16> -> !torch.vtensor<[3],ui16> @@ -171,6 +211,8 @@ func.func @test_bitshift_right_uint16(%arg0: !torch.vtensor<[3],ui16>, %arg1: !t return %0 : !torch.vtensor<[3],ui16> } +// ----- + // CHECK-LABEL: @test_bitshift_right_uint32 func.func @test_bitshift_right_uint32(%arg0: !torch.vtensor<[3],ui32>, %arg1: !torch.vtensor<[3],ui32>) -> !torch.vtensor<[3],ui32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.bitwise_right_shift.Tensor %arg0, %arg1 : !torch.vtensor<[3],ui32>, !torch.vtensor<[3],ui32> -> !torch.vtensor<[3],ui32> @@ -178,6 +220,8 @@ func.func @test_bitshift_right_uint32(%arg0: !torch.vtensor<[3],ui32>, %arg1: !t return %0 : !torch.vtensor<[3],ui32> } +// ----- + // CHECK-LABEL: @test_bitshift_right_uint64 func.func @test_bitshift_right_uint64(%arg0: !torch.vtensor<[3],ui64>, %arg1: !torch.vtensor<[3],ui64>) -> !torch.vtensor<[3],ui64> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.bitwise_right_shift.Tensor %arg0, %arg1 : !torch.vtensor<[3],ui64>, !torch.vtensor<[3],ui64> -> !torch.vtensor<[3],ui64> @@ -185,6 +229,8 @@ func.func @test_bitshift_right_uint64(%arg0: !torch.vtensor<[3],ui64>, %arg1: !t return %0 : !torch.vtensor<[3],ui64> } +// ----- + // CHECK-LABEL: @test_bitwise_and_i16_3d func.func @test_bitwise_and_i16_3d(%arg0: !torch.vtensor<[3,4,5],si16>, %arg1: !torch.vtensor<[3,4,5],si16>) -> !torch.vtensor<[3,4,5],si16> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.bitwise_and.Tensor %arg0, %arg1 : !torch.vtensor<[3,4,5],si16>, !torch.vtensor<[3,4,5],si16> -> !torch.vtensor<[3,4,5],si16> @@ -192,6 +238,8 @@ func.func @test_bitwise_and_i16_3d(%arg0: !torch.vtensor<[3,4,5],si16>, %arg1: ! return %0 : !torch.vtensor<[3,4,5],si16> } +// ----- + // CHECK-LABEL: @test_bitwise_and_i32_2d func.func @test_bitwise_and_i32_2d(%arg0: !torch.vtensor<[3,4],si32>, %arg1: !torch.vtensor<[3,4],si32>) -> !torch.vtensor<[3,4],si32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.bitwise_and.Tensor %arg0, %arg1 : !torch.vtensor<[3,4],si32>, !torch.vtensor<[3,4],si32> -> !torch.vtensor<[3,4],si32> @@ -199,6 +247,8 @@ func.func @test_bitwise_and_i32_2d(%arg0: !torch.vtensor<[3,4],si32>, %arg1: !to return %0 : !torch.vtensor<[3,4],si32> } +// ----- + // CHECK-LABEL: @test_bitwise_and_ui8_bcast_4v3d func.func @test_bitwise_and_ui8_bcast_4v3d(%arg0: !torch.vtensor<[3,4,5,6],ui8>, %arg1: !torch.vtensor<[4,5,6],ui8>) -> !torch.vtensor<[3,4,5,6],ui8> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.bitwise_and.Tensor %arg0, %arg1 : !torch.vtensor<[3,4,5,6],ui8>, !torch.vtensor<[4,5,6],ui8> -> !torch.vtensor<[3,4,5,6],ui8> @@ -206,6 +256,8 @@ func.func @test_bitwise_and_ui8_bcast_4v3d(%arg0: !torch.vtensor<[3,4,5,6],ui8>, return %0 : !torch.vtensor<[3,4,5,6],ui8> } +// ----- + // CHECK-LABEL: @test_bitwise_or_i16_4d func.func @test_bitwise_or_i16_4d(%arg0: !torch.vtensor<[3,4,5,6],si8>, %arg1: !torch.vtensor<[3,4,5,6],si8>) -> !torch.vtensor<[3,4,5,6],si8> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.bitwise_or.Tensor %arg0, %arg1 : !torch.vtensor<[3,4,5,6],si8>, !torch.vtensor<[3,4,5,6],si8> -> !torch.vtensor<[3,4,5,6],si8> @@ -213,6 +265,8 @@ func.func @test_bitwise_or_i16_4d(%arg0: !torch.vtensor<[3,4,5,6],si8>, %arg1: ! return %0 : !torch.vtensor<[3,4,5,6],si8> } +// ----- + // CHECK-LABEL: @test_bitwise_or_i32_2d func.func @test_bitwise_or_i32_2d(%arg0: !torch.vtensor<[3,4],si32>, %arg1: !torch.vtensor<[3,4],si32>) -> !torch.vtensor<[3,4],si32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.bitwise_or.Tensor %arg0, %arg1 : !torch.vtensor<[3,4],si32>, !torch.vtensor<[3,4],si32> -> !torch.vtensor<[3,4],si32> @@ -220,6 +274,8 @@ func.func @test_bitwise_or_i32_2d(%arg0: !torch.vtensor<[3,4],si32>, %arg1: !tor return %0 : !torch.vtensor<[3,4],si32> } +// ----- + // CHECK-LABEL: @test_bitwise_or_ui8_bcast_4v3d func.func @test_bitwise_or_ui8_bcast_4v3d(%arg0: !torch.vtensor<[3,4,5,6],ui8>, %arg1: !torch.vtensor<[4,5,6],ui8>) -> !torch.vtensor<[3,4,5,6],ui8> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.bitwise_or.Tensor %arg0, %arg1 : !torch.vtensor<[3,4,5,6],ui8>, !torch.vtensor<[4,5,6],ui8> -> !torch.vtensor<[3,4,5,6],ui8> @@ -227,6 +283,8 @@ func.func @test_bitwise_or_ui8_bcast_4v3d(%arg0: !torch.vtensor<[3,4,5,6],ui8>, return %0 : !torch.vtensor<[3,4,5,6],ui8> } +// ----- + // CHECK-LABEL: @test_bitwise_not_2d func.func @test_bitwise_not_2d(%arg0: !torch.vtensor<[3,4],si32>) -> !torch.vtensor<[3,4],si32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.bitwise_not %arg0 : !torch.vtensor<[3,4],si32> -> !torch.vtensor<[3,4],si32> @@ -234,6 +292,8 @@ func.func @test_bitwise_not_2d(%arg0: !torch.vtensor<[3,4],si32>) -> !torch.vten return %0 : !torch.vtensor<[3,4],si32> } +// ----- + // CHECK-LABEL: @test_bitwise_not_4d func.func @test_bitwise_not_4d(%arg0: !torch.vtensor<[3,4,5,6],ui8>) -> !torch.vtensor<[3,4,5,6],ui8> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.bitwise_not %arg0 : !torch.vtensor<[3,4,5,6],ui8> -> !torch.vtensor<[3,4,5,6],ui8> @@ -241,6 +301,8 @@ func.func @test_bitwise_not_4d(%arg0: !torch.vtensor<[3,4,5,6],ui8>) -> !torch.v return %0 : !torch.vtensor<[3,4,5,6],ui8> } +// ----- + // CHECK-LABEL: @test_bitwise_xor_i16_3d func.func @test_bitwise_xor_i16_3d(%arg0: !torch.vtensor<[3,4,5],si16>, %arg1: !torch.vtensor<[3,4,5],si16>) -> !torch.vtensor<[3,4,5],si16> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.bitwise_xor.Tensor %arg0, %arg1 : !torch.vtensor<[3,4,5],si16>, !torch.vtensor<[3,4,5],si16> -> !torch.vtensor<[3,4,5],si16> @@ -248,6 +310,8 @@ func.func @test_bitwise_xor_i16_3d(%arg0: !torch.vtensor<[3,4,5],si16>, %arg1: ! return %0 : !torch.vtensor<[3,4,5],si16> } +// ----- + // CHECK-LABEL: @test_bitwise_xor_i32_2d func.func @test_bitwise_xor_i32_2d(%arg0: !torch.vtensor<[3,4],si32>, %arg1: !torch.vtensor<[3,4],si32>) -> !torch.vtensor<[3,4],si32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.bitwise_xor.Tensor %arg0, %arg1 : !torch.vtensor<[3,4],si32>, !torch.vtensor<[3,4],si32> -> !torch.vtensor<[3,4],si32> @@ -255,6 +319,8 @@ func.func @test_bitwise_xor_i32_2d(%arg0: !torch.vtensor<[3,4],si32>, %arg1: !to return %0 : !torch.vtensor<[3,4],si32> } +// ----- + // CHECK-LABEL: @test_bitwise_xor_ui8_bcast_4v3d func.func @test_bitwise_xor_ui8_bcast_4v3d(%arg0: !torch.vtensor<[3,4,5,6],ui8>, %arg1: !torch.vtensor<[4,5,6],ui8>) -> !torch.vtensor<[3,4,5,6],ui8> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.bitwise_xor.Tensor %arg0, %arg1 : !torch.vtensor<[3,4,5,6],ui8>, !torch.vtensor<[4,5,6],ui8> -> !torch.vtensor<[3,4,5,6],ui8> @@ -262,6 +328,8 @@ func.func @test_bitwise_xor_ui8_bcast_4v3d(%arg0: !torch.vtensor<[3,4,5,6],ui8>, return %0 : !torch.vtensor<[3,4,5,6],ui8> } +// ----- + // CHECK-LABEL: @test_cast_BFLOAT16_to_FLOAT func.func @test_cast_BFLOAT16_to_FLOAT(%arg0: !torch.vtensor<[3,4],bf16>) -> !torch.vtensor<[3,4],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT:.*]] = torch.constant.int 6 @@ -272,6 +340,8 @@ func.func @test_cast_BFLOAT16_to_FLOAT(%arg0: !torch.vtensor<[3,4],bf16>) -> !to return %0 : !torch.vtensor<[3,4],f32> } +// ----- + // CHECK-LABEL: @test_cast_DOUBLE_to_FLOAT func.func @test_cast_DOUBLE_to_FLOAT(%arg0: !torch.vtensor<[3,4],f64>) -> !torch.vtensor<[3,4],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT:.*]] = torch.constant.int 6 @@ -282,6 +352,8 @@ func.func @test_cast_DOUBLE_to_FLOAT(%arg0: !torch.vtensor<[3,4],f64>) -> !torch return %0 : !torch.vtensor<[3,4],f32> } +// ----- + // CHECK-LABEL: @test_cast_DOUBLE_to_FLOAT16 func.func @test_cast_DOUBLE_to_FLOAT16(%arg0: !torch.vtensor<[3,4],f64>) -> !torch.vtensor<[3,4],f16> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT:.*]] = torch.constant.int 5 @@ -292,6 +364,8 @@ func.func @test_cast_DOUBLE_to_FLOAT16(%arg0: !torch.vtensor<[3,4],f64>) -> !tor return %0 : !torch.vtensor<[3,4],f16> } +// ----- + // CHECK-LABEL: @test_cast_FLOAT_to_BFLOAT16 func.func @test_cast_FLOAT_to_BFLOAT16(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],bf16> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT:.*]] = torch.constant.int 15 @@ -302,6 +376,8 @@ func.func @test_cast_FLOAT_to_BFLOAT16(%arg0: !torch.vtensor<[3,4],f32>) -> !tor return %0 : !torch.vtensor<[3,4],bf16> } +// ----- + // CHECK-LABEL: @test_cast_FLOAT_to_DOUBLE func.func @test_cast_FLOAT_to_DOUBLE(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f64> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT:.*]] = torch.constant.int 7 @@ -312,6 +388,8 @@ func.func @test_cast_FLOAT_to_DOUBLE(%arg0: !torch.vtensor<[3,4],f32>) -> !torch return %0 : !torch.vtensor<[3,4],f64> } +// ----- + // CHECK-LABEL: @test_cast_FLOAT_to_FLOAT16 func.func @test_cast_FLOAT_to_FLOAT16(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f16> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT:.*]] = torch.constant.int 5 @@ -322,6 +400,8 @@ func.func @test_cast_FLOAT_to_FLOAT16(%arg0: !torch.vtensor<[3,4],f32>) -> !torc return %0 : !torch.vtensor<[3,4],f16> } +// ----- + // CHECK-LABEL: @test_cast_FLOAT16_to_DOUBLE func.func @test_cast_FLOAT16_to_DOUBLE(%arg0: !torch.vtensor<[3,4],f16>) -> !torch.vtensor<[3,4],f64> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT:.*]] = torch.constant.int 7 @@ -332,6 +412,8 @@ func.func @test_cast_FLOAT16_to_DOUBLE(%arg0: !torch.vtensor<[3,4],f16>) -> !tor return %0 : !torch.vtensor<[3,4],f64> } +// ----- + // CHECK-LABEL: @test_cast_FLOAT_to_BOOL func.func @test_cast_FLOAT_to_BOOL(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT:.*]] = torch.constant.int 11 @@ -342,6 +424,8 @@ func.func @test_cast_FLOAT_to_BOOL(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.v return %0 : !torch.vtensor<[3,4],i1> } +// ----- + // CHECK-LABEL: @test_cast_FLOAT16_to_FLOAT func.func @test_cast_FLOAT16_to_FLOAT(%arg0: !torch.vtensor<[3,4],f16>) -> !torch.vtensor<[3,4],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT:.*]] = torch.constant.int 6 @@ -352,6 +436,8 @@ func.func @test_cast_FLOAT16_to_FLOAT(%arg0: !torch.vtensor<[3,4],f16>) -> !torc return %0 : !torch.vtensor<[3,4],f32> } +// ----- + // CHECK-LABEL: @test_castlike_BFLOAT16_to_FLOAT func.func @test_castlike_BFLOAT16_to_FLOAT(%arg0: !torch.vtensor<[3,4],bf16>, %arg1: !torch.vtensor<[1],f32>) -> !torch.vtensor<[3,4],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT:.*]] = torch.constant.int 6 @@ -362,6 +448,8 @@ func.func @test_castlike_BFLOAT16_to_FLOAT(%arg0: !torch.vtensor<[3,4],bf16>, %a return %0 : !torch.vtensor<[3,4],f32> } +// ----- + // CHECK-LABEL: @test_castlike_DOUBLE_to_FLOAT func.func @test_castlike_DOUBLE_to_FLOAT(%arg0: !torch.vtensor<[3,4],f64>, %arg1: !torch.vtensor<[1],f32>) -> !torch.vtensor<[3,4],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT:.*]] = torch.constant.int 6 @@ -372,6 +460,8 @@ func.func @test_castlike_DOUBLE_to_FLOAT(%arg0: !torch.vtensor<[3,4],f64>, %arg1 return %0 : !torch.vtensor<[3,4],f32> } +// ----- + // CHECK-LABEL: @test_castlike_FLOAT_to_DOUBLE func.func @test_castlike_FLOAT_to_DOUBLE(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[1],f64>) -> !torch.vtensor<[3,4],f64> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT:.*]] = torch.constant.int 7 @@ -382,6 +472,8 @@ func.func @test_castlike_FLOAT_to_DOUBLE(%arg0: !torch.vtensor<[3,4],f32>, %arg1 return %0 : !torch.vtensor<[3,4],f64> } +// ----- + // CHECK-LABEL: @test_castlike_FLOAT16_to_FLOAT func.func @test_castlike_FLOAT16_to_FLOAT(%arg0: !torch.vtensor<[3,4],f16>, %arg1: !torch.vtensor<[1],f32>) -> !torch.vtensor<[3,4],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT:.*]] = torch.constant.int 6 @@ -392,6 +484,8 @@ func.func @test_castlike_FLOAT16_to_FLOAT(%arg0: !torch.vtensor<[3,4],f16>, %arg return %0 : !torch.vtensor<[3,4],f32> } +// ----- + // CHECK-LABEL: @test_ceil_example func.func @test_ceil_example(%arg0: !torch.vtensor<[2],f32>) -> !torch.vtensor<[2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.ceil %arg0 : !torch.vtensor<[2],f32> -> !torch.vtensor<[2],f32> @@ -399,6 +493,8 @@ func.func @test_ceil_example(%arg0: !torch.vtensor<[2],f32>) -> !torch.vtensor<[ return %0 : !torch.vtensor<[2],f32> } +// ----- + // CHECK-LABEL: @test_ceil func.func @test_ceil(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.ceil %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> @@ -406,6 +502,8 @@ func.func @test_ceil(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4, return %0 : !torch.vtensor<[3,4,5],f32> } +// ----- + // CHECK-LABEL: @test_clip_default_int8_min func.func @test_clip_default_int8_min(%arg0: !torch.vtensor<[3,4,5],si8>, %arg1: !torch.vtensor<[],si8>) -> !torch.vtensor<[3,4,5],si8> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.clamp_min.Tensor %arg0, %arg1 : !torch.vtensor<[3,4,5],si8>, !torch.vtensor<[],si8> -> !torch.vtensor<[3,4,5],si8> @@ -413,6 +511,8 @@ func.func @test_clip_default_int8_min(%arg0: !torch.vtensor<[3,4,5],si8>, %arg1: return %0 : !torch.vtensor<[3,4,5],si8> } +// ----- + // CHECK-LABEL: @test_clip_default_min func.func @test_clip_default_min(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.clamp_min.Tensor %arg0, %arg1 : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[3,4,5],f32> @@ -420,6 +520,8 @@ func.func @test_clip_default_min(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !tor return %0 : !torch.vtensor<[3,4,5],f32> } +// ----- + // CHECK-LABEL: @test_clip_example func.func @test_clip_example(%arg0: !torch.vtensor<[3],f32>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.clamp.Tensor %arg0, %arg1, %arg2 : !torch.vtensor<[3],f32>, !torch.vtensor<[],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[3],f32> @@ -427,6 +529,8 @@ func.func @test_clip_example(%arg0: !torch.vtensor<[3],f32>, %arg1: !torch.vtens return %0 : !torch.vtensor<[3],f32> } +// ----- + // CHECK-LABEL: @test_clip func.func @test_clip(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.clamp.Tensor %arg0, %arg1, %arg2 : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[3,4,5],f32> @@ -434,6 +538,8 @@ func.func @test_clip(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[ return %0 : !torch.vtensor<[3,4,5],f32> } +// ----- + // CHECK-LABEL: @test_cos_example func.func @test_cos_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 7 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.cos %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> @@ -441,6 +547,8 @@ func.func @test_cos_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3 return %0 : !torch.vtensor<[3],f32> } +// ----- + // CHECK-LABEL: @test_cos func.func @test_cos(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 7 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.cos %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> @@ -497,6 +605,8 @@ func.func @test_div_bcast(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vten return %0 : !torch.vtensor<[3,4,5],f32> } +// ----- + // CHECK-LABEL: @test_div_example func.func @test_div_example(%arg0: !torch.vtensor<[2],f32>, %arg1: !torch.vtensor<[2],f32>) -> !torch.vtensor<[2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.div.Tensor %arg0, %arg1 : !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32> -> !torch.vtensor<[2],f32> @@ -504,6 +614,8 @@ func.func @test_div_example(%arg0: !torch.vtensor<[2],f32>, %arg1: !torch.vtenso return %0 : !torch.vtensor<[2],f32> } +// ----- + // CHECK-LABEL: @test_div func.func @test_div(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.div.Tensor %arg0, %arg1 : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> @@ -511,6 +623,8 @@ func.func @test_div(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3 return %0 : !torch.vtensor<[3,4,5],f32> } +// ----- + // CHECK-LABEL: @test_div_uint8 func.func @test_div_uint8(%arg0: !torch.vtensor<[3,4,5],ui8>, %arg1: !torch.vtensor<[3,4,5],ui8>) -> !torch.vtensor<[3,4,5],ui8> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.div.Tensor %arg0, %arg1 : !torch.vtensor<[3,4,5],ui8>, !torch.vtensor<[3,4,5],ui8> -> !torch.vtensor<[3,4,5],ui8> @@ -518,6 +632,8 @@ func.func @test_div_uint8(%arg0: !torch.vtensor<[3,4,5],ui8>, %arg1: !torch.vten return %0 : !torch.vtensor<[3,4,5],ui8> } +// ----- + // CHECK-LABEL: @test_equal_bcast func.func @test_equal_bcast(%arg0: !torch.vtensor<[3,4,5],si32>, %arg1: !torch.vtensor<[5],si32>) -> !torch.vtensor<[3,4,5],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.eq.Tensor %arg0, %arg1 : !torch.vtensor<[3,4,5],si32>, !torch.vtensor<[5],si32> -> !torch.vtensor<[3,4,5],i1> @@ -525,6 +641,8 @@ func.func @test_equal_bcast(%arg0: !torch.vtensor<[3,4,5],si32>, %arg1: !torch.v return %0 : !torch.vtensor<[3,4,5],i1> } +// ----- + // CHECK-LABEL: @test_erf func.func @test_erf(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.erf %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> @@ -532,6 +650,8 @@ func.func @test_erf(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5 return %0 : !torch.vtensor<[3,4,5],f32> } +// ----- + // CHECK-LABEL: @test_equal func.func @test_equal(%arg0: !torch.vtensor<[3,4,5],si32>, %arg1: !torch.vtensor<[3,4,5],si32>) -> !torch.vtensor<[3,4,5],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.eq.Tensor %arg0, %arg1 : !torch.vtensor<[3,4,5],si32>, !torch.vtensor<[3,4,5],si32> -> !torch.vtensor<[3,4,5],i1> @@ -539,6 +659,7 @@ func.func @test_equal(%arg0: !torch.vtensor<[3,4,5],si32>, %arg1: !torch.vtensor return %0 : !torch.vtensor<[3,4,5],i1> } +// ----- // CHECK-LABEL: @test_floor_example func.func @test_floor_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { @@ -547,6 +668,8 @@ func.func @test_floor_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor< return %0 : !torch.vtensor<[3],f32> } +// ----- + // CHECK-LABEL: @test_floor func.func @test_floor(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.floor %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> @@ -554,6 +677,8 @@ func.func @test_floor(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4 return %0 : !torch.vtensor<[3,4,5],f32> } +// ----- + // CHECK-LABEL: @test_averagepool_1d_default func.func @test_averagepool_1d_default(%arg0: !torch.vtensor<[1,3,32],f32>) -> !torch.vtensor<[1,3,31],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.avg_pool1d %arg0, %0, %2, %1, %false, %true : !torch.vtensor<[1,3,32],f32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.bool -> !torch.vtensor<[1,3,31],f32> @@ -561,6 +686,8 @@ func.func @test_averagepool_1d_default(%arg0: !torch.vtensor<[1,3,32],f32>) -> ! return %0 : !torch.vtensor<[1,3,31],f32> } +// ----- + // CHECK-LABEL: @test_averagepool_2d_ceil func.func @test_averagepool_2d_ceil(%arg0: !torch.vtensor<[1,1,4,4],f32>) -> !torch.vtensor<[1,1,2,2],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.avg_pool2d %arg0, %0, %2, %1, %true, %false, %none : !torch.vtensor<[1,1,4,4],f32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,1,2,2],f32> @@ -568,6 +695,8 @@ func.func @test_averagepool_2d_ceil(%arg0: !torch.vtensor<[1,1,4,4],f32>) -> !to return %0 : !torch.vtensor<[1,1,2,2],f32> } +// ----- + // CHECK-LABEL: @test_averagepool_3d_default func.func @test_averagepool_3d_default(%arg0: !torch.vtensor<[1,3,32,32,32],f32>) -> !torch.vtensor<[1,3,31,31,31],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.avg_pool3d %arg0, %0, %2, %1, %false, %false_2, %none : !torch.vtensor<[1,3,32,32,32],f32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,3,31,31,31],f32> @@ -575,6 +704,8 @@ func.func @test_averagepool_3d_default(%arg0: !torch.vtensor<[1,3,32,32,32],f32> return %0 : !torch.vtensor<[1,3,31,31,31],f32> } +// ----- + // CHECK-LABEL: @test_conv_with_strides_no_padding func.func @test_conv_with_strides_no_padding(%arg0: !torch.vtensor<[1,1,7,5],f32>, %arg1: !torch.vtensor<[1,1,3,3],f32>) -> !torch.vtensor<[1,1,3,2],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[C0:.*]] = torch.constant.int 0 @@ -596,6 +727,8 @@ func.func @test_conv_with_strides_no_padding(%arg0: !torch.vtensor<[1,1,7,5],f32 return %0 : !torch.vtensor<[1,1,3,2],f32> } +// ----- + // CHECK-LABEL: @test_conv_with_strides_padding func.func @test_conv_with_strides_padding(%arg0: !torch.vtensor<[1,1,7,5],f32>, %arg1: !torch.vtensor<[1,1,3,3],f32>) -> !torch.vtensor<[1,1,4,3],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[C1:.*]] = torch.constant.int 1 @@ -617,6 +750,8 @@ func.func @test_conv_with_strides_padding(%arg0: !torch.vtensor<[1,1,7,5],f32>, return %0 : !torch.vtensor<[1,1,4,3],f32> } +// ----- + // CHECK-LABEL: @test_conv_with_bias_strides_padding func.func @test_conv_with_bias_strides_padding(%arg0: !torch.vtensor<[?,?,224,224],f32>, %arg1: !torch.vtensor<[64,3,7,7],f32>, %arg2: !torch.vtensor<[64],f32>) -> !torch.vtensor<[?,64,112,112],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[C3:.*]] = torch.constant.int 3 @@ -637,6 +772,8 @@ func.func @test_conv_with_bias_strides_padding(%arg0: !torch.vtensor<[?,?,224,22 return %0 : !torch.vtensor<[?,64,112,112],f32> } +// ----- + // CHECK-LABEL: @test_convtranspose_dilations func.func @test_convtranspose_dilations(%arg0: !torch.vtensor<[1,1,3,3],f32>, %arg1: !torch.vtensor<[1,1,2,2],f32>) -> !torch.vtensor<[1,1,5,5],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[C0:.*]] = torch.constant.int 0 @@ -659,6 +796,8 @@ func.func @test_convtranspose_dilations(%arg0: !torch.vtensor<[1,1,3,3],f32>, %a return %0 : !torch.vtensor<[1,1,5,5],f32> } +// ----- + // CHECK-LABEL: @test_convtranspose func.func @test_convtranspose(%arg0: !torch.vtensor<[1,1,3,3],f32>, %arg1: !torch.vtensor<[1,2,3,3],f32>) -> !torch.vtensor<[1,2,5,5],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[C0:.*]] = torch.constant.int 0 @@ -681,6 +820,8 @@ func.func @test_convtranspose(%arg0: !torch.vtensor<[1,1,3,3],f32>, %arg1: !torc return %0 : !torch.vtensor<[1,2,5,5],f32> } +// ----- + // CHECK-LABEL: @test_convtranspose_pad func.func @test_convtranspose_pad(%arg0: !torch.vtensor<[1,1,3,3],f32>, %arg1: !torch.vtensor<[1,2,3,3],f32>) -> !torch.vtensor<[1,2,10,8],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[C0:.*]] = torch.constant.int 0 @@ -703,6 +844,8 @@ func.func @test_convtranspose(%arg0: !torch.vtensor<[1,1,3,3],f32>, %arg1: !torc return %0 : !torch.vtensor<[1,2,10,8],f32> } +// ----- + // CHECK-LABEL: @test_convtranspose_pads func.func @test_convtranspose_pads(%arg0: !torch.vtensor<[1,1,3,3],f32>, %arg1: !torch.vtensor<[1,2,3,3],f32>) -> !torch.vtensor<[1,2,7,3],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[C1:.*]] = torch.constant.int 1 @@ -725,6 +868,8 @@ func.func @test_convtranspose(%arg0: !torch.vtensor<[1,1,3,3],f32>, %arg1: !torc return %0 : !torch.vtensor<[1,2,7,3],f32> } +// ----- + // CHECK-LABEL: @test_batchnorm_epsilon func.func @test_batchnorm_epsilon(%arg0: !torch.vtensor<[2,3,4,5],f32>, %arg1: !torch.vtensor<[3],f32>, %arg2: !torch.vtensor<[3],f32>, %arg3: !torch.vtensor<[3],f32>, %arg4: !torch.vtensor<[3],f32>) -> !torch.vtensor<[2,3,4,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 15 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[FALSE:.*]] = torch.constant.bool false @@ -735,6 +880,8 @@ func.func @test_batchnorm_epsilon(%arg0: !torch.vtensor<[2,3,4,5],f32>, %arg1: ! return %0 : !torch.vtensor<[2,3,4,5],f32> } +// ----- + // CHECK-LABEL: @test_batchnorm_example func.func @test_batchnorm_example(%arg0: !torch.vtensor<[2,3,4,5],f32>, %arg1: !torch.vtensor<[3],f32>, %arg2: !torch.vtensor<[3],f32>, %arg3: !torch.vtensor<[3],f32>, %arg4: !torch.vtensor<[3],f32>) -> !torch.vtensor<[2,3,4,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 15 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[FALSE:.*]] = torch.constant.bool false @@ -745,6 +892,8 @@ func.func @test_batchnorm_example(%arg0: !torch.vtensor<[2,3,4,5],f32>, %arg1: ! return %0 : !torch.vtensor<[2,3,4,5],f32> } +// ----- + // CHECK-LABEL: @test_concat_1d_axis_0 func.func @test_concat_1d_axis_0(%arg0: !torch.vtensor<[2],f32>, %arg1: !torch.vtensor<[2],f32>) -> !torch.vtensor<[4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[TENSORS_LIST:.*]] = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>) -> !torch.list @@ -754,6 +903,8 @@ func.func @test_concat_1d_axis_0(%arg0: !torch.vtensor<[2],f32>, %arg1: !torch.v return %0 : !torch.vtensor<[4],f32> } +// ----- + // CHECK-LABEL: @test_concat_1d_axis_negative_1 func.func @test_concat_1d_axis_negative_1(%arg0: !torch.vtensor<[2],f32>, %arg1: !torch.vtensor<[2],f32>) -> !torch.vtensor<[4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[TENSORS_LIST:.*]] = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>) -> !torch.list @@ -763,6 +914,8 @@ func.func @test_concat_1d_axis_negative_1(%arg0: !torch.vtensor<[2],f32>, %arg1: return %0 : !torch.vtensor<[4],f32> } +// ----- + // CHECK-LABEL: @test_concat_2d_axis_0 func.func @test_concat_2d_axis_0(%arg0: !torch.vtensor<[2,2],f32>, %arg1: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[4,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[TENSORS_LIST:.*]] = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[2,2],f32>, !torch.vtensor<[2,2],f32>) -> !torch.list @@ -772,6 +925,8 @@ func.func @test_concat_2d_axis_0(%arg0: !torch.vtensor<[2,2],f32>, %arg1: !torch return %0 : !torch.vtensor<[4,2],f32> } +// ----- + // CHECK-LABEL: @test_concat_2d_axis_1 func.func @test_concat_2d_axis_1(%arg0: !torch.vtensor<[2,2],f32>, %arg1: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2,4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[TENSORS_LIST:.*]] = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[2,2],f32>, !torch.vtensor<[2,2],f32>) -> !torch.list @@ -781,6 +936,8 @@ func.func @test_concat_2d_axis_1(%arg0: !torch.vtensor<[2,2],f32>, %arg1: !torch return %0 : !torch.vtensor<[2,4],f32> } +// ----- + // CHECK-LABEL: @test_concat_2d_axis_negative_1 func.func @test_concat_2d_axis_negative_1(%arg0: !torch.vtensor<[2,2],f32>, %arg1: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[2,4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[TENSORS_LIST:.*]] = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[2,2],f32>, !torch.vtensor<[2,2],f32>) -> !torch.list @@ -790,6 +947,8 @@ func.func @test_concat_2d_axis_negative_1(%arg0: !torch.vtensor<[2,2],f32>, %arg return %0 : !torch.vtensor<[2,4],f32> } +// ----- + // CHECK-LABEL: @test_concat_2d_axis_negative_2 func.func @test_concat_2d_axis_negative_2(%arg0: !torch.vtensor<[2,2],f32>, %arg1: !torch.vtensor<[2,2],f32>) -> !torch.vtensor<[4,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[TENSORS_LIST:.*]] = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[2,2],f32>, !torch.vtensor<[2,2],f32>) -> !torch.list @@ -799,6 +958,8 @@ func.func @test_concat_2d_axis_negative_2(%arg0: !torch.vtensor<[2,2],f32>, %arg return %0 : !torch.vtensor<[4,2],f32> } +// ----- + // CHECK-LABEL: @test_concat_3d_axis_0 func.func @test_concat_3d_axis_0(%arg0: !torch.vtensor<[2,2,2],f32>, %arg1: !torch.vtensor<[2,2,2],f32>) -> !torch.vtensor<[4,2,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[TENSORS_LIST:.*]] = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[2,2,2],f32>, !torch.vtensor<[2,2,2],f32>) -> !torch.list @@ -808,6 +969,8 @@ func.func @test_concat_3d_axis_0(%arg0: !torch.vtensor<[2,2,2],f32>, %arg1: !tor return %0 : !torch.vtensor<[4,2,2],f32> } +// ----- + // CHECK-LABEL: @test_concat_3d_axis_1 func.func @test_concat_3d_axis_1(%arg0: !torch.vtensor<[2,2,2],f32>, %arg1: !torch.vtensor<[2,2,2],f32>) -> !torch.vtensor<[2,4,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[TENSORS_LIST:.*]] = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[2,2,2],f32>, !torch.vtensor<[2,2,2],f32>) -> !torch.list @@ -817,6 +980,8 @@ func.func @test_concat_3d_axis_1(%arg0: !torch.vtensor<[2,2,2],f32>, %arg1: !tor return %0 : !torch.vtensor<[2,4,2],f32> } +// ----- + // CHECK-LABEL: @test_concat_3d_axis_2 func.func @test_concat_3d_axis_2(%arg0: !torch.vtensor<[2,2,2],f32>, %arg1: !torch.vtensor<[2,2,2],f32>) -> !torch.vtensor<[2,2,4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[TENSORS_LIST:.*]] = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[2,2,2],f32>, !torch.vtensor<[2,2,2],f32>) -> !torch.list @@ -826,6 +991,8 @@ func.func @test_concat_3d_axis_2(%arg0: !torch.vtensor<[2,2,2],f32>, %arg1: !tor return %0 : !torch.vtensor<[2,2,4],f32> } +// ----- + // CHECK-LABEL: @test_concat_3d_axis_negative_1 func.func @test_concat_3d_axis_negative_1(%arg0: !torch.vtensor<[2,2,2],f32>, %arg1: !torch.vtensor<[2,2,2],f32>) -> !torch.vtensor<[2,2,4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[TENSORS_LIST:.*]] = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[2,2,2],f32>, !torch.vtensor<[2,2,2],f32>) -> !torch.list @@ -835,6 +1002,8 @@ func.func @test_concat_3d_axis_negative_1(%arg0: !torch.vtensor<[2,2,2],f32>, %a return %0 : !torch.vtensor<[2,2,4],f32> } +// ----- + // CHECK-LABEL: @test_concat_3d_axis_negative_2 func.func @test_concat_3d_axis_negative_2(%arg0: !torch.vtensor<[2,2,2],f32>, %arg1: !torch.vtensor<[2,2,2],f32>) -> !torch.vtensor<[2,4,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[TENSORS_LIST:.*]] = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[2,2,2],f32>, !torch.vtensor<[2,2,2],f32>) -> !torch.list @@ -844,6 +1013,8 @@ func.func @test_concat_3d_axis_negative_2(%arg0: !torch.vtensor<[2,2,2],f32>, %a return %0 : !torch.vtensor<[2,4,2],f32> } +// ----- + // CHECK-LABEL: @test_concat_3d_axis_negative_3 func.func @test_concat_3d_axis_negative_3(%arg0: !torch.vtensor<[2,2,2],f32>, %arg1: !torch.vtensor<[2,2,2],f32>) -> !torch.vtensor<[4,2,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[TENSORS_LIST:.*]] = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[2,2,2],f32>, !torch.vtensor<[2,2,2],f32>) -> !torch.list @@ -853,6 +1024,17 @@ func.func @test_concat_3d_axis_negative_3(%arg0: !torch.vtensor<[2,2,2],f32>, %a return %0 : !torch.vtensor<[4,2,2],f32> } +// ----- + +// CHECK-LABEL: func.func @test_exp +func.func @test_exp(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 6 : si64} { + // CHECK: torch.aten.exp %arg0 : !torch.vtensor<[3,4],f32> -> !torch.vtensor<[3,4],f32> + %0 = torch.operator "onnx.Exp"(%arg0) : (!torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> + return %0 : !torch.vtensor<[3,4],f32> +} + +// ----- + // CHECK-LABEL: @test_expand_dim2_shape2 func.func @test_expand_dim2_shape2(%arg0: !torch.vtensor<[1,4],f32>, %arg1: !torch.vtensor<[2],si32>) -> !torch.vtensor<[3,4],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { @@ -868,6 +1050,9 @@ func.func @test_expand_dim2_shape2(%arg0: !torch.vtensor<[1,4],f32>, %arg1: !tor %0 = torch.operator "onnx.Expand"(%arg0, %arg1) : (!torch.vtensor<[1,4],f32>, !torch.vtensor<[2],si32>) -> !torch.vtensor<[3,4],f32> return %0 : !torch.vtensor<[3,4],f32> } + +// ----- + // CHECK-LABEL: @test_expand_dim2_shape3 func.func @test_expand_dim2_shape3(%arg0: !torch.vtensor<[3,1],f32>, %arg1: !torch.vtensor<[3],si64>) -> !torch.vtensor<[2,3,6],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT0:.+]] = torch.constant.int 0 @@ -886,6 +1071,8 @@ func.func @test_expand_dim2_shape3(%arg0: !torch.vtensor<[3,1],f32>, %arg1: !tor return %0 : !torch.vtensor<[2,3,6],f32> } +// ----- + // CHECK-LABEL: @test_expand_dim3_shape4 func.func @test_expand_dim3_shape4(%arg0: !torch.vtensor<[1,3,1],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[3,3,3,3],f32> attributes {torch.onnx_meta.ir_version = 4 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT0:.+]] = torch.constant.int 0 @@ -906,6 +1093,9 @@ func.func @test_expand_dim3_shape4(%arg0: !torch.vtensor<[1,3,1],f32>, %arg1: !t %0 = torch.operator "onnx.Expand"(%arg0, %arg1) : (!torch.vtensor<[1,3,1],f32>, !torch.vtensor<[4],si64>) -> !torch.vtensor<[3,3,3,3],f32> return %0 : !torch.vtensor<[3,3,3,3],f32> } + +// ----- + // CHECK-LABEL: @test_dropout func.func @test_dropout(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.dropout %arg0, %float5.000000e-01, %false : !torch.vtensor<[3],f32>, !torch.float, !torch.bool -> !torch.vtensor<[3],f32 @@ -913,6 +1103,8 @@ func.func @test_dropout(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f3 return %0 : !torch.vtensor<[3],f32> } +// ----- + // CHECK-LABEL: @test_dropout_default func.func @test_dropout_default(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.dropout %arg0, %float5.000000e-01, %false : !torch.vtensor<[3,4,5],f32>, !torch.float, !torch.bool -> !torch.vtensor<[3,4,5],f32> @@ -920,6 +1112,8 @@ func.func @test_dropout_default(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vt return %0 : !torch.vtensor<[3,4,5],f32> } +// ----- + // CHECK-LABEL: @test_dropout_default_mask func.func @test_dropout_default_mask(%arg0: !torch.vtensor<[3,4,5],f32>) -> (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],i1>) attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.dropout %arg0, %float5.000000e-01, %false : !torch.vtensor<[3,4,5],f32>, !torch.float, !torch.bool -> !torch.vtensor<[3,4,5],f32> @@ -928,6 +1122,8 @@ func.func @test_dropout_default_mask(%arg0: !torch.vtensor<[3,4,5],f32>) -> (!to return %0#0, %0#1 : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],i1> } +// ----- + // CHECK-LABEL: @test_dropout_default_mask_ratio func.func @test_dropout_default_mask_ratio(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[],f32>) -> (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],i1>) attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.dropout %arg0, %0, %false : !torch.vtensor<[3,4,5],f32>, !torch.float, !torch.bool -> !torch.vtensor<[3,4,5],f32> @@ -936,6 +1132,8 @@ func.func @test_dropout_default_mask_ratio(%arg0: !torch.vtensor<[3,4,5],f32>, % return %0#0, %0#1 : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],i1> } +// ----- + // CHECK-LABEL: @test_dropout_default_ratio func.func @test_dropout_default_ratio(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.dropout %arg0, %0, %false : !torch.vtensor<[3,4,5],f32>, !torch.float, !torch.bool -> !torch.vtensor<[3,4,5],f32> @@ -943,6 +1141,8 @@ func.func @test_dropout_default_ratio(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: return %0 : !torch.vtensor<[3,4,5],f32> } +// ----- + // CHECK-LABEL: @test_training_dropout_zero_ratio func.func @test_training_dropout_zero_ratio(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],i1>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.dropout %arg0, %0, %2 : !torch.vtensor<[3,4,5],f32>, !torch.float, !torch.bool -> !torch.vtensor<[3,4,5],f32> @@ -950,6 +1150,8 @@ func.func @test_training_dropout_zero_ratio(%arg0: !torch.vtensor<[3,4,5],f32>, return %0 : !torch.vtensor<[3,4,5],f32> } +// ----- + // CHECK-LABEL: @test_elu_default func.func @test_elu_default(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 6 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.elu %arg0, %float0.000000e00, %float1.000000e00, %float1.000000e00 : !torch.vtensor<[3,4,5],f32>, !torch.float, !torch.float, !torch.float -> !torch.vtensor<[3,4,5],f32> @@ -957,6 +1159,8 @@ func.func @test_elu_default(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtenso return %0 : !torch.vtensor<[3,4,5],f32> } +// ----- + // CHECK-LABEL: @test_elu_example func.func @test_elu_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 6 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.elu %arg0, %float2.000000e00, %float1.000000e00, %float1.000000e00 : !torch.vtensor<[3],f32>, !torch.float, !torch.float, !torch.float -> !torch.vtensor<[3],f32> @@ -964,6 +1168,8 @@ func.func @test_elu_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3 return %0 : !torch.vtensor<[3],f32> } +// ----- + // CHECK-LABEL: @test_depthtospace_example func.func @test_depthtospace_example(%arg0: !torch.vtensor<[1,8,2,3],f32>) -> !torch.vtensor<[1,2,4,6],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[C0:.*]] = torch.constant.int 0 @@ -998,6 +1204,8 @@ func.func @test_depthtospace_example(%arg0: !torch.vtensor<[1,8,2,3],f32>) -> !t return %0 : !torch.vtensor<[1,2,4,6],f32> } +// ----- + // CHECK-LABEL: @test_depthtospace_crd_mode_example func.func @test_depthtospace_crd_mode_example(%arg0: !torch.vtensor<[1,8,2,3],f32>) -> !torch.vtensor<[1,2,4,6],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[C0:.*]] = torch.constant.int 0 @@ -1072,6 +1280,8 @@ func.func @ints_constant() -> !torch.vtensor<[2], si64> attributes {torch.onnx_m return %0 : !torch.vtensor<[2],si64> } +// ----- + // CHECK-LABEL: @test_flatten_4d_axis_2 func.func @test_flatten_4d_axis_2(%arg0: !torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[6,20],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 2 @@ -1084,6 +1294,8 @@ func.func @test_flatten_4d_axis_2(%arg0: !torch.vtensor<[2,3,4,5],f32>) -> !torc return %0 : !torch.vtensor<[6,20],f32> } +// ----- + // CHECK-LABEL: @test_flatten_4d_axis_0 func.func @test_flatten_4d_axis_0(%arg0: !torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[1,120],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 0 @@ -1095,6 +1307,8 @@ func.func @test_flatten_4d_axis_0(%arg0: !torch.vtensor<[2,3,4,5],f32>) -> !torc return %0 : !torch.vtensor<[1,120],f32> } +// ----- + // CHECK-LABEL: @test_flatten_4d_axis_4 func.func @test_flatten_4d_axis_4(%arg0: !torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[120,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK-DAG: %[[RIGHT_INDEX:.*]] = torch.constant.int 4 @@ -1106,6 +1320,8 @@ func.func @test_flatten_4d_axis_4(%arg0: !torch.vtensor<[2,3,4,5],f32>) -> !torc return %0 : !torch.vtensor<[120,1],f32> } +// ----- + // CHECK-LABEL: @test_flatten_4d_axis_negative_2 func.func @test_flatten_4d_axis_negative_2(%arg0: !torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[6,20],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 2 @@ -1118,6 +1334,8 @@ func.func @test_flatten_4d_axis_negative_2(%arg0: !torch.vtensor<[2,3,4,5],f32>) return %0 : !torch.vtensor<[6,20],f32> } +// ----- + // CHECK-LABEL: @test_flatten_4d_axis_negative_1 func.func @test_flatten_4d_axis_negative_1(%arg0: !torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[24,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 3 @@ -1130,6 +1348,8 @@ func.func @test_flatten_4d_axis_negative_1(%arg0: !torch.vtensor<[2,3,4,5],f32>) return %0 : !torch.vtensor<[24,5],f32> } +// ----- + // CHECK-LABEL: @test_flatten_4d_axis_negative_4 func.func @test_flatten_4d_axis_negative_4(%arg0: !torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[1,120],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 0 @@ -1141,6 +1361,8 @@ func.func @test_flatten_4d_axis_negative_4(%arg0: !torch.vtensor<[2,3,4,5],f32>) return %0 : !torch.vtensor<[1,120],f32> } +// ----- + // CHECK-LABEL: @test_flatten_2d_axis_1 func.func @test_flatten_2d_axis_1(%arg0: !torch.vtensor<[2,3],f32>) -> !torch.vtensor<[2,3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 1 @@ -1153,6 +1375,8 @@ func.func @test_flatten_2d_axis_1(%arg0: !torch.vtensor<[2,3],f32>) -> !torch.vt return %0 : !torch.vtensor<[2,3],f32> } +// ----- + // CHECK-LABEL: @test_flatten_1d_axis_0 func.func @test_flatten_1d_axis_0(%arg0: !torch.vtensor<[2],f32>) -> !torch.vtensor<[1,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 0 @@ -1164,6 +1388,8 @@ func.func @test_flatten_1d_axis_0(%arg0: !torch.vtensor<[2],f32>) -> !torch.vten return %0 : !torch.vtensor<[1,2],f32> } +// ----- + // CHECK-LABEL: @test_flatten_1d_axis_negative_1 func.func @test_flatten_1d_axis_negative_1(%arg0: !torch.vtensor<[2],f32>) -> !torch.vtensor<[1,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 0 @@ -1175,6 +1401,8 @@ func.func @test_flatten_1d_axis_negative_1(%arg0: !torch.vtensor<[2],f32>) -> !t return %0 : !torch.vtensor<[1,2],f32> } +// ----- + // COM: CHECK-LABEL: @test_flatten_1d_axis_1 func.func @test_flatten_1d_axis_1(%arg0: !torch.vtensor<[2],f32>) -> !torch.vtensor<[2,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK-DAG: %[[RIGHT_INDEX:.*]] = torch.constant.int 1 From ccaac857885ad8ab6532e900cd4e47bb4ee1c424 Mon Sep 17 00:00:00 2001 From: "Xida Ren (Cedar)" Date: Wed, 24 Jan 2024 00:30:03 -0500 Subject: [PATCH 106/283] implement aten.conv1d, aten.conv3d, and aten.conv_tbc (#2757) convolution with [time,batch,channel] ordering, as opposed to the default [batch, channel, time]. Currently implementing by transposing the input and output, but may need to get its own implementation in the future because this is supposed to be an op that gives a speedup. This is used by fairseq (https://github.com/facebookresearch/fairseq/issues/172). (in case you were wondering like me, this is different from transposed convolution. Transposed convolution has fractional strides). --------- Co-authored-by: Xida Ren Co-authored-by: Frederik Harwath --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 113 +++++++++++++ lib/Conversion/TorchToLinalg/Linear.cpp | 70 ++++++-- .../Transforms/AbstractInterpLibrary.cpp | 159 ++++++++++++++++++ .../Torch/Transforms/DecomposeComplexOps.cpp | 130 +++++++++++++- .../Transforms/LowerToBackendContract.cpp | 3 + projects/pt1/e2e_testing/xfail_sets.py | 4 + .../build_tools/abstract_interp_lib_gen.py | 62 +++++++ .../build_tools/torch_ods_gen.py | 8 + .../torch_mlir_e2e_test/test_suite/conv.py | 96 +++++++++++ 9 files changed, 626 insertions(+), 19 deletions(-) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 7e74e698f60a..8ed176a8eae5 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -5494,6 +5494,35 @@ def Torch_AtenCosineSimilarityOp : Torch_Op<"aten.cosine_similarity", [ }]; } +def Torch_AtenConv3dOp : Torch_Op<"aten.conv3d", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::conv3d : (Tensor, Tensor, Tensor?, int[], int[], int[], int) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$input, + AnyTorchTensorType:$weight, + AnyTorchOptionalTensorType:$bias, + AnyTorchListOfTorchIntType:$stride, + AnyTorchListOfTorchIntType:$padding, + AnyTorchListOfTorchIntType:$dilation, + Torch_IntType:$groups + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenConv3dOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 7, 1); + } + void AtenConv3dOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 7, 1); + } + }]; +} + def Torch_AtenConv2dOp : Torch_Op<"aten.conv2d", [ AllowsTypeRefinement, HasValueSemantics, @@ -5523,6 +5552,35 @@ def Torch_AtenConv2dOp : Torch_Op<"aten.conv2d", [ }]; } +def Torch_AtenConv1dOp : Torch_Op<"aten.conv1d", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::conv1d : (Tensor, Tensor, Tensor?, int[], int[], int[], int) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$input, + AnyTorchTensorType:$weight, + AnyTorchOptionalTensorType:$bias, + AnyTorchListOfTorchIntType:$stride, + AnyTorchListOfTorchIntType:$padding, + AnyTorchListOfTorchIntType:$dilation, + Torch_IntType:$groups + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenConv1dOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 7, 1); + } + void AtenConv1dOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 7, 1); + } + }]; +} + def Torch_AtenConvTranspose1dOp : Torch_Op<"aten.conv_transpose1d", [ AllowsTypeRefinement, HasValueSemantics, @@ -5613,6 +5671,61 @@ def Torch_AtenConvTranspose3dInputOp : Torch_Op<"aten.conv_transpose3d.input", [ }]; } +def Torch_AtenConvTbcOp : Torch_Op<"aten.conv_tbc", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::conv_tbc : (Tensor, Tensor, Tensor, int) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$weight, + AnyTorchTensorType:$bias, + Torch_IntType:$pad + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenConvTbcOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 4, 1); + } + void AtenConvTbcOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 4, 1); + } + }]; +} + +def Torch_AtenConvTbcBackwardOp : Torch_Op<"aten.conv_tbc_backward", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::conv_tbc_backward : (Tensor, Tensor, Tensor, Tensor, int) -> (Tensor, Tensor, Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$input, + AnyTorchTensorType:$weight, + AnyTorchTensorType:$bias, + Torch_IntType:$pad + ); + let results = (outs + AnyTorchTensorType:$result0, + AnyTorchTensorType:$result1, + AnyTorchTensorType:$result2 + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenConvTbcBackwardOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 5, 3); + } + void AtenConvTbcBackwardOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 5, 3); + } + }]; +} + def Torch_AtenConvolutionOp : Torch_Op<"aten.convolution", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp index 7c5f2c88c033..0d64a0dd2d94 100644 --- a/lib/Conversion/TorchToLinalg/Linear.cpp +++ b/lib/Conversion/TorchToLinalg/Linear.cpp @@ -566,9 +566,9 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { return op.emitError("unimplemented: non-floating point type"); size_t inRank = input.getType().cast().getRank(); size_t numSpacialDims = inRank - 2; - if (numSpacialDims != 2) + if (numSpacialDims < 1 || numSpacialDims > 3) return rewriter.notifyMatchFailure( - op, "unimplemented: only 2D convolution currently supported"); + op, "unimplemented: only 1d-3d convolution currently supported"); Type intType = IntegerType::get(context, 64); auto castIndexToInt = [&](Value v) { @@ -796,15 +796,50 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { weightSliceSizes.append(weightDims); Value conv; + // the code so far is able to respect all numSpacialDims + // the code below this point is numSpacialDims specific and groupSize specific + // TODO: factor out the above code into a helper function, and then separate convolution into: + // - grouped 1d-3d + // - ungrouped 1d-3d if (groupSize == 1) { - // TODO: add 1D and 3D case - conv = - rewriter - .create( - loc, outputTensor.getType(), ValueRange{paddedInput, weight}, - outputTensor, stridesAttr, dilationAttr) - .getResult(0); + // TODO: 3D case + switch (numSpacialDims) { + case 1: + conv = rewriter + .create( + loc, outputTensor.getType(), + ValueRange{paddedInput, weight}, outputTensor, + stridesAttr, dilationAttr) + .getResult(0); + break; + case 2: + conv = + rewriter + .create( + loc, outputTensor.getType(), ValueRange{paddedInput, weight}, + outputTensor, stridesAttr, dilationAttr) + .getResult(0); + break; + case 3: + conv = + rewriter + .create( + loc, outputTensor.getType(), ValueRange{paddedInput, weight}, + outputTensor, stridesAttr, dilationAttr) + .getResult(0); + break; + default: + return rewriter.notifyMatchFailure( + op, "unimplemented: only 1D, 2D, and 3D convolution supported"); + }; + Type newResultType = getTypeConverter()->convertType(op.getType()); + rewriter.replaceOpWithNewOp(op, newResultType, conv); + return success(); } else { + if(numSpacialDims != 2) + return rewriter.notifyMatchFailure( + op, "unimplemented: only 2D grouped convolution supported"); + // Special depthwise case auto inShape = makeShapeTorchCompatible( input.getType().cast().getShape()); @@ -824,11 +859,11 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { loc, collapsedType, weight, collapsedDims); conv = rewriter - .create( - loc, outputTensor.getType(), - ValueRange{paddedInput, collapsedWeight}, outputTensor, - stridesAttr, dilationAttr) - .getResult(0); + .create( + loc, outputTensor.getType(), + ValueRange{paddedInput, collapsedWeight}, outputTensor, + stridesAttr, dilationAttr) + .getResult(0); Type newResultType = getTypeConverter()->convertType(op.getType()); rewriter.replaceOpWithNewOp(op, newResultType, conv); @@ -902,11 +937,10 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { conv = rewriter.create( loc, outputTensor.getType(), conv, expandOutputTensor.getReassociationIndices()); + Type newResultType = getTypeConverter()->convertType(op.getType()); + rewriter.replaceOpWithNewOp(op, newResultType, conv); + return success(); } - - Type newResultType = getTypeConverter()->convertType(op.getType()); - rewriter.replaceOpWithNewOp(op, newResultType, conv); - return success(); } }; } // namespace diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index 4b4ae748f9e9..0e6313ea8978 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -8566,6 +8566,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.conv2d(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6) : (!torch.list, !torch.list, !torch.optional>, !torch.list, !torch.list, !torch.list, !torch.int) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.conv3d\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.int) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.conv3d(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6) : (!torch.list, !torch.list, !torch.optional>, !torch.list, !torch.list, !torch.list, !torch.int) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.conv_transpose2d.input\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.int, %arg7: !torch.list) -> !torch.list {\n" " %0 = torch.derefine %arg3 : !torch.list to !torch.optional>\n" " %1 = torch.derefine %arg4 : !torch.list to !torch.optional>\n" @@ -8574,10 +8578,67 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %4 = call @__torch__.torch.jit._shape_functions.conv_transpose2d_input(%arg0, %arg1, %arg2, %0, %1, %2, %arg6, %3) : (!torch.list, !torch.list, !torch.optional>, !torch.optional>, !torch.optional>, !torch.optional>, !torch.int, !torch.optional>) -> !torch.list\n" " return %4 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.conv_tbc\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.int) -> !torch.list {\n" +" %false = torch.constant.bool false\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int3 = torch.constant.int 3\n" +" %int1 = torch.constant.int 1\n" +" %int2 = torch.constant.int 2\n" +" %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %1 = torch.aten.eq.int %0, %int3 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %2 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" +" %3 = torch.aten.eq.int %2, %int3 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %3 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %4 = torch.aten.len.t %arg2 : !torch.list -> !torch.int\n" +" %5 = torch.aten.eq.int %4, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %5 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %6 = torch.aten.__getitem__.t %arg0, %int2 : !torch.list, !torch.int -> !torch.int\n" +" %7 = torch.aten.__getitem__.t %arg1, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %8 = torch.aten.eq.int %6, %7 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %8 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %9 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list\n" +" %10 = torch.prim.ListConstruct %arg3 : (!torch.int) -> !torch.list\n" +" %11 = torch.prim.ListConstruct : () -> !torch.list\n" +" %12 = torch.prim.ListConstruct : () -> !torch.list\n" +" %13 = torch.derefine %arg2 : !torch.list to !torch.optional>\n" +" %14 = call @__torch__.torch.jit._shape_functions.conv_forwards(%arg0, %arg1, %13, %9, %10, %11, %false, %12, %int1) : (!torch.list, !torch.list, !torch.optional>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int) -> !torch.list\n" +" %15:3 = torch.prim.ListUnpack %14 : !torch.list -> !torch.int, !torch.int, !torch.int\n" +" %16 = torch.prim.ListConstruct %15#2, %15#0, %15#1 : (!torch.int, !torch.int, !torch.int) -> !torch.list\n" +" return %16 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.convolution\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.bool, %arg7: !torch.list, %arg8: !torch.int) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.conv_forwards(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8) : (!torch.list, !torch.list, !torch.optional>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.conv1d\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.int) -> !torch.list {\n" +" %false = torch.constant.bool false\n" +" %int1 = torch.constant.int 1\n" +" %0 = torch.prim.ListConstruct : () -> !torch.list\n" +" %1 = call @__torch__.torch.jit._shape_functions.conv_forwards(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %false, %0, %int1) : (!torch.list, !torch.list, !torch.optional>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int) -> !torch.list\n" +" return %1 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten._convolution\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.bool, %arg7: !torch.list, %arg8: !torch.int, %arg9: !torch.bool, %arg10: !torch.bool, %arg11: !torch.bool, %arg12: !torch.bool) -> !torch.list {\n" " %0 = call @\"__torch_mlir_shape_fn.aten.convolution\"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8) : (!torch.list, !torch.list, !torch.optional>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int) -> !torch.list\n" " return %0 : !torch.list\n" @@ -10839,6 +10900,100 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %11 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%9, %10) : (!torch.list>, !torch.list) -> !torch.int\n" " return %11 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.conv1d\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.int) -> !torch.int {\n" +" %false = torch.constant.bool false\n" +" %int11 = torch.constant.int 11\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.aten.eq.int %0#1, %1#1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %3 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %4 = torch.aten.__not__ %3 : !torch.bool -> !torch.bool\n" +" %5 = torch.prim.If %4 -> (!torch.bool) {\n" +" %12 = torch.aten.__isnot__ %0#1, %int11 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %12 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %5 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %6 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" %7 = torch.aten.__not__ %6 : !torch.bool -> !torch.bool\n" +" %8 = torch.prim.If %7 -> (!torch.bool) {\n" +" %12 = torch.aten.__isnot__ %1#1, %int11 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %12 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %8 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %9 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %10 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %11 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%9, %10) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %11 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.conv_tbc\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.tuple, %arg3: !torch.int) -> !torch.int {\n" +" %false = torch.constant.bool false\n" +" %int11 = torch.constant.int 11\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.aten.eq.int %0#1, %1#1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %3 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %4 = torch.aten.__not__ %3 : !torch.bool -> !torch.bool\n" +" %5 = torch.prim.If %4 -> (!torch.bool) {\n" +" %12 = torch.aten.__isnot__ %0#1, %int11 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %12 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %5 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %6 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%1#1) : (!torch.int) -> !torch.bool\n" +" %7 = torch.aten.__not__ %6 : !torch.bool -> !torch.bool\n" +" %8 = torch.prim.If %7 -> (!torch.bool) {\n" +" %12 = torch.aten.__isnot__ %1#1, %int11 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %12 : !torch.bool\n" +" } else {\n" +" torch.prim.If.yield %false : !torch.bool\n" +" }\n" +" torch.prim.If %8 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %9 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %10 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %11 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%9, %10) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %11 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten._convolution.deprecated\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.bool, %arg7: !torch.list, %arg8: !torch.int, %arg9: !torch.bool, %arg10: !torch.bool, %arg11: !torch.bool) -> !torch.int {\n" " %false = torch.constant.bool false\n" " %int11 = torch.constant.int 11\n" @@ -10890,6 +11045,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.conv3d\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.int) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.conv_transpose2d.input\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.optional>, %arg3: !torch.list, %arg4: !torch.list, %arg5: !torch.list, %arg6: !torch.int, %arg7: !torch.list) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 8afccbba0346..7f2fb2b53006 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -2590,6 +2590,106 @@ class DecomposeAten_ConvolutionLikeOp }; } // namespace +namespace { + + static LogicalResult createTorchTransposeOpForConvTbc(PatternRewriter &rewriter, + Location loc, Value input, + int64_t dimA, int64_t dimB, + Value &transposed) { + Type transposedType; + if (failed(getTransposedType(input.getType().cast(), + dimA, dimB, transposedType))) + return failure(); + Value cstDimA = rewriter.create( + loc, rewriter.getI64IntegerAttr(dimA)); + Value cstDimB = rewriter.create( + loc, rewriter.getI64IntegerAttr(dimB)); + transposed = rewriter.create( + loc, transposedType, input, cstDimA, cstDimB); + return success(); + } + + class DecomposeAtenConvTbcOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenConvTbcOp op, + PatternRewriter &rewriter) const override { + Value emptyList = rewriter.create( + op.getLoc(), Torch::ListType::get(Torch::IntType::get(op.getContext())), + SmallVector()); + Value zeroList = rewriter.create( + op.getLoc(), Torch::ListType::get(Torch::IntType::get(op.getContext())), + SmallVector{rewriter.create(op.getLoc(), rewriter.getI64IntegerAttr(0))}); + Value cstFalse = rewriter.create(op.getLoc(), false); + Value oneList = rewriter.create( + op.getLoc(), Torch::ListType::get(Torch::IntType::get(op.getContext())), + SmallVector{rewriter.create(op.getLoc(), rewriter.getI64IntegerAttr(1))}); + Value padding = rewriter.create( + op.getLoc(), Torch::ListType::get(Torch::IntType::get(op.getContext())), + SmallVector{op.getPad()}); + Value groups = rewriter.create(op.getLoc(), rewriter.getI64IntegerAttr(1)); + + // convtbc has WNC layout for input and output + // and WCF layout for weight + // whereas Convolution is going to use Conv1DNcwFcwOp for 1d + // which means we need the inputs in NCW and the weight in FCW + Value selfWnc = op.getSelf(); + Value selfNwc; + Value selfNcw; + if(failed(createTorchTransposeOpForConvTbc(rewriter, op.getLoc(), selfWnc, 0, 1, selfNwc))) + return rewriter.notifyMatchFailure(op, "failed to transpose input to Nwc"); + if(failed(createTorchTransposeOpForConvTbc(rewriter, op.getLoc(), selfNwc, 1, 2, selfNcw))) + return rewriter.notifyMatchFailure(op, "failed to transpose input to Ncw"); + + Value weightWcf = op.getWeight(); + Value weightFcw; + if(failed(createTorchTransposeOpForConvTbc(rewriter, op.getLoc(), weightWcf, 0, 2, weightFcw))) + return rewriter.notifyMatchFailure(op, "failed to transpose weight to Fcw"); + + + Value outputNcw = rewriter.create( + op.getLoc(), op->getResultTypes(), selfNcw, weightFcw, op.getBias(), /*stride*/oneList, + /*padding*/ padding, /*dilation*/ oneList, + /*transpose*/ cstFalse, /*output_padding*/ emptyList, + groups); + + // convert output from Ncw to Wnc + Value outputNwc; + Value outputWnc; + if(failed(createTorchTransposeOpForConvTbc(rewriter, op.getLoc(), outputNcw, 1, 2, outputNwc))) + return rewriter.notifyMatchFailure(op, "failed to transpose output to Nwc"); + if(failed(createTorchTransposeOpForConvTbc(rewriter, op.getLoc(), outputNwc, 0, 1, outputWnc))) + return rewriter.notifyMatchFailure(op, "failed to transpose output to Wnc"); + rewriter.replaceOp(op, outputWnc); + + return success(); + } + }; +} + + +// Decompose aten.conv1d to aten.convolution +namespace { +class DecomposeAtenConv1dOp : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenConv1dOp op, + PatternRewriter &rewriter) const override { + + Value emptyList = rewriter.create( + op.getLoc(), Torch::ListType::get(Torch::IntType::get(op.getContext())), + SmallVector()); + Value cstFalse = rewriter.create(op.getLoc(), false); + rewriter.replaceOpWithNewOp( + op, op->getResultTypes(), op.getInput(), op.getWeight(), op.getBias(), + op.getStride(), op.getPadding(), op.getDilation(), cstFalse, emptyList, + op.getGroups()); + + return success(); + } +}; +} // namespace + // Decompose aten.conv2d to aten.convolution namespace { class DecomposeAtenConv2dOp : public OpRewritePattern { @@ -2612,6 +2712,28 @@ class DecomposeAtenConv2dOp : public OpRewritePattern { }; } // namespace +// Decompose aten.conv3d to aten.convolution +namespace { +class DecomposeAtenConv3dOp : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenConv3dOp op, + PatternRewriter &rewriter) const override { + + Value emptyList = rewriter.create( + op.getLoc(), Torch::ListType::get(Torch::IntType::get(op.getContext())), + SmallVector()); + Value cstFalse = rewriter.create(op.getLoc(), false); + rewriter.replaceOpWithNewOp( + op, op->getResultTypes(), op.getInput(), op.getWeight(), op.getBias(), + op.getStride(), op.getPadding(), op.getDilation(), cstFalse, emptyList, + op.getGroups()); + + return success(); + } +}; +} // namespace + // Decompose aten.conv_transpose2d to aten.convolution namespace { class DecomposeAtenConvTranspose2dOp @@ -6531,7 +6653,6 @@ class DecomposeComplexOpsPass DecomposeAten_ConvolutionLikeOp>( patterns); addPatternIfTargetOpIsIllegal(patterns); - addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); @@ -6643,6 +6764,13 @@ class DecomposeComplexOpsPass addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + // More specific conv ops + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + + GreedyRewriteConfig config; config.useTopDownTraversal = true; diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp index da7811ad0a3f..34874cb59635 100644 --- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp +++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp @@ -412,7 +412,10 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 341dad1e6192..7fd1ba55bce4 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -315,6 +315,9 @@ # Dynamo does not support tracing quantized tensors "ElementwiseDequantizePerTensorModule_basic", "ElementwiseQuantizePerTensorModule_basic", + + # Dynamo not supporting conv_tbc + "ConvTbcModule_basic", } TORCHDYNAMO_CRASHING_SET = { @@ -1417,6 +1420,7 @@ "PixelShuffleModuleFullDynamic_basic", "PixelShuffleModuleSpatiallyDynamic_basic", "PixelShuffleModuleSpatiallyStatic_basic", + "ConvTbcModule_basic", "_Convolution2DAllFalseModule_basic", "_Convolution2DBenchmarkModule_basic", "_Convolution2DCudnnModule_basic", diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index 9876864a86d8..ca5d983f2c8d 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -1299,12 +1299,50 @@ def aten〇view_as_real〡dtype(self_rank_dtype: Tuple[int, int]) -> int: def aten〇conv2d〡shape(input: List[int], weight: List[int], bias: Optional[List[int]] = None, stride: List[int] = (1, 1,), padding: List[int] = (0, 0,), dilation: List[int] = (1, 1,), groups: int = 1) -> List[int]: return upstream_shape_functions.conv2d(input, weight, bias, stride, padding, dilation, groups) +def aten〇conv3d〡shape(input: List[int], weight: List[int], bias: Optional[List[int]] = None, stride: List[int] = (1, 1, 1,), padding: List[int] = (0, 0, 0,), dilation: List[int] = (1, 1, 1,), groups: int = 1) -> List[int]: + return upstream_shape_functions.conv3d(input, weight, bias, stride, padding, dilation, groups) + def aten〇conv_transpose2d〇input〡shape(input: List[int], weight: List[int], bias: Optional[List[int]] = None, stride: List[int] = (1, 1,), padding: List[int] = (0, 0,), output_padding: List[int] = (0, 0,), groups: int = 1, dilation: List[int] = (1, 1,)) -> List[int]: return upstream_shape_functions.conv_transpose2d_input(input, weight, bias, stride, padding, output_padding, groups, dilation) +def aten〇conv_tbc〡shape(self: List[int], weight: List[int], bias: List[int], pad: int = 0) -> List[int]: + assert len(self) == 3 # only 1d is supported by tbc + assert len(weight) == 3 + assert len(bias) == 1 + + # tbc -> bct + time = self[0] + batch = self[1] + channels = self[2] + + kernel_width = weight[0] + channels_w = weight[1] + out_channels = weight[2] + + # out_channels_b = bias[0] + + assert channels == channels_w + # the out_channels in weights and biases should also match, but this assert doesn't work because typing problems + # assert out_channels == out_channels_b + + self_bct = [batch, channels, time] + weight_bct = [out_channels, channels, kernel_width] + bias_bct = bias + + # use existing shape inf + output_size_bct = upstream_shape_functions.conv_forwards(self, weight, bias, stride=[1], padding=[pad], dilation=[], transposed=False, output_padding=[], groups=1) + + batch_out, channels_out, time_out = output_size_bct + + # bct -> tbc + return [time_out, batch_out, channels_out] + def aten〇convolution〡shape(input: List[int], weight: List[int], bias: Optional[List[int]], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int) -> List[int]: return upstream_shape_functions.conv_forwards(input, weight, bias, stride, padding, dilation, transposed, output_padding, groups) +def aten〇conv1d〡shape(input: List[int], weight: List[int], bias: Optional[List[int]] = None, stride: List[int] = (1,), padding: List[int] = (0,), dilation: List[int] = (1,), groups: int = 1) -> List[int]: + return upstream_shape_functions.conv_forwards(input, weight, bias, stride, padding, dilation, transposed=False, output_padding=[], groups=1) + def aten〇_convolution〡shape(input: List[int], weight: List[int], bias: Optional[List[int]], stride: List[int], padding: List[int], dilation: List[int], transposed: bool, output_padding: List[int], groups: int, benchmark: bool, deterministic: bool, cudnn_enabled: bool, allow_tf32: bool) -> List[int]: return aten〇convolution〡shape(input, weight, bias, stride, padding, dilation, transposed, output_padding, groups) @@ -3043,6 +3081,26 @@ def aten〇_convolution〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_d dtypes = [input_dtype, weight_dtype] return promote_dtypes(ranks, dtypes) +def aten〇conv1d〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int], bias_rank_dtype: Optional[Tuple[int, int]] = None, stride: List[int] = (1,), padding: List[int] = (0,), dilation: List[int] = (1,), groups: int = 1) -> int: + input_rank, input_dtype = input_rank_dtype + weight_rank, weight_dtype = weight_rank_dtype + assert input_dtype == weight_dtype + assert not is_complex_dtype(input_dtype) and input_dtype is not torch.bool + assert not is_complex_dtype(weight_dtype) and weight_dtype is not torch.bool + ranks: List[Optional[int]] = [input_rank, weight_rank] + dtypes = [input_dtype, weight_dtype] + return promote_dtypes(ranks, dtypes) + +def aten〇conv_tbc〡dtype(self_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int], bias_rank_dtype: Tuple[int, int], pad: int = 0) -> int: + self_rank, self_dtype = self_rank_dtype + weight_rank, weight_dtype = weight_rank_dtype + assert self_dtype == weight_dtype + assert not is_complex_dtype(self_dtype) and self_dtype is not torch.bool + assert not is_complex_dtype(weight_dtype) and weight_dtype is not torch.bool + ranks: List[Optional[int]] = [self_rank, weight_rank] + dtypes = [self_dtype, weight_dtype] + return promote_dtypes(ranks, dtypes) + _convolution_deprecated_kwargs = { "stride" : [1, 1], "padding" : [0, 0], "dilation" : [1, 1], "transposed" : False, "output_padding" : [0, 0], "groups" : 1, "benchmark" : False, "deterministic" : False, "cudnn_enabled" : False} @@ -3089,6 +3147,10 @@ def aten〇conv2d〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: input_rank, input_dtype = input_rank_dtype return input_dtype +def aten〇conv3d〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int], bias_rank_dtype: Optional[Tuple[int, int]] = None, stride: List[int] = (1, 1, 1,), padding: List[int] = (0, 0, 0,), dilation: List[int] = (1, 1, 1,), groups: int = 1) -> int: + input_rank, input_dtype = input_rank_dtype + return input_dtype + @check_dtype_function( _check_tensors_with_the_same_dtype(tensor_shapes=[(1, 1, 1, 1), (1, 1, 1, 1)]) + [Invocation(TensorOfShape(1, 1, 1, 1, dtype=torch.bool), TensorOfShape(1, 1, 1, 1, dtype=torch.float32)), diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index 6099fd64e5f4..45f580ba5e13 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -406,12 +406,20 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::matmul : (Tensor, Tensor) -> (Tensor)") emit("aten::mv : (Tensor, Tensor) -> (Tensor)") emit("aten::cosine_similarity : (Tensor, Tensor, int, float) -> (Tensor)") + emit( + "aten::conv3d : (Tensor, Tensor, Tensor?, int[], int[], int[], int) -> (Tensor)" + ) emit( "aten::conv2d : (Tensor, Tensor, Tensor?, int[], int[], int[], int) -> (Tensor)" ) + emit( + "aten::conv1d : (Tensor, Tensor, Tensor?, int[], int[], int[], int) -> (Tensor)" + ) emit("aten::conv_transpose1d : (Tensor, Tensor, Tensor?, int[], int[], int[], int, int[]) -> (Tensor)") emit("aten::conv_transpose2d.input : (Tensor, Tensor, Tensor?, int[], int[], int[], int, int[]) -> (Tensor)") emit("aten::conv_transpose3d.input : (Tensor, Tensor, Tensor?, int[], int[], int[], int, int[]) -> (Tensor)") + emit("aten::conv_tbc : (Tensor, Tensor, Tensor, int) -> (Tensor)") + emit("aten::conv_tbc_backward : (Tensor, Tensor, Tensor, Tensor, int) -> (Tensor, Tensor, Tensor)") emit("aten::convolution : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int) -> (Tensor)") emit("aten::_convolution : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int, bool, bool, bool, bool) -> (Tensor)") emit("aten::_convolution.deprecated : (Tensor, Tensor, Tensor?, int[], int[], int[], bool, int[], int, bool, bool, bool) -> (Tensor)") diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py index 5fc443d98605..f75e17a4f6cd 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py @@ -761,3 +761,99 @@ def forward(self, inputVec): @register_test_case(module_factory=lambda: UpSampleNearest2dSameFactor()) def UpSampleNearest2dStaticFactor_basic(module, tu: TestUtils): module.forward(tu.rand(2, 3, 4, 4)) +class Conv1dModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True), + ([-1, -1, -1], torch.float32, True), + ([-1], torch.float32, True), + ]) + def forward(self, inputVec, weight, bias): + return torch.ops.aten.conv1d(inputVec, + weight, + bias=bias, + stride=[1], + padding=[0], + dilation=[1], + groups=1) +@register_test_case(module_factory=lambda: Conv1dModule()) +def Conv1dModule_basic(module, tu: TestUtils): + inputVec = tu.rand(2, 2, 6) + weight = torch.randn(8, 2, 3) + bias = torch.randn(8) + module.forward(inputVec, weight, bias) + +class Conv2dModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1, -1], torch.float32, True), + ([-1, -1, -1, -1], torch.float32, True), + ([-1], torch.float32, True), + ]) + def forward(self, inputVec, weight, bias): + return torch.ops.aten.conv2d(inputVec, + weight, + bias=bias, + stride=[1, 1], + padding=[0, 0], + dilation=[1, 1], + groups=1) +@register_test_case(module_factory=lambda: Conv2dModule()) +def Conv2dModule_basic(module, tu: TestUtils): + inputVec = tu.rand(2, 2, 6, 6) + weight = torch.randn(8, 2, 3, 3) + bias = torch.randn(8) + module.forward(inputVec, weight, bias) + +class Conv3dModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1, -1, -1], torch.float32, True), + ([-1, -1, -1, -1, -1], torch.float32, True), + ([-1], torch.float32, True), + ]) + def forward(self, inputVec, weight, bias): + return torch.ops.aten.conv3d(inputVec, + weight, + bias=bias, + stride=[1, 1, 1], + padding=[0, 0, 0], + dilation=[1, 1, 1], + groups=1) +@register_test_case(module_factory=lambda: Conv3dModule()) +def Conv3dModule_basic(module, tu: TestUtils): + inputVec = tu.rand(2, 2, 6, 6, 6) + weight = torch.randn(8, 2, 3, 3, 3) + bias = torch.randn(8) + module.forward(inputVec, weight, bias) + +class ConvTbcModule(torch.nn.Module): + def __init__(self): + super().__init__() + + # shapes from https://github.com/pytorch/pytorch/blob/3e8c8ce37bbfaafa8581fb48506c0a70ea54463d/test/nn/test_convolution.py#L623 + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True), + ([-1, -1, -1], torch.float32, True), + ([-1], torch.float32, True), + ]) + def forward(self, x, weight, bias): + return torch.conv_tbc(x, weight, bias) + +@register_test_case(module_factory=lambda: ConvTbcModule()) +def ConvTbcModule_basic(module, tu: TestUtils): + module.forward(tu.rand(9, 4, 5), tu.rand(3, 5, 6), tu.rand(6)) From 311b6b0286bfa016346bc7fd8b441bbd50216060 Mon Sep 17 00:00:00 2001 From: Vivek Khandelwal Date: Wed, 24 Jan 2024 15:55:12 +0530 Subject: [PATCH 107/283] CI: Fix Roll PyTorch CI failure at determining commit hash (#2796) Signed-Off By: Vivek Khandelwal --- .github/workflows/RollPyTorch.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/RollPyTorch.yml b/.github/workflows/RollPyTorch.yml index 5c8d74ee0941..4f2d9d8c509a 100644 --- a/.github/workflows/RollPyTorch.yml +++ b/.github/workflows/RollPyTorch.yml @@ -68,7 +68,7 @@ jobs: printf -- "-f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html\n--pre\ntorchvision==%s\n" "${VISION_RELEASE}" > torchvision-requirements.txt # Read the commit hash from the downloaded whl file without extracting it - PT_HASH=$(unzip -p torch-"${PT_RELEASE}"*.whl torch/version.py | grep git_version | awk '{ print $3 }' | tr -d "'") + PT_HASH=$(unzip -p torch-"${PT_RELEASE}"*.whl torch/version.py | grep git_version | tail -1 | awk '{ print $3 }' | tr -d "'") echo "Found torch commit hash ${PT_HASH}" PT_HASH_CHANGED=0 From c531f5495bf2046d86bb76285f0d5d23076c71f8 Mon Sep 17 00:00:00 2001 From: zjgarvey <47986913+zjgarvey@users.noreply.github.com> Date: Wed, 24 Jan 2024 11:09:56 -0600 Subject: [PATCH 108/283] AtenAdaptiveMaxPool2d Conversion to Linalg (#2779) The logic here is very similar to the conversion for AdaptiveAvgPool1d #2661 with a few modifications: 1. buffVal = -inf instead of 0 2. the main linalg generic op accumulates a max, instead of a sum, to the first output tensor 3. avg pooling requires dividing the sum pool by the kernel width, which we stored as an auxilliary tensor (kSizeTensor). Here, the auxiliary tensor will be recording the indices. Strangely enough, the only signature available for this function is to return indices, and it appears that they must be computed whether the user desires them or not. See [pytorch/torch/nn/functional.py](https://github.com/pytorch/pytorch/blob/main/torch/nn/functional.py#L1174). Before writing other adaptive pooling conversions, the logic of this decomposition should be rolled into a helper function that will work for both max and avg pooling ops. Even the auxiliary tensor should likely be automated. This code was written in a slightly more tedious way than strictly necessary (often using loops to fill SmallVectors up to rank-2, which is only two in this case), in order to more easily facilitate the transition to a helper function. --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 25 +++ lib/Conversion/TorchToLinalg/Pooling.cpp | 204 ++++++++++++++++++ .../Transforms/AbstractInterpLibrary.cpp | 73 +++++++ .../build_tools/abstract_interp_lib_gen.py | 23 ++ .../build_tools/torch_ods_gen.py | 1 + .../torch_mlir_e2e_test/test_suite/pooling.py | 81 +++++++ 6 files changed, 407 insertions(+) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 8ed176a8eae5..a46c79acb941 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -6856,6 +6856,31 @@ def Torch_Aten_AdaptiveAvgPool3dBackwardOp : Torch_Op<"aten._adaptive_avg_pool3d }]; } +def Torch_AtenAdaptiveMaxPool2dOp : Torch_Op<"aten.adaptive_max_pool2d", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::adaptive_max_pool2d : (Tensor, int[]) -> (Tensor, Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchListOfTorchIntType:$output_size + ); + let results = (outs + AnyTorchTensorType:$result0, + AnyTorchTensorType:$result1 + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenAdaptiveMaxPool2dOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 2); + } + void AtenAdaptiveMaxPool2dOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 2); + } + }]; +} + def Torch_AtenTopkOp : Torch_Op<"aten.topk", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Conversion/TorchToLinalg/Pooling.cpp b/lib/Conversion/TorchToLinalg/Pooling.cpp index 14d2c71dbc92..eed79072d0f9 100644 --- a/lib/Conversion/TorchToLinalg/Pooling.cpp +++ b/lib/Conversion/TorchToLinalg/Pooling.cpp @@ -793,6 +793,208 @@ class ConvertAtenAdaptiveAvgPool1dOp }; } // namespace +// The logic for this conversion is similar to the AdaptiveAvgPool1dOp +// conversion. Before writing any more adaptive pooling conversions, the logic +// in this should be off-loaded to a helper function, since each of the adaptive +// ops are essentially the same with some minor tweaks. Instead of kSizeTensor, +// we named the additional output of the linalg generic op auxTensor. +// For max pooling, auxTensor holds the indices of max values, and for +// avg pooling, the auxTensor will be kSizeTensor, used to later divide the +// sum pool by the kernel size. +namespace { +class ConvertAtenAdaptiveMaxPool2dOp + : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(AtenAdaptiveMaxPool2dOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + Location loc = op->getLoc(); + const TypeConverter *typeConverter = getTypeConverter(); + + // get rank of input (same as rank of output) + int64_t rank = + adaptor.getSelf().getType().cast().getRank(); + // input operand should be NCHW (i.e. rank 4) + if (rank != 4) { + return rewriter.notifyMatchFailure(op, "only supports input type NCHW"); + } + + // input tensor and output shape + Value input = adaptor.getSelf(); + Value outputShape = op.getOutputSize(); + SmallVector outShapeVector; + getListConstructElements(outputShape, outShapeVector); + outShapeVector = + getTypeConvertedValues(rewriter, loc, typeConverter, outShapeVector); + SmallVector inputSpatialSizes; + for (unsigned i = 2; i < rank; i++) { + inputSpatialSizes.push_back(getDimOp(rewriter, loc, input, i)); + } + SmallVector outShapeIndexVector; + for (auto v : outShapeVector) { + outShapeIndexVector.push_back(castIntToIndex(rewriter, loc, v)); + } + RankedTensorType inputType = input.getType().cast(); + RankedTensorType outputType = + typeConverter->convertType(op.getResult0().getType()) + .cast(); + + // get elementType of input tensor + Type elementType = inputType.getElementType(); + + // make an iteration space of size kMax = 1 + ceildiv (hIn - 1) , hOut + Type boolType = rewriter.getI1Type(); + SmallVector kIterSizeVector; + Value constantOne = + rewriter.create(loc, rewriter.getIndexAttr(1)); + for (int i = 0; i < rank - 2; i++) { + Value hInPlusOne = rewriter.create( + loc, inputSpatialSizes[i], constantOne); + Value kMaxMinusOne = rewriter.create( + loc, hInPlusOne, outShapeIndexVector[i]); + Value kMax = + rewriter.create(loc, constantOne, kMaxMinusOne); + kIterSizeVector.push_back(kMax); + } + Value kIter = rewriter.create( + loc, getAsOpFoldResult(kIterSizeVector), boolType); + + // need to buffer input, else there will possibly be an out of bounds access + // later buffVal = 0 for avg pooling and -inf for max pooling + auto smallestFPValueAttr = rewriter.getFloatAttr( + elementType, + APFloat::getInf(elementType.cast().getFloatSemantics(), + /*Negative=*/true)); + Value buffVal = rewriter.create(loc, elementType, + smallestFPValueAttr); + SmallVector lowPadding(rank, 0); + SmallVector highPadding(2, 0); + for (int i = 0; i < rank - 2; i++) { + highPadding.push_back(1); + } + Value buffInput = torch_to_linalg::getPaddedTensor( + op, rewriter, input, lowPadding, highPadding, buffVal); + + // make a list of outputSizes + SmallVector outputSizes; + for (unsigned i = 0; i < 2; i++) { + outputSizes.push_back(getDimOp(rewriter, loc, input, i)); + } + for (unsigned i = 2; i < rank; i++) { + outputSizes.push_back(outShapeIndexVector[i - 2]); + } + + // for avg pooling the auxTensor should hold kernel widths (kSizeTensor) + // for max Pooling, it should hold the indices + RankedTensorType outputType1 = + typeConverter->convertType(op.getResult1().getType()) + .cast(); + Type indicesType = outputType1.getElementType(); + Value auxTensor = rewriter.create( + loc, getAsOpFoldResult(outputSizes), indicesType); + + // initialize an output tensor + Value initOutput = + createInitTensor(rewriter, loc, outputSizes, elementType, buffVal); + + // setup indexing maps and iterator types for linalg generic op (outputShape + // (rank),kIter (rank -2)) for kIter (d0,d1,d2,d3,d4,d5) -> (d4,d5) for + // output (d0,d1,d2,d3,d4,d5) -> (d0,d1,d2,d3) for auxTensor + // (d0,d1,d2,d3,d4,d5) -> (d0,d1,d2,d3) (or (d2,d3) for avg pooling) + SmallVector kIterExprs, outputExprs, auxTensorExprs; + // batch + channel + output spatial dims + for (unsigned i = 0; i < rank; i++) { + outputExprs.push_back(rewriter.getAffineDimExpr(i)); + auxTensorExprs.push_back(rewriter.getAffineDimExpr(i)); + } + // kIter covers last rank-2 indices + for (unsigned i = rank; i < 2 * rank - 2; i++) { + kIterExprs.push_back(rewriter.getAffineDimExpr(i)); + } + SmallVector indexingMaps = + AffineMap::inferFromExprList({kIterExprs, outputExprs, auxTensorExprs}); + SmallVector iteratorTypes( + rank, utils::IteratorType::parallel); + for (unsigned i = 0; i < rank - 2; i++) { + iteratorTypes.push_back(utils::IteratorType::reduction); + } + Value indexOne = rewriter.create(loc, 1); + auto maxPool = rewriter.create( + loc, /*resultTensorTypes=*/ + TypeRange({initOutput.getType(), auxTensor.getType()}), + /*inputs=*/ValueRange({kIter}), + /*outputs=*/ValueRange({initOutput, auxTensor}), + /*indexingMaps=*/indexingMaps, + /*iteratorTypes=*/iteratorTypes, + [&](OpBuilder &b, Location loc, ValueRange args) { + Value res = args[1]; + Value maxIndex = args[2]; + SmallVector ind; + for (unsigned i = 0; i < 2 * rank - 2; i++) { + ind.push_back(b.create(loc, i)); + } + // compute start and end indices + // st = s1( s0(ind2 * Hin) // Hout ) + SmallVector starts; + SmallVector ends; + for (unsigned i = 2; i < rank; i++) { + Value s0 = + b.create(loc, ind[i], inputSpatialSizes[i - 2]); + Value s1 = b.create( + loc, s0, outShapeIndexVector[i - 2]); + starts.push_back(s1); + // en = e4( 1 + e3( e2( e1( e0(ind2 + 1) * hIn ) - 1 ) // hOut ) ) + Value e0 = b.create(loc, ind[i], indexOne); + Value e1 = + b.create(loc, e0, inputSpatialSizes[i - 2]); + Value e2 = b.create(loc, e1, indexOne); + Value e3 = b.create( + loc, e2, outShapeIndexVector[i - 2]); + Value e4 = b.create(loc, indexOne, e3); + ends.push_back(e4); + } + SmallVector inputElementIndices; + inputElementIndices.push_back(ind[0]); + inputElementIndices.push_back(ind[1]); + for (unsigned i = 2; i < rank; i++) { + inputElementIndices.push_back( + b.create(loc, starts[i - 2], ind[rank - 2 + i])); + } + Value inElt = b.create(loc, elementType, buffInput, + inputElementIndices); + // check if we extracted at windex < end index + for (unsigned i = 0; i < rank - 2; i++) { + Value cond = + b.create(loc, arith::CmpIPredicate(6), + inputElementIndices[i + 2], ends[i]); + inElt = b.create(loc, cond, inElt, buffVal); + } + Value cond1 = b.create(loc, arith::CmpFPredicate::OGT, + inElt, res); + // index location is (ih * input_width + iw) + Value indexOut0 = b.create(loc, inputElementIndices[2], + inputSpatialSizes[1]); + Value indexOut1 = + b.create(loc, indexOut0, inputElementIndices[3]); + Value indexOut1Int = castIndexToInt64(b, loc, indexOut1); + Value indexOut2 = + b.create(loc, cond1, indexOut1Int, maxIndex); + Value out2 = b.create(loc, cond1, inElt, res); + b.create(loc, ValueRange({out2, indexOut2})); + }); + + Value maxValues = rewriter.create( + loc, outputType, maxPool.getResultTensors()[0]); + Value outputIndices = rewriter.create( + loc, outputType1, maxPool.getResultTensors()[1]); + rewriter.replaceOp(op, {maxValues, outputIndices}); + return success(); + } +}; +} // namespace + void mlir::torch::torch_to_linalg::populatePoolingPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target) { @@ -813,4 +1015,6 @@ void mlir::torch::torch_to_linalg::populatePoolingPatternsAndLegality( typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); + target.addIllegalOp(); + patterns.add(typeConverter, context); } diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index 0e6313ea8978..590bea8d7176 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -7964,6 +7964,73 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.adaptive_avg_pool2d(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.adaptive_max_pool2d\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.tuple, list> {\n" +" %0 = call @__torch__.adaptive_max_pool2d(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.tuple, list>\n" +" return %0 : !torch.tuple, list>\n" +" }\n" +" func.func @__torch__.adaptive_max_pool2d(%arg0: !torch.list, %arg1: !torch.list) -> !torch.tuple, list> {\n" +" %true = torch.constant.bool true\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %int2 = torch.constant.int 2\n" +" %int3 = torch.constant.int 3\n" +" %int4 = torch.constant.int 4\n" +" %int0 = torch.constant.int 0\n" +" %0 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" +" %1 = torch.aten.eq.int %0, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %2 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %3 = torch.aten.eq.int %2, %int3 : !torch.int, !torch.int -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %11 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %12 = torch.aten.eq.int %11, %int4 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %12 : !torch.bool\n" +" }\n" +" torch.prim.If %4 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %5 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" torch.prim.Loop %5, %true, init() {\n" +" ^bb0(%arg2: !torch.int):\n" +" %11 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list, !torch.int -> !torch.int\n" +" %12 = torch.aten.ne.int %11, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %12 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" %6 = torch.prim.ListConstruct : () -> !torch.list\n" +" %7 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %8 = torch.aten.sub.int %7, %int2 : !torch.int, !torch.int -> !torch.int\n" +" torch.prim.Loop %8, %true, init() {\n" +" ^bb0(%arg2: !torch.int):\n" +" %11 = torch.aten.__getitem__.t %arg0, %arg2 : !torch.list, !torch.int -> !torch.int\n" +" %12 = torch.aten.append.t %6, %11 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" %9 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" +" torch.prim.Loop %9, %true, init() {\n" +" ^bb0(%arg2: !torch.int):\n" +" %11 = torch.aten.__getitem__.t %arg1, %arg2 : !torch.list, !torch.int -> !torch.int\n" +" %12 = torch.aten.append.t %6, %11 : !torch.list, !torch.int -> !torch.list\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" %10 = torch.prim.TupleConstruct %6, %6 : !torch.list, !torch.list -> !torch.tuple, list>\n" +" return %10 : !torch.tuple, list>\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.flatten.using_ints\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.int) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.flatten(%arg0, %arg1, %arg2) : (!torch.list, !torch.int, !torch.int) -> !torch.list\n" " return %0 : !torch.list\n" @@ -9896,6 +9963,12 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %1 = torch.prim.TupleConstruct %0#1, %int4 : !torch.int, !torch.int -> !torch.tuple\n" " return %1 : !torch.tuple\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.adaptive_max_pool2d\"(%arg0: !torch.tuple, %arg1: !torch.list) -> !torch.tuple {\n" +" %int4 = torch.constant.int 4\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.prim.TupleConstruct %0#1, %int4 : !torch.int, !torch.int -> !torch.tuple\n" +" return %1 : !torch.tuple\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.mish\"(%arg0: !torch.tuple) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index ca5d983f2c8d..28e87cc60990 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -902,6 +902,24 @@ def aten〇avg_pool2d〡shape(self: List[int], kernel_size: List[int], stride: L def aten〇adaptive_avg_pool2d〡shape(self: List[int], output_size: List[int]) -> List[int]: return upstream_shape_functions.adaptive_avg_pool2d(self, output_size) +def adaptive_max_pool2d(self: List[int], out: List[int]): + assert len(out) == 2 + assert len(self) == 3 or len(self) == 4 + + for i in range(len(self)): + assert self[i] != 0 + + shape: List[int] = [] + for i in range(len(self) - 2): + shape.append(self[i]) + for j in range(len(out)): + shape.append(out[j]) + + return shape, shape + +def aten〇adaptive_max_pool2d〡shape(self: List[int], output_size: List[int]) -> Tuple[List[int], List[int]]: + return adaptive_max_pool2d(self, output_size) + def aten〇flatten〇using_ints〡shape(self: List[int], start_dim: int = 0, end_dim: int = -1) -> List[int]: return upstream_shape_functions.flatten(self, start_dim, end_dim) @@ -2334,6 +2352,11 @@ def aten〇max_pool2d_with_indices〡dtype(self_rank_dtype: Tuple[int, int], ker self_rank, self_dtype = self_rank_dtype return self_dtype, torch.int64 +@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3, 5, 7)], output_size=[2, 2])) +def aten〇adaptive_max_pool2d〡dtype(self_rank_dtype: Tuple[int, int], output_size: List[int]) -> Tuple[int, int]: + self_rank, self_dtype = self_rank_dtype + return self_dtype, torch.int64 + @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) def aten〇mish〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index 45f580ba5e13..ae4c608c6de7 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -502,6 +502,7 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::adaptive_avg_pool3d : (Tensor, int[]) -> (Tensor)") emit("aten::_adaptive_avg_pool3d : (Tensor, int[]) -> (Tensor)") emit("aten::_adaptive_avg_pool3d_backward : (Tensor, Tensor) -> (Tensor)") + emit("aten::adaptive_max_pool2d : (Tensor, int[]) -> (Tensor, Tensor)") emit("aten::topk : (Tensor, int, int, bool, bool) -> (Tensor, Tensor)") emit("aten::transpose.int : (Tensor, int, int) -> (Tensor)") emit("aten::pixel_shuffle : (Tensor, int) -> (Tensor)") diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py index b19596be7031..1d3481196e5f 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py @@ -1061,3 +1061,84 @@ def forward(self, x): def AdaptiveAvgPool1dUnitOutputSizeDynamicModule_basic( module, tu: TestUtils): module.forward(tu.rand(1, 512, 7)) + +class AdaptiveMaxPool2dDynamic(torch.nn.Module): + + def __init__(self): + super().__init__() + self.amp2d = torch.nn.AdaptiveMaxPool2d(output_size=(7,13), return_indices=False) + + @export + @annotate_args([ + None, + ([-1,-1,-1,-1], torch.float32, True) + ]) + def forward(self,x): + return self.amp2d(x) + +@register_test_case( + module_factory=lambda: AdaptiveMaxPool2dDynamic()) +def AdaptiveMaxPool2dDynamic_basic( + module, tu: TestUtils): + module.forward(tu.rand(1, 512, 10, 16)) + +class AdaptiveMaxPool2dDynamicWithIndices(torch.nn.Module): + + def __init__(self): + super().__init__() + self.amp2d = torch.nn.AdaptiveMaxPool2d(output_size=(7,13), return_indices=True) + + @export + @annotate_args([ + None, + ([-1,-1,-1,-1], torch.float32, True) + ]) + def forward(self,x): + return self.amp2d(x) + +@register_test_case( + module_factory=lambda: AdaptiveMaxPool2dDynamicWithIndices()) +def AdaptiveMaxPool2dDynamicWithIndices_basic( + module, tu: TestUtils): + module.forward(tu.rand(1, 512, 10, 16)) + + +class AdaptiveMaxPool2dStatic(torch.nn.Module): + + def __init__(self): + super().__init__() + self.amp2d = torch.nn.AdaptiveMaxPool2d(output_size=(7,13), return_indices=False) + + @export + @annotate_args([ + None, + ([1, 512, 10, 9], torch.float32, True) + ]) + def forward(self,x): + return self.amp2d(x) + +@register_test_case( + module_factory=lambda: AdaptiveMaxPool2dStatic()) +def AdaptiveMaxPool2dStatic_basic( + module, tu: TestUtils): + module.forward(tu.rand(1, 512, 10, 9)) + +class AdaptiveMaxPool2dStaticWithIndices(torch.nn.Module): + + def __init__(self): + super().__init__() + self.amp2d = torch.nn.AdaptiveMaxPool2d(output_size=(7,13), return_indices=True) + + @export + @annotate_args([ + None, + ([1, 512, 10, 16], torch.float32, True) + ]) + def forward(self,x): + return self.amp2d(x) + +@register_test_case( + module_factory=lambda: AdaptiveMaxPool2dStaticWithIndices()) +def AdaptiveMaxPool2dStaticWithIndices_basic( + module, tu: TestUtils): + module.forward(tu.rand(1, 512, 10, 16)) \ No newline at end of file From ac8975ea1276e1f53f1bb3eaedc26f32252c91a1 Mon Sep 17 00:00:00 2001 From: Phaneesh Barwaria Date: Wed, 24 Jan 2024 22:56:21 +0530 Subject: [PATCH 109/283] [MLIR] [ONNX] lowering for onnx tile op and sign op (#2725) --- .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 53 +++++++++++++++++++ .../TorchOnnxToTorch/simple_ops_q_to_z.mlir | 29 ++++++++++ 2 files changed, 82 insertions(+) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 8d2c50c08dcc..b8468773549c 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -1445,6 +1445,47 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( none, none, none); return success(); }); + patterns.onOp( + "Tile", 6, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value operand; + Value repeatDims; + if (binder.tensorOperands(operand, repeatDims) || + binder.tensorResultType(resultType)) + return failure(); + + // convert repeatDims tensor to list of ints + auto repeatDimsSizes = + dyn_cast(repeatDims.getType()).getSizes(); + SmallVector dimList; + SmallVector selectSizes; + selectSizes.push_back(1); + Torch::BaseTensorType shapeType = + repeatDims.getType().cast(); + Type selectResultType = shapeType.getWithSizesAndDtype( + llvm::ArrayRef(selectSizes), shapeType.getOptionalDtype()); + Value zero = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); + for (int i = 0; i < repeatDimsSizes[0]; i++) { + Value selectIndex = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), i)); + Value extract = rewriter.create( + binder.getLoc(), selectResultType, repeatDims, zero, selectIndex); + Value dim = rewriter.create( + binder.getLoc(), rewriter.getType(), extract); + dimList.push_back(dim); + } + Value dimValueList = rewriter.create( + binder.getLoc(), + Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), + dimList); + + rewriter.replaceOpWithNewOp(binder.op, resultType, + operand, dimValueList); + return success(); + }); patterns.onOp( "Topk", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType Values_type, Indices_type; @@ -1476,4 +1517,16 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( cstSorted); return success(); }); + patterns.onOp("Sign", 9, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value operand; + if (binder.tensorOperand(operand) || + binder.tensorResultType(resultType)) + return failure(); + + rewriter.replaceOpWithNewOp( + binder.op, resultType, operand); + return success(); + }); } diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir index f18f28a60d2d..22b529d12c90 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir @@ -1333,3 +1333,32 @@ func.func @test_reshape_zero_and_negative_dim(%arg0: !torch.vtensor<[2,3,4],f32> %0:2 = torch.operator "onnx.TopK"(%arg0, %arg1) {torch.onnx.axis = -1 : si64} : (!torch.vtensor<[3,4],f32>, !torch.vtensor<[1],si64>) -> (!torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64>) return %0#0, %0#1 : !torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64> } + +// ----- + +// CHECK-LABEL: func.func @test_tile +func.func @test_tile(%arg0: !torch.vtensor<[2, 3, 4],f32>, %arg1: !torch.vtensor<[3], si64>) -> !torch.vtensor<[2,12,4],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 6 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT0:.*]] = torch.constant.int 0 + // CHECK: %[[INT0_0:.*]] = torch.constant.int 0 + // CHECK: %[[EXTRACT_0:.*]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: %[[ELE_0:.*]] = torch.aten.item %[[EXTRACT_0]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: %[[EXTRACT_1:.*]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT1]] : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: %[[ELE_1:.*]] = torch.aten.item %[[EXTRACT_1]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[INT2:.*]] = torch.constant.int 2 + // CHECK: %[[EXTRACT_2:.*]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT2]] : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: %[[ELE_2:.*]] = torch.aten.item %[[EXTRACT_2]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[DIM_LIST:.*]] = torch.prim.ListConstruct %[[ELE_0]], %[[ELE_1]], %[[ELE_2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %7 = torch.aten.tile %arg0, %[[DIM_LIST]] : !torch.vtensor<[2,3,4],f32>, !torch.list -> !torch.vtensor<[2,12,4],f32> + %0 = torch.operator "onnx.Tile"(%arg0, %arg1) : (!torch.vtensor<[2, 3, 4],f32>, !torch.vtensor<[3], si64>) -> !torch.vtensor<[2, 12, 4],f32> + return %0 : !torch.vtensor<[2, 12, 4],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_sign +func.func @test_sign(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.sign %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> + %0 = torch.operator "onnx.Sign"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> +} From 12f123eff8ba7e2b70b80d5f78099fa75e4df6b7 Mon Sep 17 00:00:00 2001 From: Gaurav Shukla Date: Thu, 25 Jan 2024 00:33:37 +0530 Subject: [PATCH 110/283] [ONNX][MLIR] Add support for pad op in the onnx pipeline (#2738) This commit adds mapping from `onnx.pad` op to `torch.pad` op. Currently it does not support `axes` parameter of `onnx.pad` op. Signed-off-by: Gaurav Shukla --- .../TorchOnnxToTorch/DefaultDomainGtoP.cpp | 94 +++++++++++++++++++ .../TorchOnnxToTorch/simple_ops_g_to_p.mlir | 26 +++++ 2 files changed, 120 insertions(+) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index 859215d287d9..a3a053e76b40 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -708,6 +708,100 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( binder.op, resultType, operand, constAlpha); return success(); }); + patterns.onOp( + "Pad", 19, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value data, pads, constantValue, axes; + std::string mode; + + // TODO: The `axes` parameter is not supported yet. + if (!binder.tensorOperandAtIndex(axes, 3)) { + return rewriter.notifyMatchFailure( + binder.op, "The axes parameter is not supported yet"); + } + if (binder.tensorOperandAtIndex(data, 0) || + binder.tensorOperandAtIndex(pads, 1) || + binder.tensorOperandAtIndex(constantValue, 2) || + binder.tensorResultType(resultType) || + binder.customOpNameStringAttr(mode, "mode", "constant")) + return failure(); + Location loc = binder.getLoc(); + + // Get pads shape and rank. The pads tensor is expected to be 1-D + // tensor. + auto padsTensorType = pads.getType().cast(); + if (!padsTensorType || !padsTensorType.hasSizes()) { + return rewriter.notifyMatchFailure(binder.op, + "Expect non empty pad tensor"); + } + ArrayRef padsShape = padsTensorType.getSizes(); + int64_t padsRank = padsShape.size(); + if (padsRank != 1) { + return rewriter.notifyMatchFailure(binder.op, + "Expect 1-D pad tensor"); + } + + // Extract all the values of 1-D pad tensor and create a list of all + // these values as torch.pad op expects pad list. + int64_t padsSize = padsShape[0]; + Value constZero = rewriter.create( + loc, rewriter.getI64IntegerAttr(0)); + SmallVector padsTensorValue; + SmallVector emptyShape; + Type padsElemType = + Torch::ValueTensorType::get(padsTensorType.getContext(), emptyShape, + padsTensorType.getOptionalDtype()); + for (uint32_t i = 0; i < padsSize; ++i) { + Value index = rewriter.create( + loc, rewriter.getI64IntegerAttr(i)); + padsTensorValue.emplace_back(rewriter.create( + loc, padsElemType, pads, constZero, index)); + } + + // The torch.pad op expects a different arrangement of padding pairs for + // each dimension as compared to the onnx.pad op. So, rearranging pad + // tensor to satisfy torch.pad op semantics. + SmallVector padsRearrange; + for (uint32_t i = 0; i < padsSize / 2; i++) { + padsRearrange.emplace_back(padsTensorValue[(padsSize / 2) - 1 - i]); + padsRearrange.emplace_back(padsTensorValue[padsSize - 1 - i]); + } + + Value padsSizeList = + rewriter + .create( + loc, + Torch::ListType::get(rewriter.getType()), + padsRearrange) + .getResult(0); + Value modeVal = rewriter.create( + loc, rewriter.getStringAttr(mode)); + + // The constant value is a 0-d tensor, which needs to be converted to a + // float scalar as torch.pad op expects a float scalar. + auto constValueType = + constantValue.getType().cast(); + if (!constValueType) { + return rewriter.notifyMatchFailure(binder.op, + "Expect non-none constant value"); + } + auto resultTensorType = Torch::ValueTensorType::get( + constValueType.getContext(), emptyShape, rewriter.getF64Type()); + Value none = rewriter.create(loc); + Value cstFalse = rewriter.create(loc, false); + Value constFloatValue = rewriter.create( + loc, resultTensorType, constantValue, + Torch::getDtypeIntValueForType(rewriter, loc, + resultTensorType.getOptionalDtype()), + /*non_blocking=*/cstFalse, /*copy=*/cstFalse, + /*memory_format=*/none); + Value constScalar = rewriter.create( + loc, rewriter.getType(), constFloatValue); + + rewriter.replaceOpWithNewOp( + binder.op, resultType, data, padsSizeList, modeVal, constScalar); + return success(); + }); patterns.onOp("Pow", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir index ba4487152b5c..3e83f6016d9d 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir @@ -348,6 +348,32 @@ func.func @test_less_or_equal(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch. // ----- +// CHECK-LABEL: func.func @test_pad +func.func @test_pad(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[4], si64>, %arg2: !torch.vtensor<[], f32>) -> !torch.vtensor<[5,4],f32> attributes {torch.onnx_meta.opset_version = 19 : si64} { + // CHECK: %[[INT0:.+]] = torch.constant.int 0 + // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 + // CHECK: %[[SELECT_0:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> + // CHECK: %[[INT1:.+]] = torch.constant.int 1 + // CHECK: %[[SELECT_1:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT1]] : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> + // CHECK: %[[INT2:.+]] = torch.constant.int 2 + // CHECK: %[[SELECT_2:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT2]] : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> + // CHECK: %[[INT3:.+]] = torch.constant.int 3 + // CHECK: %[[SELECT_3:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT3]] : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> + // CHECK: %[[LIST:.+]] = torch.prim.tolist(%[[SELECT_1]], %[[SELECT_3]], %[[SELECT_0]], %[[SELECT_2]]) : !torch.vtensor<[],si64>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.list + // CHECK: %[[STR:.+]] = torch.constant.str "constant" + // CHECK: %[[NONE:.+]] = torch.constant.none + // CHECK: %[[FALSE:.+]] = torch.constant.bool false + // CHECK: %[[INT7:.+]] = torch.constant.int 7 + // CHECK: %[[CONVERT:.+]] = torch.aten.to.dtype %arg2, %[[INT7]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f64> + // CHECK: %[[ITEM:.+]] = torch.aten.item %[[CONVERT]] : !torch.vtensor<[],f64> -> !torch.float + // CHECK: %[[PAD:.+]] = torch.aten.pad %arg0, %[[LIST]], %[[STR]], %[[ITEM]] : !torch.vtensor<[3,4],f32>, !torch.list, !torch.str, !torch.float -> !torch.vtensor<[5,4],f32> + // CHECK: return %[[PAD]] : !torch.vtensor<[5,4],f32> + %0 = torch.operator "onnx.Pad"(%arg0, %arg1, %arg2) {torch.onnx.mode = "constant"} : (!torch.vtensor<[3,4],f32>, !torch.vtensor<[4], si64>, !torch.vtensor<[], f32>) -> !torch.vtensor<[5,4],f32> + return %0 : !torch.vtensor<[5,4],f32> +} + +// ----- + // CHECK-LABEL: func.func @test_pow func.func @test_pow(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 15 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.pow.Tensor_Tensor %arg0, %arg1 : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> From 894805dd5e21c3255a6352606a76145740085b92 Mon Sep 17 00:00:00 2001 From: Vivek Khandelwal Date: Thu, 25 Jan 2024 00:38:20 +0530 Subject: [PATCH 111/283] [MLIR][TORCH] Support for `onnx.LayerNormalization` (#2789) Signed-Off By: Vivek Khandelwal --- .../TorchOnnxToTorch/DefaultDomainGtoP.cpp | 109 +++++++++++------- .../TorchOnnxToTorch/simple_ops_g_to_p.mlir | 13 +++ 2 files changed, 79 insertions(+), 43 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index a3a053e76b40..7e3025da3e9b 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -649,49 +649,72 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( } return failure(); }); - patterns.onOp("LayerNormalization", 17, - [](OpBinder binder, ConversionPatternRewriter &rewriter) { - Torch::ValueTensorType Y_type; - Torch::ValueTensorType Mean_type; - Torch::ValueTensorType InvStdDev_type; - Value X; - Value Scale; - Value B; - int64_t axis; - float epsilon; - int64_t stash_type; - if (binder.tensorOperandAtIndex(X, 0) || - binder.tensorOperandAtIndex(Scale, 1) || - binder.tensorOperandAtIndex(B, 2) || - binder.tensorResultTypeAtIndex(Y_type, 0) || - binder.tensorResultTypeAtIndex(Mean_type, 1) || - binder.tensorResultTypeAtIndex(InvStdDev_type, 2) || - binder.s64IntegerAttr(axis, "axis", -1) || - binder.f32FloatAttr(epsilon, "epsilon", 0.00001) || - binder.s64IntegerAttr(stash_type, "stash_type", 1)) - return failure(); - Value constEpsilon = rewriter.create( - binder.getLoc(), rewriter.getType(), - rewriter.getF64FloatAttr(epsilon)); - unsigned rank = 1; - if(std::optional maybeRank = Torch::getTensorRank(X)) - rank = *maybeRank; - SmallVector normalized; - axis = Torch::toPositiveDim(axis, rank); - auto X_type = X.getType().cast(); - ArrayRef X_shape = X_type.getSizes(); - for (int64_t n = axis; n < rank ; n++) { - normalized.push_back(rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(X_shape[n]))); - } - Value normalized_shape = rewriter.create( - binder.getLoc(), - Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), - normalized); - rewriter.replaceOpWithNewOp( - binder.op, Y_type, Mean_type, InvStdDev_type, X, normalized_shape, Scale, B, constEpsilon); - return success(); - }); + patterns.onOp( + "LayerNormalization", 17, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType yType, meanType, invStdDevType; + Value x, scale, b; + int64_t axis, stashType; + float epsilon; + if (binder.tensorOperandAtIndex(x, 0) || + binder.tensorOperandAtIndex(scale, 1) || + binder.tensorOperandAtIndex(b, 2) || + binder.tensorResultTypeAtIndex(yType, 0) || + binder.s64IntegerAttr(axis, "axis", -1) || + binder.f32FloatAttr(epsilon, "epsilon", 0.00001) || + binder.s64IntegerAttr(stashType, "stash_type", 1)) + return failure(); + Value constEpsilon = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getF64FloatAttr(epsilon)); + unsigned rank = 1; + if (std::optional maybeRank = Torch::getTensorRank(x)) + rank = *maybeRank; + SmallVector normalized; + axis = Torch::toPositiveDim(axis, rank); + auto xType = x.getType().cast(); + if (!xType.hasSizes()) { + return rewriter.notifyMatchFailure( + binder.op, "Expected input (X) to have sizes"); + } + ArrayRef xShape = xType.getSizes(); + for (int64_t n = axis; n < rank; n++) { + normalized.push_back(rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(xShape[n]))); + } + Value normalized_shape = rewriter.create( + binder.getLoc(), + Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), + normalized); + + int64_t numResults = binder.op->getNumResults(); + if (numResults == 1) { + SmallVector reducedShape(rank, 1); + for (int64_t i = 0; i < axis; i++) + reducedShape[i] = xShape[i]; + auto reducedType = xType.getWithSizesAndDtype( + reducedShape, xType.getOptionalDtype()); + Value y = rewriter + .create( + binder.getLoc(), yType, /*meanType=*/reducedType, + /*invStdDevType=*/reducedType, x, normalized_shape, + scale, b, constEpsilon) + .getResult0(); + rewriter.replaceOp(binder.op, y); + return success(); + } + if (numResults == 3) { + if (binder.tensorResultTypeAtIndex(meanType, 1) || + binder.tensorResultTypeAtIndex(invStdDevType, 2)) + return failure(); + rewriter.replaceOpWithNewOp( + binder.op, yType, meanType, invStdDevType, x, normalized_shape, + scale, b, constEpsilon); + return success(); + } + return rewriter.notifyMatchFailure( + binder.op, "Unimplemented: expected either 1 or 3 results"); + }); patterns.onOp("LeakyRelu", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir index 3e83f6016d9d..00d1355c4699 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir @@ -172,6 +172,19 @@ func.func @test_layer_norm(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtens // ----- +// CHECK-LABEL : func.func @test_layer_norm_single_result +func.func @test_layer_norm_single_result(%arg0: !torch.vtensor<[1,4,768],f32>, %arg1: !torch.vtensor<[768],f32>, %arg2: !torch.vtensor<[768],f32>) -> (!torch.vtensor<[1,4,768], f32>) + attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %float9.999990e-06 = torch.constant.float 9.9999997473787516E-6 + // CHECK: %int768 = torch.constant.int 768 + // CHECK: %0 = torch.prim.ListConstruct %int768 : (!torch.int) -> !torch.list + // CHECK: %result0, %result1, %result2 = torch.aten.native_layer_norm %arg0, %0, %arg1, %arg2 + %0 = torch.operator "onnx.LayerNormalization"(%arg0, %arg1, %arg2) {torch.onnx.axis = -1 : si64, torch.onnx.epsilon = 9.99999974E-6 : f32} : (!torch.vtensor<[1,4,768],f32>, !torch.vtensor<[768],f32>, !torch.vtensor<[768],f32>) -> !torch.vtensor<[1,4,768],f32> + return %0 : !torch.vtensor<[1,4,768],f32> +} + +// ----- + // CHECK-LABEL: func.func @test_leaky_relu func.func @test_leaky_relu(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.opset_version = 16 : si64} { // CHECK-DAG: %[[F2:.+]] = torch.constant.float 2 From 60bf6c25af8a34f8d9356636d90a18c24467c6ec Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Wed, 24 Jan 2024 12:28:48 -0800 Subject: [PATCH 112/283] [onnx] Lower `onnx.QLinearMatMul` lowering to `torch` operators (#2776) We can plumb the linear matmul into pytorch using its quantized types with side channel information. To handle the final int8 operation we dequantize and requantize. --- .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 134 +++++++++++++++++- .../TorchOnnxToTorch/simple_ops_q_to_z.mlir | 65 ++++++++- 2 files changed, 195 insertions(+), 4 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index b8468773549c..2ead942ded1b 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -55,7 +55,9 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( Value zeropoint = operands[2]; auto scaleTy = scale.getType().dyn_cast(); - if (!scaleTy || !scaleTy.hasSizes()) return rewriter.notifyMatchFailure(binder.op, "requires known rank"); + if (!scaleTy || !scaleTy.hasSizes()) + return rewriter.notifyMatchFailure(binder.op, + "requires known rank"); if (!resultType.hasDtype()) return rewriter.notifyMatchFailure( binder.op, "requires known result dtype"); @@ -89,9 +91,135 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( } return failure(); + }); + patterns.onOp( + "QLinearMatMul", 1, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + llvm::SmallVector operands; + if (binder.tensorOperands(operands, 8) || + binder.tensorResultType(resultType)) + return failure(); + Value a = operands[0]; + Value aScale = operands[1]; + Value aZp = operands[2]; + Value b = operands[3]; + Value bScale = operands[4]; + Value bZp = operands[5]; + Value cScale = operands[6]; + Value cZp = operands[7]; + + auto check = [](Value v) { + auto vTy = v.getType().cast(); + for (auto dim : vTy.getSizes()) + if (dim != 1) + return false; + return true; + }; + if (!check(aScale) || !check(aZp) || !check(bScale) || !check(bZp) || + !check(cScale) || !check(cScale)) + return rewriter.notifyMatchFailure( + binder.op, "not supported for non per-tensor quantization"); + + Value emptyList = rewriter.create( + binder.getLoc(), + rewriter.getType( + rewriter.getType()), + ValueRange{}); + auto extract = [&rewriter, &binder, &emptyList](Value v) { + auto vTy = v.getType().cast(); + if (!vTy.getSizes().empty()) { + vTy = rewriter.getType( + ArrayRef({}), vTy.getOptionalDtype()); + v = rewriter.create(binder.getLoc(), vTy, v, + emptyList); + } + + Type extractTy = rewriter.getType(); + if (isa(vTy.getDtype())) + extractTy = rewriter.getType(); + + return rewriter.create(binder.getLoc(), extractTy, + v); + }; + + aZp = extract(aZp); + bZp = extract(bZp); + cZp = extract(cZp); + aScale = extract(aScale); + bScale = extract(bScale); + cScale = extract(cScale); + + auto getQTy = + [&rewriter](Torch::ValueTensorType ty) -> Torch::ValueTensorType { + auto dt = ty.getDtype(); + Type newDt; + if (dt.isUnsignedInteger(8)) { + newDt = rewriter.getType(); + } else if (dt.isSignedInteger(8)) { + newDt = rewriter.getType(); + } else if (dt.isSignedInteger(32)) { + newDt = rewriter.getType(); + } else { + return nullptr; + } - } - ); + return rewriter.getType(ty.getOptionalSizes(), + newDt); + }; + + auto make = [&rewriter, &binder, &getQTy](Value v, Value scale, + Value zp) -> Value { + auto ty = v.getType().cast(); + auto newTy = getQTy(ty); + return rewriter.create( + binder.getLoc(), newTy, v, scale, zp); + }; + + a = make(a, aScale, aZp); + b = make(b, bScale, bZp); + + auto cTy = rewriter.getType( + resultType.getOptionalSizes(), + rewriter.getIntegerType(32, /*issigned=*/true)); + + Value c; + if (cTy.getSizes().size() == 2) { + c = rewriter.create(binder.getLoc(), cTy, a, b); + } else { + c = rewriter.create(binder.getLoc(), cTy, a, b); + } + + cTy = rewriter.getType( + resultType.getOptionalSizes(), + rewriter.getType()); + + Value mmScale = rewriter.create( + binder.getLoc(), rewriter.getType(), aScale, + bScale); + Value mmZp = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); + c = rewriter.create( + binder.getLoc(), cTy, c, mmScale, mmZp); + cTy = rewriter.getType( + resultType.getOptionalSizes(), rewriter.getF32Type()); + + c = rewriter.create(binder.getLoc(), cTy, + c); + cTy = getQTy(resultType); + Value dtyVal = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr( + rewriter.getIntegerType(64), + static_cast( + Torch::getScalarTypeForType(cTy.getDtype())))); + c = rewriter.create( + binder.getLoc(), cTy, c, cScale, cZp, dtyVal); + rewriter.replaceOpWithNewOp(binder.op, resultType, + c); + return success(); + }); patterns.onOp("Reciprocal", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir index 22b529d12c90..b2128175d9d7 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir @@ -4,7 +4,6 @@ // level constants. This is a pragmatic choice which lets us have a lot // of tests in this file, whereas the others tend to be more bespoke. - // CHECK-LABEL: @test_quantizelinear_si8 func.func @test_quantizelinear_si8(%arg0: !torch.vtensor<[6],f32>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],si8>) -> !torch.vtensor<[6],si8> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64} { %0 = torch.operator "onnx.QuantizeLinear"(%arg0, %arg1, %arg2) : (!torch.vtensor<[6],f32>, !torch.vtensor<[],f32>, !torch.vtensor<[],si8>) -> !torch.vtensor<[6],si8> @@ -48,6 +47,70 @@ func.func @test_quantizelinear_i32(%arg0: !torch.vtensor<[6],f32>, %arg1: !torch // ----- +// CHECK-LABEL: @test_qlinearmatmul_2D +func.func @test_qlinearmatmul_2D(%arg0: !torch.vtensor<[2,4],ui8>, %arg1: !torch.vtensor<[1],f32>, %arg2: !torch.vtensor<[1],ui8>, %arg3: !torch.vtensor<[4,3],ui8>, %arg4: !torch.vtensor<[1],f32>, %arg5: !torch.vtensor<[1],ui8>, %arg6: !torch.vtensor<[1],f32>, %arg7: !torch.vtensor<[1],ui8>) -> !torch.vtensor<[2,3],ui8> attributes {torch.onnx_meta.ir_version = 5 : si64, torch.onnx_meta.opset_version = 10 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + %0 = torch.operator "onnx.QLinearMatMul"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7) : (!torch.vtensor<[2,4],ui8>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],ui8>, !torch.vtensor<[4,3],ui8>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],ui8>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],ui8>) -> !torch.vtensor<[2,3],ui8> + // CHECK-DAG: %[[EMPTY:.+]] = torch.prim.ListConstruct : () -> !torch.list + // CHECK-DAG: %[[RESH0:.+]] = torch.aten.reshape %arg2, %[[EMPTY]] : !torch.vtensor<[1],ui8>, !torch.list -> !torch.vtensor<[],ui8> + // CHECK-DAG: %[[RESH1:.+]] = torch.aten.reshape %arg5, %[[EMPTY]] : !torch.vtensor<[1],ui8>, !torch.list -> !torch.vtensor<[],ui8> + // CHECK-DAG: %[[RESH2:.+]] = torch.aten.reshape %arg7, %[[EMPTY]] : !torch.vtensor<[1],ui8>, !torch.list -> !torch.vtensor<[],ui8> + // CHECK-DAG: %[[RESH3:.+]] = torch.aten.reshape %arg1, %[[EMPTY]] : !torch.vtensor<[1],f32>, !torch.list -> !torch.vtensor<[],f32> + // CHECK-DAG: %[[RESH4:.+]] = torch.aten.reshape %arg4, %[[EMPTY]] : !torch.vtensor<[1],f32>, !torch.list -> !torch.vtensor<[],f32> + // CHECK-DAG: %[[RESH5:.+]] = torch.aten.reshape %arg6, %[[EMPTY]] : !torch.vtensor<[1],f32>, !torch.list -> !torch.vtensor<[],f32> + // CHECK-DAG: %[[AZP:.+]] = torch.aten.item %[[RESH0]] : !torch.vtensor<[],ui8> -> !torch.int + // CHECK-DAG: %[[BZP:.+]] = torch.aten.item %[[RESH1]] : !torch.vtensor<[],ui8> -> !torch.int + // CHECK-DAG: %[[CZP:.+]] = torch.aten.item %[[RESH2]] : !torch.vtensor<[],ui8> -> !torch.int + // CHECK-DAG: %[[ASCALE:.+]] = torch.aten.item %[[RESH3]] : !torch.vtensor<[],f32> -> !torch.float + // CHECK-DAG: %[[BSCALE:.+]] = torch.aten.item %[[RESH4]] : !torch.vtensor<[],f32> -> !torch.float + // CHECK-DAG: %[[CCSCALE:.+]] = torch.aten.item %[[RESH5]] : !torch.vtensor<[],f32> -> !torch.float + // CHECK-DAG: %[[LHS:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[ASCALE]], %[[AZP]] : !torch.vtensor<[2,4],ui8>, !torch.float, !torch.int -> !torch.vtensor<[2,4],!torch.quint8> + // CHECK-DAG: %[[RHS:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg3, %[[BSCALE]], %[[BZP]] : !torch.vtensor<[4,3],ui8>, !torch.float, !torch.int -> !torch.vtensor<[4,3],!torch.quint8> + // CHECK: %[[MM:.+]] = torch.aten.mm %[[LHS]], %[[RHS]] : !torch.vtensor<[2,4],!torch.quint8>, !torch.vtensor<[4,3],!torch.quint8> -> !torch.vtensor<[2,3],si32> + // CHECK: %[[CSCALE:.+]] = torch.aten.mul.float %[[ASCALE]], %[[BSCALE]] : !torch.float, !torch.float -> !torch.float + // CHECK: %[[ZERO:.+]] = torch.constant.int 0 + // CHECK: %[[QC:.+]] = torch.aten._make_per_tensor_quantized_tensor %[[MM]], %[[CSCALE]], %[[ZERO]] : !torch.vtensor<[2,3],si32>, !torch.float, !torch.int -> !torch.vtensor<[2,3],!torch.qint32> + // CHECK: %[[FC:.+]] = torch.aten.dequantize.self %[[QC]] : !torch.vtensor<[2,3],!torch.qint32> -> !torch.vtensor<[2,3],f32> + // CHECK: %[[DTY:.+]] = torch.constant.int 13 + // CHECK: %[[QO:.+]] = torch.aten.quantize_per_tensor %[[FC]], %[[CCSCALE]], %[[CZP]], %[[DTY]] : !torch.vtensor<[2,3],f32>, !torch.float, !torch.int, !torch.int -> !torch.vtensor<[2,3],!torch.quint8> + // CHECK: %[[OUT:.+]] = torch.aten.int_repr %[[QO]] : !torch.vtensor<[2,3],!torch.quint8> -> !torch.vtensor<[2,3],ui8> + // CHECK: return %[[OUT]] + return %0 : !torch.vtensor<[2,3],ui8> +} + +// ----- + +// CHECK-LABEL: @test_qlinearmatmul_3D +func.func @test_qlinearmatmul_3D(%arg0: !torch.vtensor<[2,2,4],ui8>, %arg1: !torch.vtensor<[1],f32>, %arg2: !torch.vtensor<[1],ui8>, %arg3: !torch.vtensor<[2,4,3],ui8>, %arg4: !torch.vtensor<[1],f32>, %arg5: !torch.vtensor<[1],ui8>, %arg6: !torch.vtensor<[1],f32>, %arg7: !torch.vtensor<[1],ui8>) -> !torch.vtensor<[2,2,3],ui8> attributes {torch.onnx_meta.ir_version = 5 : si64, torch.onnx_meta.opset_version = 10 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + %0 = torch.operator "onnx.QLinearMatMul"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7) : (!torch.vtensor<[2,2,4],ui8>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],ui8>, !torch.vtensor<[2,4,3],ui8>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],ui8>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],ui8>) -> !torch.vtensor<[2,2,3],ui8> + // CHECK-DAG: %[[EMPTY:.+]] = torch.prim.ListConstruct : () -> !torch.list + // CHECK-DAG: %[[RESH0:.+]] = torch.aten.reshape %arg2, %[[EMPTY]] : !torch.vtensor<[1],ui8>, !torch.list -> !torch.vtensor<[],ui8> + // CHECK-DAG: %[[RESH1:.+]] = torch.aten.reshape %arg5, %[[EMPTY]] : !torch.vtensor<[1],ui8>, !torch.list -> !torch.vtensor<[],ui8> + // CHECK-DAG: %[[RESH2:.+]] = torch.aten.reshape %arg7, %[[EMPTY]] : !torch.vtensor<[1],ui8>, !torch.list -> !torch.vtensor<[],ui8> + // CHECK-DAG: %[[RESH3:.+]] = torch.aten.reshape %arg1, %[[EMPTY]] : !torch.vtensor<[1],f32>, !torch.list -> !torch.vtensor<[],f32> + // CHECK-DAG: %[[RESH4:.+]] = torch.aten.reshape %arg4, %[[EMPTY]] : !torch.vtensor<[1],f32>, !torch.list -> !torch.vtensor<[],f32> + // CHECK-DAG: %[[RESH5:.+]] = torch.aten.reshape %arg6, %[[EMPTY]] : !torch.vtensor<[1],f32>, !torch.list -> !torch.vtensor<[],f32> + // CHECK-DAG: %[[AZP:.+]] = torch.aten.item %[[RESH0]] : !torch.vtensor<[],ui8> -> !torch.int + // CHECK-DAG: %[[BZP:.+]] = torch.aten.item %[[RESH1]] : !torch.vtensor<[],ui8> -> !torch.int + // CHECK-DAG: %[[CZP:.+]] = torch.aten.item %[[RESH2]] : !torch.vtensor<[],ui8> -> !torch.int + // CHECK-DAG: %[[ASCALE:.+]] = torch.aten.item %[[RESH3]] : !torch.vtensor<[],f32> -> !torch.float + // CHECK-DAG: %[[BSCALE:.+]] = torch.aten.item %[[RESH4]] : !torch.vtensor<[],f32> -> !torch.float + // CHECK-DAG: %[[CCSCALE:.+]] = torch.aten.item %[[RESH5]] : !torch.vtensor<[],f32> -> !torch.float + // CHECK-DAG: %[[LHS:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[ASCALE]], %[[AZP]] : !torch.vtensor<[2,2,4],ui8>, !torch.float, !torch.int -> !torch.vtensor<[2,2,4],!torch.quint8> + // CHECK-DAG: %[[RHS:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg3, %[[BSCALE]], %[[BZP]] : !torch.vtensor<[2,4,3],ui8>, !torch.float, !torch.int -> !torch.vtensor<[2,4,3],!torch.quint8> + // CHECK: %[[MM:.+]] = torch.aten.bmm %[[LHS]], %[[RHS]] : !torch.vtensor<[2,2,4],!torch.quint8>, !torch.vtensor<[2,4,3],!torch.quint8> -> !torch.vtensor<[2,2,3],si32> + // CHECK: %[[CSCALE:.+]] = torch.aten.mul.float %[[ASCALE]], %[[BSCALE]] : !torch.float, !torch.float -> !torch.float + // CHECK: %[[ZERO:.+]] = torch.constant.int 0 + // CHECK: %[[QC:.+]] = torch.aten._make_per_tensor_quantized_tensor %[[MM]], %[[CSCALE]], %[[ZERO]] : !torch.vtensor<[2,2,3],si32>, !torch.float, !torch.int -> !torch.vtensor<[2,2,3],!torch.qint32> + // CHECK: %[[FC:.+]] = torch.aten.dequantize.self %[[QC]] : !torch.vtensor<[2,2,3],!torch.qint32> -> !torch.vtensor<[2,2,3],f32> + // CHECK: %[[DTY:.+]] = torch.constant.int 13 + // CHECK: %[[QO:.+]] = torch.aten.quantize_per_tensor %[[FC]], %[[CCSCALE]], %[[CZP]], %[[DTY]] : !torch.vtensor<[2,2,3],f32>, !torch.float, !torch.int, !torch.int -> !torch.vtensor<[2,2,3],!torch.quint8> + // CHECK: %[[OUT:.+]] = torch.aten.int_repr %[[QO]] : !torch.vtensor<[2,2,3],!torch.quint8> -> !torch.vtensor<[2,2,3],ui8> + // CHECK: return %[[OUT]] + return %0 : !torch.vtensor<[2,2,3],ui8> +} + +// ----- + // CHECK-LABEL: func.func @test_reciprocal func.func @test_reciprocal(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.reciprocal %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> From f6f890520b67482c32bc4c175816f7a811039f50 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Wed, 24 Jan 2024 14:02:50 -0800 Subject: [PATCH 113/283] [torch][quant] Quantized `torch.mm` for linalg with end-to-end test (#2750) This includes custom op matching for decomposed operations and fusing dequantization into dense operations. As a validation we compare to the dequant+mm torch implementation. --- externals/llvm-project | 2 +- .../Dialect/Torch/Transforms/Passes.h | 4 + .../Dialect/Torch/Transforms/Passes.td | 28 +++ lib/Conversion/TorchToLinalg/Linear.cpp | 45 +++- lib/Dialect/Torch/Transforms/CMakeLists.txt | 2 + .../Torch/Transforms/FuseQuantizedOps.cpp | 214 ++++++++++++++++++ .../Torch/Transforms/MatchQuantizedOps.cpp | 109 +++++++++ .../TorchConversion/Transforms/Passes.cpp | 8 +- .../base_lazy_backend/shape_inference.cpp | 38 ++++ projects/pt1/e2e_testing/xfail_sets.py | 4 +- .../torch_mlir_e2e_test/test_suite/matmul.py | 27 +++ test/Dialect/Torch/fuse-quantized-ops.mlir | 62 +++++ .../Torch/match-quantized-customs-ops.mlir | 42 ++++ 13 files changed, 577 insertions(+), 8 deletions(-) create mode 100644 lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp create mode 100644 lib/Dialect/Torch/Transforms/MatchQuantizedOps.cpp create mode 100644 test/Dialect/Torch/fuse-quantized-ops.mlir create mode 100644 test/Dialect/Torch/match-quantized-customs-ops.mlir diff --git a/externals/llvm-project b/externals/llvm-project index 0cb024b357af..eae82ac259ee 160000 --- a/externals/llvm-project +++ b/externals/llvm-project @@ -1 +1 @@ -Subproject commit 0cb024b357aff294b1ba0f9d3de8f48ab684962b +Subproject commit eae82ac259ee5a58bc4070a414bc53239e18bad0 diff --git a/include/torch-mlir/Dialect/Torch/Transforms/Passes.h b/include/torch-mlir/Dialect/Torch/Transforms/Passes.h index 84efddcc93d4..fd7468847e5f 100644 --- a/include/torch-mlir/Dialect/Torch/Transforms/Passes.h +++ b/include/torch-mlir/Dialect/Torch/Transforms/Passes.h @@ -106,6 +106,10 @@ createDecomposeComplexOpsPass(ArrayRef legalOps); std::unique_ptr> createRecomposeComplexOpsPass(); +std::unique_ptr> createFuseQuantizedOpsPass(); +std::unique_ptr> +createMatchQuantizedCustomOpsPass(); + std::unique_ptr> createReifyShapeCalculationsPass(StringRef extraLibrary); diff --git a/include/torch-mlir/Dialect/Torch/Transforms/Passes.td b/include/torch-mlir/Dialect/Torch/Transforms/Passes.td index 8967855c2e52..7b52d786610e 100644 --- a/include/torch-mlir/Dialect/Torch/Transforms/Passes.td +++ b/include/torch-mlir/Dialect/Torch/Transforms/Passes.td @@ -258,6 +258,34 @@ def RecomposeComplexOps : Pass<"torch-recompose-complex-ops", "func::FuncOp"> { }]; } +def FuseQuantizedOps : Pass<"torch-fuse-quantized-ops", "func::FuncOp"> { + let summary = "QDQ: Fuse recognized QDQ op sequences."; + let constructor = "mlir::torch::Torch::createFuseQuantizedOpsPass()"; + let description = [{ + Torch models often represents quantized operations as the sequence: + Dequantize + DenseOp + Quantize + This allows the existing dense operations to be used without specifically + representing quantized types. It is more computationally efficient to + perform the dense operation in the quantized domain, so we fuse the + quantization / dequantization behavior together and represent as purely + quantized operations. + }]; +} + +def MatchQuantizedCustomOps : Pass<"torch-match-quantized-custom-ops", "func::FuncOp"> { + let summary = "Match quantized operations that occur in different namespace."; + let constructor = "mlir::torch::Torch::createMatchQuantizedCustomOpsPass()"; + let description = [{ + Torch quantization utilities generated custom op versions of known aten + quantziation operations. We can match these specially named operations and + rewrite to the corresponding aten quantized operations. + + We handle this post import to maintain a simplified import process. + }]; +} + def ReifyShapeCalculations : Pass<"torch-reify-shape-calculations", "ModuleOp"> { let summary = "Reify shape calculations."; let constructor = [{ diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp index 0d64a0dd2d94..d818b99c0c4a 100644 --- a/lib/Conversion/TorchToLinalg/Linear.cpp +++ b/lib/Conversion/TorchToLinalg/Linear.cpp @@ -29,6 +29,13 @@ using namespace mlir::torch; using namespace mlir::torch::Torch; namespace { + +static void getZeroPoint(Value value, Value &zeropoint) { + if (auto make = value.getDefiningOp()) { + zeropoint = make.getZeroPoint(); + } +} + class ConvertAtenMmOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -64,11 +71,27 @@ class ConvertAtenMmOp : public OpConversionPattern { op.getSelf().getType().cast(); ValueTensorType rhsTorchType = op.getMat2().getType().cast(); + + Value lhsZeroPoint, rhsZeroPoint; + getZeroPoint(op.getSelf(), lhsZeroPoint); + getZeroPoint(op.getMat2(), rhsZeroPoint); + + if (static_cast(lhsZeroPoint) != static_cast(lhsZeroPoint)) { + return rewriter.notifyMatchFailure( + op, "unsupported: aten.mm with mixed quantization"); + } + if (lhsTorchType.getDtype() != rhsTorchType.getDtype()) { return rewriter.notifyMatchFailure( op, "unsupported: aten.mm with different input element types"); } + bool isUnsigned = torch_to_linalg::isUnsignedTorchType(lhsTorchType); + if (lhsZeroPoint && isUnsigned) { + return rewriter.notifyMatchFailure( + op, "unsupported: unsigned quantized matmul not supported"); + } + Value lhsDim0 = rewriter.create(loc, lhs, 0); Value rhsDim1 = rewriter.create(loc, rhs, 1); @@ -89,8 +112,26 @@ class ConvertAtenMmOp : public OpConversionPattern { rewriter, loc, ValueRange{lhsDim0, rhsDim1}, elementType); Value matmul; - auto intType = dyn_cast(lhsTorchType.getDtype()); - if (intType && intType.isUnsigned()) { + if (lhsZeroPoint && !isUnsigned) { + lhsZeroPoint = typeConverter->materializeTargetConversion( + rewriter, loc, + getTypeConverter()->convertType(lhsZeroPoint.getType()), + lhsZeroPoint); + rhsZeroPoint = typeConverter->materializeTargetConversion( + rewriter, loc, + getTypeConverter()->convertType(rhsZeroPoint.getType()), + rhsZeroPoint); + lhsZeroPoint = rewriter.create( + loc, rewriter.getI32Type(), lhsZeroPoint); + rhsZeroPoint = rewriter.create( + loc, rewriter.getI32Type(), rhsZeroPoint); + matmul = + rewriter + .create( + loc, zeroFill.getType(), + ValueRange{lhs, rhs, lhsZeroPoint, rhsZeroPoint}, zeroFill) + .getResult(0); + } else if (isUnsigned) { matmul = rewriter .create( loc, zeroFill.getType(), ValueRange{lhs, rhs}, zeroFill) diff --git a/lib/Dialect/Torch/Transforms/CMakeLists.txt b/lib/Dialect/Torch/Transforms/CMakeLists.txt index 0f7621ff0da4..4def554d9f49 100644 --- a/lib/Dialect/Torch/Transforms/CMakeLists.txt +++ b/lib/Dialect/Torch/Transforms/CMakeLists.txt @@ -3,10 +3,12 @@ add_mlir_library(TorchMLIRTorchPasses DecomposeComplexOps.cpp DropAbstractInterpCalculations.cpp EraseModuleInitializer.cpp + FuseQuantizedOps.cpp Passes.cpp GlobalizeObjectGraph.cpp InlineGlobalSlots.cpp LowerToBackendContract.cpp + MatchQuantizedOps.cpp MaximizeValueSemantics.cpp PrepareForGlobalizeObjectGraph.cpp RecomposeComplexOps.cpp diff --git a/lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp b/lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp new file mode 100644 index 000000000000..85dadb755112 --- /dev/null +++ b/lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp @@ -0,0 +1,214 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#include "PassDetail.h" + +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "torch-mlir/Dialect/Torch/IR/TorchOps.h" +#include "torch-mlir/Dialect/Torch/Transforms/Passes.h" +#include "torch-mlir/Dialect/Torch/Utils/Utils.h" + +using namespace mlir; +using namespace mlir::torch; +using namespace mlir::torch::Torch; + +namespace { + +template +class QuantizeOperands : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(SrcOp op, + PatternRewriter &rewriter) const override { + llvm::SmallVector operands(op->getOperands()); + + bool dequanted = false; + for (auto &operand : operands) { + if (auto dequant = operand.getDefiningOp()) { + operand = dequant.getOperand(); + dequanted = true; + } + if (auto dequant = operand.getDefiningOp()) { + operand = dequant.getOperand(); + dequanted = true; + } + } + + if (!dequanted) { + return rewriter.notifyMatchFailure(op, "no dequantizations found"); + } + + rewriter.replaceOpWithNewOp(op, op.getType(), operands); + return success(); + } +}; + +template class QuantizeBias : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(SrcOp op, + PatternRewriter &rewriter) const override { + llvm::SmallVector operands(op->getOperands()); + if (operands.size() < 3) + return failure(); + + Value bias = operands[2]; + if (bias.getDefiningOp()) + return failure(); + + Value lhsScale; + if (auto qLhs = + operands[0].getDefiningOp()) + lhsScale = qLhs.getScale(); + + Value rhsScale; + if (auto qRhs = + operands[1].getDefiningOp()) + rhsScale = qRhs.getScale(); + + if (!rhsScale || !lhsScale) + return failure(); + + auto biasTy = bias.getType().cast(); + auto biasETy = biasTy.getOptionalDtype(); + if (!biasETy || !isa(biasETy)) + return failure(); + + Value biasScale = rewriter.create( + op.getLoc(), lhsScale.getType(), lhsScale, rhsScale); + + Value zero = rewriter.create( + op.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); + + auto qi32Ty = rewriter.getType(); + auto newBiasTy = + rewriter.getType(biasTy.getOptionalSizes(), qi32Ty); + Value dtype = getDtypeIntValueForType(rewriter, op.getLoc(), qi32Ty); + bias = rewriter.create( + op.getLoc(), newBiasTy, bias, biasScale, zero, dtype); + + operands[2] = bias; + rewriter.replaceOpWithNewOp(op, op.getType(), operands); + return success(); + } +}; + +template +class QuantizeAccumulator : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(SrcOp op, + PatternRewriter &rewriter) const override { + auto lhs = op.getOperand(0); + auto rhs = op.getOperand(1); + + auto resultTy = dyn_cast_or_null(op.getType()); + if (!resultTy || !resultTy.hasDtype()) + return failure(); + + Type resultETy = resultTy.getDtype(); + if (!resultETy.isa()) + return failure(); + + Value lhsScale; + if (auto defining = + lhs.template getDefiningOp()) { + lhsScale = defining.getScale(); + } + + Value rhsScale; + if (auto defining = + rhs.template getDefiningOp()) { + rhsScale = defining.getScale(); + } + + if (!lhsScale || !rhsScale) + return failure(); + + // Quantize the bias input to the expected result: + Value zero = rewriter.create( + op.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); + + auto qi32Ty = rewriter.getType(); + Value biasScale = rewriter.create( + op.getLoc(), lhsScale.getType(), lhsScale, rhsScale); + + // Update the quantied type: + llvm::SmallVector operands(op.getOperands()); + + auto newResultTy = + rewriter.getType(resultTy.getOptionalSizes(), qi32Ty); + auto conv = rewriter.create(op.getLoc(), newResultTy, operands); + + // Attach the quantize information to the resulting quint32: + auto intReprTy = rewriter.getType( + resultTy.getOptionalSizes(), + rewriter.getIntegerType(32, IntegerType::Signed)); + auto intRepr = rewriter.create(op.getLoc(), intReprTy, conv); + + auto quantTy = + rewriter.getType(resultTy.getOptionalSizes(), qi32Ty); + auto quant = rewriter.create( + op.getLoc(), quantTy, intRepr, biasScale, zero); + auto dequant = + rewriter.create(op.getLoc(), resultTy, quant); + rewriter.replaceOp(op, dequant); + + return success(); + } +}; + +template class RemoveUnused : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(SrcOp op, + PatternRewriter &rewriter) const override { + auto result = op.getResult(); + if (result.use_empty()) { + op.erase(); + return success(); + } + return failure(); + } +}; + +class FuseQuantizedOpsPass : public FuseQuantizedOpsBase { +public: + void runOnOperation() override { + MLIRContext *context = &getContext(); + RewritePatternSet patterns(context); + patterns + .insert, + RemoveUnused, + RemoveUnused, + QuantizeOperands, QuantizeOperands, + QuantizeAccumulator, + QuantizeAccumulator, QuantizeBias>( + context); + + GreedyRewriteConfig config; + if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), + config))) { + return signalPassFailure(); + } + } +}; + +} // namespace + +std::unique_ptr> +mlir::torch::Torch::createFuseQuantizedOpsPass() { + return std::make_unique(); +} diff --git a/lib/Dialect/Torch/Transforms/MatchQuantizedOps.cpp b/lib/Dialect/Torch/Transforms/MatchQuantizedOps.cpp new file mode 100644 index 000000000000..147f16c08eb3 --- /dev/null +++ b/lib/Dialect/Torch/Transforms/MatchQuantizedOps.cpp @@ -0,0 +1,109 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#include "PassDetail.h" + +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "torch-mlir/Dialect/Torch/IR/TorchOps.h" +#include "torch-mlir/Dialect/Torch/Transforms/Passes.h" +#include "torch-mlir/Dialect/Torch/Utils/Utils.h" + +using namespace mlir; +using namespace mlir::torch; +using namespace mlir::torch::Torch; + +namespace { + +Type getQuantizedType(MLIRContext *context, Type t) { + if (t.isSignlessInteger(8)) + return Torch::QUInt8Type::get(context); + if (t.isInteger(8) || t.isSignedInteger(8)) + return Torch::QInt8Type::get(context); + if (t.isInteger(32)) + return Torch::QInt32Type::get(context); + return {}; +} + +class MatchQuantizeOperator : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(OperatorOp op, + PatternRewriter &rewriter) const override { + if (op.getName() == "torch.quantized_decomposed.quantize_per_tensor") { + auto resultTy = cast(op.getType(0)); + auto qeTy = getQuantizedType(rewriter.getContext(), resultTy.getDtype()); + if (!qeTy) + qeTy = resultTy.getDtype(); + + auto qTy = + rewriter.getType(resultTy.getOptionalSizes(), qeTy); + Value quant = rewriter.create( + op.getLoc(), qTy, + /*self=*/op.getOperand(0), /*scale=*/op.getOperand(1), + /*zero_point=*/op.getOperand(2), /*dtype=*/op.getOperand(5)); + + if (qTy != resultTy) { + quant = rewriter.create(op.getLoc(), resultTy, quant); + } + + rewriter.replaceOpWithNewOp( + op, resultTy, quant, op.getOperand(3), op.getOperand(4)); + return success(); + } + + if (op.getName() == "torch.quantized_decomposed.dequantize_per_tensor") { + auto clamp = rewriter.create( + op.getLoc(), op.getOperand(0).getType(), op.getOperand(0), + op.getOperand(3), op.getOperand(4)); + + auto clampTy = clamp.getType().cast(); + if (!clampTy.hasDtype()) + return rewriter.notifyMatchFailure(op, + "dequantization has unknown dtype"); + + Type dtype = clampTy.getDtype(); + Type qetype = getQuantizedType(op.getContext(), dtype); + if (!qetype) + return rewriter.notifyMatchFailure(op, + "dequantization has unknown qtype"); + + Type qTy = Torch::ValueTensorType::get( + op.getContext(), clampTy.getOptionalSizes(), qetype); + auto quant = rewriter.create( + op.getLoc(), qTy, clamp, op.getOperand(1), op.getOperand(2)); + rewriter.replaceOpWithNewOp( + op, op.getResultTypes(), quant); + return success(); + } + + return failure(); + } +}; + +class MatchQuantizedCustomOpsPass + : public MatchQuantizedCustomOpsBase { +public: + void runOnOperation() override { + MLIRContext *context = &getContext(); + RewritePatternSet patterns(context); + patterns.insert(context); + + GreedyRewriteConfig config; + if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), + config))) + return signalPassFailure(); + } +}; + +} // namespace + +std::unique_ptr> +mlir::torch::Torch::createMatchQuantizedCustomOpsPass() { + return std::make_unique(); +} diff --git a/lib/Dialect/TorchConversion/Transforms/Passes.cpp b/lib/Dialect/TorchConversion/Transforms/Passes.cpp index 09e99057e0b6..91d468a6941f 100644 --- a/lib/Dialect/TorchConversion/Transforms/Passes.cpp +++ b/lib/Dialect/TorchConversion/Transforms/Passes.cpp @@ -15,12 +15,13 @@ #include "mlir/Dialect/Tosa/Transforms/Passes.h" #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/Passes.h" +#include "torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h" +#include "torch-mlir/Conversion/TorchToArith/TorchToArith.h" #include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h" #include "torch-mlir/Conversion/TorchToSCF/TorchToSCF.h" -#include "torch-mlir/Conversion/TorchToArith/TorchToArith.h" #include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h" #include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h" -#include "torch-mlir/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.h" +#include "torch-mlir/Dialect/Torch/Transforms/Passes.h" #ifdef TORCH_MLIR_ENABLE_STABLEHLO #include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h" #endif @@ -64,6 +65,9 @@ void mlir::torch::registerTorchConversionPasses() { void TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline( OpPassManager &pm) { + // We want to fuse quantized operations together before lowering to linalg. + pm.addNestedPass(Torch::createFuseQuantizedOpsPass()); + // Lower to linalg + guards which is the input to codegen backends. // We do this first as it tends to involve pattern-matching against constants, // (e.g. dimensions which must be constant in a ranked programming model) diff --git a/projects/ltc/csrc/base_lazy_backend/shape_inference.cpp b/projects/ltc/csrc/base_lazy_backend/shape_inference.cpp index 15080f9764cc..ff43359ebe80 100644 --- a/projects/ltc/csrc/base_lazy_backend/shape_inference.cpp +++ b/projects/ltc/csrc/base_lazy_backend/shape_inference.cpp @@ -39,6 +39,38 @@ std::vector compute_shape_div(const at::Tensor& self, return {Shape(self.scalar_type(), self.sizes().vec())}; } +std::vector compute_shape__make_per_tensor_quantized_tensor( + const at::Tensor &self, double scale, int64_t zero_point) { + if (self.scalar_type() == at::kChar) + return {Shape(at::kQInt8, self.sizes().vec())}; + if (self.scalar_type() == at::kByte) + return {Shape(at::kQUInt8, self.sizes().vec())}; + if (self.scalar_type() == at::kInt) + return {Shape(at::kQInt32, self.sizes().vec())}; + assert(false); +} + +std::vector compute_shape_int_repr(const at::Tensor &self) { + if (self.scalar_type() == at::kQInt8) + return {Shape(at::kChar, self.sizes().vec())}; + if (self.scalar_type() == at::kQUInt8) + return {Shape(at::kByte, self.sizes().vec())}; + if (self.scalar_type() == at::kQInt32) + return {Shape(at::kInt, self.sizes().vec())}; + assert(false); +} + +std::vector +compute_shape_dequantize(const at::Tensor &self) { + return {Shape(at::kFloat, self.sizes().vec())}; +} + +std::vector +compute_shape_quantize_per_tensor(const at::Tensor &self, double scale, + int64_t zero_point, at::ScalarType dtype) { + return {Shape(dtype, self.sizes().vec())}; +} + std::vector compute_shape_isinf(const at::Tensor& self) { return {Shape(at::kBool, self.sizes().vec())}; } @@ -102,6 +134,12 @@ std::vector compute_shape_var( return {Shape(self.scalar_type(), {})}; } +std::vector compute_shape_nan_to_num( + const at::Tensor & self, c10::optional nan, + c10::optional posinf, c10::optional neginf) { + return {Shape(self.scalar_type(), self.sizes().vec())}; +} + std::vector compute_shape_hardtanh( const at::Tensor& self, const at::Scalar& min_val, const at::Scalar& max_val) { diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 7fd1ba55bce4..f0261b16f6af 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -315,6 +315,7 @@ # Dynamo does not support tracing quantized tensors "ElementwiseDequantizePerTensorModule_basic", "ElementwiseQuantizePerTensorModule_basic", + "AtenMmQuint8_basic", # Dynamo not supporting conv_tbc "ConvTbcModule_basic", @@ -1539,7 +1540,4 @@ "ElementwiseBitwiseAndScalarInt64Module_basic", "ElementwiseBitwiseAndScalarInt32Module_basic", "ElementwiseBitwiseAndScalarInt8Module_basic", - "ElementwiseNanToNumModule_Basic", - "ElementwiseQuantizePerTensorModule_basic", - "ElementwiseDequantizePerTensorModule_basic" } diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py index ae7ea72031a5..72a4097bc302 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py @@ -262,3 +262,30 @@ def forward(self, a, b): @register_test_case(module_factory=lambda: AtenMmIntTypes()) def AtenMmIntTypes_basic(module, tu: TestUtils): module.forward(tu.randint(16, 4, high=100), tu.randint(4, 16, high=100)) + + +# ============================================================================== + +class AtenMmQuint8(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([3, 4], torch.int8, True), + ([4, 3], torch.int8, True), + ]) + def forward(self, x, y): + qx = torch._make_per_tensor_quantized_tensor(x, 0.1, 8) + qx = torch.dequantize(qx) + qy = torch._make_per_tensor_quantized_tensor(y, 0.1, 8) + qy = torch.dequantize(qy) + qz = torch.mm(qx, qy) + return qz + +@register_test_case(module_factory=lambda: AtenMmQuint8()) +def AtenMmQuint8_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 4, low=-128, high=127).to(torch.int8), + tu.randint(4, 3, low=-128, high=127).to(torch.int8)) diff --git a/test/Dialect/Torch/fuse-quantized-ops.mlir b/test/Dialect/Torch/fuse-quantized-ops.mlir new file mode 100644 index 000000000000..c62a0d13d9cf --- /dev/null +++ b/test/Dialect/Torch/fuse-quantized-ops.mlir @@ -0,0 +1,62 @@ +// RUN: torch-mlir-opt %s --split-input-file --torch-fuse-quantized-ops | FileCheck %s + +// CHECK-LABEL: @mm +func.func @mm(%arg0: !torch.vtensor<[4, 4],si8>, %arg1: !torch.vtensor<[4, 4],si8>) -> !torch.vtensor<[4, 4],f32> { + %scale = torch.constant.float 0.5 + %false = torch.constant.bool false + %zero = torch.constant.int 0 + %one = torch.constant.int 1 + %zp = torch.constant.int -128 + %6 = torch.aten._make_per_tensor_quantized_tensor %arg0, %scale, %one : !torch.vtensor<[4, 4],si8>, !torch.float, !torch.int -> !torch.vtensor<[4, 4],!torch.qint8> + %7 = torch.aten.dequantize.tensor %6 : !torch.vtensor<[4, 4],!torch.qint8> -> !torch.vtensor<[4, 4],f32> + %12 = torch.aten._make_per_tensor_quantized_tensor %arg1, %scale, %zero : !torch.vtensor<[4, 4],si8>, !torch.float, !torch.int -> !torch.vtensor<[4, 4],!torch.qint8> + %13 = torch.aten.dequantize.tensor %12 : !torch.vtensor<[4, 4],!torch.qint8> -> !torch.vtensor<[4, 4],f32> + %16 = torch.aten.mm %7, %13 : !torch.vtensor<[4, 4],f32>, !torch.vtensor<[4, 4],f32> -> !torch.vtensor<[4, 4],f32> + + // CHECK-DAG: %[[ZERO:.+]] = torch.constant.int 0 + // CHECK-DAG: %[[QUARTER:.+]] = torch.constant.float 2.500000e-01 + // CHECK-DAG: %[[HALF:.+]] = torch.constant.float 5.000000e-01 + // CHECK-DAG: %[[ONE:.+]] = torch.constant.int 1 + // CHECK-DAG: %[[QLHS:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[HALF:.+]], %[[ONE]] : !torch.vtensor<[4,4],si8>, !torch.float, !torch.int -> !torch.vtensor<[4,4],!torch.qint8> + // CHECK-DAG: %[[QRHS:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg1, %[[HALF:.+]], %[[ZERO]] : !torch.vtensor<[4,4],si8>, !torch.float, !torch.int -> !torch.vtensor<[4,4],!torch.qint8> + // CHECK-DAG: %[[MM:.+]] = torch.aten.mm %[[QLHS]], %[[QRHS]] : !torch.vtensor<[4,4],!torch.qint8>, !torch.vtensor<[4,4],!torch.qint8> -> !torch.vtensor<[4,4],!torch.qint32> + // CHECK-DAG: %[[INT:.+]] = torch.aten.int_repr %[[MM]] : !torch.vtensor<[4,4],!torch.qint32> -> !torch.vtensor<[4,4],si32> + // CHECK-DAG: %[[QOUT:.+]] = torch.aten._make_per_tensor_quantized_tensor %[[INT]], %[[QUARTER]], %[[ZERO]] : !torch.vtensor<[4,4],si32>, !torch.float, !torch.int -> !torch.vtensor<[4,4],!torch.qint32> + // CHECK: %[[OUT:.+]] = torch.aten.dequantize.tensor %[[QOUT]] : !torch.vtensor<[4,4],!torch.qint32> -> !torch.vtensor<[4,4],f32> + return %16 : !torch.vtensor<[4, 4],f32> +} + +// ----- + +// CHECK-LABEL: @convolution +func.func @convolution(%arg0: !torch.vtensor<[1,3,8,8],si8>, %arg1: !torch.vtensor<[3,3,2,2],si8>, %arg2 : !torch.vtensor<[3], f32>) -> !torch.vtensor<[1,3,7,7],f32> { + %scale = torch.constant.float 0.5 + %false = torch.constant.bool false + %zero = torch.constant.int 0 + %one = torch.constant.int 1 + %zp = torch.constant.int -128 + %6 = torch.aten._make_per_tensor_quantized_tensor %arg0, %scale, %one : !torch.vtensor<[1,3,8,8],si8>, !torch.float, !torch.int -> !torch.vtensor<[1,3,8,8],!torch.qint8> + %7 = torch.aten.dequantize.tensor %6 : !torch.vtensor<[1,3,8,8],!torch.qint8> -> !torch.vtensor<[1,3,8,8],f32> + %12 = torch.aten._make_per_tensor_quantized_tensor %arg1, %scale, %zero : !torch.vtensor<[3,3,2,2],si8>, !torch.float, !torch.int -> !torch.vtensor<[3,3,2,2],!torch.qint8> + %13 = torch.aten.dequantize.tensor %12 : !torch.vtensor<[3,3,2,2],!torch.qint8> -> !torch.vtensor<[3,3,2,2],f32> + %14 = torch.prim.ListConstruct %one, %one : (!torch.int, !torch.int) -> !torch.list + %15 = torch.prim.ListConstruct %zero, %zero : (!torch.int, !torch.int) -> !torch.list + %16 = torch.aten.convolution %7, %13, %arg2, %14, %15, %14, %false, %15, %one : !torch.vtensor<[1,3,8,8],f32>, !torch.vtensor<[3,3,2,2],f32>, !torch.vtensor<[3],f32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[1,3,7,7],f32> + + // CHECK-DAG: %[[ZERO:.+]] = torch.constant.int 0 + // CHECK-DAG: %[[SCALEO:.+]] = torch.constant.float 2.500000e-01 + // CHECK-DAG: %[[DTYPE:.+]] = torch.constant.int 14 + // CHECK-DAG: %[[HALF:.+]] = torch.constant.float 5.000000e-01 + // CHECK-DAG: %[[FALSE:.+]] = torch.constant.bool false + // CHECK-DAG: %[[ONE:.+]] = torch.constant.int 1 + // CHECK-DAG: %[[QLHS:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[HALF]], %[[ONE]] : !torch.vtensor<[1,3,8,8],si8>, !torch.float, !torch.int -> !torch.vtensor<[1,3,8,8],!torch.qint8> + // CHECK-DAG: %[[QRHS:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg1, %[[HALF]], %[[ZERO]] : !torch.vtensor<[3,3,2,2],si8>, !torch.float, !torch.int -> !torch.vtensor<[3,3,2,2],!torch.qint8> + // CHECK-DAG: %[[ONES:.+]] = torch.prim.ListConstruct %[[ONE]], %[[ONE]] : (!torch.int, !torch.int) -> !torch.list + // CHECK-DAG: %[[ZEROS:.+]] = torch.prim.ListConstruct %[[ZERO]], %[[ZERO]] : (!torch.int, !torch.int) -> !torch.list + // CHECK-DAG: %[[QBIAS:.+]] = torch.aten.quantize_per_tensor %arg2, %[[SCALEO]], %[[ZERO]], %[[DTYPE]] : !torch.vtensor<[3],f32>, !torch.float, !torch.int, !torch.int -> !torch.vtensor<[3],!torch.qint32> + // CHECK-DAG: %[[CONV:.+]] = torch.aten.convolution %[[QLHS]], %[[QRHS]], %[[QBIAS]], %[[ONES]], %[[ZEROS]], %[[ONES]], %[[FALSE]], %[[ZEROS]], %[[ONE]] : !torch.vtensor<[1,3,8,8],!torch.qint8>, !torch.vtensor<[3,3,2,2],!torch.qint8>, !torch.vtensor<[3],!torch.qint32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[1,3,7,7],!torch.qint32> + // CHECK-DAG: %[[INT:.+]] = torch.aten.int_repr %[[CONV]] : !torch.vtensor<[1,3,7,7],!torch.qint32> -> !torch.vtensor<[1,3,7,7],si32> + // CHECK-DAG: %[[QOUT:.+]] = torch.aten._make_per_tensor_quantized_tensor %[[INT]], %[[SCALEO]], %[[ZERO]] : !torch.vtensor<[1,3,7,7],si32>, !torch.float, !torch.int -> !torch.vtensor<[1,3,7,7],!torch.qint32> + // CHECK: %[[FOUT:.+]] = torch.aten.dequantize.tensor %[[QOUT]] : !torch.vtensor<[1,3,7,7],!torch.qint32> -> !torch.vtensor<[1,3,7,7],f32> + return %16 : !torch.vtensor<[1,3,7,7],f32> +} diff --git a/test/Dialect/Torch/match-quantized-customs-ops.mlir b/test/Dialect/Torch/match-quantized-customs-ops.mlir new file mode 100644 index 000000000000..c1a0e9ebf20d --- /dev/null +++ b/test/Dialect/Torch/match-quantized-customs-ops.mlir @@ -0,0 +1,42 @@ +// RUN: torch-mlir-opt --split-input-file --torch-match-quantized-custom-ops %s | FileCheck %s + +// CHECK-LABEL: func.func @quantize_per_tensor +func.func @quantize_per_tensor(%arg0: !torch.vtensor<[1,3,8,8],f32>) -> !torch.vtensor<[1,3,8,8],si8> { + %float = torch.constant.float 0.5 + %zp = torch.constant.int 17 + %min = torch.constant.int -128 + %max = torch.constant.int 127 + %dtype = torch.constant.int 1 + + // CHECK-DAG: %[[SCALE:.+]] = torch.constant.float 5.000000e-01 + // CHECK-DAG: %[[ZP:.+]] = torch.constant.int 17 + // CHECK-DAG: %[[MIN:.+]] = torch.constant.int -128 + // CHECK-DAG: %[[MAX:.+]] = torch.constant.int 127 + // CHECK-DAG: %[[DTYPE:.+]] = torch.constant.int 1 + // CHECK-DAG: %[[QUANT:.+]] = torch.aten.quantize_per_tensor %arg0, %[[SCALE]], %[[ZP]], %[[DTYPE]] : !torch.vtensor<[1,3,8,8],f32>, !torch.float, !torch.int, !torch.int -> !torch.vtensor<[1,3,8,8],!torch.qint8> + // CHECK-DAG: %[[REPR:.+]] = torch.aten.int_repr %[[QUANT]] : !torch.vtensor<[1,3,8,8],!torch.qint8> -> !torch.vtensor<[1,3,8,8],si8> + // CHECK: torch.aten.clamp %[[REPR]], %[[MIN]], %[[MAX]] + %0 = torch.operator "torch.quantized_decomposed.quantize_per_tensor"(%arg0, %float, %zp, %min, %max, %dtype) : (!torch.vtensor<[1,3,8,8],f32>, !torch.float, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.vtensor<[1,3,8,8],si8> + return %0 : !torch.vtensor<[1,3,8,8],si8> +} + +// ----- + +// CHECK-LABEL: func.func @dequantize_per_tensor +func.func @dequantize_per_tensor(%arg0: !torch.vtensor<[1,3,8,8],si8>) -> !torch.vtensor<[1,3,8,8],f32> { + %float = torch.constant.float 0.5 + %zp = torch.constant.int 17 + %min = torch.constant.int -128 + %max = torch.constant.int 127 + %dtype = torch.constant.int 1 + + // CHECK-DAG: %[[SCALE:.+]] = torch.constant.float 5.000000e-01 + // CHECK-DAG: %[[ZP:.+]] = torch.constant.int 17 + // CHECK-DAG: %[[MIN:.+]] = torch.constant.int -128 + // CHECK-DAG: %[[MAX:.+]] = torch.constant.int 127 + // CHECK-DAG: %[[CLAMP:.+]] = torch.aten.clamp %arg0, %[[MIN]], %[[MAX]] : !torch.vtensor<[1,3,8,8],si8>, !torch.int, !torch.int -> !torch.vtensor<[1,3,8,8],si8> + // CHECK-DAG: %[[QINT:.+]] = torch.aten._make_per_tensor_quantized_tensor %[[CLAMP]], %[[SCALE]], %[[ZP]] : !torch.vtensor<[1,3,8,8],si8>, !torch.float, !torch.int -> !torch.vtensor<[1,3,8,8],!torch.qint8> + // CHECK: %[[DEQUANT:.+]] = torch.aten.dequantize.tensor %[[QINT]] : !torch.vtensor<[1,3,8,8],!torch.qint8> -> !torch.vtensor<[1,3,8,8],f32> + %13 = torch.operator "torch.quantized_decomposed.dequantize_per_tensor"(%arg0, %float, %zp, %min, %max, %dtype) : (!torch.vtensor<[1,3,8,8],si8>, !torch.float, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.vtensor<[1,3,8,8],f32> + return %13 : !torch.vtensor<[1,3,8,8],f32> +} From e581b33f9694204e12213f26b5de7b4b5126d6ea Mon Sep 17 00:00:00 2001 From: lonely eagle <75576166+linuxlonelyeagle@users.noreply.github.com> Date: Thu, 25 Jan 2024 10:44:08 +0800 Subject: [PATCH 114/283] [Stablehlo]fix CumsumInputDtypeInt32Module_basic on stablehlo backend. (#2797) Code used for testing.For the location of CumsumInputDtypeInt32Module in the repo you can see [here](https://github.com/llvm/torch-mlir/blob/311b6b0286bfa016346bc7fd8b441bbd50216060/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py#L4148). ```python import torch import torch_mlir class CumsumInputDtypeInt32Module(torch.nn.Module): def __init__(self): super().__init__() def forward(self, val): return torch.ops.aten.cumsum(val, 1) module = torch_mlir.compile(CumsumInputDtypeInt32Module(), [torch.randn(2, 7, 4).to(torch.int32)], output_type="stablehlo") print(module.operation.get_asm()) ``` After fixing the bugs. ``` module attributes {torch.debug_module_name = "CumsumInputDtypeInt32Module"} { func.func @forward(%arg0: tensor<2x7x4xi32>) -> tensor<2x7x4xi64> { %0 = stablehlo.constant dense<0> : tensor %1 = stablehlo.convert %arg0 : (tensor<2x7x4xi32>) -> tensor<2x7x4xi64> %2 = "stablehlo.reduce_window"(%1, %0) ({ ^bb0(%arg1: tensor, %arg2: tensor): %3 = stablehlo.add %arg1, %arg2 : tensor stablehlo.return %3 : tensor }) {padding = dense<[[0, 0], [6, 0], [0, 0]]> : tensor<3x2xi64>, window_dilations = dense<1> : tensor<3xi64>, window_dimensions = dense<[1, 7, 1]> : tensor<3xi64>, window_strides = dense<1> : tensor<3xi64>} : (tensor<2x7x4xi64>, tensor) -> tensor<2x7x4xi64> return %2 : tensor<2x7x4xi64> } } ``` --- lib/Conversion/TorchToStablehlo/Pooling.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/lib/Conversion/TorchToStablehlo/Pooling.cpp b/lib/Conversion/TorchToStablehlo/Pooling.cpp index 7c28a2fd3004..e90f231c74f5 100644 --- a/lib/Conversion/TorchToStablehlo/Pooling.cpp +++ b/lib/Conversion/TorchToStablehlo/Pooling.cpp @@ -569,11 +569,13 @@ LogicalResult ConvertAtenOp::matchAndRewrite( ConversionPatternRewriter &rewriter) const { Value input = adaptor.getSelf(); auto inputTy = input.getType().cast(); + auto outTy = + getTypeConverter()->convertType(op.getType()).cast(); + input = hlo::promoteType(rewriter, op.getLoc(), input, outTy); + inputTy = input.getType().cast(); auto inputElemTy = inputTy.getElementType(); auto inputRank = inputTy.getRank(); auto inputShape = inputTy.getShape(); - auto outTy = - getTypeConverter()->convertType(op.getType()).cast(); int64_t dim; if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) { From e824fbc65cd9481a733af7dcf11ffe7e9eaa9bf0 Mon Sep 17 00:00:00 2001 From: Aart Bik <39774503+aartbik@users.noreply.github.com> Date: Thu, 25 Jan 2024 10:04:04 -0800 Subject: [PATCH 115/283] [torch-mlir][torch] add encoding field to torch type (#2799) This adds an encoding field to the torch type, using the interfaces for printing, parsing, and verification. Note that although this change prepares adding sparsity to the torch type (as illustrated by the round trip and invalid tests), nothing in this change depends on the actual contents of the encoding field! --- .../torch-mlir/Dialect/Torch/IR/TorchTypes.h | 2 +- .../torch-mlir/Dialect/Torch/IR/TorchTypes.td | 23 ++++++- lib/Dialect/Torch/IR/TorchTypes.cpp | 68 ++++++++++++++----- test/Dialect/Torch/invalid.mlir | 13 ++++ test/Dialect/Torch/ops.mlir | 6 ++ 5 files changed, 93 insertions(+), 19 deletions(-) diff --git a/include/torch-mlir/Dialect/Torch/IR/TorchTypes.h b/include/torch-mlir/Dialect/Torch/IR/TorchTypes.h index de77a1a8f8a3..c8d1c5051f28 100644 --- a/include/torch-mlir/Dialect/Torch/IR/TorchTypes.h +++ b/include/torch-mlir/Dialect/Torch/IR/TorchTypes.h @@ -32,7 +32,7 @@ class ValueTensorType; /// Common getter function signature that covers all tensor types. /// Used for sharing code between NonValueTensorType and ValueTensorType. using GetTensorTypeFn = llvm::function_ref>, Type)>; + MLIRContext *, std::optional>, Type, Attribute)>; /// The representation of an unknown dimension size in an ArrayRef. constexpr static int64_t kUnknownSize = -1; diff --git a/include/torch-mlir/Dialect/Torch/IR/TorchTypes.td b/include/torch-mlir/Dialect/Torch/IR/TorchTypes.td index c3b5c1582c02..898c768ae1c2 100644 --- a/include/torch-mlir/Dialect/Torch/IR/TorchTypes.td +++ b/include/torch-mlir/Dialect/Torch/IR/TorchTypes.td @@ -63,12 +63,13 @@ class AnyTorchTensorType ``` tensor-type ::= (`!torch.tensor` | `!torch.vtensor`) tensor-modifiers? - tensor-modifiers ::= `<` sizes-spec `,` dtype-spec `>` + tensor-modifiers ::= `<` sizes-spec `,` dtype-spec (',' sparsity)? `>` sizes-spec ::= `*` | `[` size-list `]` size-list ::= /*empty*/ | size-list-nonempty size-list-nonempty = size (`,` size)* size ::= `?` | decimal-literal dtype-spec ::= `unk` | type + sparsity ::= attribute-value ``` Represents a multi-dimensional array to model Torch's `torch.Tensor` type. @@ -133,6 +134,12 @@ class AnyTorchTensorType |-------------------|--------------------| ``` + The `sparsity` attribute directly mirrors the additional tensor `encoding` + defined by upstream MLIR on the RankedTensorType. Unlike the upstream + attribute, however, this attribute is exclusively used to denote a + straightforward tensor (with an empty attribute) or a sparse tensor + (with a SparseTensorEncodingAttr). + TODO: Support the full set of Torch dtypes. TODO: Use si1? @@ -149,8 +156,20 @@ class AnyTorchTensorType }]; let parameters = (ins OptionalArrayRefTorchParameter<"int64_t", "sizes of dimensions">:$optionalSizes, - "::mlir::Type":$optionalDtype + "::mlir::Type":$optionalDtype, + "Attribute":$optionalSparsity ); + let builders = [ + // Provide builder where optionalSparsity is empty by default. + TypeBuilder<(ins + "::std::optional>":$optionalSizes, + "::mlir::Type":$optionalDtype, + CArg<"Attribute", "{}">:$optionalSparsity + ), [{ + return $_get(context, optionalSizes, optionalDtype, optionalSparsity); + }]> + ]; + let skipDefaultBuilders = 1; let genVerifyDecl = 1; let hasCustomAssemblyFormat = 1; string extraBaseClassDeclaration = [{ diff --git a/lib/Dialect/Torch/IR/TorchTypes.cpp b/lib/Dialect/Torch/IR/TorchTypes.cpp index 33ef459081c4..b5b63954fe42 100644 --- a/lib/Dialect/Torch/IR/TorchTypes.cpp +++ b/lib/Dialect/Torch/IR/TorchTypes.cpp @@ -8,10 +8,11 @@ //===----------------------------------------------------------------------===// #include "torch-mlir/Dialect/Torch/IR/TorchTypes.h" +#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" #include "mlir/IR/DialectImplementation.h" -#include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" +#include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "llvm/ADT/STLExtras.h" using namespace mlir; @@ -239,7 +240,7 @@ ValueTensorType BaseTensorType::getWithValueSemantics() const { static LogicalResult verifyTensorType(function_ref emitError, std::optional> optionalSizes, - Type optionalDtype) { + Type optionalDtype, Attribute optionalSparsity) { if (optionalDtype && !isValidTorchDtype(optionalDtype)) { emitError() << "invalid dtype " << optionalDtype << " for !torch.tensor type"; @@ -253,6 +254,24 @@ verifyTensorType(function_ref emitError, } } } + // Verify sparsity encoding against a known type and shape using the encoding + // verification interface. Any implementation emits a diagnostic on failure. + // Also verify sparsity encoding is truly a sparse encoding attrbute. + if (optionalSparsity) { + if (optionalDtype && optionalSizes.has_value()) { + if (auto venc = llvm::dyn_cast_or_null( + optionalSparsity)) { + if (failed(venc.verifyEncoding(optionalSizes.value(), optionalDtype, + emitError))) { + return failure(); + } + } + } + if (!optionalSparsity.isa()) { + emitError() << "invalid sparsity encoding attribute"; + return failure(); + } + } return success(); } @@ -262,7 +281,8 @@ Type parseTensorType(MLIRContext *context, AsmParser &parser, if (parser.parseOptionalLess()) return getTensorType(context, /*optionalSizes=*/std::nullopt, - /*optionalDtype=*/Type()); + /*optionalDtype=*/Type(), + /*optionalSparsity=*/Attribute()); bool hasSizes; SmallVector sizes; if (succeeded(parser.parseOptionalStar())) { @@ -307,6 +327,12 @@ Type parseTensorType(MLIRContext *context, AsmParser &parser, if (parser.parseType(optionalDtype)) return Type(); } + Attribute optionalSparsity; + if (succeeded(parser.parseOptionalComma())) { + // Explicit encoding. + if (parser.parseAttribute(optionalSparsity)) + return Type(); + } if (parser.parseGreater()) return Type(); std::optional> optionalSizes; @@ -314,15 +340,15 @@ Type parseTensorType(MLIRContext *context, AsmParser &parser, optionalSizes.emplace(sizes); if (failed(verifyTensorType([&]() { return parser.emitError(startLoc); }, - optionalSizes, optionalDtype))) + optionalSizes, optionalDtype, optionalSparsity))) return Type(); - return getTensorType(context, optionalSizes, optionalDtype); + return getTensorType(context, optionalSizes, optionalDtype, optionalSparsity); } static void printTensorType(AsmPrinter &printer, std::optional> optionalSizes, - Type optionalDtype) { + Type optionalDtype, Attribute optionalSparsity) { if (!optionalSizes && !optionalDtype) return; printer << "<"; @@ -345,6 +371,10 @@ static void printTensorType(AsmPrinter &printer, printer.printType(optionalDtype); else printer << "unk"; + if (optionalSparsity) { + printer << ","; + printer.printAttribute(optionalSparsity); + } printer << ">"; } @@ -367,8 +397,9 @@ NonValueTensorType::getWithLeastStaticInformation(MLIRContext *context) { LogicalResult NonValueTensorType::verify(function_ref emitError, std::optional> optionalSizes, - Type optionalDtype) { - return verifyTensorType(emitError, optionalSizes, optionalDtype); + Type optionalDtype, Attribute optionalSparsity) { + return verifyTensorType(emitError, optionalSizes, optionalDtype, + optionalSparsity); } Type NonValueTensorType::parse(AsmParser &parser) { @@ -376,13 +407,15 @@ Type NonValueTensorType::parse(AsmParser &parser) { return parseTensorType( context, parser, [](MLIRContext *context, std::optional> optionalSizes, - Type optionalType) { - return NonValueTensorType::get(context, optionalSizes, optionalType); + Type optionalType, Attribute optionalSparsity) { + return NonValueTensorType::get(context, optionalSizes, optionalType, + optionalSparsity); }); } void NonValueTensorType::print(AsmPrinter &printer) const { - printTensorType(printer, getOptionalSizes(), getOptionalDtype()); + printTensorType(printer, getOptionalSizes(), getOptionalDtype(), + getOptionalSparsity()); } //===----------------------------------------------------------------------===// @@ -440,8 +473,9 @@ TensorType ValueTensorType::toBuiltinTensor() const { LogicalResult ValueTensorType::verify(function_ref emitError, std::optional> optionalSizes, - Type optionalDtype) { - return verifyTensorType(emitError, optionalSizes, optionalDtype); + Type optionalDtype, Attribute optionalSparsity) { + return verifyTensorType(emitError, optionalSizes, optionalDtype, + optionalSparsity); } Type ValueTensorType::parse(AsmParser &parser) { @@ -449,13 +483,15 @@ Type ValueTensorType::parse(AsmParser &parser) { return parseTensorType( context, parser, [](MLIRContext *context, std::optional> optionalSizes, - Type optionalType) { - return ValueTensorType::get(context, optionalSizes, optionalType); + Type optionalType, Attribute optionalSparsity) { + return ValueTensorType::get(context, optionalSizes, optionalType, + optionalSparsity); }); } void ValueTensorType::print(AsmPrinter &printer) const { - printTensorType(printer, getOptionalSizes(), getOptionalDtype()); + printTensorType(printer, getOptionalSizes(), getOptionalDtype(), + getOptionalSparsity()); } Type Torch::meetTensorTypes(BaseTensorType lhs, BaseTensorType rhs) { diff --git a/test/Dialect/Torch/invalid.mlir b/test/Dialect/Torch/invalid.mlir index 067c1a9b67f4..63aa1e3755a9 100644 --- a/test/Dialect/Torch/invalid.mlir +++ b/test/Dialect/Torch/invalid.mlir @@ -362,3 +362,16 @@ func.func @torch.permute$invalid_index_in_permutation (%arg0: !torch.vtensor<[1, return %3 : !torch.vtensor<[1,2,3],f32> } +// ----- + +#SV = #sparse_tensor.encoding<{ map = (d0) -> (d0 : compressed) }> + +// expected-error @+1 {{dimension-rank mismatch between encoding and tensor shape: 1 != 2}} +func.func @foo(%arg0: !torch.vtensor<[64,64],f32,#SV>) -> !torch.vtensor<[64,64],f32,#SV> { + return %arg0 : !torch.vtensor<[64,64],f32,#SV> +} + +// ----- + +// expected-error @+1 {{invalid sparsity encoding attribute}} +func.func private @tensor.sparse() -> !torch.vtensor<[64,64],f32,12345> diff --git a/test/Dialect/Torch/ops.mlir b/test/Dialect/Torch/ops.mlir index d8d0fc33098b..623217fd22dc 100644 --- a/test/Dialect/Torch/ops.mlir +++ b/test/Dialect/Torch/ops.mlir @@ -1,5 +1,7 @@ // RUN: torch-mlir-opt %s | torch-mlir-opt | FileCheck %s +// CHECK: #[[$ENCODING:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }> + // CHECK-LABEL: func.func @torch.operator( func.func @torch.operator(%arg0: !torch.tensor, %arg1: !torch.tensor) -> !torch.tensor { // CHECK: torch.operator "ns.unqual.overload"(%arg0, %arg1) : (!torch.tensor, !torch.tensor) -> !torch.tensor @@ -28,6 +30,10 @@ func.func private @tensor.some_sizes_known() -> !torch.tensor<[?,2,?,4],unk> // CHECK: @tensor.fully_determined() -> !torch.vtensor<[1,2,3,4],f32> func.func private @tensor.fully_determined() -> !torch.vtensor<[1,2,3,4],f32> +// CHECK: @tensor.sparse() -> !torch.vtensor<[64,64],f32,#[[$ENCODING]]> +#CSR = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }> +func.func private @tensor.sparse() -> !torch.vtensor<[64,64],f32,#CSR> + // CHECK: @tuple.empty() -> !torch.tuple<> func.func private @tuple.empty() -> !torch.tuple<> // CHECK: @tuple.one_element() -> !torch.tuple From dc9c624a29d8b5ca42f0e5d6474837cf12ea410c Mon Sep 17 00:00:00 2001 From: Aart Bik <39774503+aartbik@users.noreply.github.com> Date: Thu, 25 Jan 2024 12:54:40 -0800 Subject: [PATCH 116/283] [torch-mlir][sparse] provide a bazel build (#2805) --- utils/bazel/torch-mlir-overlay/BUILD.bazel | 1 + 1 file changed, 1 insertion(+) diff --git a/utils/bazel/torch-mlir-overlay/BUILD.bazel b/utils/bazel/torch-mlir-overlay/BUILD.bazel index 138dcbefb6ac..493383cf9161 100644 --- a/utils/bazel/torch-mlir-overlay/BUILD.bazel +++ b/utils/bazel/torch-mlir-overlay/BUILD.bazel @@ -133,6 +133,7 @@ cc_library( "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:InferTypeOpInterface", + "@llvm-project//mlir:SparseTensorDialect", "@llvm-project//mlir:TransformUtils", ], ) From fe836ceebf48e369bc6f2904c64704559dbb9689 Mon Sep 17 00:00:00 2001 From: Aart Bik <39774503+aartbik@users.noreply.github.com> Date: Thu, 25 Jan 2024 14:24:13 -0800 Subject: [PATCH 117/283] [torch-mlir][test] cleanup trailing whitespace in mlir files (#2806) --- test/Dialect/Torch/GlobalizeObjectGraph/submodules.mlir | 2 +- test/Dialect/Torch/canonicalize.mlir | 2 +- test/Dialect/Torch/match-quantized-customs-ops.mlir | 2 +- test/Dialect/Torch/simplify-shape-calculations.mlir | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/test/Dialect/Torch/GlobalizeObjectGraph/submodules.mlir b/test/Dialect/Torch/GlobalizeObjectGraph/submodules.mlir index eacd36493791..7fc261850def 100644 --- a/test/Dialect/Torch/GlobalizeObjectGraph/submodules.mlir +++ b/test/Dialect/Torch/GlobalizeObjectGraph/submodules.mlir @@ -1,6 +1,6 @@ // RUN: torch-mlir-opt -torch-globalize-object-graph -split-input-file %s | FileCheck %s -// Check that linkage names consist of the dotted path from the root. +// Check that linkage names consist of the dotted path from the root. // CHECK-LABEL: torch.global_slot.module_initializer { // CHECK: %[[FLOAT:.*]] = torch.constant.float 4.200000e+01 diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir index abb990cccc8c..9172d4642759 100644 --- a/test/Dialect/Torch/canonicalize.mlir +++ b/test/Dialect/Torch/canonicalize.mlir @@ -2129,7 +2129,7 @@ func.func @torch.aten.floor$canonicalize(%arg0: !torch.vtensor<[?,?],si64>) -> ! } // CHECK-LABEL: func.func @torch.aten.numel$canonicalize -// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[3,4],f32> +// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[3,4],f32> // CHECK-NEXT: %int12 = torch.constant.int 12 // CHECK-NEXT: return %int12 : !torch.int func.func @torch.aten.numel$canonicalize(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.int { diff --git a/test/Dialect/Torch/match-quantized-customs-ops.mlir b/test/Dialect/Torch/match-quantized-customs-ops.mlir index c1a0e9ebf20d..4196e688157f 100644 --- a/test/Dialect/Torch/match-quantized-customs-ops.mlir +++ b/test/Dialect/Torch/match-quantized-customs-ops.mlir @@ -7,7 +7,7 @@ func.func @quantize_per_tensor(%arg0: !torch.vtensor<[1,3,8,8],f32>) -> !torch.v %min = torch.constant.int -128 %max = torch.constant.int 127 %dtype = torch.constant.int 1 - + // CHECK-DAG: %[[SCALE:.+]] = torch.constant.float 5.000000e-01 // CHECK-DAG: %[[ZP:.+]] = torch.constant.int 17 // CHECK-DAG: %[[MIN:.+]] = torch.constant.int -128 diff --git a/test/Dialect/Torch/simplify-shape-calculations.mlir b/test/Dialect/Torch/simplify-shape-calculations.mlir index 0d3b1f661bde..b7e7cf17ba0e 100644 --- a/test/Dialect/Torch/simplify-shape-calculations.mlir +++ b/test/Dialect/Torch/simplify-shape-calculations.mlir @@ -316,7 +316,7 @@ func.func @abstractly_interpret_list_ops$mutation_in_child_region(%arg0: !torch. // CHECK: } else { // CHECK: torch.prim.If.yield %[[ARG1]] : !torch.list // CHECK: } - // .... and this one don't have the same object identity, but should! + // .... and this one don't have the same object identity, but should! // CHECK: %[[VAL_8:.*]] = torch.prim.ListConstruct %[[INT3]], %[[INT3]] : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[VAL_9:.*]] = torch.prim.If %[[ARG2]] -> (!torch.list) { // CHECK: torch.prim.If.yield %[[VAL_8]] : !torch.list From 0aed231e216ea8cd84aa64887d94a57a2c2c5f36 Mon Sep 17 00:00:00 2001 From: Aart Bik <39774503+aartbik@users.noreply.github.com> Date: Thu, 25 Jan 2024 14:24:28 -0800 Subject: [PATCH 118/283] [torch-mlir][conversion-test] cleanup trailing whitespace in mlir files (#2807) --- .../TorchOnnxToTorch/simple_ops_a_to_f.mlir | 6 ++--- .../TorchOnnxToTorch/simple_ops_g_to_p.mlir | 14 +++++------ .../TorchOnnxToTorch/simple_ops_q_to_z.mlir | 24 +++++++++---------- .../unsupported_fb_opt_ops.mlir | 12 +++++----- test/Conversion/TorchToLinalg/basic.mlir | 6 ++--- test/Conversion/TorchToLinalg/pooling.mlir | 8 +++---- test/Conversion/TorchToLinalg/view.mlir | 4 ++-- test/Conversion/TorchToStablehlo/basic.mlir | 2 +- test/Conversion/TorchToStablehlo/linear.mlir | 14 +++++------ test/Conversion/TorchToStablehlo/pooling.mlir | 2 +- .../TorchToTosa/conv2d_transpose.mlir | 2 +- 11 files changed, 47 insertions(+), 47 deletions(-) diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir index 77cdd5786ea8..493cdc98312f 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir @@ -1036,7 +1036,7 @@ func.func @test_exp(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f3 // ----- // CHECK-LABEL: @test_expand_dim2_shape2 -func.func @test_expand_dim2_shape2(%arg0: !torch.vtensor<[1,4],f32>, %arg1: !torch.vtensor<[2],si32>) +func.func @test_expand_dim2_shape2(%arg0: !torch.vtensor<[1,4],f32>, %arg1: !torch.vtensor<[2],si32>) -> !torch.vtensor<[3,4],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT0:.+]] = torch.constant.int 0 // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 @@ -1089,10 +1089,10 @@ func.func @test_expand_dim3_shape4(%arg0: !torch.vtensor<[1,3,1],f32>, %arg1: !t // CHECK: torch.aten.select.int %arg1, %int0, %int3 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> // CHECK: torch.aten.item %6 : !torch.vtensor<[1],si64> -> !torch.int // CHECK: torch.prim.ListConstruct %1, %3, %5, %7 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - // CHECK: %9 = torch.aten.broadcast_to %arg0, %8 : !torch.vtensor<[1,3,1],f32>, !torch.list -> !torch.vtensor<[3,3,3,3],f32> + // CHECK: %9 = torch.aten.broadcast_to %arg0, %8 : !torch.vtensor<[1,3,1],f32>, !torch.list -> !torch.vtensor<[3,3,3,3],f32> %0 = torch.operator "onnx.Expand"(%arg0, %arg1) : (!torch.vtensor<[1,3,1],f32>, !torch.vtensor<[4],si64>) -> !torch.vtensor<[3,3,3,3],f32> return %0 : !torch.vtensor<[3,3,3,3],f32> -} +} // ----- diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir index 00d1355c4699..6a420300cdc9 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir @@ -160,8 +160,8 @@ func.func @test_gemm_alpha_beta(%arg0: !torch.vtensor<[3,5],f32>, %arg1: !torch. // ----- // CHECK-LABEL : func.func @test_layer_norm -func.func @test_layer_norm(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[3,4],f32>, %arg2: !torch.vtensor<[3,4],f32>) -> (!torch.vtensor<[3,4], f32>, !torch.vtensor<[1,1],f32>, !torch.vtensor<[1,1],f32>) - attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { +func.func @test_layer_norm(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[3,4],f32>, %arg2: !torch.vtensor<[3,4],f32>) -> (!torch.vtensor<[3,4], f32>, !torch.vtensor<[1,1],f32>, !torch.vtensor<[1,1],f32>) + attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %int3 = torch.constant.int 3 // CHECK: %int4 = torch.constant.int 4 // CHECK: %0 = torch.prim.ListConstruct %int3, %int4 : (!torch.int, !torch.int) -> !torch.list @@ -173,8 +173,8 @@ func.func @test_layer_norm(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtens // ----- // CHECK-LABEL : func.func @test_layer_norm_single_result -func.func @test_layer_norm_single_result(%arg0: !torch.vtensor<[1,4,768],f32>, %arg1: !torch.vtensor<[768],f32>, %arg2: !torch.vtensor<[768],f32>) -> (!torch.vtensor<[1,4,768], f32>) - attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { +func.func @test_layer_norm_single_result(%arg0: !torch.vtensor<[1,4,768],f32>, %arg1: !torch.vtensor<[768],f32>, %arg2: !torch.vtensor<[768],f32>) -> (!torch.vtensor<[1,4,768], f32>) + attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %float9.999990e-06 = torch.constant.float 9.9999997473787516E-6 // CHECK: %int768 = torch.constant.int 768 // CHECK: %0 = torch.prim.ListConstruct %int768 : (!torch.int) -> !torch.list @@ -224,7 +224,7 @@ func.func @test_matmul_4d(%arg0: !torch.vtensor<[1,2,3,4],f32>, %arg1: !torch.vt // CHECK-LABEL: func.func @test_mul func.func @test_mul(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: torch.aten.mul.Tensor %arg0, %arg1 : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> + // CHECK: torch.aten.mul.Tensor %arg0, %arg1 : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> %0 = torch.operator "onnx.Mul"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> return %0 : !torch.vtensor<[3,4,5],f32> } @@ -399,7 +399,7 @@ func.func @test_pad(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[4], // CHECK-LABEL: @test_hardsigmoid_example func.func @test_hardsigmoid_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 6 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[ALPHA_FLOAT:.*]] = torch.constant.float 5.000000e-01 - // CHECK: %[[BETA_FLOAT:.*]] = torch.constant.float 0.60000002384185791 + // CHECK: %[[BETA_FLOAT:.*]] = torch.constant.float 0.60000002384185791 // CHECK: %[[ALPHA_MULTI_X_PLUS_BETA:.*]] = torch.aten.add.Scalar %arg0, %[[BETA_FLOAT:.*]], %[[ALPHA_FLOAT:.*]] : !torch.vtensor<[3],f32>, !torch.float, !torch.float -> !torch.vtensor<[3],f32> // CHECK: %[[INT_1:.*]] = torch.constant.int 1 // CHECK: %[[TENSOR_DIMENSION_LIST_FOR_ONE:.*]] = torch.prim.ListConstruct : () -> !torch.list @@ -414,7 +414,7 @@ func.func @test_hardsigmoid_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vt // CHECK: %[[ZERO_TENSOR:.*]] = torch.aten.full %[[TENSOR_DIMENSION_LIST_FOR_ZERO:.*]], %[[INT_0:.*]], %[[INT_TYPE_FOR_TENSOR_ZERO:.*]], %[[NONE_FOR_ZERO:.*]], %none_0, %none_0 : !torch.list, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32> // CHECK: %[[RESULT:.*]] = torch.aten.maximum %[[ZERO_TENSOR:.*]], %[[MIN_EXPRESSION:.*]] : !torch.vtensor<[],f32>, !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> // CHECK: return %[[RESULT:.*]] : !torch.vtensor<[3],f32> - + %0 = torch.operator "onnx.HardSigmoid"(%arg0) {torch.onnx.alpha = 5.000000e-01 : f32, torch.onnx.beta = 6.000000e-01 : f32} : (!torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> return %0 : !torch.vtensor<[3],f32> } diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir index b2128175d9d7..918141065cc6 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir @@ -931,7 +931,7 @@ func.func @test_split_2d_uneven_split_opset18(%arg0: !torch.vtensor<[2,8],f32>) } // ----- - + // CHECK-LABEL: func.func @test_tan func.func @test_tan(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 7 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[TAN:.+]] = torch.aten.tan %arg0 @@ -973,7 +973,7 @@ func.func @test_transpose_all_permutations_4(%arg0: !torch.vtensor<[2,3,4],f32>) // CHECK-LABEL: func.func @test_slice func.func @test_slice(%arg0: !torch.vtensor<[20,10,5],f32>, %arg1: !torch.vtensor<[2],si64>, %arg2: !torch.vtensor<[2],si64>, %arg3: !torch.vtensor<[2],si64>, %arg4: !torch.vtensor<[2],si64>) -> !torch.vtensor<[3,10,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { //CHECK: %[[INDEX_TO_GRAB:.*]] = torch.constant.int 0 - + //CHECK: %[[CONST_0:.*]] = torch.constant.int 0 //CHECK: %[[ZERO_INDEX_VEC:.*]] = torch.prim.NumToTensor.Scalar %[[CONST_0:.*]] : !torch.int -> !torch.vtensor<[1],si64> //CHECK: %[[STARTS_INDEX_VEC_0:.*]] = torch.aten.index_select %arg1, %[[INDEX_TO_GRAB:.*]], %[[ZERO_INDEX_VEC:.*]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> @@ -985,7 +985,7 @@ func.func @test_slice(%arg0: !torch.vtensor<[20,10,5],f32>, %arg1: !torch.vtenso //CHECK: %[[STEPS_INDEX_VEC_0:.*]] = torch.aten.index_select %arg4, %[[INDEX_TO_GRAB:.*]], %[[ZERO_INDEX_VEC:.*]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> //CHECK: %[[STEPS_ELEMENT_0:.*]] = torch.aten.item %[[STEPS_INDEX_VEC_0:.*]] : !torch.vtensor<[1],si64> -> !torch.int //CHECK: %[[SLICE_0:.*]] = torch.aten.slice.Tensor %arg0, %[[AXES_ELEMENT_0:.*]], %[[STARTS_ELEMENT_0:.*]], %[[ENDS_ELEMENT_0:.*]], %[[STEPS_ELEMENT_0:.*]] : !torch.vtensor<[20,10,5],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,10,5],f32> - + //CHECK: %[[CONST_1:.*]] = torch.constant.int 1 //CHECK: %[[ONE_INDEX_VEC:.*]] = torch.prim.NumToTensor.Scalar %[[CONST_1:.*]] : !torch.int -> !torch.vtensor<[1],si64> //CHECK: %[[STARTS_INDEX_VEC_1:.*]] = torch.aten.index_select %arg1, %[[INDEX_TO_GRAB:.*]], %[[ONE_INDEX_VEC:.*]] : !torch.vtensor<[2],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> @@ -1013,7 +1013,7 @@ func.func @test_slice_default_axes_and_slices(%arg0: !torch.vtensor<[20,10,5],f3 //CHECK: %[[DEFAULT_SIZE_INPUT:.*]] = torch.prim.ListConstruct %[[DEFAULT_SIZE_AMOUNT:.*]] : (!torch.int) -> !torch.list //CHECK: %[[DEFAULT_SIZES:.*]] = torch.aten.ones %[[DEFAULT_SIZE_INPUT:.*]], %[[NONE_2:.*]], %[[NONE_2:.*]], %[[NONE_2:.*]], %[[NONE_2:.*]] : !torch.list, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[3],si64> //CHECK: %[[INDEX_TO_GRAB:.*]] = torch.constant.int 0 - + //CHECK: %[[CONST_0:.*]] = torch.constant.int 0 //CHECK: %[[ZERO_INDEX_VEC:.*]] = torch.prim.NumToTensor.Scalar %[[CONST_0:.*]] : !torch.int -> !torch.vtensor<[1],si64> //CHECK: %[[STARTS_INDEX_VEC_0:.*]] = torch.aten.index_select %arg1, %[[INDEX_TO_GRAB:.*]], %[[ZERO_INDEX_VEC:.*]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> @@ -1025,7 +1025,7 @@ func.func @test_slice_default_axes_and_slices(%arg0: !torch.vtensor<[20,10,5],f3 //CHECK: %[[STEPS_INDEX_VEC_0:.*]] = torch.aten.index_select %[[DEFAULT_SIZES:.*]], %[[INDEX_TO_GRAB:.*]], %[[ZERO_INDEX_VEC:.*]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> //CHECK: %[[STEPS_ELEMENT_0:.*]] = torch.aten.item %[[STEPS_INDEX_VEC_0:.*]] : !torch.vtensor<[1],si64> -> !torch.int //CHECK: %[[SLICE_0:.*]] = torch.aten.slice.Tensor %arg0, %[[AXES_ELEMENT_0:.*]], %[[STARTS_ELEMENT_0:.*]], %[[ENDS_ELEMENT_0:.*]], %[[STEPS_ELEMENT_0:.*]] : !torch.vtensor<[20,10,5],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[20,10,?],f32> - + //CHECK: %[[CONST_1:.*]] = torch.constant.int 1 //CHECK: %[[ONE_INDEX_VEC:.*]] = torch.prim.NumToTensor.Scalar %[[CONST_1:.*]] : !torch.int -> !torch.vtensor<[1],si64> //CHECK: %[[STARTS_INDEX_VEC_1:.*]] = torch.aten.index_select %arg1, %[[INDEX_TO_GRAB:.*]], %[[ONE_INDEX_VEC:.*]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> @@ -1037,7 +1037,7 @@ func.func @test_slice_default_axes_and_slices(%arg0: !torch.vtensor<[20,10,5],f3 //CHECK: %[[STEPS_INDEX_VEC_1:.*]] = torch.aten.index_select %[[DEFAULT_SIZES:.*]], %[[INDEX_TO_GRAB:.*]], %[[ONE_INDEX_VEC:.*]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> //CHECK: %[[STEPS_ELEMENT_1:.*]] = torch.aten.item %[[STEPS_INDEX_VEC_1:.*]] : !torch.vtensor<[1],si64> -> !torch.int //CHECK: %[[TWO_INDEX_VEC:.*]] = torch.aten.slice.Tensor %[[SLICE_0:.*]], %[[AXES_ELEMENT_1:.*]], %[[STARTS_ELEMENT_1:.*]], %[[ENDS_ELEMENT_1:.*]], %[[STEPS_ELEMENT_1:.*]] : !torch.vtensor<[20,10,?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[20,10,?],f32> - + //CHECK: %[[CONST_2:.*]] = torch.constant.int 2 //CHECK: %[[TWO_INDEX_VEC:.*]] = torch.prim.NumToTensor.Scalar %[[CONST_2:.*]] : !torch.int -> !torch.vtensor<[1],si64> //CHECK: %[[STARTS_INDEX_VEC_2:.*]] = torch.aten.index_select %arg1, %[[INDEX_TO_GRAB:.*]], %[[TWO_INDEX_VEC:.*]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> @@ -1062,7 +1062,7 @@ func.func @test_slice_default_steps(%arg0: !torch.vtensor<[20,10,5],f32>, %arg1: //CHECK: %[[DEFAULT_SIZE_INPUT:.*]] = torch.prim.ListConstruct %[[DEFAULT_SIZE_AMOUNT:.*]] : (!torch.int) -> !torch.list //CHECK: %[[DEFAULT_SIZES:.*]] = torch.aten.ones %[[DEFAULT_SIZE_INPUT:.*]], %[[NONE:.*]], %[[NONE:.*]], %[[NONE:.*]], %[[NONE:.*]] : !torch.list, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[3],si64> //CHECK: %[[INDEX_TO_GRAB:.*]] = torch.constant.int 0 - + //CHECK: %[[CONST_0:.*]] = torch.constant.int 0 //CHECK: %[[ZERO_INDEX_VEC:.*]] = torch.prim.NumToTensor.Scalar %[[CONST_0:.*]] : !torch.int -> !torch.vtensor<[1],si64> //CHECK: %[[STARTS_INDEX_VEC_0:.*]] = torch.aten.index_select %arg1, %[[INDEX_TO_GRAB:.*]], %[[ZERO_INDEX_VEC:.*]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> @@ -1074,7 +1074,7 @@ func.func @test_slice_default_steps(%arg0: !torch.vtensor<[20,10,5],f32>, %arg1: //CHECK: %[[STEPS_INDEX_VEC_0:.*]] = torch.aten.index_select %[[DEFAULT_SIZES:.*]], %[[INDEX_TO_GRAB:.*]], %[[ZERO_INDEX_VEC:.*]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> //CHECK: %[[STEPS_ELEMENT_0:.*]] = torch.aten.item %[[STEPS_INDEX_VEC_0:.*]] : !torch.vtensor<[1],si64> -> !torch.int //CHECK: %[[SLICE_0:.*]] = torch.aten.slice.Tensor %arg0, %[[AXES_ELEMENT_0:.*]], %[[STARTS_ELEMENT_0:.*]], %[[ENDS_ELEMENT_0:.*]], %[[STEPS_ELEMENT_0:.*]] : !torch.vtensor<[20,10,5],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[20,10,?],f32> - + //CHECK: %[[CONST_1:.*]] = torch.constant.int 1 //CHECK: %[[ONE_INDEX_VEC:.*]] = torch.prim.NumToTensor.Scalar %[[CONST_1:.*]] : !torch.int -> !torch.vtensor<[1],si64> //CHECK: %[[STARTS_INDEX_VEC_1:.*]] = torch.aten.index_select %arg1, %[[INDEX_TO_GRAB:.*]], %[[ONE_INDEX_VEC:.*]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> @@ -1086,7 +1086,7 @@ func.func @test_slice_default_steps(%arg0: !torch.vtensor<[20,10,5],f32>, %arg1: //CHECK: %[[STEPS_INDEX_VEC_1:.*]] = torch.aten.index_select %[[DEFAULT_SIZES:.*]], %[[INDEX_TO_GRAB:.*]], %[[ONE_INDEX_VEC:.*]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> //CHECK: %[[STEPS_ELEMENT_1:.*]] = torch.aten.item %[[STEPS_INDEX_VEC_1:.*]] : !torch.vtensor<[1],si64> -> !torch.int //CHECK: %[[TWO_INDEX_VEC:.*]] = torch.aten.slice.Tensor %[[SLICE_0:.*]], %[[AXES_ELEMENT_1:.*]], %[[STARTS_ELEMENT_1:.*]], %[[ENDS_ELEMENT_1:.*]], %[[STEPS_ELEMENT_1:.*]] : !torch.vtensor<[20,10,?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[20,10,?],f32> - + //CHECK: %[[CONST_1:.*]] = torch.constant.int 2 //CHECK: %[[TWO_INDEX_VEC:.*]] = torch.prim.NumToTensor.Scalar %[[CONST_1:.*]] : !torch.int -> !torch.vtensor<[1],si64> //CHECK: %[[STARTS_INDEX_VEC_2:.*]] = torch.aten.index_select %arg1, %[[INDEX_TO_GRAB:.*]], %[[TWO_INDEX_VEC:.*]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> @@ -1339,7 +1339,7 @@ func.func @test_reshape_zero_and_negative_dim(%arg0: !torch.vtensor<[2,3,4],f32> // CHECK: torch.aten.arange.start_step %0, %1, %2, %none, %none, %none, %none : !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2],si64> %0 = torch.operator "onnx.Range"(%arg0, %arg1, %arg2) : (!torch.vtensor<[],si64>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64>) -> !torch.vtensor<[2],si64> return %0 : !torch.vtensor<[2],si64> - } + } // ----- @@ -1352,7 +1352,7 @@ func.func @test_reshape_zero_and_negative_dim(%arg0: !torch.vtensor<[2,3,4],f32> // CHECK: torch.aten.arange.start_step %0, %1, %2, %none, %none, %none, %none : !torch.int, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2],si32> %0 = torch.operator "onnx.Range"(%arg0, %arg1, %arg2) : (!torch.vtensor<[],si32>, !torch.vtensor<[],si32>, !torch.vtensor<[],si32>) -> !torch.vtensor<[2],si32> return %0 : !torch.vtensor<[2],si32> - } + } // ----- @@ -1375,7 +1375,7 @@ func.func @test_reshape_zero_and_negative_dim(%arg0: !torch.vtensor<[2,3,4],f32> // CHECK: return %[[RESULTS]]#0, %[[RESULTS]]#1 : !torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64> %0:2 = torch.operator "onnx.TopK"(%arg0, %arg1) {torch.onnx.axis = 1 : si64} : (!torch.vtensor<[3,4],f32>, !torch.vtensor<[1],si64>) -> (!torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64>) return %0#0, %0#1 : !torch.vtensor<[3,3],f32>, !torch.vtensor<[3,3],si64> - } + } // ----- diff --git a/test/Conversion/TorchOnnxToTorch/unsupported_fb_opt_ops.mlir b/test/Conversion/TorchOnnxToTorch/unsupported_fb_opt_ops.mlir index 6659935ffa6f..0a8bbfe1a8e3 100644 --- a/test/Conversion/TorchOnnxToTorch/unsupported_fb_opt_ops.mlir +++ b/test/Conversion/TorchOnnxToTorch/unsupported_fb_opt_ops.mlir @@ -11,8 +11,8 @@ func.func @cast_operation(%arg0: !torch.vtensor<[?,?,?,?],si64>) -> !torch.vtens } // ----- -func.func @div_operation(%arg0: !torch.vtensor<[1,64,768],f32>, - %arg1: !torch.vtensor<[1,64,1],f32>) +func.func @div_operation(%arg0: !torch.vtensor<[1,64,768],f32>, + %arg1: !torch.vtensor<[1,64,1],f32>) -> !torch.vtensor<[1,64,768],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { %209 = torch.operator "onnx.Div"(%arg0, %arg1) : (!torch.vtensor<[1,64,768],f32>, !torch.vtensor<[1,64,1],f32>) -> !torch.vtensor<[1,64,768],f32> return %209 : !torch.vtensor<[1,64,768],f32> @@ -22,8 +22,8 @@ func.func @div_operation(%arg0: !torch.vtensor<[1,64,768],f32>, // Fixed. // this is the onnx opset 1 version of Equal, only int types. // this used to fail to legalize because the "since" value is set unecessarily high (19) -func.func @equal_operation(%arg0: !torch.vtensor<[4],si64>, - %arg1: !torch.vtensor<[4],si64>) +func.func @equal_operation(%arg0: !torch.vtensor<[4],si64>, + %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[4],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 1 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { %205 = torch.operator "onnx.Equal"(%arg0, %arg1) : (!torch.vtensor<[4],si64>, !torch.vtensor<[4],si64>) -> !torch.vtensor<[4],i1> return %205 : !torch.vtensor<[4],i1> @@ -40,8 +40,8 @@ func.func @reduce_mean_operation(%arg0: !torch.vtensor<[1,64,768],f32>) // ----- // Fixed. -func.func @cumsum_operation(%arg0: !torch.vtensor<[2,3],f64>, - %arg1: !torch.vtensor<[],si32>) +func.func @cumsum_operation(%arg0: !torch.vtensor<[2,3],f64>, + %arg1: !torch.vtensor<[],si32>) -> !torch.vtensor<[2,3],f64> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { %212 = torch.operator "onnx.CumSum"(%arg0, %arg1) : (!torch.vtensor<[2,3],f64>, !torch.vtensor<[],si32>) -> !torch.vtensor<[2,3],f64> return %212 : !torch.vtensor<[2,3],f64> diff --git a/test/Conversion/TorchToLinalg/basic.mlir b/test/Conversion/TorchToLinalg/basic.mlir index 486b8b641dfd..cfb252cd104a 100644 --- a/test/Conversion/TorchToLinalg/basic.mlir +++ b/test/Conversion/TorchToLinalg/basic.mlir @@ -45,7 +45,7 @@ func.func @torch.aten.matmul.2d(%arg0: !torch.vtensor<[8,16],f32>, %arg1: !torch // CHECK-LABEL: func.func @torch.aten.mm$basic_strict( // CHECK-NOT: assert -func.func @torch.aten.mm$basic_strict(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,2],f32> +func.func @torch.aten.mm$basic_strict(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,2],f32> attributes {torch.assume_strict_symbolic_shapes} { %0 = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,2],f32> @@ -56,7 +56,7 @@ func.func @torch.aten.mm$basic_strict(%arg0: !torch.vtensor<[?,?],f32>, %arg1: ! // CHECK-LABEL: func.func @torch.aten.mm$basic_unsigned( // CHECK: linalg.matmul_unsigned -func.func @torch.aten.mm$basic_unsigned(%arg0: !torch.vtensor<[?,?],ui32>, %arg1: !torch.vtensor<[?,?],ui32>) -> !torch.vtensor<[?,2],ui32> +func.func @torch.aten.mm$basic_unsigned(%arg0: !torch.vtensor<[?,?],ui32>, %arg1: !torch.vtensor<[?,?],ui32>) -> !torch.vtensor<[?,2],ui32> attributes {torch.assume_strict_symbolic_shapes} { %0 = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[?,?],ui32>, !torch.vtensor<[?,?],ui32> -> !torch.vtensor<[?,2],ui32> @@ -324,7 +324,7 @@ func.func @torch.aten.cat$convert(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torc // ----- // CHECK-LABEL: func.func @torch.aten.cat( -// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[?,?],f32>, +// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[?,?],f32>, // CHECK-SAME: %[[ARG_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK: %int0 = torch.constant.int 0 // CHECK: %[[VAL_0:.*]] = torch.prim.ListConstruct %[[ARG_0]], %[[ARG_1]] : (!torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>) -> !torch.list diff --git a/test/Conversion/TorchToLinalg/pooling.mlir b/test/Conversion/TorchToLinalg/pooling.mlir index 70f543ad4f74..8ed75f648f5e 100644 --- a/test/Conversion/TorchToLinalg/pooling.mlir +++ b/test/Conversion/TorchToLinalg/pooling.mlir @@ -50,13 +50,13 @@ func.func @forward_max_pool3d(%arg0: !torch.vtensor<[?,?,?,?,?],f32>) -> !torch. %dilation1 = torch.constant.int 3 %dilation2 = torch.constant.int 3 %dilation3 = torch.constant.int 3 - + %false = torch.constant.bool false %kernel_size = torch.prim.ListConstruct %kernel_size1, %kernel_size2, %kernel_size3 : (!torch.int, !torch.int, !torch.int) -> !torch.list %stride = torch.prim.ListConstruct %stride1, %stride2, %stride3 : (!torch.int, !torch.int, !torch.int) -> !torch.list %padding = torch.prim.ListConstruct %padding1, %padding2, %padding3 : (!torch.int, !torch.int, !torch.int) -> !torch.list %dilation = torch.prim.ListConstruct %dilation1, %dilation2, %dilation3 : (!torch.int, !torch.int, !torch.int) -> !torch.list - + %4 = torch.aten.max_pool3d %arg0, %kernel_size, %stride, %padding, %dilation, %false : !torch.vtensor<[?,?,?,?,?],f32>, !torch.list, !torch.list, !torch.list, !torch.list, !torch.bool -> !torch.vtensor<[?,?,?,?,?],f32> // CHECK: %[[MIN_VALUE:.*]] = arith.constant 0xFF800000 : f32 @@ -64,13 +64,13 @@ func.func @forward_max_pool3d(%arg0: !torch.vtensor<[?,?,?,?,?],f32>) -> !torch. // CHECK-NEXT: ^bb0(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index): // CHECK-NEXT: tensor.yield %[[MIN_VALUE:.*]] : f32 // CHECK: } : tensor to tensor - + // CHECK: %[[OUTPUT_TENSOR:.*]] = linalg.fill ins(%[[MIN_VALUE:.*]] : f32) outs(%{{.*}} : tensor) -> tensor // CHECK: %[[MAX_3D_POOL:.*]] = linalg.generic {indexing_maps = [#map, #map1, #map2], iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel", "reduction", "reduction", "reduction"]} ins(%[[PADDED_INPUT_TENSOR:.*]], %{{.*}} : tensor, tensor) outs(%[[OUTPUT_TENSOR:.*]] : tensor) { // CHECK-NEXT: ^bb0(%[[CURRENT_VALUE:.*]]: f32, %[[KERNEL:.*]]: f32, %[[ACC_OUT:.*]]: f32): // CHECK-NEXT: %[[MAXF:.*]] = arith.maximumf %[[CURRENT_VALUE:.*]], %[[ACC_OUT:.*]] : f32 // CHECK-NEXT: linalg.yield %[[MAXF:.*]] : f32 // CHECK: } -> tensor - + return %4 : !torch.vtensor<[?,?,?,?,?],f32> } diff --git a/test/Conversion/TorchToLinalg/view.mlir b/test/Conversion/TorchToLinalg/view.mlir index 83424a17d843..4f9c1f867ee4 100644 --- a/test/Conversion/TorchToLinalg/view.mlir +++ b/test/Conversion/TorchToLinalg/view.mlir @@ -148,8 +148,8 @@ func.func @torch.aten.view$expandInferredDim(%arg0: !torch.vtensor<[2,6],f32>) - // CHECK: return %[[BUILTIN_TENSOR_CAST]] : !torch.vtensor<[2,3,5,?,6],f32> // [10,3,?,2,3] -> [30,?,6] -> [2,3,5,?,6] -// Associations are, -// -- for collapse, [0,1], [2], [3,4] and +// Associations are, +// -- for collapse, [0,1], [2], [3,4] and // -- for expand [0,1,2], [3], [4]. func.func @torch.aten.view$singleUnknownMatches0(%arg0: !torch.vtensor<[10,3,?,2,3],f32>) -> !torch.vtensor<[2,3,5,?,6],f32> { %int3 = torch.constant.int 3 diff --git a/test/Conversion/TorchToStablehlo/basic.mlir b/test/Conversion/TorchToStablehlo/basic.mlir index 51e3e6f9bdbb..e0ab6bf1502b 100644 --- a/test/Conversion/TorchToStablehlo/basic.mlir +++ b/test/Conversion/TorchToStablehlo/basic.mlir @@ -281,7 +281,7 @@ func.func @torch.aten.cat$convert(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torc // ----- // CHECK-LABEL: func.func @torch.aten.cat( -// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[?,?],f32>, +// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[?,?],f32>, // CHECK-SAME: %[[ARG_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { // CHECK: %int0 = torch.constant.int 0 // CHECK: %[[VAL_0:.*]] = torch.prim.ListConstruct %[[ARG_0]], %[[ARG_1]] : (!torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32>) -> !torch.list diff --git a/test/Conversion/TorchToStablehlo/linear.mlir b/test/Conversion/TorchToStablehlo/linear.mlir index b9bac97ca6c9..7f253a98df04 100644 --- a/test/Conversion/TorchToStablehlo/linear.mlir +++ b/test/Conversion/TorchToStablehlo/linear.mlir @@ -269,7 +269,7 @@ func.func @torch.aten.mm$proj(%arg0: !torch.vtensor<[?,256],f32>) -> !torch.vten // ----- // CHECK-LABEL: func.func @torch.aten.convolution( -// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[?,?,?,?],f32>, +// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[?,?,?,?],f32>, // CHECK-SAME: %[[ARG_1:.*]]: !torch.vtensor<[?,?,3,3],f32>) -> !torch.vtensor<[?,?,?,?],f32> { // CHECK: %[[T_0:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor // CHECK: %[[T_1:.*]] = torch_c.to_builtin_tensor %[[ARG_1]] : !torch.vtensor<[?,?,3,3],f32> -> tensor @@ -306,7 +306,7 @@ func.func @torch.aten.convolution(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: ! // ----- // CHECK-LABEL: func.func @torch.aten.convolution$bias( -// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[?,?,?,?],f32>, %[[ARG_1:.*]]: !torch.vtensor<[?,?,3,3],f32>, +// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[?,?,?,?],f32>, %[[ARG_1:.*]]: !torch.vtensor<[?,?,3,3],f32>, // CHECK-SAME: %[[ARG_2:.*]]: !torch.vtensor<[?],f32>) -> !torch.vtensor<[?,?,?,?],f32> { // CHECK: %[[T_0:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[?,?,?,?],f32> -> tensor // CHECK: %[[T_1:.*]] = torch_c.to_builtin_tensor %[[ARG_1]] : !torch.vtensor<[?,?,3,3],f32> -> tensor @@ -349,7 +349,7 @@ func.func @torch.aten.convolution$bias(%arg0: !torch.vtensor<[?,?,?,?],f32>, %ar // ----- // CHECK-LABEL: func.func @torch.aten.convolution$transposed_basic( -// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[1,2,7,7],f32>, +// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[1,2,7,7],f32>, // CHECK-SAME: %[[ARG_1:.*]]: !torch.vtensor<[2,4,3,3],f32>) -> !torch.vtensor<[1,4,9,9],f32> { // CHECK: %[[T_0:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[1,2,7,7],f32> -> tensor<1x2x7x7xf32> // CHECK: %[[T_1:.*]] = torch_c.to_builtin_tensor %[[ARG_1]] : !torch.vtensor<[2,4,3,3],f32> -> tensor<2x4x3x3xf32> @@ -380,7 +380,7 @@ func.func @torch.aten.convolution$transposed_basic(%arg0: !torch.vtensor<[1,2,7, // ----- // CHECK-LABEL: func.func @torch.aten.convolution$transposed_stride( -// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[1,2,7,7],f32>, +// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[1,2,7,7],f32>, // CHECK-SAME: %[[ARG_1:.*]]: !torch.vtensor<[2,4,3,3],f32>) -> !torch.vtensor<[1,4,15,15],f32> { // CHECK: %[[T_0:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[1,2,7,7],f32> -> tensor<1x2x7x7xf32> // CHECK: %[[T_1:.*]] = torch_c.to_builtin_tensor %[[ARG_1]] : !torch.vtensor<[2,4,3,3],f32> -> tensor<2x4x3x3xf32> @@ -415,7 +415,7 @@ func.func @torch.aten.convolution$transposed_stride(%arg0: !torch.vtensor<[1,2,7 // ----- // CHECK-LABEL: func.func @torch.aten.convolution$transposed_outputpadding( -// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[1,2,7,7],f32>, +// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[1,2,7,7],f32>, // CHECK-SAME: %[[ARG_1:.*]]: !torch.vtensor<[2,4,3,3],f32>) -> !torch.vtensor<[1,4,16,16],f32> { // CHECK: %[[T_0:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[1,2,7,7],f32> -> tensor<1x2x7x7xf32> // CHECK: %[[T_1:.*]] = torch_c.to_builtin_tensor %[[ARG_1]] : !torch.vtensor<[2,4,3,3],f32> -> tensor<2x4x3x3xf32> @@ -450,7 +450,7 @@ func.func @torch.aten.convolution$transposed_outputpadding(%arg0: !torch.vtensor // ----- // CHECK-LABEL: func.func @torch.aten.convolution$transposed_groups( -// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[1,2,7,7],f32>, +// CHECK-SAME: %[[ARG_0:.*]]: !torch.vtensor<[1,2,7,7],f32>, // CHECK-SAME: %[[ARG_1:.*]]: !torch.vtensor<[2,2,3,3],f32>) -> !torch.vtensor<[1,4,15,15],f32> { // CHECK: %[[T_0:.*]] = torch_c.to_builtin_tensor %[[ARG_0]] : !torch.vtensor<[1,2,7,7],f32> -> tensor<1x2x7x7xf32> // CHECK: %[[T_1:.*]] = torch_c.to_builtin_tensor %[[ARG_1]] : !torch.vtensor<[2,2,3,3],f32> -> tensor<2x2x3x3xf32> @@ -485,7 +485,7 @@ func.func @torch.aten.convolution$transposed_outputpadding(%arg0: !torch.vtensor // CHECK: %[[T_15:.*]] = stablehlo.transpose %[[T_14]], dims = [0, 1, 3, 2, 4] : (tensor<3x3x2x2x1xf32>) -> tensor<3x3x2x2x1xf32> // CHECK: %from_elements_3 = tensor.from_elements %[[T_8]], %[[T_9]], %[[T_13]], %[[T_12]] : tensor<4xi64> // CHECK: %[[T_16:.*]] = stablehlo.dynamic_reshape %[[T_15]], %from_elements_3 : (tensor<3x3x2x2x1xf32>, tensor<4xi64>) -> tensor<3x3x4x1xf32> -// CHECK: %[[T_17:.*]] = stablehlo.convolution(%[[T_0]], %[[T_16]]) +// CHECK: %[[T_17:.*]] = stablehlo.convolution(%[[T_0]], %[[T_16]]) // CHECK{LITERAL}: dim_numbers = [b, f, 0, 1]x[0, 1, o, i]->[b, f, 0, 1], window = {stride = [1, 1], pad = [[2, 2], [2, 2]], lhs_dilate = [2, 2], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 2 : i64} : (tensor<1x2x7x7xf32>, tensor<3x3x4x1xf32>) -> tensor<1x4x15x15xf32> // CHECK: %[[T_18:.*]] = torch_c.from_builtin_tensor %[[T_17]] : tensor<1x4x15x15xf32> -> !torch.vtensor<[1,4,15,15],f32> // CHECK: return %[[T_18]] : !torch.vtensor<[1,4,15,15],f32> diff --git a/test/Conversion/TorchToStablehlo/pooling.mlir b/test/Conversion/TorchToStablehlo/pooling.mlir index 426a43542477..fd531006d614 100644 --- a/test/Conversion/TorchToStablehlo/pooling.mlir +++ b/test/Conversion/TorchToStablehlo/pooling.mlir @@ -50,7 +50,7 @@ func.func @torch.aten.max_pool2d(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch // CHECK: ^bb0(%[[VAL_8:.*]]: tensor, %[[VAL_9:.*]]: tensor): // CHECK: %[[VAL_10:.*]] = stablehlo.maximum %[[VAL_8]], %[[VAL_9]] : tensor // CHECK: stablehlo.return %[[VAL_10]] : tensor -// CHECK: }) +// CHECK: }) // CHECK-SAME{LITERAL}: {padding = dense<[[0, 0], [0, 0], [2, 2], [1, 1]]> : tensor<4x2xi64>, window_dilations = dense<[1, 1, 2, 1]> : tensor<4xi64>, window_dimensions = dense<[1, 1, 2, 2]> : tensor<4xi64>, window_strides = dense<1> : tensor<4xi64>} : (tensor, tensor) -> tensor // CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor -> !torch.vtensor<[?,?,?,?],f32> // CHECK: return %[[VAL_7]] : !torch.vtensor<[?,?,?,?],f32> diff --git a/test/Conversion/TorchToTosa/conv2d_transpose.mlir b/test/Conversion/TorchToTosa/conv2d_transpose.mlir index 678034cb8405..7f0d5e2ab25b 100644 --- a/test/Conversion/TorchToTosa/conv2d_transpose.mlir +++ b/test/Conversion/TorchToTosa/conv2d_transpose.mlir @@ -12,7 +12,7 @@ func.func @forward(%input: !torch.vtensor<[1,64,1,100],f32>) -> !torch.vtensor<[ %stride = torch.prim.ListConstruct %int2, %int2 : (!torch.int, !torch.int) -> !torch.list %int1x1 = torch.prim.ListConstruct %int1, %int1 : (!torch.int, !torch.int) -> !torch.list // expected-error@+1 {{failed to legalize operation 'torch.aten.convolution' that was explicitly marked illegal}} - %output = torch.aten.convolution %input, %weight, %bias, %stride, %int1x1, %int1x1, %true, %int1x1, %int1 : !torch.vtensor<[1,64,1,100],f32>, !torch.vtensor<[64,64,3,3],f32>, !torch.vtensor<[64],f32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[1,64,2,200],f32> + %output = torch.aten.convolution %input, %weight, %bias, %stride, %int1x1, %int1x1, %true, %int1x1, %int1 : !torch.vtensor<[1,64,1,100],f32>, !torch.vtensor<[64,64,3,3],f32>, !torch.vtensor<[64],f32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[1,64,2,200],f32> return %output : !torch.vtensor<[1,64,2,200],f32> } From 2ef228328f327cb4bddbdbfcdb5476481c8a55b8 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Thu, 25 Jan 2024 16:40:21 -0800 Subject: [PATCH 119/283] [torch] `torch.dequantize` for per channel tensors to` linalg` (#2769) Support a lowering for dequantization for per channel tensors from `torch` dialect to a linalg decomposition. Tested via a numerical `torch` test. --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 53 ++++++++ .../TorchToLinalg/Uncategorized.cpp | 114 ++++++++++++++++-- .../Transforms/AbstractInterpLibrary.cpp | 32 +++++ .../base_lazy_backend/shape_inference.cpp | 20 +++ projects/pt1/e2e_testing/xfail_sets.py | 1 + .../build_tools/abstract_interp_lib_gen.py | 17 +++ .../build_tools/torch_ods_gen.py | 2 + .../test_suite/elementwise.py | 27 +++++ 8 files changed, 258 insertions(+), 8 deletions(-) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index a46c79acb941..c09900ce8ecc 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -14465,6 +14465,33 @@ def Torch_AtenLeakyReluBackwardOp : Torch_Op<"aten.leaky_relu_backward", [ }]; } +def Torch_AtenQuantizePerChannelOp : Torch_Op<"aten.quantize_per_channel", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::quantize_per_channel : (Tensor, Tensor, Tensor, int, int) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$scales, + AnyTorchTensorType:$zero_points, + Torch_IntType:$axis, + Torch_IntType:$dtype + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenQuantizePerChannelOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 5, 1); + } + void AtenQuantizePerChannelOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 5, 1); + } + }]; +} + def Torch_AtenQuantizePerTensorOp : Torch_Op<"aten.quantize_per_tensor", [ AllowsTypeRefinement, HasValueSemantics, @@ -14560,6 +14587,32 @@ def Torch_AtenIntReprOp : Torch_Op<"aten.int_repr", [ }]; } +def Torch_Aten_MakePerChannelQuantizedTensorOp : Torch_Op<"aten._make_per_channel_quantized_tensor", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::_make_per_channel_quantized_tensor : (Tensor, Tensor, Tensor, int) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$scale, + AnyTorchTensorType:$zero_point, + Torch_IntType:$axis + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult Aten_MakePerChannelQuantizedTensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 4, 1); + } + void Aten_MakePerChannelQuantizedTensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 4, 1); + } + }]; +} + def Torch_Aten_MakePerTensorQuantizedTensorOp : Torch_Op<"aten._make_per_tensor_quantized_tensor", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index 593afeb1aa84..9ff4c63741b2 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -1344,7 +1344,7 @@ static Value createLinalgPayloadCalculationForElementwiseOp( auto makeQTensor = qtensor.getDefiningOp(); if (!makeQTensor) { - op->emitError( + op->emitWarning( "unimplemented: dequantizing tensor of unknown scale / zero-point"); return nullptr; } @@ -2221,16 +2221,109 @@ class ConvertAtenIntReprOp : public OpConversionPattern { } // namespace namespace { -class ConvertMakePerTensorQuantizedTensorOp - : public OpConversionPattern { +class ConvertDequantizePerChannel + : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(Aten_MakePerTensorQuantizedTensorOp op, OpAdaptor adaptor, + matchAndRewrite(AtenDequantizeSelfOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - RankedTensorType resultType = getTypeConverter() - ->convertType(op->getResult(0).getType()) - .cast(); + auto loc = op.getLoc(); + auto qoperand = op.getOperand(); + auto make = qoperand.getDefiningOp(); + if (!make) { + llvm::errs() << "Did not find make per channel\n"; + return rewriter.notifyMatchFailure(op, "did not find per channel qint"); + } + + auto converter = getTypeConverter(); + auto operand = make.getOperand(0); + auto scale = make.getScale(); + auto zeropoint = make.getZeroPoint(); + auto axis = make.getAxis(); + + IntegerAttr axisAttr; + if (!matchPattern(axis, m_Constant(&axisAttr))) { + return failure(); + } + + auto operandDTy = operand.getType().cast().getDtype(); + auto zeropointDTy = zeropoint.getType().cast().getDtype(); + operand = converter->materializeTargetConversion( + rewriter, loc, converter->convertType(operand.getType()), operand); + scale = converter->materializeTargetConversion( + rewriter, loc, converter->convertType(scale.getType()), scale); + zeropoint = converter->materializeTargetConversion( + rewriter, loc, converter->convertType(zeropoint.getType()), zeropoint); + + auto resultType = converter->convertType(op->getResult(0).getType()) + .cast(); + + llvm::SmallVector dynSizes; + for (auto [index, dim] : llvm::enumerate(resultType.getShape())) { + if (ShapedType::isDynamic(dim)) { + dynSizes.push_back(rewriter.create(loc, operand, index)); + } + } + + llvm::SmallVector iterators( + resultType.getRank(), utils::IteratorType::parallel); + llvm::SmallVector maps( + 4, {rewriter.getMultiDimIdentityMap(resultType.getRank())}); + auto broadcastMap = AffineMap::get( + resultType.getRank(), /*symbolCount=*/0, + {rewriter.getAffineDimExpr(axisAttr.getInt())}, rewriter.getContext()); + maps[1] = broadcastMap; + maps[2] = broadcastMap; + + auto empty = + rewriter.create(op.getLoc(), resultType, dynSizes); + auto linalgOp = rewriter.create( + loc, resultType, ValueRange{operand, scale, zeropoint}, + ValueRange{empty}, maps, iterators, + [&](OpBuilder &b, Location loc, ValueRange args) { + Value operand = args[0]; + Value scale = args[1]; + Value zeropoint = args[2]; + if (operandDTy.isUnsignedInteger(8)) { + operand = b.create(loc, b.getI32Type(), operand); + } else if (operandDTy.isSignedInteger(8)) { + operand = b.create(loc, b.getI32Type(), operand); + } + + if (zeropointDTy.isUnsignedInteger(8)) { + zeropoint = + b.create(loc, b.getI32Type(), zeropoint); + } else if (zeropointDTy.isSignedInteger(8)) { + zeropoint = + b.create(loc, b.getI32Type(), zeropoint); + } + + Value sub = rewriter.create(loc, operand, zeropoint); + Value fp = + rewriter.create(loc, args[3].getType(), sub); + Value mul = rewriter.create(loc, fp, scale); + b.create(loc, mul); + }); + rewriter.replaceOp(op, linalgOp.getResults()); + return success(); + } +}; +} // namespace + +namespace { + +template +class ConvertCastEquivalentOp : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename OpTy::Adaptor; + + LogicalResult + matchAndRewrite(OpTy op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto converter = this->getTypeConverter(); + RankedTensorType resultType = cast( + converter->convertType(op->getResult(0).getType())); rewriter.replaceOpWithNewOp(op, resultType, adaptor.getSelf()); return success(); @@ -2283,6 +2376,11 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality( target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); - patterns.add(typeConverter, context); + patterns.add>( + typeConverter, context); + target.addIllegalOp(); + patterns.add>( + typeConverter, context); target.addIllegalOp(); + patterns.add(typeConverter, context); } diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index 590bea8d7176..bb9717303e6b 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -6549,6 +6549,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.quantize_per_channel\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.int, %arg4: !torch.int) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.quantize_per_tensor\"(%arg0: !torch.list, %arg1: !torch.float, %arg2: !torch.int, %arg3: !torch.int) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -6565,6 +6569,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten._make_per_channel_quantized_tensor\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.int) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten._make_per_tensor_quantized_tensor\"(%arg0: !torch.list, %arg1: !torch.float, %arg2: !torch.int) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -12632,6 +12640,9 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.quantize_per_channel\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.tuple, %arg3: !torch.int, %arg4: !torch.int) -> !torch.int {\n" +" return %arg4 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.quantize_per_tensor\"(%arg0: !torch.tuple, %arg1: !torch.float, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n" " return %arg3 : !torch.int\n" " }\n" @@ -12664,6 +12675,27 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %2 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten._make_per_channel_quantized_tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.tuple, %arg3: !torch.int) -> !torch.int {\n" +" %int14 = torch.constant.int 14\n" +" %int12 = torch.constant.int 12\n" +" %int1 = torch.constant.int 1\n" +" %int13 = torch.constant.int 13\n" +" %int0 = torch.constant.int 0\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.eq.int %0#1, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %int13 : !torch.int\n" +" } else {\n" +" %3 = torch.aten.eq.int %0#1, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.int) {\n" +" torch.prim.If.yield %int12 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %int14 : !torch.int\n" +" }\n" +" torch.prim.If.yield %4 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten._make_per_tensor_quantized_tensor\"(%arg0: !torch.tuple, %arg1: !torch.float, %arg2: !torch.int) -> !torch.int {\n" " %int14 = torch.constant.int 14\n" " %int12 = torch.constant.int 12\n" diff --git a/projects/ltc/csrc/base_lazy_backend/shape_inference.cpp b/projects/ltc/csrc/base_lazy_backend/shape_inference.cpp index ff43359ebe80..325e89e14d5e 100644 --- a/projects/ltc/csrc/base_lazy_backend/shape_inference.cpp +++ b/projects/ltc/csrc/base_lazy_backend/shape_inference.cpp @@ -39,6 +39,20 @@ std::vector compute_shape_div(const at::Tensor& self, return {Shape(self.scalar_type(), self.sizes().vec())}; } +std::vector +compute_shape__make_per_channel_quantized_tensor(const at::Tensor &self, + const at::Tensor &scale, + const at::Tensor &zero_point, + int64_t axis) { + if (self.scalar_type() == at::kChar) + return {Shape(at::kQInt8, self.sizes().vec())}; + if (self.scalar_type() == at::kByte) + return {Shape(at::kQUInt8, self.sizes().vec())}; + if (self.scalar_type() == at::kInt) + return {Shape(at::kQInt32, self.sizes().vec())}; + assert(false); +} + std::vector compute_shape__make_per_tensor_quantized_tensor( const at::Tensor &self, double scale, int64_t zero_point) { if (self.scalar_type() == at::kChar) @@ -75,6 +89,12 @@ std::vector compute_shape_isinf(const at::Tensor& self) { return {Shape(at::kBool, self.sizes().vec())}; } +std::vector compute_shape_quantize_per_channel( + const at::Tensor &self, const at::Tensor &scales, + const at::Tensor &zero_points, int64_t axis, at::ScalarType dtype) { + return {Shape(dtype, self.sizes().vec())}; +} + std::vector compute_shape_max_pool3d_with_indices( const at::Tensor& self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index f0261b16f6af..f43c325069ce 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -313,6 +313,7 @@ "GroupNormNoWeightAndBiasModule_basic", # Dynamo does not support tracing quantized tensors + "ElementwiseDequantizePerChannelModule_basic", "ElementwiseDequantizePerTensorModule_basic", "ElementwiseQuantizePerTensorModule_basic", "AtenMmQuint8_basic", diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index 28e87cc60990..91e98d99c9ff 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -251,6 +251,9 @@ def aten〇clamp_max〡shape(self: List[int], max: float) -> List[int]: def aten〇rsub〇Scalar〡shape(self: List[int], other: float, alpha: float = 1) -> List[int]: return upstream_shape_functions.unary(self) +def aten〇quantize_per_channel〡shape(self: List[int], scales: List[int], zero_points: List[int], axis: int, dtype: int) -> List[int]: + return upstream_shape_functions.unary(self) + def aten〇quantize_per_tensor〡shape(self: List[int], scale: float, zero_point: int, dtype: int) -> List[int]: return upstream_shape_functions.unary(self) @@ -263,6 +266,9 @@ def aten〇dequantize〇tensor〡shape(qtensor: List[int]) -> List[int]: def aten〇int_repr〡shape(self: List[int]) -> List[int]: return upstream_shape_functions.unary(self) +def aten〇_make_per_channel_quantized_tensor〡shape(self: List[int], scale: List[int], zero_point: List[int], axis: int) -> List[int]: + return upstream_shape_functions.unary(self) + def aten〇_make_per_tensor_quantized_tensor〡shape(self: List[int], scale: float, zero_point: int) -> List[int]: return upstream_shape_functions.unary(self) @@ -4280,6 +4286,9 @@ def prims〇collapse〡dtype(a_rank_dtype: Tuple[int, int], start: int, end: int return a_dtype +def aten〇quantize_per_channel〡dtype(self_rank_dtype: Tuple[int, int], scales_rank_dtype: Tuple[int, int], zero_points_rank_dtype: Tuple[int, int], axis: int, dtype: int) -> int: + return dtype + def aten〇quantize_per_tensor〡dtype(self_rank_dtype: Tuple[int, int], scale: float, zero_point: int, dtype: int) -> int: return dtype @@ -4297,6 +4306,14 @@ def aten〇int_repr〡dtype(self_rank_dtype: Tuple[int, int]) -> int: return torch.int8 return torch.int32 +def aten〇_make_per_channel_quantized_tensor〡dtype(self_rank_dtype: Tuple[int, int], scale_rank_dtype: Tuple[int, int], zero_point_rank_dtype: Tuple[int, int], axis: int) -> int: + self_rank, self_dtype = self_rank_dtype + if (self_dtype == torch.uint8): + return torch.quint8 + if (self_dtype == torch.int8): + return torch.qint8 + return torch.qint32 + def aten〇_make_per_tensor_quantized_tensor〡dtype(self_rank_dtype: Tuple[int, int], scale: float, zero_point: int) -> int: self_rank, self_dtype = self_rank_dtype if (self_dtype == torch.uint8): diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index ae4c608c6de7..3b930c20e79d 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -820,10 +820,12 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::leaky_relu_backward : (Tensor, Tensor, Scalar, bool) -> (Tensor)") # quantized ops + emit("aten::quantize_per_channel : (Tensor, Tensor, Tensor, int, int) -> (Tensor)") emit("aten::quantize_per_tensor : (Tensor, float, int, int) -> (Tensor)") emit("aten::dequantize.self : (Tensor) -> (Tensor)") emit("aten::dequantize.tensor : (Tensor) -> (Tensor)") emit("aten::int_repr : (Tensor) -> (Tensor)") + emit("aten::_make_per_channel_quantized_tensor : (Tensor, Tensor, Tensor, int) -> (Tensor)") emit("aten::_make_per_tensor_quantized_tensor : (Tensor, float, int) -> (Tensor)") # ========================================================================== diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py index a422772fc298..26eac617a4a6 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py @@ -4328,6 +4328,33 @@ def ElementwiseDequantizePerTensorModule_basic(module, tu: TestUtils): # ============================================================================== +class ElementwiseDequantizePerChannelModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([3, 4], torch.int8, True), + ([4], torch.int8, True), + ([4], torch.float, True), + ]) + def forward(self, x, zeropoint, scale): + qx = torch._make_per_channel_quantized_tensor(x, scale, zeropoint, axis=1) + qx = torch.dequantize(qx) + return qx + +@register_test_case(module_factory=lambda: ElementwiseDequantizePerChannelModule()) +def ElementwiseDequantizePerChannelModule_basic(module, tu: TestUtils): + module.forward( + tu.randint(3, 4, low=-128, high=127).to(torch.int8), + tu.randint(4, low=-128, high=127).to(torch.int8), + tu.rand(4) + ) + +# ============================================================================== + class GluStaticModule(torch.nn.Module): def __init__(self): super().__init__() From e73c5368fb26ea80ad4ba495f699b92cb4f2bc73 Mon Sep 17 00:00:00 2001 From: Yuanqiang Liu Date: Fri, 26 Jan 2024 09:01:47 +0800 Subject: [PATCH 120/283] [FxImporter] make FxImporter to fit python<=3.9 (#2802) As that torch with py3.9 is also used widely. --- python/torch_mlir/extras/fx_importer.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py index 9ec90e766c46..d799d61f6a92 100644 --- a/python/torch_mlir/extras/fx_importer.py +++ b/python/torch_mlir/extras/fx_importer.py @@ -5,10 +5,16 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception # Also available under a BSD-style license. See LICENSE. +try: + from types import NoneType +except ImportError: + # python less than 3.10 doesn't have NoneType + NoneType = type(None) + import logging import operator import re -from types import NoneType, BuiltinMethodType, BuiltinFunctionType +from types import BuiltinMethodType, BuiltinFunctionType from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union import weakref From 4964977e8571e7f14c25581188f7cba4cf6904ca Mon Sep 17 00:00:00 2001 From: Phaneesh Barwaria Date: Fri, 26 Jan 2024 23:06:39 +0530 Subject: [PATCH 121/283] [ONNX][MLIR] support constantOfShape op (#2747) --- .../TorchOnnxToTorch/DefaultDomainAtoF.cpp | 81 +++++++++++++++++++ .../TorchOnnxToTorch/simple_ops_a_to_f.mlir | 72 +++++++++++++++++ 2 files changed, 153 insertions(+) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp index 54cfb3e2ab13..bed22b08407a 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp @@ -1472,4 +1472,85 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( binder.op, resultType, operand); return success(); }); + patterns.onOp( + "ConstantOfShape", 20, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value shape; + if (binder.tensorOperand(shape) || binder.tensorResultType(resultType)) + return failure(); + + // convert shape tensor to list of ints + auto shapeSizes = + dyn_cast(shape.getType()).getSizes(); + SmallVector dimList; + SmallVector selectSizes; + selectSizes.push_back(1); + Torch::BaseTensorType shapeType = + shape.getType().cast(); + Type selectResultType = shapeType.getWithSizesAndDtype( + llvm::ArrayRef(selectSizes), shapeType.getOptionalDtype()); + Value zero = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); + for (int i = 0; i < shapeSizes[0]; i++) { + Value selectIndex = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), i)); + Value extract = rewriter.create( + binder.getLoc(), selectResultType, shape, zero, selectIndex); + Value dim = rewriter.create( + binder.getLoc(), rewriter.getType(), extract); + dimList.push_back(dim); + } + Value dimValueList = rewriter.create( + binder.getLoc(), + Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), + dimList); + Value noneVal = rewriter.create(binder.getLoc()); + + // Get fill_value if it is present. + // Assumption : resultDType and value attr type match. + Value value_const; + auto attr = binder.op->getAttr("torch.onnx.value"); + auto resultDType = resultType.getDtype(); + + // Extract the fill value and dtype + // ONNX requires value attr to be a tensor + if (!attr) { + attr = DenseElementsAttr::get( + resultType.toBuiltinTensor().clone(resultDType), + rewriter.getFloatAttr(resultDType, 0.0)); + } + if (!isa(attr)) { + return rewriter.notifyMatchFailure( + binder.op, "`value` attr needs to be a tensor."); + } + + auto denseAttr = attr.cast(); + auto denseAttrEleType = denseAttr.getElementType(); + if (!isa(denseAttrEleType)) { + return rewriter.notifyMatchFailure( + binder.op, + "`value` attr tensor only supports types int and float for now."); + } + + // Create constant op for value + if (denseAttrEleType.isa()) { + int64_t intVal = denseAttr.getSplatValue().getSInt(); + value_const = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(intVal)); + } + if (denseAttrEleType.isa()) { + float floatVal = + denseAttr.getSplatValue().getValue().convertToFloat(); + value_const = rewriter.create( + binder.getLoc(), rewriter.getF64FloatAttr(floatVal)); + } + + rewriter.replaceOpWithNewOp( + binder.op, resultType, dimValueList, value_const, /*dtype=*/noneVal, + /*layout=*/noneVal, /*device=*/noneVal, /*pin_memory=*/noneVal); + return success(); + }); } diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir index 493cdc98312f..2c06567bde97 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir @@ -1413,3 +1413,75 @@ func.func @test_flatten_1d_axis_1(%arg0: !torch.vtensor<[2],f32>) -> !torch.vten %0 = torch.operator "onnx.Flatten"(%arg0) {torch.onnx.axis = 1 : si64} : (!torch.vtensor<[2],f32>) -> !torch.vtensor<[2,1],f32> return %0 : !torch.vtensor<[2,1],f32> } + +// ----- + +// CHECK-LABEL: func.func @test_constant_of_shape_dense_float_default +func.func @test_constant_of_shape_dense_float_default() -> !torch.vtensor<[2,3,4], f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[SHAPE_CST:.*]] = torch.vtensor.literal(dense<[2, 3, 4]> : tensor<3xsi64>) : !torch.vtensor<[3],si64> + // CHECK: %[[INT0:.*]] = torch.constant.int 0 + // CHECK: %[[INT0_0:.*]] = torch.constant.int 0 + // CHECK: %[[EXTRACT_0:.*]] = torch.aten.select.int %[[SHAPE_CST]], %[[INT0]], %[[INT0_0]] : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: %[[ELE_0:.*]] = torch.aten.item %[[EXTRACT_0]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: %[[EXTRACT_1:.*]] = torch.aten.select.int %[[SHAPE_CST]], %[[INT0]], %[[INT1]] : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: %[[ELE_1:.*]] = torch.aten.item %[[EXTRACT_1]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[INT2:.*]] = torch.constant.int 2 + // CHECK: %[[EXTRACT_2:.*]] = torch.aten.select.int %[[SHAPE_CST]], %[[INT0]], %[[INT2]] : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: %[[ELE_2:.*]] = torch.aten.item %[[EXTRACT_2]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[DIM_LIST:.*]] = torch.prim.ListConstruct %[[ELE_0]], %[[ELE_1]], %[[ELE_2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[FILL_VAL:.*]] = torch.constant.float 0.000000e+00 + // CHECK: %[[ATEN_FULL:.*]] = torch.aten.full %[[DIM_LIST]], %[[FILL_VAL]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list, !torch.float, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,3,4],f32> + %cst = torch.vtensor.literal(dense<[2,3,4]> : tensor<3xsi64>) : !torch.vtensor<[3], si64> + %0 = "torch.operator"(%cst) <{name = "onnx.ConstantOfShape"}> : (!torch.vtensor<[3], si64>) -> !torch.vtensor<[2,3,4], f32> + return %0 : !torch.vtensor<[2,3,4], f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_constant_of_shape_dense_float_cst +func.func @test_constant_of_shape_dense_float_cst() -> !torch.vtensor<[2,3,4], f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[SHAPE_CST:.*]] = torch.vtensor.literal(dense<[2, 3, 4]> : tensor<3xsi64>) : !torch.vtensor<[3],si64> + // CHECK: %[[INT0:.*]] = torch.constant.int 0 + // CHECK: %[[INT0_0:.*]] = torch.constant.int 0 + // CHECK: %[[EXTRACT_0:.*]] = torch.aten.select.int %[[SHAPE_CST]], %[[INT0]], %[[INT0_0]] : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: %[[ELE_0:.*]] = torch.aten.item %[[EXTRACT_0]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: %[[EXTRACT_1:.*]] = torch.aten.select.int %[[SHAPE_CST]], %[[INT0]], %[[INT1]] : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: %[[ELE_1:.*]] = torch.aten.item %[[EXTRACT_1]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[INT2:.*]] = torch.constant.int 2 + // CHECK: %[[EXTRACT_2:.*]] = torch.aten.select.int %[[SHAPE_CST]], %[[INT0]], %[[INT2]] : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: %[[ELE_2:.*]] = torch.aten.item %[[EXTRACT_2]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[DIM_LIST:.*]] = torch.prim.ListConstruct %[[ELE_0]], %[[ELE_1]], %[[ELE_2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[FILL_VAL:.*]] = torch.constant.float 3.4000000953674316 + // CHECK: %[[ATEN_FULL:.*]] = torch.aten.full %[[DIM_LIST]], %[[FILL_VAL]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list, !torch.float, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,3,4],f32> + %cst = torch.vtensor.literal(dense<[2,3,4]> : tensor<3xsi64>) : !torch.vtensor<[3], si64> + %0 = "torch.operator"(%cst) <{name = "onnx.ConstantOfShape"}> {torch.onnx.value = dense<3.4> : tensor<1xf32>}: (!torch.vtensor<[3], si64>) -> !torch.vtensor<[2,3,4], f32> + return %0 : !torch.vtensor<[2,3,4], f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_constant_of_shape_dense_int_cst +func.func @test_constant_of_shape_dense_int_cst() -> !torch.vtensor<[2,3,4], si64> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[SHAPE_CST:.*]] = torch.vtensor.literal(dense<[2, 3, 4]> : tensor<3xsi64>) : !torch.vtensor<[3],si64> + // CHECK: %[[INT0:.*]] = torch.constant.int 0 + // CHECK: %[[INT0_0:.*]] = torch.constant.int 0 + // CHECK: %[[EXTRACT_0:.*]] = torch.aten.select.int %[[SHAPE_CST]], %[[INT0]], %[[INT0_0]] : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: %[[ELE_0:.*]] = torch.aten.item %[[EXTRACT_0]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: %[[EXTRACT_1:.*]] = torch.aten.select.int %[[SHAPE_CST]], %[[INT0]], %[[INT1]] : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: %[[ELE_1:.*]] = torch.aten.item %[[EXTRACT_1]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[INT2:.*]] = torch.constant.int 2 + // CHECK: %[[EXTRACT_2:.*]] = torch.aten.select.int %[[SHAPE_CST]], %[[INT0]], %[[INT2]] : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: %[[ELE_2:.*]] = torch.aten.item %[[EXTRACT_2]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[DIM_LIST:.*]] = torch.prim.ListConstruct %[[ELE_0]], %[[ELE_1]], %[[ELE_2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[FILL_VAL:.*]] = torch.constant.int 3 + // CHECK: %[[ATEN_FULL:.*]] = torch.aten.full %[[DIM_LIST]], %[[FILL_VAL]], %[[NONE]], %[[NONE]], %[[NONE]], %[[NONE]] : !torch.list, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[2,3,4],si64> + %cst = torch.vtensor.literal(dense<[2,3,4]> : tensor<3xsi64>) : !torch.vtensor<[3], si64> + %0 = "torch.operator"(%cst) <{name = "onnx.ConstantOfShape"}> {torch.onnx.value = dense<3> : tensor<1xsi64>}: (!torch.vtensor<[3], si64>) -> !torch.vtensor<[2,3,4], si64> + return %0 : !torch.vtensor<[2,3,4], si64> +} From da7c6d2c16439ff5a009bd660567d38158fc3049 Mon Sep 17 00:00:00 2001 From: Vivek Khandelwal Date: Fri, 26 Jan 2024 23:16:54 +0530 Subject: [PATCH 122/283] [MLIR][TORCH] Add support for dynamic shape for Onnx.Transpose op (#2803) Signed-Off By: Vivek Khandelwal --- .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 6 ++++++ .../TorchOnnxToTorch/simple_ops_q_to_z.mlir | 12 ++++++++++++ 2 files changed, 18 insertions(+) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 2ead942ded1b..6569e3abc0b5 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -1244,6 +1244,12 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( current[i] = i; } + // Convert dynamic shape dimension. + for (unsigned i = 0; i < shape.size(); i++){ + if (shape[i] == ShapedType::kDynamic) + shape[i] = Torch::kUnknownSize; + } + for (int64_t i = 0; i < rank; ++i) { if (current[i] == permutations[i]) continue; diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir index 918141065cc6..a9f6098a26d2 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir @@ -968,6 +968,18 @@ func.func @test_transpose_all_permutations_4(%arg0: !torch.vtensor<[2,3,4],f32>) return %0 : !torch.vtensor<[4,2,3],f32> } +// ----- + +// CHECK-LABEL: func.func @test_transpose_dynamic +func.func @test_transpose_dynamic(%arg0: !torch.vtensor<[?,32,5,128],f32>) -> !torch.vtensor<[?,5,32,128],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64} { + // CHECK-DAG: %[[I1:.+]] = torch.constant.int 1 + // CHECK-DAG: %[[I2:.+]] = torch.constant.int 2 + // CHECK: %[[TRANSPOSE:.+]] = torch.aten.transpose.int %arg0, %[[I1]], %[[I2]] : !torch.vtensor<[?,32,5,128],f32>, !torch.int, !torch.int -> !torch.vtensor<[?,5,32,128],f32> + %0 = torch.operator "onnx.Transpose"(%arg0) {torch.onnx.perm = [0 : si64, 2 : si64, 1 : si64, 3 : si64]} : (!torch.vtensor<[?,32,5,128],f32>) -> !torch.vtensor<[?,5,32,128],f32> + return %0 : !torch.vtensor<[?,5,32,128],f32> +} + + // ----- // CHECK-LABEL: func.func @test_slice From 46a25d72412c1bb00bd947a44b1c7dde7bd7ef53 Mon Sep 17 00:00:00 2001 From: Aart Bik <39774503+aartbik@users.noreply.github.com> Date: Fri, 26 Jan 2024 10:54:59 -0800 Subject: [PATCH 123/283] [torch-mlir][sparse] preserve sparsity during lowering torch to linalg (#2809) This preserves sparsity at the most obvious places of lowering TORCH tensors to MLIR RankedTensorType tensors. Other places are marked for audit. With some initial lowering tests. --- lib/Conversion/TorchToLinalg/DataMovement.cpp | 3 ++ .../TorchToLinalg/IndirectDataMovement.cpp | 4 +-- lib/Conversion/TorchToLinalg/Linear.cpp | 11 +++--- lib/Conversion/TorchToLinalg/Utils.cpp | 1 + lib/Dialect/Torch/IR/TorchTypes.cpp | 3 +- test/Conversion/TorchToLinalg/sparse.mlir | 36 +++++++++++++++++++ 6 files changed, 50 insertions(+), 8 deletions(-) create mode 100644 test/Conversion/TorchToLinalg/sparse.mlir diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp index 297a0f4c2be6..e96d65970b82 100644 --- a/lib/Conversion/TorchToLinalg/DataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp @@ -978,6 +978,7 @@ class ConvertAtenViewOp : public OpConversionPattern { return success(); } + // TODO: audit possibility of sparsity on these tensors Type adjustedResultType = RankedTensorType::get( makeShapeLLVMCompatible(outputShape), resultType.getElementType()); Type adjustedInputType = RankedTensorType::get( @@ -1005,6 +1006,7 @@ class ConvertAtenViewOp : public OpConversionPattern { intermediateShape.push_back(sum); } + // TODO: audit possibility of sparsity on these tensor Type intermediateResultType = RankedTensorType::get(makeShapeLLVMCompatible(intermediateShape), resultType.getElementType()); @@ -1657,6 +1659,7 @@ class ConvertAtenSliceScatterOp auto srcType = src.getType().cast(); int64_t srcRank = srcType.getRank(); SmallVector srcAbstractSizes(srcRank, kUnknownSize); + // TODO: audit possibility of sparsity on these tensor auto abstractSrcType = RankedTensorType::get( makeShapeLLVMCompatible(srcAbstractSizes), srcType.getElementType()); Value abstractSrc = diff --git a/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp b/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp index 277341bea874..f9ee56070d61 100644 --- a/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp @@ -206,8 +206,8 @@ namespace { // // TODO: Find an optimal lowering. // current lowering is not optimal for bags of large embeddings. -// Since it traverses the output tensor multiple times. -// +// Since it traverses the output tensor multiple times. +// // class ConvertAtenEmbeddingBagPaddingIdxOp diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp index d818b99c0c4a..6d0d72075d76 100644 --- a/lib/Conversion/TorchToLinalg/Linear.cpp +++ b/lib/Conversion/TorchToLinalg/Linear.cpp @@ -377,8 +377,8 @@ class ConvertAtenMatmulOp : public OpConversionPattern { // TODO: Improve usage of static shape information. SmallVector lhsTargetShape(lhsBroadcastToShape.size(), ShapedType::kDynamic); - auto lhsBroadcastType = - RankedTensorType::get(lhsTargetShape, lhsType.getElementType()); + auto lhsBroadcastType = RankedTensorType::get( + lhsTargetShape, lhsType.getElementType(), lhsType.getEncoding()); if (failed(torch_to_linalg::broadcastToGivenShape( op, rewriter, lhs, lhsBroadcastToShape, lhsBroadcastType, broadcastedLhs))) { @@ -387,8 +387,8 @@ class ConvertAtenMatmulOp : public OpConversionPattern { } SmallVector rhsTargetShape(rhsBroadcastToShape.size(), ShapedType::kDynamic); - auto rhsBroadcastType = - RankedTensorType::get(rhsTargetShape, rhsType.getElementType()); + auto rhsBroadcastType = RankedTensorType::get( + rhsTargetShape, rhsType.getElementType(), rhsType.getEncoding()); if (failed(torch_to_linalg::broadcastToGivenShape( op, rewriter, rhs, rhsBroadcastToShape, rhsBroadcastType, broadcastedRhs))) { @@ -880,7 +880,7 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { if(numSpacialDims != 2) return rewriter.notifyMatchFailure( op, "unimplemented: only 2D grouped convolution supported"); - + // Special depthwise case auto inShape = makeShapeTorchCompatible( input.getType().cast().getShape()); @@ -894,6 +894,7 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { (weightShape[0] == kUnknownSize ? kUnknownSize : weightShape[0] * weightShape[1]), weightShape[2], weightShape[3]}; + // TODO: audit possibility of sparsity on this tensor Type collapsedType = RankedTensorType::get( makeShapeLLVMCompatible(collapsedShape), elementType); Value collapsedWeight = rewriter.create( diff --git a/lib/Conversion/TorchToLinalg/Utils.cpp b/lib/Conversion/TorchToLinalg/Utils.cpp index 77459aca3a60..8bff5034c6b4 100644 --- a/lib/Conversion/TorchToLinalg/Utils.cpp +++ b/lib/Conversion/TorchToLinalg/Utils.cpp @@ -87,6 +87,7 @@ Value torch_to_linalg::getDynamicZeroPaddedTensor( *pad = castIntToIndex(b, loc, *pad); Type elementType = input.getType().cast().getElementType(); + // TODO: audit possibility of sparsity on this tensor Type inputType = RankedTensorType::get(makeShapeLLVMCompatible(llvm::ArrayRef( SmallVector(inRank, kUnknownSize))), diff --git a/lib/Dialect/Torch/IR/TorchTypes.cpp b/lib/Dialect/Torch/IR/TorchTypes.cpp index b5b63954fe42..a154fb4653c4 100644 --- a/lib/Dialect/Torch/IR/TorchTypes.cpp +++ b/lib/Dialect/Torch/IR/TorchTypes.cpp @@ -467,7 +467,8 @@ TensorType ValueTensorType::toBuiltinTensor() const { Type elementType = convertDtypeToBuiltinElementType(getContext(), getDtype()); if (!elementType) return nullptr; - return RankedTensorType::get(makeShapeLLVMCompatible(getSizes()), elementType); + return RankedTensorType::get(makeShapeLLVMCompatible(getSizes()), elementType, + getOptionalSparsity()); } LogicalResult diff --git a/test/Conversion/TorchToLinalg/sparse.mlir b/test/Conversion/TorchToLinalg/sparse.mlir new file mode 100644 index 000000000000..5d952fde3509 --- /dev/null +++ b/test/Conversion/TorchToLinalg/sparse.mlir @@ -0,0 +1,36 @@ +// RUN: torch-mlir-opt <%s -convert-torch-to-linalg -split-input-file -verify-diagnostics | FileCheck %s + +// ----- + +#CSR = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }> + +// CHECK: #[[$CSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }> +// CHECK-LABEL: func.func @sum( +// CHECK-SAME: %[[A:.*]]: !torch.vtensor<[64,64],f32,#[[$CSR]]>) -> !torch.vtensor<[],f32> +// CHECK: %[[S:.*]] = torch_c.to_builtin_tensor %[[A]] : !torch.vtensor<[64,64],f32,#[[$CSR]]> -> tensor<64x64xf32, #[[$CSR]]> +// CHECK: linalg.generic {{{.*}}} ins(%[[S]] : tensor<64x64xf32, #[[$CSR]]>) +func.func @sum(%arg0: !torch.vtensor<[64,64],f32,#CSR>) -> !torch.vtensor<[],f32> { + %none = torch.constant.none + %0 = torch.aten.sum %arg0, %none + : !torch.vtensor<[64,64],f32,#CSR>, !torch.none -> !torch.vtensor<[],f32> + return %0 : !torch.vtensor<[],f32> +} + +// ----- + +#CSR = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }> + +// CHECK: #[[$CSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }> +// CHECK-LABEL: func.func @SpMM( +// CHECK-SAME: %[[A:.*]]: !torch.vtensor<[8,16],f32,#[[$CSR]]>, +// CHECK-SAME: %[[B:.*]]: !torch.vtensor<[16,8],f32>) -> !torch.vtensor<[8,8],f32> +// CHECK: %[[S:.*]] = torch_c.to_builtin_tensor %[[A]] : !torch.vtensor<[8,16],f32,#[[$CSR]]> -> tensor<8x16xf32, #[[$CSR]]> +// CHECK: %[[T:.*]] = torch_c.to_builtin_tensor %[[B]] : !torch.vtensor<[16,8],f32> -> tensor<16x8xf32> +// CHECK: linalg.matmul ins(%[[S]], %[[T]] : tensor<8x16xf32, #[[$CSR]]>, tensor<16x8xf32>) +func.func @SpMM(%arg0: !torch.vtensor<[8,16],f32,#CSR>, + %arg1: !torch.vtensor<[16,8],f32>) -> !torch.vtensor<[8,8],f32> { + %0 = torch.aten.matmul %arg0, %arg1 + : !torch.vtensor<[8,16],f32,#CSR>, + !torch.vtensor<[16,8],f32> -> !torch.vtensor<[8,8],f32> + return %0 : !torch.vtensor<[8,8],f32> +} From 28c7051ceb5a49944dfbd0d1fdf1ada4ac2a5b9e Mon Sep 17 00:00:00 2001 From: MaheshRavishankar <1663364+MaheshRavishankar@users.noreply.github.com> Date: Fri, 26 Jan 2024 18:38:44 -0800 Subject: [PATCH 124/283] Bump LLVM to llvm/llvm-project@5fcf907b34355980f77d7665a175b05fea7a6b7b (#2810) --- externals/llvm-project | 2 +- .../Transforms/AdjustCallingConventions.cpp | 8 ++---- .../Transforms/MaximizeValueSemantics.cpp | 4 +-- .../Torch/Transforms/ReduceOpVariants.cpp | 28 ++++++++++--------- 4 files changed, 21 insertions(+), 21 deletions(-) diff --git a/externals/llvm-project b/externals/llvm-project index eae82ac259ee..5fcf907b3435 160000 --- a/externals/llvm-project +++ b/externals/llvm-project @@ -1 +1 @@ -Subproject commit eae82ac259ee5a58bc4070a414bc53239e18bad0 +Subproject commit 5fcf907b34355980f77d7665a175b05fea7a6b7b diff --git a/lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp b/lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp index 30cc4db44181..2891a22eb817 100644 --- a/lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp +++ b/lib/Dialect/Torch/Transforms/AdjustCallingConventions.cpp @@ -81,7 +81,7 @@ class AdjustCallingConventionForFunc } newResultTypes.push_back(type); } - rewriter.updateRootInPlace(func, [&] { + rewriter.modifyOpInPlace(func, [&] { func.setType(FunctionType::get( getContext(), conversion.getConvertedTypes(), newResultTypes)); // Clear out the type bounds, now that the type incorporates them. @@ -194,14 +194,12 @@ static LogicalResult adjustCallingConventions(func::FuncOp func, TypeConverter typeConverter; typeConverter.addConversion([](Type type) { return type; }); typeConverter.addConversion( - [](Torch::TupleType type, - SmallVectorImpl &types) -> LogicalResult { + [](Torch::TupleType type, SmallVectorImpl &types) -> LogicalResult { llvm::append_range(types, type.getContainedTypes()); return success(); }); typeConverter.addConversion( - [](Torch::NoneType type, - SmallVectorImpl &types) -> LogicalResult { + [](Torch::NoneType type, SmallVectorImpl &types) -> LogicalResult { return success(); }); diff --git a/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp b/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp index cd76275a745d..7db6bc6776b3 100644 --- a/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp +++ b/lib/Dialect/Torch/Transforms/MaximizeValueSemantics.cpp @@ -175,7 +175,7 @@ class AbstractlyInterpretCopyToNonValueTensorOpUsersWithinABlock // Replace return type of view-like ops with value-semantics type variant. for (Operation *viewLikeOp : ops.viewLikeOps) { - rewriter.updateRootInPlace(viewLikeOp, [&] { + rewriter.modifyOpInPlace(viewLikeOp, [&] { Value result = viewLikeOp->getResult(0); auto resultType = result.getType().dyn_cast(); if (resultType) @@ -337,7 +337,7 @@ class RewriteViewLikeSubgraph // correctly copy them back to their mlir::func::ReturnOp's expected types. DenseMap originalTypes; for (Operation *op : viewLikeOps) { - rewriter.updateRootInPlace(op, [&]() { + rewriter.modifyOpInPlace(op, [&]() { if (auto nonValueTensorType = op->getResult(0).getType().dyn_cast()) { originalTypes[op->getResult(0)] = nonValueTensorType; diff --git a/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp b/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp index 8ba0479625d8..200f25c82c43 100644 --- a/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp +++ b/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp @@ -9,10 +9,10 @@ #include "PassDetail.h" +#include "ReifyAbstractInterpCalculationsUtils.h" #include "mlir/Transforms/DialectConversion.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h" -#include "ReifyAbstractInterpCalculationsUtils.h" #include "llvm/ADT/StringExtras.h" using namespace mlir; @@ -72,8 +72,8 @@ namespace { // immutable tensors. class ConvertHasValueSemanticsOpsToValueTensors : public RewritePattern { public: - ConvertHasValueSemanticsOpsToValueTensors(MLIRContext *context, - const std::optional& extraLibrary) + ConvertHasValueSemanticsOpsToValueTensors( + MLIRContext *context, const std::optional &extraLibrary) : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) { this->extraLibrary = extraLibrary; } @@ -87,7 +87,7 @@ class ConvertHasValueSemanticsOpsToValueTensors : public RewritePattern { return rewriter.notifyMatchFailure(op, "does not have value semantics"); } - rewriter.startRootUpdate(op); + rewriter.startOpModification(op); // Convert all operands. SmallVector newOperands; for (OpOperand &opOperand : op->getOpOperands()) { @@ -105,7 +105,7 @@ class ConvertHasValueSemanticsOpsToValueTensors : public RewritePattern { auto listConstruct = opOperand.get().getDefiningOp(); if (!listConstruct) { - rewriter.cancelRootUpdate(op); + rewriter.cancelOpModification(op); return rewriter.notifyMatchFailure( op, "unimplemented: list of non vtensor type not constructed " "from list construct"); @@ -120,7 +120,7 @@ class ConvertHasValueSemanticsOpsToValueTensors : public RewritePattern { if (!llvm::all_of(listConstruct.getElements(), [](Value val) { return val.getType().isa(); })) { - rewriter.cancelRootUpdate(op); + rewriter.cancelOpModification(op); return rewriter.notifyMatchFailure( op, "unimplemented: list containing optional type is not " "handled."); @@ -138,7 +138,7 @@ class ConvertHasValueSemanticsOpsToValueTensors : public RewritePattern { Type newListType = getContainerOrTensorTypeWithValueSemantics(listType); if (!newListType) { - rewriter.cancelRootUpdate(op); + rewriter.cancelOpModification(op); return rewriter.notifyMatchFailure( op, "Unable to convert list type to value semantics."); } @@ -154,7 +154,7 @@ class ConvertHasValueSemanticsOpsToValueTensors : public RewritePattern { // from the non value tensor of the original optional value. auto derefine = opOperand.get().getDefiningOp(); if (!derefine) { - rewriter.cancelRootUpdate(op); + rewriter.cancelOpModification(op); return rewriter.notifyMatchFailure( op, "unimplemented: optional of non vtensor type not from " "derefine"); @@ -180,9 +180,10 @@ class ConvertHasValueSemanticsOpsToValueTensors : public RewritePattern { rewriter.create(op->getLoc(), result); result.replaceAllUsesExcept(nonValueTensor, nonValueTensor); } - rewriter.finalizeRootUpdate(op); + rewriter.finalizeOpModification(op); return success(); } + private: std::optional extraLibrary; }; @@ -290,9 +291,9 @@ class ReduceTrailingUnderscoreInplaceVariant : public RewritePattern { Operation *newOp = rewriter.create(state); // Note: need to convert result to first input's dtype because mix precision // compute would result in different behaviors. - // For example: - // a = torch.randn(3, 3).half() # float16 - // b = torch.randn(3, 3) # float32 + // For example: + // a = torch.randn(3, 3).half() # float16 + // b = torch.randn(3, 3) # float32 // a += b # i.e. torch.ops.aten.add_(a, b), result is float16 // c = a + b # i.e. torch.ops.aten.add(a, b), result is float32 Value none = rewriter.create(op->getLoc()); @@ -300,7 +301,8 @@ class ReduceTrailingUnderscoreInplaceVariant : public RewritePattern { auto aDtype = rewriter.create(op->getLoc(), op->getOperand(0)); auto toDtype = rewriter.create( op->getLoc(), newOp->getResult(0).getType(), newOp->getResult(0), - aDtype, /*non_blocking=*/cstFalse, /*copy=*/cstFalse, /*memory_format=*/none); + aDtype, /*non_blocking=*/cstFalse, /*copy=*/cstFalse, + /*memory_format=*/none); auto tensor = rewriter.create(op->getLoc(), toDtype); createOverwriteTensorContents(rewriter, op->getLoc(), tensor, op->getOperand(0)); From 4a4d80a6ad18b01b7eaf096c5231106228fd719c Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Sat, 27 Jan 2024 15:48:06 -0800 Subject: [PATCH 125/283] [ci] Add lint job and enable yaml linting of GH files. (#2819) --- .github/actions/setup-build/action.yml | 99 ++-- .github/workflows/RollPyTorch.yml | 260 +++++----- .github/workflows/bazelBuildAndTest.yml | 159 +++--- .github/workflows/buildAndTest.yml | 202 ++++---- .github/workflows/buildRelease.yml | 491 +++++++++---------- .github/workflows/gh-pages-releases.yml | 1 + .github/workflows/lint.yml | 17 + .github/workflows/merge-rollpytorch.yml | 31 +- .github/workflows/oneshotSnapshotPackage.yml | 1 + .github/workflows/releaseSnapshotPackage.yml | 1 + 10 files changed, 642 insertions(+), 620 deletions(-) create mode 100644 .github/workflows/lint.yml diff --git a/.github/actions/setup-build/action.yml b/.github/actions/setup-build/action.yml index 73592a7dce86..a21c9a1d7296 100644 --- a/.github/actions/setup-build/action.yml +++ b/.github/actions/setup-build/action.yml @@ -1,3 +1,4 @@ +# yamllint disable rule:line-length name: "Setup build environment" description: "Setup the build environment. An action so that it can be shared between in-tree/out-of-tree jobs" @@ -24,59 +25,59 @@ runs: using: "composite" steps: - - name: Set up Python - if: ${{ runner.arch == 'X64' }} - uses: actions/setup-python@v4 - with: - python-version: '3.11' + - name: Set up Python + if: ${{ runner.arch == 'X64' }} + uses: actions/setup-python@v4 + with: + python-version: '3.11' - - name: Install MLIR Python depends - if: ${{ runner.os != 'Linux' }} - run: | - python -m pip install -r $GITHUB_WORKSPACE/externals/llvm-project/mlir/python/requirements.txt - shell: bash + - name: Install MLIR Python depends + if: ${{ runner.os != 'Linux' }} + run: | + python -m pip install -r $GITHUB_WORKSPACE/externals/llvm-project/mlir/python/requirements.txt + shell: bash - - name: Install PyTorch nightly depends - if: ${{ runner.os != 'Linux' }} - run: | - python -m pip install -r pytorch-requirements.txt - python -m pip install -r build-requirements.txt - shell: bash + - name: Install PyTorch nightly depends + if: ${{ runner.os != 'Linux' }} + run: | + python -m pip install -r pytorch-requirements.txt + python -m pip install -r build-requirements.txt + shell: bash - - name: Install prerequisites (Linux) - if: ${{ runner.os == 'Linux' }} - run: sudo apt-get install --yes ccache ninja-build - shell: bash + - name: Install prerequisites (Linux) + if: ${{ runner.os == 'Linux' }} + run: sudo apt-get install --yes ccache ninja-build + shell: bash - - name: Install prerequisites (macOS) - if: ${{ runner.os == 'macOS' }} - run: brew install ccache ninja - shell: bash + - name: Install prerequisites (macOS) + if: ${{ runner.os == 'macOS' }} + run: brew install ccache ninja + shell: bash - - name: Install prerequisites (Windows) - if: ${{ runner.os == 'Windows' }} - run: | - pip install ninja - choco install ccache --yes - shell: bash + - name: Install prerequisites (Windows) + if: ${{ runner.os == 'Windows' }} + run: | + pip install ninja + choco install ccache --yes + shell: bash - - name: Configure ccache - if: ${{ inputs.cache-enabled == 'true' }} - run: | - rm -rf ${{ github.workspace }}/.ccache - mkdir -p ${{ github.workspace }}/.ccache - ccache --set-config "cache_dir=${{ github.workspace }}/.ccache" - ccache --set-config "compression=true" - ccache --set-config "max_size=300M" - ccache --zero-stats - shell: bash + - name: Configure ccache + if: ${{ inputs.cache-enabled == 'true' }} + run: | + rm -rf ${{ github.workspace }}/.ccache + mkdir -p ${{ github.workspace }}/.ccache + ccache --set-config "cache_dir=${{ github.workspace }}/.ccache" + ccache --set-config "compression=true" + ccache --set-config "max_size=300M" + ccache --zero-stats + shell: bash - - name: Enable ccache - if: ${{ inputs.cache-enabled == 'true' }} - uses: actions/cache@v3 - with: - path: ${{ github.workspace }}/.ccache - key: ${{ runner.os }}-${{ inputs.cache-suffix }}-${{ github.sha }} - restore-keys: | - ${{ runner.os }}-${{ inputs.cache-suffix }}- - ${{ runner.os }}- + - name: Enable ccache + if: ${{ inputs.cache-enabled == 'true' }} + uses: actions/cache@v3 + with: + path: ${{ github.workspace }}/.ccache + key: ${{ runner.os }}-${{ inputs.cache-suffix }}-${{ github.sha }} + restore-keys: | + ${{ runner.os }}-${{ inputs.cache-suffix }}- + ${{ runner.os }}- diff --git a/.github/workflows/RollPyTorch.yml b/.github/workflows/RollPyTorch.yml index 4f2d9d8c509a..975b538c5d95 100644 --- a/.github/workflows/RollPyTorch.yml +++ b/.github/workflows/RollPyTorch.yml @@ -1,3 +1,4 @@ +# yamllint disable rule:line-length name: Roll PyTorch on: @@ -14,133 +15,132 @@ jobs: if: github.repository == 'llvm/torch-mlir' steps: - - - name: Prepare workspace - run: | - # Clear the workspace directory so that we don't run into errors about - # existing lock files. - sudo rm -rf $GITHUB_WORKSPACE/* - - - name: Get torch-mlir - uses: actions/checkout@v3 - with: - submodules: 'false' - token: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} - - - name: Get LLVM and StableHlo submodules - run: | - set -eo pipefail - cd ${GITHUB_WORKSPACE} - - # Fetching the submodules concurrently may cause problems, so we fetch - # them one after another. - rm -f .git/modules/externals/llvm-project/index.lock - rm -f .git/modules/externals/stablehlo/index.lock - git submodule update --init --recursive externals/llvm-project - git submodule update --init --recursive externals/stablehlo - - - name: Setup ccache - uses: ./.github/actions/setup-build - with: - cache-suffix: 'rollPyTorch' - - - name: Determine nightly PyTorch version - run: | - set -eo pipefail - - cd ${GITHUB_WORKSPACE} - python -m pip install wheel - sudo apt-get install unzip - - # Fetch the most recent nightly torchvision release - VISION_RELEASE=$(python -m pip index versions -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html --pre torchvision | grep "Available versions" | tr ' ' '\n' | grep "^[0-9]" | sort --version-sort --reverse | head -n1 | tr -d ',' | sed 's/\([^+]*\).*/\1/') - echo "Found torchvision release ${VISION_RELEASE}" - - # Fetch the whl file associated with the nightly torchvision release - rm -f torch*.whl - python -m pip download -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html --pre "torchvision==${VISION_RELEASE}" - - # Downloading the torchvision WHL also downloads the PyTorch WHL file - # Read the version from the downloaded whl file without extracting it - PT_RELEASE=$(unzip -p torch-*.whl 'torch-*/METADATA' | grep "^Version:" | awk '{ print $2 }' | sed 's/\([^+]*\).*/\1/') - echo "Found torch release ${PT_RELEASE}" - printf -- "-f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html\n--pre\ntorch==%s\n" "${PT_RELEASE}" > pytorch-requirements.txt - printf -- "-f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html\n--pre\ntorchvision==%s\n" "${VISION_RELEASE}" > torchvision-requirements.txt - - # Read the commit hash from the downloaded whl file without extracting it - PT_HASH=$(unzip -p torch-"${PT_RELEASE}"*.whl torch/version.py | grep git_version | tail -1 | awk '{ print $3 }' | tr -d "'") - echo "Found torch commit hash ${PT_HASH}" - - PT_HASH_CHANGED=0 - echo "${PT_HASH}" | cmp - pytorch-hash.txt --quiet || PT_HASH_CHANGED=$? - echo "${PT_HASH}" > pytorch-hash.txt - rm torch-"${PT_RELEASE}"*.whl - - # Write the release and hash to the environment file so that we can - # retrieve them when creating a PR - echo "PT_HASH=${PT_HASH}" >> ${GITHUB_ENV} - echo "PT_RELEASE=${PT_RELEASE}" >> ${GITHUB_ENV} - echo "PTVISION_RELEASE=${VISION_RELEASE}" >> ${GITHUB_ENV} - echo "PT_HASH_CHANGED=${PT_HASH_CHANGED}" >> ${GITHUB_ENV} - - - name: Build and test (out-of-tree), also update ODS and abstract interpretation library - if: env.PT_HASH_CHANGED != '0' - run: | - cd ${GITHUB_WORKSPACE} - TM_PACKAGES="out-of-tree" TM_USE_PYTORCH_BINARY="OFF" \ - TORCH_MLIR_SRC_PYTORCH_BRANCH="${{ env.PT_HASH }}" \ - TORCH_MLIR_SRC_PYTORCH_RELEASE="${{ env.PT_RELEASE }}" \ - TM_UPDATE_ODS_AND_ABSTRACT_INTERP_LIB="ON" \ - ./build_tools/python_deploy/build_linux_packages.sh - - - name: Post issue comment on build failure - if: failure() - uses: peter-evans/create-or-update-comment@v2 - with: - issue-number: 1690 - body: | - The RollPyTorch action has failed. See [CI log](${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}) for details. - - The following information may come handy when fixing the code. - ``` - torch version: ${{ env.PT_RELEASE }} - torch commit hash: ${{ env.PT_HASH }} - torchvision version: ${{ env.PTVISION_RELEASE }} - ``` - - - name: Update PyTorch Build Cache (if running on main branch) - if: github.ref_name == 'main' - id: cache-pytorch - uses: actions/cache@v3 - with: - path: ${{ github.workspace }}/build_tools/python_deploy/wheelhouse - key: ${{ runner.os }}-pytorch-${{ env.PT_HASH }} - - - name: Commit changes locally - if: env.PT_HASH_CHANGED != '0' - run: | - cd ${GITHUB_WORKSPACE} - git config user.email "torch-mlir@users.noreply.github.com" - git config user.name "Roll PyTorch Action" - git fetch --recurse-submodules=no - git checkout main - git pull origin main - - - name: Create pull request - uses: peter-evans/create-pull-request@v5.0.1 - with: - author: Roll PyTorch Action - branch: rollpytorch - body: | - torch version: ${{ env.PT_RELEASE }} - torch commit hash: ${{ env.PT_HASH }} - torchvision version: ${{ env.PTVISION_RELEASE }} - commit-message: | - update PyTorch version to ${{ env.PT_RELEASE }} - - - torch version: ${{ env.PT_RELEASE }} - - torch commit hash: ${{ env.PT_HASH }} - - torchvision version: ${{ env.PTVISION_RELEASE }} - committer: Roll PyTorch Action - title: update PyTorch version to ${{ env.PT_RELEASE }} - token: ${{ secrets.ROLLPYTORCH_TOKEN0 }} + - name: Prepare workspace + run: | + # Clear the workspace directory so that we don't run into errors about + # existing lock files. + sudo rm -rf $GITHUB_WORKSPACE/* + + - name: Get torch-mlir + uses: actions/checkout@v3 + with: + submodules: 'false' + token: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} + + - name: Get LLVM and StableHlo submodules + run: | + set -eo pipefail + cd ${GITHUB_WORKSPACE} + + # Fetching the submodules concurrently may cause problems, so we fetch + # them one after another. + rm -f .git/modules/externals/llvm-project/index.lock + rm -f .git/modules/externals/stablehlo/index.lock + git submodule update --init --recursive externals/llvm-project + git submodule update --init --recursive externals/stablehlo + + - name: Setup ccache + uses: ./.github/actions/setup-build + with: + cache-suffix: 'rollPyTorch' + + - name: Determine nightly PyTorch version + run: | + set -eo pipefail + + cd ${GITHUB_WORKSPACE} + python -m pip install wheel + sudo apt-get install unzip + + # Fetch the most recent nightly torchvision release + VISION_RELEASE=$(python -m pip index versions -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html --pre torchvision | grep "Available versions" | tr ' ' '\n' | grep "^[0-9]" | sort --version-sort --reverse | head -n1 | tr -d ',' | sed 's/\([^+]*\).*/\1/') + echo "Found torchvision release ${VISION_RELEASE}" + + # Fetch the whl file associated with the nightly torchvision release + rm -f torch*.whl + python -m pip download -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html --pre "torchvision==${VISION_RELEASE}" + + # Downloading the torchvision WHL also downloads the PyTorch WHL file + # Read the version from the downloaded whl file without extracting it + PT_RELEASE=$(unzip -p torch-*.whl 'torch-*/METADATA' | grep "^Version:" | awk '{ print $2 }' | sed 's/\([^+]*\).*/\1/') + echo "Found torch release ${PT_RELEASE}" + printf -- "-f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html\n--pre\ntorch==%s\n" "${PT_RELEASE}" > pytorch-requirements.txt + printf -- "-f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html\n--pre\ntorchvision==%s\n" "${VISION_RELEASE}" > torchvision-requirements.txt + + # Read the commit hash from the downloaded whl file without extracting it + PT_HASH=$(unzip -p torch-"${PT_RELEASE}"*.whl torch/version.py | grep git_version | tail -1 | awk '{ print $3 }' | tr -d "'") + echo "Found torch commit hash ${PT_HASH}" + + PT_HASH_CHANGED=0 + echo "${PT_HASH}" | cmp - pytorch-hash.txt --quiet || PT_HASH_CHANGED=$? + echo "${PT_HASH}" > pytorch-hash.txt + rm torch-"${PT_RELEASE}"*.whl + + # Write the release and hash to the environment file so that we can + # retrieve them when creating a PR + echo "PT_HASH=${PT_HASH}" >> ${GITHUB_ENV} + echo "PT_RELEASE=${PT_RELEASE}" >> ${GITHUB_ENV} + echo "PTVISION_RELEASE=${VISION_RELEASE}" >> ${GITHUB_ENV} + echo "PT_HASH_CHANGED=${PT_HASH_CHANGED}" >> ${GITHUB_ENV} + + - name: Build and test (out-of-tree), also update ODS and abstract interpretation library + if: env.PT_HASH_CHANGED != '0' + run: | + cd ${GITHUB_WORKSPACE} + TM_PACKAGES="out-of-tree" TM_USE_PYTORCH_BINARY="OFF" \ + TORCH_MLIR_SRC_PYTORCH_BRANCH="${{ env.PT_HASH }}" \ + TORCH_MLIR_SRC_PYTORCH_RELEASE="${{ env.PT_RELEASE }}" \ + TM_UPDATE_ODS_AND_ABSTRACT_INTERP_LIB="ON" \ + ./build_tools/python_deploy/build_linux_packages.sh + + - name: Post issue comment on build failure + if: failure() + uses: peter-evans/create-or-update-comment@v2 + with: + issue-number: 1690 + body: | + The RollPyTorch action has failed. See [CI log](${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}) for details. + + The following information may come handy when fixing the code. + ``` + torch version: ${{ env.PT_RELEASE }} + torch commit hash: ${{ env.PT_HASH }} + torchvision version: ${{ env.PTVISION_RELEASE }} + ``` + + - name: Update PyTorch Build Cache (if running on main branch) + if: github.ref_name == 'main' + id: cache-pytorch + uses: actions/cache@v3 + with: + path: ${{ github.workspace }}/build_tools/python_deploy/wheelhouse + key: ${{ runner.os }}-pytorch-${{ env.PT_HASH }} + + - name: Commit changes locally + if: env.PT_HASH_CHANGED != '0' + run: | + cd ${GITHUB_WORKSPACE} + git config user.email "torch-mlir@users.noreply.github.com" + git config user.name "Roll PyTorch Action" + git fetch --recurse-submodules=no + git checkout main + git pull origin main + + - name: Create pull request + uses: peter-evans/create-pull-request@v5.0.1 + with: + author: Roll PyTorch Action + branch: rollpytorch + body: | + torch version: ${{ env.PT_RELEASE }} + torch commit hash: ${{ env.PT_HASH }} + torchvision version: ${{ env.PTVISION_RELEASE }} + commit-message: | + update PyTorch version to ${{ env.PT_RELEASE }} + + - torch version: ${{ env.PT_RELEASE }} + - torch commit hash: ${{ env.PT_HASH }} + - torchvision version: ${{ env.PTVISION_RELEASE }} + committer: Roll PyTorch Action + title: update PyTorch version to ${{ env.PT_RELEASE }} + token: ${{ secrets.ROLLPYTORCH_TOKEN0 }} diff --git a/.github/workflows/bazelBuildAndTest.yml b/.github/workflows/bazelBuildAndTest.yml index b3cb3b8fb165..23f2addbe5af 100644 --- a/.github/workflows/bazelBuildAndTest.yml +++ b/.github/workflows/bazelBuildAndTest.yml @@ -1,8 +1,9 @@ +# yamllint disable rule:line-length name: Bazel Build and Test on: push: - branches: [ main ] + branches: [main] workflow_dispatch: # Ensure that only a single job or workflow using the same @@ -24,90 +25,90 @@ jobs: runs-on: ubuntu-latest steps: - - name: Prepare workspace - run: | - # Clear the workspace directory so that we don't run into errors about - # existing lock files. - sudo rm -rf $GITHUB_WORKSPACE/* + - name: Prepare workspace + run: | + # Clear the workspace directory so that we don't run into errors about + # existing lock files. + sudo rm -rf $GITHUB_WORKSPACE/* - - name: Checkout torch-mlir - uses: actions/checkout@v3 - with: - submodules: 'true' + - name: Checkout torch-mlir + uses: actions/checkout@v3 + with: + submodules: 'true' - # Continually update cache even if there's a "hit" during - # restore to avoid the cache going stale over time - # https://github.com/actions/cache/blob/main/workarounds.md#update-a-cache - - name: Setup cache for bazel - uses: actions/cache@v3 - with: - path: ~/.cache/bazel - key: torch_mlir-bazel-build-cache-${{ runner.os }}-${{ github.sha }} - restore-keys: | - torch_mlir-bazel-build-cache-${{ runner.os }} + # Continually update cache even if there's a "hit" during + # restore to avoid the cache going stale over time + # https://github.com/actions/cache/blob/main/workarounds.md#update-a-cache + - name: Setup cache for bazel + uses: actions/cache@v3 + with: + path: ~/.cache/bazel + key: torch_mlir-bazel-build-cache-${{ runner.os }}-${{ github.sha }} + restore-keys: | + torch_mlir-bazel-build-cache-${{ runner.os }} - # Change bazel cache directory to root ownership - # to allow writing to it from within the docker container. - # If no cache hits, this directory is not present - # so don't run chown (will error otherwise). - - name: Set bazel cache permissions - run: | - if [ -d "${HOME}/.cache/bazel" ]; then - sudo chown -R root:root "${HOME}/.cache/bazel" - fi + # Change bazel cache directory to root ownership + # to allow writing to it from within the docker container. + # If no cache hits, this directory is not present + # so don't run chown (will error otherwise). + - name: Set bazel cache permissions + run: | + if [ -d "${HOME}/.cache/bazel" ]; then + sudo chown -R root:root "${HOME}/.cache/bazel" + fi - - name: Build docker image - run: | - docker build -f utils/bazel/docker/Dockerfile \ - -t torch-mlir:ci \ - . + - name: Build docker image + run: | + docker build -f utils/bazel/docker/Dockerfile \ + -t torch-mlir:ci \ + . - - name: Verify buildifier was run (bazel lint) - run: | - docker run --rm \ - -v "$(pwd)":"/opt/src/torch-mlir" \ - -v "${HOME}/.cache/bazel":"/root/.cache/bazel" \ - torch-mlir:ci \ - bazel run @torch-mlir//:buildifier - if [ -n "$(git status --porcelain)" ]; then - echo "Please 'bazel run @torch-mlir//:buildifier' and commit changes." - exit 1 - fi + - name: Verify buildifier was run (bazel lint) + run: | + docker run --rm \ + -v "$(pwd)":"/opt/src/torch-mlir" \ + -v "${HOME}/.cache/bazel":"/root/.cache/bazel" \ + torch-mlir:ci \ + bazel run @torch-mlir//:buildifier + if [ -n "$(git status --porcelain)" ]; then + echo "Please 'bazel run @torch-mlir//:buildifier' and commit changes." + exit 1 + fi - - name: Bazel build torch-mlir - run: | - docker run --rm \ - -v "$(pwd)":"/opt/src/torch-mlir" \ - -v "${HOME}/.cache/bazel":"/root/.cache/bazel" \ - torch-mlir:ci \ - bazel build @torch-mlir//:torch-mlir-opt + - name: Bazel build torch-mlir + run: | + docker run --rm \ + -v "$(pwd)":"/opt/src/torch-mlir" \ + -v "${HOME}/.cache/bazel":"/root/.cache/bazel" \ + torch-mlir:ci \ + bazel build @torch-mlir//:torch-mlir-opt - - name: Bazel test torch-mlir (lit tests) - run: | - docker run --rm \ - -v "$(pwd)":"/opt/src/torch-mlir" \ - -v "${HOME}/.cache/bazel":"/root/.cache/bazel" \ - torch-mlir:ci \ - bazel test @torch-mlir//test/... + - name: Bazel test torch-mlir (lit tests) + run: | + docker run --rm \ + -v "$(pwd)":"/opt/src/torch-mlir" \ + -v "${HOME}/.cache/bazel":"/root/.cache/bazel" \ + torch-mlir:ci \ + bazel test @torch-mlir//test/... - # Switch back bazel cache directory to user ownership - # to allow GHA post-cache step to save cache without - # permissions issue. - - name: Switch bazel cache permissions - run: | - if [ -d "${HOME}/.cache/bazel" ]; then - sudo chown -R "$USER":"$USER" "${HOME}/.cache/bazel" - fi + # Switch back bazel cache directory to user ownership + # to allow GHA post-cache step to save cache without + # permissions issue. + - name: Switch bazel cache permissions + run: | + if [ -d "${HOME}/.cache/bazel" ]; then + sudo chown -R "$USER":"$USER" "${HOME}/.cache/bazel" + fi - - name: Send mail - if: failure() - uses: dawidd6/action-send-mail@v3 - with: - server_address: ${{ secrets.SMTP_SERVER }} - server_port: ${{ secrets.SMTP_PORT }} - username: ${{ secrets.SMTP_USERNAME }} - password: ${{ secrets.SMTP_PASSWORD }} - subject: GitHub Action Bazel Build and Test failed! - body: Bazel Build job failed! See https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }} for more information. - to: ${{ secrets.MAIL_RECEIVER }} - from: Torch-MLIR Bazel Build GitHub Actions + - name: Send mail + if: failure() + uses: dawidd6/action-send-mail@v3 + with: + server_address: ${{ secrets.SMTP_SERVER }} + server_port: ${{ secrets.SMTP_PORT }} + username: ${{ secrets.SMTP_USERNAME }} + password: ${{ secrets.SMTP_PASSWORD }} + subject: GitHub Action Bazel Build and Test failed! + body: Bazel Build job failed! See https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }} for more information. + to: ${{ secrets.MAIL_RECEIVER }} + from: Torch-MLIR Bazel Build GitHub Actions diff --git a/.github/workflows/buildAndTest.yml b/.github/workflows/buildAndTest.yml index f52ddf9b439b..c6d345860129 100644 --- a/.github/workflows/buildAndTest.yml +++ b/.github/workflows/buildAndTest.yml @@ -1,10 +1,11 @@ +# yamllint disable rule:line-length name: Build and Test on: pull_request: - branches: [ main ] + branches: [main] push: - branches: [ main ] + branches: [main] workflow_dispatch: # Ensure that only a single job or workflow using the same @@ -57,102 +58,101 @@ jobs: runs-on: ${{ matrix.os }} steps: - - - name: Prepare workspace - if: ${{ matrix.os-arch == 'ubuntu-x86_64' }} - run: | - # Clear the workspace directory so that we don't run into errors about - # existing lock files. - sudo rm -rf $GITHUB_WORKSPACE/* - - - name: Checkout torch-mlir - uses: actions/checkout@v3 - with: - submodules: 'true' - fetch-depth: 0 - - - name: Fetch PyTorch commit hash - if: ${{ matrix.os-arch != 'windows-x86_64' }} - run: | - PT_HASH="$(cat ${GITHUB_WORKSPACE}/pytorch-hash.txt)" - echo "PT_HASH=${PT_HASH}" >> ${GITHUB_ENV} - - - name: Setup ccache - uses: ./.github/actions/setup-build - with: - cache-suffix: 'build-${{ matrix.llvm-build }}-${{ matrix.torch-version }}' - torch-version: ${{ matrix.torch-version }} - - - name: Set up Visual Studio shell - if: ${{ matrix.os-arch == 'windows-x86_64' }} - uses: egor-tensin/vs-shell@v2 - with: - arch: x64 - - - name: Try to Restore PyTorch Build Cache - if: ${{ matrix.torch-binary == 'OFF' }} - id: cache-pytorch - uses: actions/cache/restore@v3 - with: - path: ${{ github.workspace }}/build_tools/python_deploy/wheelhouse - key: ${{ runner.os }}-pytorch-${{ env.PT_HASH }} - - - name: Build and Test os-arch='ubuntu-x86_64' llvm-build='${{ matrix.llvm-build }}' torch-binary='${{ matrix.torch-binary }}' - if: ${{ matrix.os-arch == 'ubuntu-x86_64' }} - run: | - cd $GITHUB_WORKSPACE - TORCH_MLIR_SRC_PYTORCH_BRANCH="$(cat pytorch-hash.txt)" \ - TM_PACKAGES="${{ matrix.llvm-build }}" \ - TM_USE_PYTORCH_BINARY="${{ matrix.torch-binary }}" \ - TM_PYTORCH_INSTALL_WITHOUT_REBUILD="${{ steps.cache-pytorch.outputs.cache-hit }}" \ - TM_TORCH_VERSION="${{ matrix.torch-version }}" \ - ./build_tools/python_deploy/build_linux_packages.sh - - - name: Configure os-arch='macos-arm64' llvm-build='in-tree' torch-binary='${{ matrix.torch-binary }}' - # cross compile, can't test arm64 - if: ${{ matrix.os-arch == 'macos-arm64' && matrix.llvm-build == 'in-tree' }} - run: | - # TODO: Reenable LTC after build on macOS-arm64 is fixed (https://github.com/llvm/torch-mlir/issues/1253) - cmake -GNinja -Bbuild_arm64 \ - -DCMAKE_BUILD_TYPE=Release \ - -DCMAKE_C_COMPILER=clang \ - -DCMAKE_CXX_COMPILER=clang++ \ - -DCMAKE_C_COMPILER_LAUNCHER=ccache \ - -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \ - -DCMAKE_LINKER=lld \ - -DCMAKE_OSX_ARCHITECTURES=arm64 \ - -DLLVM_ENABLE_ASSERTIONS=ON \ - -DLLVM_ENABLE_PROJECTS=mlir \ - -DLLVM_EXTERNAL_PROJECTS="torch-mlir" \ - -DLLVM_EXTERNAL_TORCH_MLIR_SOURCE_DIR="$GITHUB_WORKSPACE" \ - -DLLVM_TARGETS_TO_BUILD=AArch64 \ - -DLLVM_USE_HOST_TOOLS=ON \ - -DLLVM_ENABLE_ZSTD=OFF \ - -DMLIR_ENABLE_BINDINGS_PYTHON=ON \ - -DTORCH_MLIR_ENABLE_STABLEHLO=OFF \ - -DTORCH_MLIR_ENABLE_LTC=OFF \ - -DTORCH_MLIR_USE_INSTALLED_PYTORCH="${{ matrix.torch-binary }}" \ - -DMACOSX_DEPLOYMENT_TARGET=12.0 \ - -DPython3_EXECUTABLE="$(which python)" \ - $GITHUB_WORKSPACE/externals/llvm-project/llvm - - - name: Build torch-mlir (cross-compile) - if: ${{ matrix.os-arch == 'macos-arm64' }} - run: | - cmake --build build_arm64 - - - name: Build (Windows) - if: ${{ matrix.os-arch == 'windows-x86_64' }} - shell: bash - run: ./build_tools/python_deploy/build_windows_ci.sh - - - name: Save PyTorch Build Cache - if: ${{ github.ref_name == 'main' && matrix.torch-binary == 'OFF' }} - uses: actions/cache/save@v3 - with: - path: ${{ github.workspace }}/build_tools/python_deploy/wheelhouse - key: ${{ runner.os }}-pytorch-${{ env.PT_HASH }} - - - name: Print ccache statistics - shell: bash - run: ccache --show-stats + - name: Prepare workspace + if: ${{ matrix.os-arch == 'ubuntu-x86_64' }} + run: | + # Clear the workspace directory so that we don't run into errors about + # existing lock files. + sudo rm -rf $GITHUB_WORKSPACE/* + + - name: Checkout torch-mlir + uses: actions/checkout@v3 + with: + submodules: 'true' + fetch-depth: 0 + + - name: Fetch PyTorch commit hash + if: ${{ matrix.os-arch != 'windows-x86_64' }} + run: | + PT_HASH="$(cat ${GITHUB_WORKSPACE}/pytorch-hash.txt)" + echo "PT_HASH=${PT_HASH}" >> ${GITHUB_ENV} + + - name: Setup ccache + uses: ./.github/actions/setup-build + with: + cache-suffix: 'build-${{ matrix.llvm-build }}-${{ matrix.torch-version }}' + torch-version: ${{ matrix.torch-version }} + + - name: Set up Visual Studio shell + if: ${{ matrix.os-arch == 'windows-x86_64' }} + uses: egor-tensin/vs-shell@v2 + with: + arch: x64 + + - name: Try to Restore PyTorch Build Cache + if: ${{ matrix.torch-binary == 'OFF' }} + id: cache-pytorch + uses: actions/cache/restore@v3 + with: + path: ${{ github.workspace }}/build_tools/python_deploy/wheelhouse + key: ${{ runner.os }}-pytorch-${{ env.PT_HASH }} + + - name: Build and Test os-arch='ubuntu-x86_64' llvm-build='${{ matrix.llvm-build }}' torch-binary='${{ matrix.torch-binary }}' + if: ${{ matrix.os-arch == 'ubuntu-x86_64' }} + run: | + cd $GITHUB_WORKSPACE + TORCH_MLIR_SRC_PYTORCH_BRANCH="$(cat pytorch-hash.txt)" \ + TM_PACKAGES="${{ matrix.llvm-build }}" \ + TM_USE_PYTORCH_BINARY="${{ matrix.torch-binary }}" \ + TM_PYTORCH_INSTALL_WITHOUT_REBUILD="${{ steps.cache-pytorch.outputs.cache-hit }}" \ + TM_TORCH_VERSION="${{ matrix.torch-version }}" \ + ./build_tools/python_deploy/build_linux_packages.sh + + - name: Configure os-arch='macos-arm64' llvm-build='in-tree' torch-binary='${{ matrix.torch-binary }}' + # cross compile, can't test arm64 + if: ${{ matrix.os-arch == 'macos-arm64' && matrix.llvm-build == 'in-tree' }} + run: | + # TODO: Reenable LTC after build on macOS-arm64 is fixed (https://github.com/llvm/torch-mlir/issues/1253) + cmake -GNinja -Bbuild_arm64 \ + -DCMAKE_BUILD_TYPE=Release \ + -DCMAKE_C_COMPILER=clang \ + -DCMAKE_CXX_COMPILER=clang++ \ + -DCMAKE_C_COMPILER_LAUNCHER=ccache \ + -DCMAKE_CXX_COMPILER_LAUNCHER=ccache \ + -DCMAKE_LINKER=lld \ + -DCMAKE_OSX_ARCHITECTURES=arm64 \ + -DLLVM_ENABLE_ASSERTIONS=ON \ + -DLLVM_ENABLE_PROJECTS=mlir \ + -DLLVM_EXTERNAL_PROJECTS="torch-mlir" \ + -DLLVM_EXTERNAL_TORCH_MLIR_SOURCE_DIR="$GITHUB_WORKSPACE" \ + -DLLVM_TARGETS_TO_BUILD=AArch64 \ + -DLLVM_USE_HOST_TOOLS=ON \ + -DLLVM_ENABLE_ZSTD=OFF \ + -DMLIR_ENABLE_BINDINGS_PYTHON=ON \ + -DTORCH_MLIR_ENABLE_STABLEHLO=OFF \ + -DTORCH_MLIR_ENABLE_LTC=OFF \ + -DTORCH_MLIR_USE_INSTALLED_PYTORCH="${{ matrix.torch-binary }}" \ + -DMACOSX_DEPLOYMENT_TARGET=12.0 \ + -DPython3_EXECUTABLE="$(which python)" \ + $GITHUB_WORKSPACE/externals/llvm-project/llvm + + - name: Build torch-mlir (cross-compile) + if: ${{ matrix.os-arch == 'macos-arm64' }} + run: | + cmake --build build_arm64 + + - name: Build (Windows) + if: ${{ matrix.os-arch == 'windows-x86_64' }} + shell: bash + run: ./build_tools/python_deploy/build_windows_ci.sh + + - name: Save PyTorch Build Cache + if: ${{ github.ref_name == 'main' && matrix.torch-binary == 'OFF' }} + uses: actions/cache/save@v3 + with: + path: ${{ github.workspace }}/build_tools/python_deploy/wheelhouse + key: ${{ runner.os }}-pytorch-${{ env.PT_HASH }} + + - name: Print ccache statistics + shell: bash + run: ccache --show-stats diff --git a/.github/workflows/buildRelease.yml b/.github/workflows/buildRelease.yml index 31888220491c..e84aabb4b388 100644 --- a/.github/workflows/buildRelease.yml +++ b/.github/workflows/buildRelease.yml @@ -1,3 +1,4 @@ +# yamllint disable rule:line-length name: Release Build on: @@ -16,295 +17,293 @@ jobs: runs-on: a100 strategy: matrix: - package: [ torch-mlir ] - py_version: [ cp38-cp38, cp311-cp311 ] + package: [torch-mlir] + py_version: [cp38-cp38, cp311-cp311] steps: + - name: Prepare workspace + run: | + # Clear the workspace directory so that we don't run into errors about + # existing lock files. + sudo rm -rf $GITHUB_WORKSPACE/* - - name: Prepare workspace - run: | - # Clear the workspace directory so that we don't run into errors about - # existing lock files. - sudo rm -rf $GITHUB_WORKSPACE/* + - name: Get torch-mlir + uses: actions/checkout@v3 + with: + submodules: 'true' + fetch-depth: 0 - - name: Get torch-mlir - uses: actions/checkout@v3 - with: - submodules: 'true' - fetch-depth: 0 + - uses: ./.github/actions/setup-build + with: + cache-enabled: 'false' + - name: Build Python wheels and smoke test. + run: | + cd $GITHUB_WORKSPACE + TM_PACKAGE_VERSION=${{ github.event.inputs.python_package_version }} + printf "TORCH_MLIR_PYTHON_PACKAGE_VERSION=%s\n" $TM_PACKAGE_VERSION > ./torch_mlir_package_version + TM_PYTHON_VERSIONS=${{ matrix.py_version }} TM_PACKAGES=${{ matrix.package }} ./build_tools/python_deploy/build_linux_packages.sh - - uses: ./.github/actions/setup-build - with: - cache-enabled: 'false' - - name: Build Python wheels and smoke test. - run: | - cd $GITHUB_WORKSPACE - TM_PACKAGE_VERSION=${{ github.event.inputs.python_package_version }} - printf "TORCH_MLIR_PYTHON_PACKAGE_VERSION=%s\n" $TM_PACKAGE_VERSION > ./torch_mlir_package_version - TM_PYTHON_VERSIONS=${{ matrix.py_version }} TM_PACKAGES=${{ matrix.package }} ./build_tools/python_deploy/build_linux_packages.sh + # If we were given a release_id, then upload the package we just built + # to the github releases page. + - name: Upload Release Assets (if requested) + if: github.event.inputs.release_id != '' + id: upload-release-assets + uses: dwenegar/upload-release-assets@v1 + env: + GITHUB_TOKEN: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} + with: + release_id: ${{ github.event.inputs.release_id }} + assets_path: ./build_tools/python_deploy/wheelhouse/torch*.whl + # Publishing is necessary to make the release visible to `pip` + # on the github releases page. + - name: Publish Release (if requested) + if: github.event.inputs.release_id != '' + id: publish_release + uses: eregon/publish-release@v1 + env: + GITHUB_TOKEN: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} + with: + release_id: ${{ github.event.inputs.release_id }} + - name: Create dist directory + if: github.event.inputs.release_id != '' + run: mkdir dist + - name: Copy releases to publish to dist directory + if: github.event.inputs.release_id != '' + run: cp build_tools/python_deploy/wheelhouse/torch_mlir*.whl dist/ - # If we were given a release_id, then upload the package we just built - # to the github releases page. - - name: Upload Release Assets (if requested) - if: github.event.inputs.release_id != '' - id: upload-release-assets - uses: dwenegar/upload-release-assets@v1 - env: - GITHUB_TOKEN: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} - with: - release_id: ${{ github.event.inputs.release_id }} - assets_path: ./build_tools/python_deploy/wheelhouse/torch*.whl - # Publishing is necessary to make the release visible to `pip` - # on the github releases page. - - name: Publish Release (if requested) - if: github.event.inputs.release_id != '' - id: publish_release - uses: eregon/publish-release@v1 - env: - GITHUB_TOKEN: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} - with: - release_id: ${{ github.event.inputs.release_id }} - - name: Create dist directory - if: github.event.inputs.release_id != '' - run: mkdir dist - - name: Copy releases to publish to dist directory - if: github.event.inputs.release_id != '' - run: cp build_tools/python_deploy/wheelhouse/torch_mlir*.whl dist/ - - # Wheels must be published from a linux environment. - # - # See https://github.com/pypa/gh-action-pypi-publish/discussions/15 - - name: Store the binary wheel - uses: actions/upload-artifact@v2 - with: - name: wheels - path: dist + # Wheels must be published from a linux environment. + # + # See https://github.com/pypa/gh-action-pypi-publish/discussions/15 + - name: Store the binary wheel + uses: actions/upload-artifact@v2 + with: + name: wheels + path: dist build_linux_arm64: name: Manylinux arm64 Build runs-on: linux-arm64 strategy: matrix: - package: [ torch-mlir ] - py_version: [ cp311-cp311 ] + package: [torch-mlir] + py_version: [cp311-cp311] steps: + - name: Prepare workspace + run: | + # Clear the workspace directory so that we don't run into errors about + # existing lock files. + sudo rm -rf $GITHUB_WORKSPACE/* - - name: Prepare workspace - run: | - # Clear the workspace directory so that we don't run into errors about - # existing lock files. - sudo rm -rf $GITHUB_WORKSPACE/* - - - name: Get torch-mlir - uses: actions/checkout@v3 - with: - submodules: 'true' - fetch-depth: 0 + - name: Get torch-mlir + uses: actions/checkout@v3 + with: + submodules: 'true' + fetch-depth: 0 - - uses: ./.github/actions/setup-build - with: - cache-enabled: 'false' - - name: Build Python wheels and smoke test. - run: | - cd $GITHUB_WORKSPACE - TM_PACKAGE_VERSION=${{ github.event.inputs.python_package_version }} - printf "TORCH_MLIR_PYTHON_PACKAGE_VERSION=%s\n" $TM_PACKAGE_VERSION > ./torch_mlir_package_version - TM_PYTHON_VERSIONS=${{ matrix.py_version }} TM_PACKAGES=${{ matrix.package }} TORCH_MLIR_ENABLE_LTC='0' ./build_tools/python_deploy/build_linux_packages.sh + - uses: ./.github/actions/setup-build + with: + cache-enabled: 'false' + - name: Build Python wheels and smoke test. + run: | + cd $GITHUB_WORKSPACE + TM_PACKAGE_VERSION=${{ github.event.inputs.python_package_version }} + printf "TORCH_MLIR_PYTHON_PACKAGE_VERSION=%s\n" $TM_PACKAGE_VERSION > ./torch_mlir_package_version + TM_PYTHON_VERSIONS=${{ matrix.py_version }} TM_PACKAGES=${{ matrix.package }} TORCH_MLIR_ENABLE_LTC='0' ./build_tools/python_deploy/build_linux_packages.sh - # If we were given a release_id, then upload the package we just built - # to the github releases page. - - name: Upload Release Assets (if requested) - if: github.event.inputs.release_id != '' - id: upload-release-assets - uses: dwenegar/upload-release-assets@v1 - env: - GITHUB_TOKEN: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} - with: - release_id: ${{ github.event.inputs.release_id }} - assets_path: ./build_tools/python_deploy/wheelhouse/torch*.whl - # Publishing is necessary to make the release visible to `pip` - # on the github releases page. - - name: Publish Release (if requested) - if: github.event.inputs.release_id != '' - id: publish_release - uses: eregon/publish-release@v1 - env: - GITHUB_TOKEN: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} - with: - release_id: ${{ github.event.inputs.release_id }} - - name: Create dist directory - if: github.event.inputs.release_id != '' - run: mkdir dist - - name: Copy releases to publish to dist directory - if: github.event.inputs.release_id != '' - run: cp build_tools/python_deploy/wheelhouse/torch_mlir*.whl dist/ + # If we were given a release_id, then upload the package we just built + # to the github releases page. + - name: Upload Release Assets (if requested) + if: github.event.inputs.release_id != '' + id: upload-release-assets + uses: dwenegar/upload-release-assets@v1 + env: + GITHUB_TOKEN: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} + with: + release_id: ${{ github.event.inputs.release_id }} + assets_path: ./build_tools/python_deploy/wheelhouse/torch*.whl + # Publishing is necessary to make the release visible to `pip` + # on the github releases page. + - name: Publish Release (if requested) + if: github.event.inputs.release_id != '' + id: publish_release + uses: eregon/publish-release@v1 + env: + GITHUB_TOKEN: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} + with: + release_id: ${{ github.event.inputs.release_id }} + - name: Create dist directory + if: github.event.inputs.release_id != '' + run: mkdir dist + - name: Copy releases to publish to dist directory + if: github.event.inputs.release_id != '' + run: cp build_tools/python_deploy/wheelhouse/torch_mlir*.whl dist/ - # Wheels must be published from a linux environment. - # - # See https://github.com/pypa/gh-action-pypi-publish/discussions/15 - - name: Store the binary wheel - uses: actions/upload-artifact@v2 - with: - name: wheels - path: dist + # Wheels must be published from a linux environment. + # + # See https://github.com/pypa/gh-action-pypi-publish/discussions/15 + - name: Store the binary wheel + uses: actions/upload-artifact@v2 + with: + name: wheels + path: dist build_macos: name: MacOS Build runs-on: macos-latest strategy: matrix: - package: [ torch-mlir ] + package: [torch-mlir] steps: - - name: Get torch-mlir - uses: actions/checkout@v3 - with: - submodules: 'true' - - uses: ./.github/actions/setup-build - with: - cache-enabled: 'false' - - name: Build Python wheels and smoke test. - run: | - cd $GITHUB_WORKSPACE - python -m pip install wheel - TM_PACKAGE_VERSION=${{ github.event.inputs.python_package_version }} - printf "TORCH_MLIR_PYTHON_PACKAGE_VERSION=%s\n" $TM_PACKAGE_VERSION > ./torch_mlir_package_version - sudo ./build_tools/python_deploy/install_macos_deps.sh - packages=${{ matrix.package }} TORCH_MLIR_PYTHON_VERSIONS="3.11" ./build_tools/python_deploy/build_macos_packages.sh + - name: Get torch-mlir + uses: actions/checkout@v3 + with: + submodules: 'true' + - uses: ./.github/actions/setup-build + with: + cache-enabled: 'false' + - name: Build Python wheels and smoke test. + run: | + cd $GITHUB_WORKSPACE + python -m pip install wheel + TM_PACKAGE_VERSION=${{ github.event.inputs.python_package_version }} + printf "TORCH_MLIR_PYTHON_PACKAGE_VERSION=%s\n" $TM_PACKAGE_VERSION > ./torch_mlir_package_version + sudo ./build_tools/python_deploy/install_macos_deps.sh + packages=${{ matrix.package }} TORCH_MLIR_PYTHON_VERSIONS="3.11" ./build_tools/python_deploy/build_macos_packages.sh - # If we were given a release_id, then upload the package we just built - # to the github releases page. - - name: Upload Release Assets (if requested) - if: github.event.inputs.release_id != '' - id: upload-release-assets - uses: dwenegar/upload-release-assets@v1 - env: - GITHUB_TOKEN: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} - with: - release_id: ${{ github.event.inputs.release_id }} - assets_path: ./build_tools/python_deploy/wheelhouse/torch*.whl - # Publishing is necessary to make the release visible to `pip` - # on the github releases page. - - name: Publish Release (if requested) - if: github.event.inputs.release_id != '' - id: publish_release - uses: eregon/publish-release@v1 - env: - GITHUB_TOKEN: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} - with: - release_id: ${{ github.event.inputs.release_id }} - - name: Create dist directory - if: github.event.inputs.release_id != '' - run: mkdir dist - - name: Copy releases to publish to dist directory - if: github.event.inputs.release_id != '' - run: cp build_tools/python_deploy/wheelhouse/torch_mlir*.whl dist/ + # If we were given a release_id, then upload the package we just built + # to the github releases page. + - name: Upload Release Assets (if requested) + if: github.event.inputs.release_id != '' + id: upload-release-assets + uses: dwenegar/upload-release-assets@v1 + env: + GITHUB_TOKEN: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} + with: + release_id: ${{ github.event.inputs.release_id }} + assets_path: ./build_tools/python_deploy/wheelhouse/torch*.whl + # Publishing is necessary to make the release visible to `pip` + # on the github releases page. + - name: Publish Release (if requested) + if: github.event.inputs.release_id != '' + id: publish_release + uses: eregon/publish-release@v1 + env: + GITHUB_TOKEN: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} + with: + release_id: ${{ github.event.inputs.release_id }} + - name: Create dist directory + if: github.event.inputs.release_id != '' + run: mkdir dist + - name: Copy releases to publish to dist directory + if: github.event.inputs.release_id != '' + run: cp build_tools/python_deploy/wheelhouse/torch_mlir*.whl dist/ - # Wheels must be published from a linux environment. - # - # See https://github.com/pypa/gh-action-pypi-publish/discussions/15 - - name: Store the binary wheel - uses: actions/upload-artifact@v2 - with: - name: wheels - path: dist + # Wheels must be published from a linux environment. + # + # See https://github.com/pypa/gh-action-pypi-publish/discussions/15 + - name: Store the binary wheel + uses: actions/upload-artifact@v2 + with: + name: wheels + path: dist build_windows: name: Windows Build runs-on: windows-latest strategy: matrix: - package: [ torch-mlir ] + package: [torch-mlir] steps: - - name: Get torch-mlir - uses: actions/checkout@v3 - with: - submodules: 'true' - - uses: ./.github/actions/setup-build - with: - cache-enabled: 'false' - - name: Set up Visual Studio shell - uses: egor-tensin/vs-shell@v2 - with: - arch: x64 - - name: Build Python wheels and smoke test. - shell: pwsh - run: | - $env:TORCH_MLIR_ENABLE_JIT_IR_IMPORTER='1' - $env:TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS='0' - $env:TORCH_MLIR_PYTHON_PACKAGE_VERSION = '${{ github.event.inputs.python_package_version }}' - ./build_tools/python_deploy/build_windows.ps1 + - name: Get torch-mlir + uses: actions/checkout@v3 + with: + submodules: 'true' + - uses: ./.github/actions/setup-build + with: + cache-enabled: 'false' + - name: Set up Visual Studio shell + uses: egor-tensin/vs-shell@v2 + with: + arch: x64 + - name: Build Python wheels and smoke test. + shell: pwsh + run: | + $env:TORCH_MLIR_ENABLE_JIT_IR_IMPORTER='1' + $env:TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS='0' + $env:TORCH_MLIR_PYTHON_PACKAGE_VERSION = '${{ github.event.inputs.python_package_version }}' + ./build_tools/python_deploy/build_windows.ps1 - # If we were given a release_id, then upload the package we just built - # to the github releases page. - - name: Upload Release Assets (if requested) - if: github.event.inputs.release_id != '' - id: upload-release-assets - uses: dwenegar/upload-release-assets@v1 - env: - GITHUB_TOKEN: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} - with: - release_id: ${{ github.event.inputs.release_id }} - assets_path: ./wheelhouse/torch*.whl - # Publishing is necessary to make the release visible to `pip` - # on the github releases page. - - name: Publish Release (if requested) - if: github.event.inputs.release_id != '' - id: publish_release - uses: eregon/publish-release@v1 - env: - GITHUB_TOKEN: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} - with: - release_id: ${{ github.event.inputs.release_id }} - - name: Create dist directory - if: github.event.inputs.release_id != '' - run: mkdir dist - continue-on-error: true - - name: Copy releases to publish to dist directory - if: github.event.inputs.release_id != '' - run: cp ./wheelhouse/torch_mlir*.whl dist/ + # If we were given a release_id, then upload the package we just built + # to the github releases page. + - name: Upload Release Assets (if requested) + if: github.event.inputs.release_id != '' + id: upload-release-assets + uses: dwenegar/upload-release-assets@v1 + env: + GITHUB_TOKEN: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} + with: + release_id: ${{ github.event.inputs.release_id }} + assets_path: ./wheelhouse/torch*.whl + # Publishing is necessary to make the release visible to `pip` + # on the github releases page. + - name: Publish Release (if requested) + if: github.event.inputs.release_id != '' + id: publish_release + uses: eregon/publish-release@v1 + env: + GITHUB_TOKEN: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} + with: + release_id: ${{ github.event.inputs.release_id }} + - name: Create dist directory + if: github.event.inputs.release_id != '' + run: mkdir dist + continue-on-error: true + - name: Copy releases to publish to dist directory + if: github.event.inputs.release_id != '' + run: cp ./wheelhouse/torch_mlir*.whl dist/ - # Wheels must be published from a linux environment. - # - # See https://github.com/pypa/gh-action-pypi-publish/discussions/15 - - name: Store the binary wheel - uses: actions/upload-artifact@v2 - with: - name: wheels - path: dist + # Wheels must be published from a linux environment. + # + # See https://github.com/pypa/gh-action-pypi-publish/discussions/15 + - name: Store the binary wheel + uses: actions/upload-artifact@v2 + with: + name: wheels + path: dist publish_releases: runs-on: ubuntu-latest needs: - - build_linux - - build_linux_arm64 - - build_macos - - build_windows + - build_linux + - build_linux_arm64 + - build_macos + - build_windows # Publish even if one of the builds failed if: ${{ always() }} steps: - - name: Invoke Publish Releases Page - uses: benc-uk/workflow-dispatch@v1 - with: - workflow: Publish releases page - token: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} + - name: Invoke Publish Releases Page + uses: benc-uk/workflow-dispatch@v1 + with: + workflow: Publish releases page + token: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} - # Wheels must be published from a linux environment. - # - # See https://github.com/pypa/gh-action-pypi-publish/discussions/15 - # - # We're temporarily disabling pypi publishing until we can fix audit wheel - # ODR torch issues. See https://github.com/llvm/torch-mlir/issues/1709 - # - #- name: Download wheels for publishing to PyPI - # uses: actions/download-artifact@v3 - # with: - # name: wheels - # path: dist - #- name: Publish to PyPI - # if: github.event.inputs.release_id != '' - # uses: pypa/gh-action-pypi-publish@v1.5.1 - # with: - # password: ${{ secrets.PYPI_API_TOKEN }} + # Wheels must be published from a linux environment. + # + # See https://github.com/pypa/gh-action-pypi-publish/discussions/15 + # + # We're temporarily disabling pypi publishing until we can fix audit wheel + # ODR torch issues. See https://github.com/llvm/torch-mlir/issues/1709 + # + #- name: Download wheels for publishing to PyPI + # uses: actions/download-artifact@v3 + # with: + # name: wheels + # path: dist + #- name: Publish to PyPI + # if: github.event.inputs.release_id != '' + # uses: pypa/gh-action-pypi-publish@v1.5.1 + # with: + # password: ${{ secrets.PYPI_API_TOKEN }} diff --git a/.github/workflows/gh-pages-releases.yml b/.github/workflows/gh-pages-releases.yml index c6df475cca4d..a0eb45257b11 100644 --- a/.github/workflows/gh-pages-releases.yml +++ b/.github/workflows/gh-pages-releases.yml @@ -1,3 +1,4 @@ +# yamllint disable rule:line-length # See: https://github.com/llvm/torch-mlir/issues/1374 name: Publish releases page diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml new file mode 100644 index 000000000000..464ebdad93c0 --- /dev/null +++ b/.github/workflows/lint.yml @@ -0,0 +1,17 @@ +name: Lint Checks + +on: + push: + branches: [main] + pull_request: + branches: [main] + +jobs: + static_lint_checks: + name: Static Lint Checks + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v3 + - name: Validate GitHub Actions yaml files + run: | + yamllint ./.github/workflows/ ./.github/actions/ diff --git a/.github/workflows/merge-rollpytorch.yml b/.github/workflows/merge-rollpytorch.yml index 7247a3683281..58a91fd1d409 100644 --- a/.github/workflows/merge-rollpytorch.yml +++ b/.github/workflows/merge-rollpytorch.yml @@ -1,3 +1,4 @@ +# yamllint disable rule:line-length name: RollPyTorch Merge on: @@ -15,19 +16,19 @@ jobs: github.event.workflow_run.conclusion == 'success' steps: - # Fetch the repo first so that the gh command knows where to look for the PR - - name: Fetch Repo - uses: actions/checkout@v3 - with: - token: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} + # Fetch the repo first so that the gh command knows where to look for the PR + - name: Fetch Repo + uses: actions/checkout@v3 + with: + token: ${{ secrets.WORKFLOW_INVOCATION_TOKEN }} - - name: Merge RollPyTorch PR - run: | - for pr_id in ${{ join(github.event.workflow_run.pull_requests.*.number, ' ') }} - do - echo "Merging PR: $pr_id" - gh pr merge $pr_id --delete-branch --squash - done - shell: bash - env: - GH_TOKEN: ${{ secrets.ROLLPYTORCH_TOKEN1 }} + - name: Merge RollPyTorch PR + run: | + for pr_id in ${{ join(github.event.workflow_run.pull_requests.*.number, ' ') }} + do + echo "Merging PR: $pr_id" + gh pr merge $pr_id --delete-branch --squash + done + shell: bash + env: + GH_TOKEN: ${{ secrets.ROLLPYTORCH_TOKEN1 }} diff --git a/.github/workflows/oneshotSnapshotPackage.yml b/.github/workflows/oneshotSnapshotPackage.yml index 46832ce9c667..ec1878606624 100644 --- a/.github/workflows/oneshotSnapshotPackage.yml +++ b/.github/workflows/oneshotSnapshotPackage.yml @@ -1,3 +1,4 @@ +# yamllint disable rule:line-length name: Release oneshot snapshot package on: diff --git a/.github/workflows/releaseSnapshotPackage.yml b/.github/workflows/releaseSnapshotPackage.yml index c18eff88d32f..71acd9ac3ae8 100644 --- a/.github/workflows/releaseSnapshotPackage.yml +++ b/.github/workflows/releaseSnapshotPackage.yml @@ -1,3 +1,4 @@ +# yamllint disable rule:line-length name: Release snapshot package on: From 77c14ab22b11c31b57a2e45958cf360e70eef103 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Sat, 27 Jan 2024 18:35:45 -0800 Subject: [PATCH 126/283] [ci] Upgrade to new runners and disable unsupported jobs. (#2818) Per the RFC and numerous conversations on Discord, this rebuilds the torch-mlir CI and discontinues the infra and coupling to the binary releases (https://discourse.llvm.org/t/rfc-discontinuing-pytorch-1-binary-releases/76371). I iterated on this to get latency back to about what it was with the old (much larger and non-ephemeral) runners: About 4m - 4.5m for an incremental change. Behind the scenes changes: * Uses a new runner pool operated by AMD. It is currently set to manual scaling and has two runners (32-core, 64GiB RAM) while we get some traction. We can either fiddle with some auto-scaling or use a schedule to give it an increase during certain high traffic hours. * Builds are now completely isolated and cannot have run-to-run interference like we were getting before (i.e. lock file/permissions stuff). * The GHA runner is installed directly into a manylinux 2.28 container with upgraded dev tools. This eliminates the need to do sub-invocations of docker on Linux in order to run on the same OS that is used to build wheels. * While not using it now, this setup was cloned from another project that posts the built artifacts to the job and fans out testing. Might be useful here later. * Uses a special git cache that lets us have ephemeral runners and still check out the repo and deps (incl. llvm) in ~13s. * Running in an Azure VM Scale Set. In-repo changes: * Disables (but does not yet delete): * Old buildAndTest.yml jobs * releaseSnapshotPackage.yml * Adds a new `ci.yml` pipeline and scripts the steps in `build_tools/ci` (by decomposing the existing `build_linux_packages.sh` for in-tree builds and modularizing it a bit better). * Test framework changes: * Adds a `TORCH_MLIR_TEST_CONCURRENCY` env var that can be used to bound the multiprocess concurrency. Ended up not using this in the final version but is useful to have as a knob. * Changes the default concurrency to `nproc * 0.8 + 1` vs `nproc * 1.1`. We're running on systems with significantly less virtual memory and I did a bit of fiddling to find a good tradeoff. * Changed multiprocess mode to spawn instead of fork. Otherwise, I was getting instability (as discussed on discord). * Added MLIR configuration to disable multithreaded contexts globally for the project. Constantly spawning `nproc * nproc` threads (more than that actually) was OOM'ing. * Added a test timeout of 5 minutes. If a multiprocess worker crashes, the framework can get wedged indefinitely (and then will just be reaped after multiple hours). We should fix this, but this at least keeps the CI pool from wedging with stuck jobs. Functional changes needing followup: * No matter what I did, I couldn't get the LTC tests to work, and I'm not 100% sure they were being run in the old setup as the scripts were a bit twisty. I disabled them and left a comment. * Dropped out-of-tree build variants. These were not providing much signal and increase CI needs by 50%. * Dropped MacOS and Windows builds. Now that we are "just a library" and not building releases, there is less pressure to test these commit by commit. Further, since we bump torch-mlir to known good commits on these platforms, it has been a long time since either of these jobs have provided much signal (and they take ~an hour+ to run). We can add them back later post-submit if ever needed. --- .github/workflows/buildAndTest.yml | 12 +-- .github/workflows/ci.yml | 77 +++++++++++++++++++ .github/workflows/releaseSnapshotPackage.yml | 5 +- build_tools/ci/build_posix.sh | 60 +++++++++++++++ build_tools/ci/check_generated_sources.sh | 45 +++++++++++ build_tools/ci/install_python_deps.sh | 34 ++++++++ build_tools/ci/linux_default_toolchain.cmake | 14 ++++ build_tools/ci/test_posix.sh | 44 +++++++++++ .../python/torch_mlir_e2e_test/framework.py | 24 +++++- python/CMakeLists.txt | 8 ++ .../_mlir_libs/_site_initialize_0.py | 9 +++ 11 files changed, 319 insertions(+), 13 deletions(-) create mode 100644 .github/workflows/ci.yml create mode 100755 build_tools/ci/build_posix.sh create mode 100755 build_tools/ci/check_generated_sources.sh create mode 100755 build_tools/ci/install_python_deps.sh create mode 100644 build_tools/ci/linux_default_toolchain.cmake create mode 100755 build_tools/ci/test_posix.sh create mode 100644 python/torch_mlir/_mlir_libs/_site_initialize_0.py diff --git a/.github/workflows/buildAndTest.yml b/.github/workflows/buildAndTest.yml index c6d345860129..08ab86571cd9 100644 --- a/.github/workflows/buildAndTest.yml +++ b/.github/workflows/buildAndTest.yml @@ -2,11 +2,11 @@ name: Build and Test on: - pull_request: - branches: [main] - push: - branches: [main] - workflow_dispatch: + # pull_request: + # branches: [main] + # push: + # branches: [main] + # workflow_dispatch: # Ensure that only a single job or workflow using the same # concurrency group will run at a time. This would cancel @@ -30,7 +30,7 @@ jobs: strategy: fail-fast: true matrix: - os-arch: [ubuntu-x86_64, macos-arm64, windows-x86_64] + os-arch: [macos-arm64, windows-x86_64] llvm-build: [in-tree, out-of-tree] torch-binary: [ON] torch-version: [nightly, stable] diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 000000000000..4bb27526b063 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,77 @@ +name: CI + +on: + workflow_dispatch: + workflow_call: + pull_request: + branches: [main] + push: + branches: [main] + +concurrency: + # A PR number if a pull request and otherwise the commit hash. This cancels + # queued and in-progress runs for the same PR (presubmit) or commit + # (postsubmit). + group: ci-build-test-cpp-linux-${{ github.event.number || github.sha }} + cancel-in-progress: true + +jobs: + build-test-linux: + strategy: + fail-fast: true + matrix: + torch-version: [nightly, stable] + name: Build and Test (Linux, torch-${{ matrix.torch-version }}, assertions) + runs-on: torch-mlir-cpubuilder-manylinux-x86-64 + env: + CACHE_DIR: ${{ github.workspace }}/.container-cache + steps: + - name: Configure local git mirrors + run: | + # Our stock runners have access to certain local git caches. If these + # files are available, it will prime the cache and configure git to + # use them. Practically, this eliminates network/latency for cloning + # llvm. + if [[ -x /gitmirror/scripts/trigger_update_mirrors.sh ]]; then + /gitmirror/scripts/trigger_update_mirrors.sh + /gitmirror/scripts/git_config.sh + fi + - name: "Checking out repository" + uses: actions/checkout@8f4b7f84864484a7bf31766abe9204da3cbe65b3 # v3.5.0 + with: + submodules: true + + - name: Enable cache + uses: actions/cache/restore@v3 + with: + path: ${{ env.CACHE_DIR }} + key: build-test-cpp-asserts-manylinux-v2-${{ github.sha }} + restore-keys: | + build-test-cpp-asserts-manylinux-v2- + + - name: Install python deps (torch-${{ matrix.torch-version }}) + run: | + export cache_dir="${{ env.CACHE_DIR }}" + bash build_tools/ci/install_python_deps.sh ${{ matrix.torch-version }} + + - name: Build project + run: | + export cache_dir="${{ env.CACHE_DIR }}" + bash build_tools/ci/build_posix.sh + + - name: Save cache + uses: actions/cache/save@v3 + if: ${{ !cancelled() }} + with: + path: ${{ env.CACHE_DIR }} + key: build-test-cpp-asserts-manylinux-v2-${{ github.sha }} + + - name: Test project (torch-${{ matrix.torch-version }}) + run: | + export cache_dir="${{ env.CACHE_DIR }}" + bash build_tools/ci/test_posix.sh ${{ matrix.torch-version }} + + - name: Check generated sources (torch-nightly only) + if: ${{ matrix.torch-version == 'nightly' }} + run: | + bash build_tools/ci/check_generated_sources.sh diff --git a/.github/workflows/releaseSnapshotPackage.yml b/.github/workflows/releaseSnapshotPackage.yml index 71acd9ac3ae8..8a0ec914440f 100644 --- a/.github/workflows/releaseSnapshotPackage.yml +++ b/.github/workflows/releaseSnapshotPackage.yml @@ -2,9 +2,8 @@ name: Release snapshot package on: - schedule: - - cron: '0 11 * * *' - + # schedule: + # - cron: '0 11 * * *' workflow_dispatch: jobs: diff --git a/build_tools/ci/build_posix.sh b/build_tools/ci/build_posix.sh new file mode 100755 index 000000000000..c6e7e168dc81 --- /dev/null +++ b/build_tools/ci/build_posix.sh @@ -0,0 +1,60 @@ +#!/bin/bash + +set -eu -o errtrace + +this_dir="$(cd $(dirname $0) && pwd)" +repo_root="$(cd $this_dir/../.. && pwd)" +build_dir="$repo_root/build" +install_dir="$repo_root/install" +mkdir -p "$build_dir" +build_dir="$(cd $build_dir && pwd)" +cache_dir="${cache_dir:-}" + +# Setup cache dir. +if [ -z "${cache_dir}" ]; then + cache_dir="${repo_root}/.build-cache" + mkdir -p "${cache_dir}" + cache_dir="$(cd ${cache_dir} && pwd)" +fi +echo "Caching to ${cache_dir}" +mkdir -p "${cache_dir}/ccache" +mkdir -p "${cache_dir}/pip" + +python="$(which python)" +echo "Using python: $python" + +export CMAKE_TOOLCHAIN_FILE="$this_dir/linux_default_toolchain.cmake" +export CC=clang +export CXX=clang++ +export CCACHE_DIR="${cache_dir}/ccache" +export CCACHE_MAXSIZE="350M" +export CMAKE_C_COMPILER_LAUNCHER=ccache +export CMAKE_CXX_COMPILER_LAUNCHER=ccache + +# Clear ccache stats. +ccache -z + +cd $repo_root + +echo "::group::CMake configure" +cmake -S "$repo_root/externals/llvm-project/llvm" -B "$build_dir" \ + -GNinja \ + -DCMAKE_BUILD_TYPE=Release \ + -DPython3_EXECUTABLE="$(which python)" \ + -DLLVM_ENABLE_ASSERTIONS=ON \ + -DCMAKE_INSTALL_PREFIX="$install_dir" \ + -DCMAKE_INSTALL_LIBDIR=lib \ + -DLLVM_ENABLE_PROJECTS=mlir \ + -DLLVM_EXTERNAL_PROJECTS="torch-mlir" \ + -DLLVM_EXTERNAL_TORCH_MLIR_SOURCE_DIR="$repo_root" \ + -DLLVM_TARGETS_TO_BUILD=host \ + -DMLIR_ENABLE_BINDINGS_PYTHON=ON \ + -DTORCH_MLIR_ENABLE_LTC=ON +echo "::endgroup::" + +echo "::group::Build" +cmake --build "$build_dir" --target tools/torch-mlir/all -- -k 0 +echo "::endgroup::" + +# Show ccache stats. +ccache --show-stats diff --git a/build_tools/ci/check_generated_sources.sh b/build_tools/ci/check_generated_sources.sh new file mode 100755 index 000000000000..719e221d71ba --- /dev/null +++ b/build_tools/ci/check_generated_sources.sh @@ -0,0 +1,45 @@ +#!/bin/bash + +set -eu -o errtrace + +this_dir="$(cd $(dirname $0) && pwd)" +repo_root="$(cd $this_dir/../.. && pwd)" + +function _check_file_not_changed_by() { + # _check_file_not_changed_by + cmd="$1" + file="$2" + file_backup="$PWD/$(basename $file)" + file_new="$PWD/$(basename $file).new" + # Save the original file. + cp "$file" "$file_backup" + # Run the command to regenerate it. + "$1" || return 1 + # Save the new generated file. + cp "$file" "$file_new" + # Restore the original file. We want this function to not change the user's + # working tree state. + mv "$file_backup" "$file" + # We use git-diff as "just a diff program" (no SCM stuff) because it has + # nicer output than regular `diff`. + if ! git diff --no-index --quiet "$file" "$file_new"; then + echo "#######################################################" + echo "Generated file '${file}' is not up to date (see diff below)" + echo ">>> Please run '${cmd}' to update it <<<" + echo "#######################################################" + git diff --no-index --color=always "$file" "$file_new" + # TODO: Is there a better cleanup strategy that doesn't require duplicating + # this inside and outside the `if`? + rm "$file_new" + return 1 + fi + rm "$file_new" +} + +echo "::group:: Check that update_abstract_interp_lib.sh has been run" +_check_file_not_changed_by $repo_root/build_tools/update_abstract_interp_lib.sh $repo_root/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +echo "::endgroup::" + +echo "::group:: Check that update_torch_ods.sh has been run" +_check_file_not_changed_by $repo_root/build_tools/update_torch_ods.sh $repo_root/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +echo "::endgroup::" diff --git a/build_tools/ci/install_python_deps.sh b/build_tools/ci/install_python_deps.sh new file mode 100755 index 000000000000..6b49689ce8ea --- /dev/null +++ b/build_tools/ci/install_python_deps.sh @@ -0,0 +1,34 @@ +#!/bin/bash + +set -eu -o errtrace + +this_dir="$(cd $(dirname $0) && pwd)" +repo_root="$(cd $this_dir/../.. && pwd)" +torch_version="${1:-unknown}" + +echo "::group::installing llvm python deps" +python -m pip install --no-cache-dir -r $repo_root/externals/llvm-project/mlir/python/requirements.txt +echo "::endgroup::" + +case $torch_version in + nightly) + echo "::group::installing nightly torch" + python3 -m pip install --no-cache-dir -r $repo_root/requirements.txt + python3 -m pip install --no-cache-dir -r $repo_root/torchvision-requirements.txt + echo "::endgroup::" + ;; + stable) + echo "::group::installing stable torch" + python3 -m pip install --no-cache-dir torch torchvision --index-url https://download.pytorch.org/whl/cpu + python3 -m pip install --no-cache-dir -r $repo_root/build-requirements.txt + echo "::endgroup::" + ;; + *) + echo "Unrecognized torch version '$torch_version' (specify 'nightly' or 'stable' with cl arg)" + exit 1 + ;; +esac + +echo "::group::installing test requirements" +python -m pip install --no-cache-dir -r $repo_root/test-requirements.txt +echo "::endgroup::" diff --git a/build_tools/ci/linux_default_toolchain.cmake b/build_tools/ci/linux_default_toolchain.cmake new file mode 100644 index 000000000000..4e0c36c71be7 --- /dev/null +++ b/build_tools/ci/linux_default_toolchain.cmake @@ -0,0 +1,14 @@ +message(STATUS "Enabling thin archives (static libraries will not be relocatable)") +set(CMAKE_C_ARCHIVE_APPEND " qT ") +set(CMAKE_CXX_ARCHIVE_APPEND " qT ") +set(CMAKE_C_ARCHIVE_CREATE " crT ") +set(CMAKE_CXX_ARCHIVE_CREATE " crT ") + +set(CMAKE_EXE_LINKER_FLAGS_INIT "-fuse-ld=lld -Wl,--gdb-index") +set(CMAKE_MODULE_LINKER_FLAGS_INIT "-fuse-ld=lld -Wl,--gdb-index") +set(CMAKE_SHARED_LINKER_FLAGS_INIT "-fuse-ld=lld -Wl,--gdb-index") + +set(CMAKE_C_FLAGS_DEBUG "${CMAKE_C_FLAGS_DEBUG} -gsplit-dwarf -ggnu-pubnames") +set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -gsplit-dwarf -ggnu-pubnames") +set(CMAKE_C_FLAGS_RELWITHDEBINFO "${CMAKE_C_FLAGS_RELWITHDEBINFO} -gsplit-dwarf -ggnu-pubnames") +set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "${CMAKE_CXX_FLAGS_RELWITHDEBINFO} -gsplit-dwarf -ggnu-pubnames") diff --git a/build_tools/ci/test_posix.sh b/build_tools/ci/test_posix.sh new file mode 100755 index 000000000000..8cc68d77bd79 --- /dev/null +++ b/build_tools/ci/test_posix.sh @@ -0,0 +1,44 @@ +#!/bin/bash + +set -eu -o errtrace + +this_dir="$(cd $(dirname $0) && pwd)" +repo_root="$(cd $this_dir/../.. && pwd)" +torch_version="${1:-unknown}" + +export PYTHONPATH="$repo_root/build/tools/torch-mlir/python_packages/torch_mlir:$repo_root/projects/pt1" + +echo "::group::Run Linalg e2e integration tests" +python -m e2e_testing.main --config=linalg -v +echo "::endgroup::" + +echo "::group::Run make_fx + TOSA e2e integration tests" +python -m e2e_testing.main --config=make_fx_tosa -v +echo "::endgroup::" + +echo "::group::Run TOSA e2e integration tests" +python -m e2e_testing.main --config=tosa -v +echo "::endgroup::" + +case $torch_version in + nightly) + # Failing with: NotImplementedError: + # Could not run 'aten::empty.memory_format' with arguments from the 'Lazy' backend. + # As of 2024-01-07 + # echo "::group::Run Lazy Tensor Core e2e integration tests" + # python -m e2e_testing.main --config=lazy_tensor_core -v + # echo "::endgroup::" + + # TODO: There is one failing test in this group on stable. It could + # be xfailed vs excluding entirely. + echo "::group::Run TorchDynamo e2e integration tests" + python -m e2e_testing.main --config=torchdynamo -v + echo "::endgroup::" + ;; + stable) + ;; + *) + echo "Unrecognized torch version '$torch_version' (specify 'nightly' or 'stable' with cl arg)" + exit 1 + ;; +esac diff --git a/projects/pt1/python/torch_mlir_e2e_test/framework.py b/projects/pt1/python/torch_mlir_e2e_test/framework.py index f1fbad2ec914..388976256591 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/framework.py +++ b/projects/pt1/python/torch_mlir_e2e_test/framework.py @@ -24,11 +24,19 @@ from typing import Any, Callable, List, NamedTuple, Optional, TypeVar, Union, Dict from itertools import repeat +import os import sys import traceback -import torch import multiprocess as mp +from multiprocess import set_start_method +try: + set_start_method("spawn") +except RuntimeError: + # Children can error here so we suppress. + pass + +import torch TorchScriptValue = Union[int, float, List['TorchScriptValue'], Dict['TorchScriptValue', @@ -316,8 +324,16 @@ def compile_and_run_test(test: Test, config: TestConfig, verbose=False) -> Any: def run_tests(tests: List[Test], config: TestConfig, sequential=False, verbose=False) -> List[TestResult]: - """Invoke the given `Test`'s with the provided `TestConfig`.""" - num_processes = min(int(mp.cpu_count() * 1.1), len(tests)) + """Invoke the given `Test`'s with the provided `TestConfig`.""" + num_processes = min(int(mp.cpu_count() * 0.8) + 1, len(tests)) + try: + env_concurrency = int(os.getenv("TORCH_MLIR_TEST_CONCURRENCY", "0")) + except ValueError as e: + raise ValueError("Bad value for TORCH_MLIR_TEST_CONCURRENCY env var: " + "Expected integer.") from e + if env_concurrency > 0: + num_processes = min(num_processes, env_concurrency) + # TODO: We've noticed that on certain 2 core machine parallelizing the tests # makes the llvm backend legacy pass manager 20x slower than using a # single process. Need to investigate the root cause eventually. This is a @@ -344,7 +360,7 @@ def run_tests(tests: List[Test], config: TestConfig, sequential=False, verbose=F pool = mp.Pool(num_processes) arg_list = zip(tests, repeat(config)) handles = pool.starmap_async(compile_and_run_test, arg_list) - results = handles.get() + results = handles.get(timeout=360) tests_with_results = {result.unique_name for result in results} all_tests = {test.unique_name for test in tests} diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt index b8f8394459d9..d725aae6c584 100644 --- a/python/CMakeLists.txt +++ b/python/CMakeLists.txt @@ -46,6 +46,13 @@ declare_mlir_python_sources(TorchMLIRPythonSources.Tools tools/import_onnx/__main__.py ) +declare_mlir_python_sources(TorchMLIRSiteInitialize + ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}" + ADD_TO_PARENT TorchMLIRPythonSources + SOURCES + _mlir_libs/_site_initialize_0.py +) + ################################################################################ # Extensions ################################################################################ @@ -79,6 +86,7 @@ set(_source_components MLIRPythonExtension.RegisterEverything TorchMLIRPythonSources TorchMLIRPythonExtensions + TorchMLIRSiteInitialize # Sources related to optional Torch extension dependent features. Typically # empty unless if project features are enabled. diff --git a/python/torch_mlir/_mlir_libs/_site_initialize_0.py b/python/torch_mlir/_mlir_libs/_site_initialize_0.py new file mode 100644 index 000000000000..3b93b1fa930d --- /dev/null +++ b/python/torch_mlir/_mlir_libs/_site_initialize_0.py @@ -0,0 +1,9 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. + +# Multi-threading rarely helps the frontend and we are also running in contexts +# where we want to run a lot of test parallelism (and nproc*nproc threads +# puts a large load on the system and virtual memory). +disable_multithreading = True From 4513c3ca8702fd4fac96673207559556bb21d114 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Sat, 27 Jan 2024 19:35:48 -0800 Subject: [PATCH 127/283] [ci] Add step to run unit tests. (#2820) --- .github/workflows/ci.yml | 3 +-- build_tools/ci/build_posix.sh | 4 ++++ 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 4bb27526b063..525e62c9bba4 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -66,9 +66,8 @@ jobs: path: ${{ env.CACHE_DIR }} key: build-test-cpp-asserts-manylinux-v2-${{ github.sha }} - - name: Test project (torch-${{ matrix.torch-version }}) + - name: Integration tests (torch-${{ matrix.torch-version }}) run: | - export cache_dir="${{ env.CACHE_DIR }}" bash build_tools/ci/test_posix.sh ${{ matrix.torch-version }} - name: Check generated sources (torch-nightly only) diff --git a/build_tools/ci/build_posix.sh b/build_tools/ci/build_posix.sh index c6e7e168dc81..438a55c74389 100755 --- a/build_tools/ci/build_posix.sh +++ b/build_tools/ci/build_posix.sh @@ -56,5 +56,9 @@ echo "::group::Build" cmake --build "$build_dir" --target tools/torch-mlir/all -- -k 0 echo "::endgroup::" +echo "::group::Unit tests" +cmake --build $repo_root/build --target check-torch-mlir-all +echo "::endgroup::" + # Show ccache stats. ccache --show-stats From 6b3ebb237fbf9103f63a66fc41c2e27e15ab0c8c Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Sat, 27 Jan 2024 19:42:29 -0800 Subject: [PATCH 128/283] [ci] Use a different cache key for torch nightly vs stable. --- .github/workflows/ci.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 525e62c9bba4..4665dabe190c 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -45,9 +45,9 @@ jobs: uses: actions/cache/restore@v3 with: path: ${{ env.CACHE_DIR }} - key: build-test-cpp-asserts-manylinux-v2-${{ github.sha }} + key: build-test-cpp-asserts-manylinux-${{ matrix.torch-version }}-v2-${{ github.sha }} restore-keys: | - build-test-cpp-asserts-manylinux-v2- + build-test-cpp-asserts-manylinux-${{ matrix.torch-version }}-v2- - name: Install python deps (torch-${{ matrix.torch-version }}) run: | @@ -64,7 +64,7 @@ jobs: if: ${{ !cancelled() }} with: path: ${{ env.CACHE_DIR }} - key: build-test-cpp-asserts-manylinux-v2-${{ github.sha }} + key: build-test-cpp-asserts-manylinux-${{ matrix.torch-version }}-v2-${{ github.sha }} - name: Integration tests (torch-${{ matrix.torch-version }}) run: | From 032f225fa513b8eb781a11a0e27cb759943594cf Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Sat, 27 Jan 2024 19:43:41 -0800 Subject: [PATCH 129/283] [ci] Allow long line in YAML --- .github/workflows/ci.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 4665dabe190c..63ef01cdeb51 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -1,3 +1,4 @@ +# yamllint disable rule:line-length name: CI on: From 67cb2e7341658e28ca8756774801e8e771067f16 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Mon, 29 Jan 2024 09:23:05 -0800 Subject: [PATCH 130/283] Fix illegal use of TypeRange (#2815) TypeRange is an ArrayRef and therefore cannot be safely instantiated from a list initializer. --- lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 7f2fb2b53006..f4e8a60ec1cd 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -2615,11 +2615,9 @@ namespace { LogicalResult matchAndRewrite(AtenConvTbcOp op, PatternRewriter &rewriter) const override { Value emptyList = rewriter.create( - op.getLoc(), Torch::ListType::get(Torch::IntType::get(op.getContext())), + op.getLoc(), + Torch::ListType::get(Torch::IntType::get(op.getContext())), SmallVector()); - Value zeroList = rewriter.create( - op.getLoc(), Torch::ListType::get(Torch::IntType::get(op.getContext())), - SmallVector{rewriter.create(op.getLoc(), rewriter.getI64IntegerAttr(0))}); Value cstFalse = rewriter.create(op.getLoc(), false); Value oneList = rewriter.create( op.getLoc(), Torch::ListType::get(Torch::IntType::get(op.getContext())), @@ -5406,8 +5404,8 @@ class DecomposeAten_EmbeddingBagOp auto resultType2 = op->getResult(2).getType(); auto resultType3 = op->getResult(3).getType(); - mlir::TypeRange returnTypes{resultType0, resultType1, resultType2, - resultType3}; + llvm::SmallVector returnTypes{resultType0, resultType1, resultType2, + resultType3}; rewriter.replaceOpWithNewOp( op, returnTypes, weight, indices, offsets, scaleGradByFreq, mode, From d3fd754b93ea10e6f3a1cc46bbb471d1a1ff287a Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Mon, 29 Jan 2024 09:40:21 -0800 Subject: [PATCH 131/283] [onnx] `onnx.MatMulInteger` lowering to `torch.mm` and `quint*` types (#2761) Torch does not have an equivalent matmul operation for integers. Instead it sidechannels the information via its quantized types. For this lowering we setup these sidechannels then invoke `torch.mm`. --- .../TorchOnnxToTorch/DefaultDomainGtoP.cpp | 69 +++++++++++++++++++ .../TorchOnnxToTorch/simple_ops_g_to_p.mlir | 15 ++++ 2 files changed, 84 insertions(+) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index 7e3025da3e9b..4ee71af3fadb 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -136,6 +136,75 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( binder.op, resultType, lhs, rhs); return success(); }); + patterns.onOp( + "MatMulInteger", 10, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value lhs, rhs, lhsZp, rhsZp; + if (binder.tensorOperandAtIndex(lhs, 0) || + binder.tensorOperandAtIndex(rhs, 1) || + binder.tensorResultType(resultType)) + return failure(); + + if (binder.tensorOperandAtIndex(lhsZp, 2)) { + lhsZp = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); + } + + if (binder.tensorOperandAtIndex(rhsZp, 3)) { + rhsZp = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); + } + + auto lhsTy = dyn_cast(lhs.getType()); + auto rhsTy = dyn_cast(rhs.getType()); + + if (auto zpTy = dyn_cast(lhsZp.getType())) { + for (auto dim : zpTy.getSizes()) + if (dim != 1) + return failure(); + lhsZp = rewriter.create( + binder.getLoc(), rewriter.getType(), lhsZp); + } + + if (auto zpTy = dyn_cast(rhsZp.getType())) { + for (auto dim : zpTy.getSizes()) + if (dim != 1) + return failure(); + rhsZp = rewriter.create( + binder.getLoc(), rewriter.getType(), rhsZp); + } + + Value scale = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getF64FloatAttr(1.0)); + + auto q = [&](Type qty) -> Type { + if (qty.isSignedInteger(8)) + return rewriter.getType(); + if (qty.isUnsignedInteger(8)) + return rewriter.getType(); + if (qty.isSignedInteger(32)) + return rewriter.getType(); + return {}; + }; + + Type lhsQTy = rewriter.getType( + lhsTy.getOptionalSizes(), q(lhsTy.getDtype())); + Type rhsQTy = rewriter.getType( + rhsTy.getOptionalSizes(), q(rhsTy.getDtype())); + + lhs = rewriter.create( + binder.getLoc(), lhsQTy, lhs, scale, lhsZp); + rhs = rewriter.create( + binder.getLoc(), rhsQTy, rhs, scale, rhsZp); + + rewriter.replaceOpWithNewOp(binder.op, resultType, lhs, + rhs); + return success(); + }); patterns.onOp("Mul", 7, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir index 6a420300cdc9..449b7e4feb32 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir @@ -222,6 +222,21 @@ func.func @test_matmul_4d(%arg0: !torch.vtensor<[1,2,3,4],f32>, %arg1: !torch.vt // ----- +// CHECK-LABEL: @test_matmulinteger +func.func @test_matmulinteger(%arg0: !torch.vtensor<[4,3],ui8>, %arg1: !torch.vtensor<[3,2],ui8>, %arg2: !torch.vtensor<[1],ui8>, %arg3: !torch.vtensor<[1],ui8>) -> !torch.vtensor<[4,2],si32> attributes {torch.onnx_meta.ir_version = 5 : si64, torch.onnx_meta.opset_version = 10 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + %0 = torch.operator "onnx.MatMulInteger"(%arg0, %arg1, %arg2, %arg3) : (!torch.vtensor<[4,3],ui8>, !torch.vtensor<[3,2],ui8>, !torch.vtensor<[1],ui8>, !torch.vtensor<[1],ui8>) -> !torch.vtensor<[4,2],si32> + // CHECK: %[[LITEM:.+]] = torch.aten.item %arg2 + // CHECK: %[[RITEM:.+]] = torch.aten.item %arg3 + // CHECK: %[[SCALE:.+]] = torch.constant.float 1.000000e+00 + // CHECK: %[[LMAKE:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[SCALE]], %[[LITEM]] : !torch.vtensor<[4,3],ui8>, !torch.float, !torch.int -> !torch.vtensor<[4,3],!torch.quint8> + // CHECK: %[[RMAKE:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg1, %[[SCALE]], %[[RITEM]] : !torch.vtensor<[3,2],ui8>, !torch.float, !torch.int -> !torch.vtensor<[3,2],!torch.quint8> + // CHECK: %[[MM:.+]] = torch.aten.mm %[[LMAKE]], %[[RMAKE]] + // CHECK: return %[[MM]] + return %0 : !torch.vtensor<[4,2],si32> +} + +// ----- + // CHECK-LABEL: func.func @test_mul func.func @test_mul(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.mul.Tensor %arg0, %arg1 : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> From 494089d53db4c183b3ba12e36f61ce1c7553984c Mon Sep 17 00:00:00 2001 From: Quinn Dawkins Date: Mon, 29 Jan 2024 12:59:33 -0500 Subject: [PATCH 132/283] Clang format refresh (#2812) After noticing a number of commits with unrelated formatting changes, I think something was changed with clang-format at one point and we're seeing a number of unrelated changes. Doing a refresh can help avoid this. The changes made here came from ``` find lib -iname *.h -o -iname *.cpp | xargs clang-format -i --style=llvm find include -iname *.h -o -iname *.cpp | xargs clang-format -i --style=llvm find projects -iname *.h -o -iname *.cpp | xargs clang-format -i --style=llvm ``` --- include/torch-mlir-c/Dialects.h | 2 +- .../Dialect/TMTensor/IR/TMTensorInterfaces.h | 2 +- .../Conversion/TorchOnnxToTorch/Patterns.h | 7 +- .../TorchToTosa/TosaLegalizeCommon.h | 20 +- .../TorchToTosa/TosaLegalizeUtils.h | 6 +- .../torch-mlir/Dialect/Torch/IR/TorchTraits.h | 6 +- .../Dialect/Torch/Transforms/Passes.h | 6 +- .../Dialect/Torch/Utils/TorchUpstream.h | 7 +- .../torch-mlir/Dialect/Torch/Utils/Utils.h | 3 +- lib/CAPI/Dialects.cpp | 5 +- lib/Conversion/Passes.cpp | 4 +- .../TorchConversionToMLProgram.cpp | 3 +- .../TorchOnnxToTorch/DefaultDomainGtoP.cpp | 105 ++--- .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 107 ++--- lib/Conversion/TorchToArith/TorchToArith.cpp | 26 +- lib/Conversion/TorchToLinalg/DataMovement.cpp | 125 +++--- .../TorchToLinalg/IndirectDataMovement.cpp | 69 +-- lib/Conversion/TorchToLinalg/Linear.cpp | 70 +-- lib/Conversion/TorchToLinalg/Random.cpp | 1 - lib/Conversion/TorchToLinalg/Reduction.cpp | 22 +- .../TorchToLinalg/TensorConstructors.cpp | 413 +++++++++--------- .../TorchToLinalg/Uncategorized.cpp | 5 +- lib/Conversion/TorchToLinalg/Utils.cpp | 2 +- lib/Conversion/TorchToStablehlo/Basic.cpp | 6 +- .../TorchToStablehlo/GatherScatter.cpp | 4 +- lib/Conversion/TorchToStablehlo/Pooling.cpp | 135 +++--- lib/Conversion/TorchToStablehlo/Reduction.cpp | 2 +- lib/Conversion/TorchToStablehlo/ViewLike.cpp | 5 +- .../TorchToTosa/TosaLegalizeCommon.cpp | 20 +- .../TorchToTosa/TosaLegalizeUtils.cpp | 45 +- lib/Dialect/TMTensor/Transforms/Bufferize.cpp | 3 +- lib/Dialect/Torch/IR/TorchDialect.cpp | 2 +- lib/Dialect/Torch/IR/TorchOps.cpp | 74 ++-- lib/Dialect/Torch/IR/TorchTypes.cpp | 8 +- .../Torch/Transforms/DecomposeComplexOps.cpp | 204 +++++---- .../Torch/Transforms/GlobalizeObjectGraph.cpp | 17 +- .../Torch/Transforms/InlineGlobalSlots.cpp | 13 +- .../Transforms/LowerToBackendContract.cpp | 17 +- .../ReifyAbstractInterpCalculationsUtils.cpp | 8 +- .../Transforms/SimplifyDtypeCalculations.cpp | 3 +- lib/Dialect/Torch/Utils/Utils.cpp | 8 +- .../IR/TorchConversionDialect.cpp | 6 +- .../Transforms/BackendTypeConversion.cpp | 95 ++-- .../Transforms/ConvertCustomQuantOp.cpp | 65 +-- .../VerifyLinalgOnTensorsBackendContract.cpp | 4 +- .../VerifyStablehloBackendContract.cpp | 3 +- .../csrc/base_lazy_backend/backend_impl.cpp | 72 +-- .../ltc/csrc/base_lazy_backend/backend_impl.h | 53 ++- .../ltc/csrc/base_lazy_backend/dynamic_ir.cpp | 29 +- .../mlir_lowering_context.cpp | 183 ++++---- .../base_lazy_backend/mlir_lowering_context.h | 64 ++- .../mlir_native_functions.cpp | 401 +++++++++-------- .../ltc/csrc/base_lazy_backend/mlir_node.cpp | 81 ++-- .../ltc/csrc/base_lazy_backend/mlir_node.h | 48 +- .../base_lazy_backend/mlir_node_lowering.cpp | 178 ++++---- .../base_lazy_backend/mlir_node_lowering.h | 6 +- .../base_lazy_backend/ops/device_data.cpp | 24 +- .../csrc/base_lazy_backend/ops/device_data.h | 20 +- .../csrc/base_lazy_backend/ops/generic.cpp | 8 +- .../ltc/csrc/base_lazy_backend/ops/generic.h | 12 +- .../ltc/csrc/base_lazy_backend/ops/index.cpp | 32 +- .../ltc/csrc/base_lazy_backend/ops/index.h | 32 +- .../ltc/csrc/base_lazy_backend/ops/ivalue.cpp | 8 +- .../ltc/csrc/base_lazy_backend/ops/ivalue.h | 10 +- .../ltc/csrc/base_lazy_backend/ops/split.cpp | 32 +- .../ltc/csrc/base_lazy_backend/ops/split.h | 30 +- .../ltc/csrc/base_lazy_backend/ops/to_copy.h | 73 ++-- .../csrc/base_lazy_backend/ops/unbind_int.cpp | 12 +- .../csrc/base_lazy_backend/ops/unbind_int.h | 8 +- .../base_lazy_backend/shape_inference.cpp | 253 ++++++----- .../ltc/csrc/base_lazy_backend/tensor.cpp | 10 +- projects/ltc/csrc/base_lazy_backend/tensor.h | 3 +- .../csrc/base_lazy_backend/utils/exception.h | 4 +- .../base_lazy_backend/utils/jit_utils.cpp | 16 +- .../csrc/base_lazy_backend/utils/jit_utils.h | 2 +- .../base_lazy_backend/utils/string_utils.h | 60 +-- .../csrc/base_lazy_backend/utils/sys_utils.h | 13 +- .../base_lazy_backend/utils/tensor_utils.cpp | 128 +++--- .../base_lazy_backend/utils/tensor_utils.h | 27 +- .../reference_lazy_backend/backend_impl.cpp | 32 +- .../reference_lazy_backend_pybind.cpp | 26 +- 81 files changed, 1955 insertions(+), 1798 deletions(-) diff --git a/include/torch-mlir-c/Dialects.h b/include/torch-mlir-c/Dialects.h index 99156c17009c..60f6ec1e5e26 100644 --- a/include/torch-mlir-c/Dialects.h +++ b/include/torch-mlir-c/Dialects.h @@ -22,4 +22,4 @@ MLIR_DECLARE_CAPI_DIALECT_REGISTRATION(Torch, torch); } #endif -#endif // TORCHMLIR_C_DIALECTS_H +#endif // TORCHMLIR_C_DIALECTS_H diff --git a/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorInterfaces.h b/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorInterfaces.h index f16b436c8790..159bcea7899e 100644 --- a/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorInterfaces.h +++ b/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorInterfaces.h @@ -10,9 +10,9 @@ #ifndef TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_IR_TMTENSORINTERFACES_H_ #define TORCH_MLIR_DIALECTS_DIALECT_TMTENSOR_IR_TMTENSORINTERFACES_H_ -#include "mlir/IR/IRMapping.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/IRMapping.h" #include "mlir/IR/OpDefinition.h" #include "mlir/Support/LLVM.h" diff --git a/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h b/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h index 44e33ab09741..2df6f95c8ad6 100644 --- a/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h +++ b/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h @@ -78,8 +78,8 @@ struct OpBinder { return failure(); return success(); } - - ParseResult tensorOperandsList( llvm::SmallVectorImpl &values) { + + ParseResult tensorOperandsList(llvm::SmallVectorImpl &values) { for (uint32_t i = 0; i < op->getNumOperands(); i++) { values.push_back(op->getOperand(i)); } @@ -97,7 +97,8 @@ struct OpBinder { return success(); } - ParseResult tensorResultTypeAtIndex(Torch::ValueTensorType &typeIdx, int64_t idx) { + ParseResult tensorResultTypeAtIndex(Torch::ValueTensorType &typeIdx, + int64_t idx) { if (idx >= op->getNumResults()) return failure(); auto t = toValidTensorType(op->getResult(idx).getType()); diff --git a/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h b/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h index 16bf235de89e..44b9bbdde3b2 100644 --- a/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h +++ b/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h @@ -37,33 +37,31 @@ TosaOpT createBinaryOpAndCast(PatternRewriter &rewriter, Operation *op, return CreateOpAndInfer(rewriter, op->getLoc(), outType, lhs, rhs); } -// This specialization is for Div op. Unlike other binary ops, it doesn't support -// floating type. +// This specialization is for Div op. Unlike other binary ops, it doesn't +// support floating type. template <> tosa::DivOp createBinaryOpAndCast(PatternRewriter &rewriter, Operation *op, TensorType outType, Value lhs, Value rhs); std::optional convertTorchIndexToTfIndices(PatternRewriter &rewriter, - Operation *op, - Value params_value, - Value index_value, - int32_t axis); + Operation *op, + Value params_value, + Value index_value, + int32_t axis); // Lowers torch.aten.Gather operators to a sequence of TOSA ops. // Revised from // https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc -std::optional convertGatherNdOp(PatternRewriter &rewriter, - Operation *op, Type out_type, - Value params_value, - Value indices_value); +std::optional convertGatherNdOp(PatternRewriter &rewriter, Operation *op, + Type out_type, Value params_value, + Value indices_value); std::optional convertScatterNdOp(PatternRewriter &rewriter, Operation *op, Type outType, Value paramsValue, Value indicesValue, Value fillValues); - // Lowers ReduceAll to a sequence of TOSA ops. std::optional convertReduceAllOp(PatternRewriter &rewriter, Operation *op, diff --git a/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h b/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h index 14cf9cba7d2e..44c033eb82c4 100644 --- a/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h +++ b/include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h @@ -67,7 +67,7 @@ Value promoteType(PatternRewriter &rewriter, Value input, TensorType outType); // op. This allows shape inference during the framework to TOSA lowering. template TosaOp CreateOpAndInfer(PatternRewriter &rewriter, Location loc, Type result_ty, - Args &&... args) { + Args &&...args) { auto op = rewriter.create(loc, result_ty, args...); InferShapedTypeOpInterface shapeInterface = @@ -111,7 +111,7 @@ TosaOp CreateOpAndInfer(PatternRewriter &rewriter, Location loc, Type result_ty, template void CreateReplaceOpAndInfer(PatternRewriter &rewriter, Operation *op, - Type result_ty, Args &&... args) { + Type result_ty, Args &&...args) { auto result = CreateOpAndInfer(rewriter, op->getLoc(), result_ty, args...); rewriter.replaceOp(op, result->getResults()); @@ -119,7 +119,7 @@ void CreateReplaceOpAndInfer(PatternRewriter &rewriter, Operation *op, // Get accumulator type for AvgPool2dOp. LogicalResult getAvgPool2dAccType(PatternRewriter &rewriter, Value input, - TypeAttr &accType); + TypeAttr &accType); } // namespace tosa } // namespace mlir diff --git a/include/torch-mlir/Dialect/Torch/IR/TorchTraits.h b/include/torch-mlir/Dialect/Torch/IR/TorchTraits.h index 20f1bc109885..271481f0ae8a 100644 --- a/include/torch-mlir/Dialect/Torch/IR/TorchTraits.h +++ b/include/torch-mlir/Dialect/Torch/IR/TorchTraits.h @@ -36,8 +36,7 @@ class HasValueSemantics // This is a weaker form of HasValueSemantics, since that trait also requires no // aliasing. That is, HasValueSemantics implies this trait. template -class ReadOnly - : public ::mlir::OpTrait::TraitBase {}; +class ReadOnly : public ::mlir::OpTrait::TraitBase {}; // If a Torch op has this trait, it means that the op is a "trailing underscore" // op variant that performs an in-place operation on its first argument. These @@ -62,7 +61,8 @@ class AllowsTypeRefinement // by the IValue importer. template class AllowedInModuleInitializer - : public ::mlir::OpTrait::TraitBase {}; + : public ::mlir::OpTrait::TraitBase {}; } // namespace OpTrait } // namespace Torch diff --git a/include/torch-mlir/Dialect/Torch/Transforms/Passes.h b/include/torch-mlir/Dialect/Torch/Transforms/Passes.h index fd7468847e5f..71111c00cd28 100644 --- a/include/torch-mlir/Dialect/Torch/Transforms/Passes.h +++ b/include/torch-mlir/Dialect/Torch/Transforms/Passes.h @@ -61,7 +61,8 @@ struct TorchLoweringPipelineOptions Option extraLibrary{ *this, "extra-library", - llvm::cl::desc("Filename of MLIR module for splicing into the abstract interpretation library.")}; + llvm::cl::desc("Filename of MLIR module for splicing into the abstract " + "interpretation library.")}; }; /// Creates a pipeline that lowers the object graph IR that is produced by @@ -125,8 +126,7 @@ createSimplifyDtypeCalculationsPass(); std::unique_ptr> createDropAbstractInterpCalculationsPass(); -std::unique_ptr> -createEraseModuleInitializerPass(); +std::unique_ptr> createEraseModuleInitializerPass(); std::unique_ptr> createLowerToBackendContractPass(int maxIterations, bool decompose, diff --git a/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h b/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h index efb114fbfa14..043dd92549b2 100644 --- a/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h +++ b/include/torch-mlir/Dialect/Torch/Utils/TorchUpstream.h @@ -140,12 +140,7 @@ enum Reduction { None, Mean, Sum, END }; // Source: // https://github.com/pytorch/pytorch/blob/master/c10/core/MemoryFormat.h //===----------------------------------------------------------------------===// -enum MemoryFormat { - Contiguous, - Preserve, - ChannelsLast, - ChannelsLast3d -}; +enum MemoryFormat { Contiguous, Preserve, ChannelsLast, ChannelsLast3d }; //===----------------------------------------------------------------------===// // Possible values for `layout` argument in PyTorch ops that support it. diff --git a/include/torch-mlir/Dialect/Torch/Utils/Utils.h b/include/torch-mlir/Dialect/Torch/Utils/Utils.h index b5c815ca7614..beafe7d21adc 100644 --- a/include/torch-mlir/Dialect/Torch/Utils/Utils.h +++ b/include/torch-mlir/Dialect/Torch/Utils/Utils.h @@ -121,8 +121,7 @@ LogicalResult checkDefaultStrideHelper(Operation *op, PatternRewriter &rewriter, // Helper to create a tensor filled with the given scalar. Scalar would be // converted the to the element type of the given tensor type. Value createInitTensor(PatternRewriter &rewriter, Location loc, - BaseTensorType resultType, Value scalar, - Value sizeList); + BaseTensorType resultType, Value scalar, Value sizeList); // Helper to create a rank 0 tensor filled with the given `scalar`. `scalar` // would be converted to the element type of the given `inputType`. diff --git a/lib/CAPI/Dialects.cpp b/lib/CAPI/Dialects.cpp index 06be821c0cfd..048e37e083a3 100644 --- a/lib/CAPI/Dialects.cpp +++ b/lib/CAPI/Dialects.cpp @@ -9,7 +9,8 @@ #include "torch-mlir-c/Dialects.h" -#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "mlir/CAPI/Registration.h" +#include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" -MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Torch, torch, mlir::torch::Torch::TorchDialect) +MLIR_DEFINE_CAPI_DIALECT_REGISTRATION(Torch, torch, + mlir::torch::Torch::TorchDialect) diff --git a/lib/Conversion/Passes.cpp b/lib/Conversion/Passes.cpp index b9af2afa3f81..6d8adbaa146d 100644 --- a/lib/Conversion/Passes.cpp +++ b/lib/Conversion/Passes.cpp @@ -30,6 +30,4 @@ namespace { #include "torch-mlir/Conversion/Passes.h.inc" } // end namespace -void mlir::torch::registerConversionPasses() { - ::registerPasses(); -} +void mlir::torch::registerConversionPasses() { ::registerPasses(); } diff --git a/lib/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.cpp b/lib/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.cpp index eab81c2bec18..6a00e5190f4b 100644 --- a/lib/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.cpp +++ b/lib/Conversion/TorchConversionToMLProgram/TorchConversionToMLProgram.cpp @@ -82,7 +82,8 @@ class ConvertGetNextSeedOp : public OpConversionPattern { // temp = multiplier * currentSeed + incrementStep Value mul = rewriter.create(loc, currentSeed, multiplier); Value seed = rewriter.create(loc, mul, incrementStep); - globalVar = rewriter.create(loc, seed, globalVar, ValueRange()); + globalVar = + rewriter.create(loc, seed, globalVar, ValueRange()); rewriter.create( loc, SymbolRefAttr::get(op->getContext(), getSeedGobalVarName()), globalVar); diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index 4ee71af3fadb..df20a83515bf 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -29,7 +29,8 @@ using namespace mlir::torch::onnx_c; // thing here, so we simplify. void mlir::torch::onnx_c::populateDefaultDomainGtoP( OnnxCustomOpConversionPattern &patterns) { - patterns.onOp("HardSigmoid", 6, + patterns.onOp( + "HardSigmoid", 6, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value tensorOperand; @@ -39,8 +40,9 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( binder.f32FloatAttr(beta, "beta", 0.5f) || binder.tensorResultType(resultType)) return failure(); - - // HardSigmoid computes the following expression: max(0, min(1, alpha * x + beta)) + + // HardSigmoid computes the following expression: + // max(0, min(1, alpha * x + beta)) Value constAlpha = rewriter.create( binder.getLoc(), rewriter.getType(), rewriter.getF64FloatAttr(alpha)); @@ -51,7 +53,8 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( // Expression: alpha * x + beta Value alpha_x_plus_beta = rewriter.create( - binder.getLoc(), resultType, tensorOperand, constBeta, /*alpha=*/constAlpha); + binder.getLoc(), resultType, tensorOperand, constBeta, + /*alpha=*/constAlpha); // Expression: min(1, alpha * x + beta) Value constantOne = rewriter.create( @@ -100,7 +103,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( rewriter.replaceOpWithNewOp( binder.op, resultType, lhs, rhs); return success(); - }); + }); patterns.onOp("LessOrEqual", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; @@ -109,9 +112,9 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( binder.tensorResultType(resultType)) { return failure(); } - rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( binder.op, resultType, lhs, rhs); - return success(); + return success(); }); patterns.onOp("Log", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { @@ -126,7 +129,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( return success(); }); patterns.onOp("MatMul", 13, - [](OpBinder binder, ConversionPatternRewriter &rewriter) { + [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value lhs, rhs; if (binder.tensorOperands(lhs, rhs) || @@ -206,20 +209,20 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( return success(); }); patterns.onOp("Mul", 7, - [](OpBinder binder, ConversionPatternRewriter &rewriter) { - Torch::ValueTensorType resultType; + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; Value lhs, rhs; if (binder.tensorOperands(lhs, rhs) || binder.tensorResultType(resultType)) { return failure(); } rewriter.replaceOpWithNewOp( - binder.op, resultType, lhs, rhs); - return success(); - }); + binder.op, resultType, lhs, rhs); + return success(); + }); patterns.onOp("NonZero", 13, - [](OpBinder binder, ConversionPatternRewriter &rewriter) { - Torch::ValueTensorType resultType; + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; Value operand; if (binder.tensorOperand(operand) || binder.tensorResultType(resultType)) { @@ -332,41 +335,38 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( binder.op, resultType, lhs, rhs); return success(); }); - patterns.onOp("Max", 1, - [](OpBinder binder, ConversionPatternRewriter &rewriter) { - Torch::ValueTensorType resultType; - llvm::SmallVector operands; - if (binder.tensorOperandsList(operands) || - binder.tensorResultType(resultType) || - operands.size() == 0) { - return failure(); - } - Value result = operands[0]; - for (uint64_t i = 1; i < operands.size(); i++) { - result = rewriter.create( - binder.getLoc(), resultType, result, operands[i]); - } - rewriter.replaceOp(binder.op, result.getDefiningOp()); - return success(); - }); - patterns.onOp("Min", 1, - [](OpBinder binder, ConversionPatternRewriter &rewriter) { - Torch::ValueTensorType resultType; - llvm::SmallVector operands; - if (binder.tensorOperandsList(operands) || - binder.tensorResultType(resultType) || - operands.size() == 0) { - return failure(); - } - Value result = operands[0]; - for (uint64_t i = 1; i < operands.size(); i++) { - result = rewriter.create( - binder.getLoc(), resultType, result, operands[i]); - } - rewriter.replaceOp( - binder.op, result.getDefiningOp()); - return success(); - }); + patterns.onOp( + "Max", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + llvm::SmallVector operands; + if (binder.tensorOperandsList(operands) || + binder.tensorResultType(resultType) || operands.size() == 0) { + return failure(); + } + Value result = operands[0]; + for (uint64_t i = 1; i < operands.size(); i++) { + result = rewriter.create( + binder.getLoc(), resultType, result, operands[i]); + } + rewriter.replaceOp(binder.op, result.getDefiningOp()); + return success(); + }); + patterns.onOp( + "Min", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + llvm::SmallVector operands; + if (binder.tensorOperandsList(operands) || + binder.tensorResultType(resultType) || operands.size() == 0) { + return failure(); + } + Value result = operands[0]; + for (uint64_t i = 1; i < operands.size(); i++) { + result = rewriter.create( + binder.getLoc(), resultType, result, operands[i]); + } + rewriter.replaceOp(binder.op, result.getDefiningOp()); + return success(); + }); patterns.onOp("Neg", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; @@ -693,7 +693,8 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( binder.getLoc(), Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), cstStrides); - Value cstFalse = rewriter.create(binder.getLoc(), false); + Value cstFalse = + rewriter.create(binder.getLoc(), false); Value cstCeilMode = cstFalse; Value cstCountIncludePad = cstFalse; Value cstNone = rewriter.create(binder.getLoc()); @@ -903,7 +904,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( return failure(); } rewriter.replaceOpWithNewOp( - binder.op, resultType, lhs, rhs); + binder.op, resultType, lhs, rhs); return success(); }); patterns.onOp( diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 6569e3abc0b5..87f68375a593 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -42,56 +42,63 @@ Value getItemOp(OpBinder binder, ConversionPatternRewriter &rewriter, void mlir::torch::onnx_c::populateDefaultDomainQtoZ( OnnxCustomOpConversionPattern &patterns) { - patterns.onOp("QuantizeLinear", 1, - [](OpBinder binder, ConversionPatternRewriter &rewriter) { - Torch::ValueTensorType resultType; - llvm::SmallVector operands; - if (binder.tensorOperands(operands, 3) || - binder.tensorResultType(resultType)) - return failure(); + patterns.onOp( + "QuantizeLinear", 1, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + llvm::SmallVector operands; + if (binder.tensorOperands(operands, 3) || + binder.tensorResultType(resultType)) + return failure(); - Value operand = operands[0]; - Value scale = operands[1]; - Value zeropoint = operands[2]; - - auto scaleTy = scale.getType().dyn_cast(); - if (!scaleTy || !scaleTy.hasSizes()) - return rewriter.notifyMatchFailure(binder.op, - "requires known rank"); - if (!resultType.hasDtype()) - return rewriter.notifyMatchFailure( - binder.op, "requires known result dtype"); - - if (scaleTy.getSizes().size() == 0) { - Type qTy = resultType.getDtype(); - - if (qTy.isUnsignedInteger(8)) { - qTy = rewriter.getType(); - } else if (qTy.isSignedInteger(8)) { - qTy = rewriter.getType(); - } else if (qTy.isSignedInteger(32)) { - qTy = rewriter.getType(); - } else { - return rewriter.notifyMatchFailure(binder.op, "unsupported result dtype"); - } - - auto qTensorTy = rewriter.getType(resultType.getOptionalSizes(), qTy); - auto torchqTy = Torch::getScalarTypeForType(qTy); - - Value tyConst = rewriter.create( - binder.getLoc(), rewriter.getType(), - rewriter.getIntegerAttr(rewriter.getIntegerType(64), static_cast(torchqTy))); - - scale = rewriter.create(binder.getLoc(), rewriter.getType(), scale); - zeropoint = rewriter.create(binder.getLoc(), rewriter.getType(), zeropoint); - - auto quantize = rewriter.create(binder.getLoc(), qTensorTy, operand, scale, zeropoint, tyConst); - rewriter.replaceOpWithNewOp(binder.op, resultType, quantize); - return success(); - } - - return failure(); - }); + Value operand = operands[0]; + Value scale = operands[1]; + Value zeropoint = operands[2]; + + auto scaleTy = scale.getType().dyn_cast(); + if (!scaleTy || !scaleTy.hasSizes()) + return rewriter.notifyMatchFailure(binder.op, "requires known rank"); + if (!resultType.hasDtype()) + return rewriter.notifyMatchFailure(binder.op, + "requires known result dtype"); + + if (scaleTy.getSizes().size() == 0) { + Type qTy = resultType.getDtype(); + + if (qTy.isUnsignedInteger(8)) { + qTy = rewriter.getType(); + } else if (qTy.isSignedInteger(8)) { + qTy = rewriter.getType(); + } else if (qTy.isSignedInteger(32)) { + qTy = rewriter.getType(); + } else { + return rewriter.notifyMatchFailure(binder.op, + "unsupported result dtype"); + } + + auto qTensorTy = rewriter.getType( + resultType.getOptionalSizes(), qTy); + auto torchqTy = Torch::getScalarTypeForType(qTy); + + Value tyConst = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), + static_cast(torchqTy))); + + scale = rewriter.create( + binder.getLoc(), rewriter.getType(), scale); + zeropoint = rewriter.create( + binder.getLoc(), rewriter.getType(), zeropoint); + + auto quantize = rewriter.create( + binder.getLoc(), qTensorTy, operand, scale, zeropoint, tyConst); + rewriter.replaceOpWithNewOp( + binder.op, resultType, quantize); + return success(); + } + + return failure(); + }); patterns.onOp( "QLinearMatMul", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { @@ -1245,7 +1252,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( } // Convert dynamic shape dimension. - for (unsigned i = 0; i < shape.size(); i++){ + for (unsigned i = 0; i < shape.size(); i++) { if (shape[i] == ShapedType::kDynamic) shape[i] = Torch::kUnknownSize; } diff --git a/lib/Conversion/TorchToArith/TorchToArith.cpp b/lib/Conversion/TorchToArith/TorchToArith.cpp index e1e53acb2363..d2000d7fc3d2 100644 --- a/lib/Conversion/TorchToArith/TorchToArith.cpp +++ b/lib/Conversion/TorchToArith/TorchToArith.cpp @@ -43,7 +43,8 @@ class ConvertAtenDimOp : public OpConversionPattern { LogicalResult matchAndRewrite(AtenDimOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto rank = rewriter.create(op->getLoc(), adaptor.getSelf()); + auto rank = + rewriter.create(op->getLoc(), adaptor.getSelf()); rewriter.replaceOpWithNewOp( op, getTypeConverter()->convertType(op.getType()), rank); return success(); @@ -74,7 +75,8 @@ class ConvertAtenBinaryOp : public OpConversionPattern { matchAndRewrite(AtenOp op, typename OpConversionPattern::OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - rewriter.template replaceOpWithNewOp(op, adaptor.getA(), adaptor.getB()); + rewriter.template replaceOpWithNewOp(op, adaptor.getA(), + adaptor.getB()); return success(); } }; @@ -112,10 +114,10 @@ class ConvertAtenDivIntOp : public OpConversionPattern { typename OpConversionPattern::OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { Location loc = op.getLoc(); - Value a = - convertScalarToDtype(rewriter, loc, adaptor.getA(), rewriter.getF64Type()); - Value b = - convertScalarToDtype(rewriter, loc, adaptor.getB(), rewriter.getF64Type()); + Value a = convertScalarToDtype(rewriter, loc, adaptor.getA(), + rewriter.getF64Type()); + Value b = convertScalarToDtype(rewriter, loc, adaptor.getB(), + rewriter.getF64Type()); rewriter.replaceOpWithNewOp(op, a, b); return success(); } @@ -176,15 +178,16 @@ class ConvertTorchTensorLiteralOp unsigned bitWidth = elemTy.getIntOrFloatBitWidth(); Type builtinTensorElemTy = IntegerType::get(context, bitWidth); auto shapedType = - RankedTensorType::get(type.getShape(), builtinTensorElemTy); + RankedTensorType::get(type.getShape(), builtinTensorElemTy); auto rawData = elements.getRawData(); - DenseElementsAttr newAttr = DenseElementsAttr::getFromRawBuffer( - shapedType, rawData); + DenseElementsAttr newAttr = + DenseElementsAttr::getFromRawBuffer(shapedType, rawData); rewriter.replaceOpWithNewOp(op, newAttr); return success(); } } - if (auto elements = op.getValueAttr().dyn_cast()) { + if (auto elements = + op.getValueAttr().dyn_cast()) { if (auto type = elements.getType().dyn_cast()) { if (auto intType = type.getElementType().dyn_cast()) { Type builtinTensorElemTy = @@ -360,7 +363,8 @@ class ConvertAtenBoolLikeOp : public OpConversionPattern { // ----------------------------------------------------------------------------- namespace { -class ConvertTorchToArith : public ConvertTorchToArithBase { +class ConvertTorchToArith + : public ConvertTorchToArithBase { public: void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp index e96d65970b82..add32ff05cd6 100644 --- a/lib/Conversion/TorchToLinalg/DataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp @@ -110,22 +110,32 @@ LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor, // Example: // input = tensor([[[0., 1., 2., 3.], // [4., 5., 6., 7.]]]) -// torch.ops.aten.reflection_pad1d(input, (3,1)) ; padding_left = 3, padding_right = 1 -// tensor([[[3., 2., 1., 0., 1., 2., 3., 2.], -// [7., 6., 5., 4., 5., 6., 7., 6.]]]) -// Checks: 1) Each of padding_left and padding_right must be non-negative less than size of last dimension -// Implementation: a) Construct a result tensor of shape of input tensor except for the last dimension. -// The last dimension of the result tensor should be last dimension of input tensor + -// left padding size + right padding size. INitialize result tensor to all zeros -// b) Setup affine map to take slice from input tensor of size left padding starting from -// second column onwards as first column is reflection boundary +// torch.ops.aten.reflection_pad1d(input, (3,1)); +// padding_left = 3, +// padding_right = 1 +// output = tensor([[[3., 2., 1., 0., 1., 2., 3., 2.], +// [7., 6., 5., 4., 5., 6., 7., 6.]]]) +// Checks: 1) Each of padding_left and padding_right must be non-negative and +// less than the size of the last dimension. +// Implementation: a) Construct a result tensor of +// shape of input tensor except for the last dimension. +// The last dimension of the result tensor should be last +// dimension of input tensor + left padding size + right +// padding size. Initialize result tensor to all zeros +// b) Setup affine map to take slice from input tensor of size +// left padding starting from +// second column onwards as first column is reflection +// boundary // c) Reflect the affine map to have resultant slice reflected // d) Take the slice and write from begining in result tensor // e) write the original tensor next into result tensor -// f) Setup affine map to take slice from input tensor of right padding size ending -// at second last column as last column is reflection boundary for right padding +// f) Setup affine map to take slice from input tensor of right +// padding size ending +// at second last column as last column is reflection +// boundary for right padding // g) Reflect the affine map to have resultant slice reflected -// h) Take the slice and write from left padding size + orignal tensor last dim size +// h) Take the slice and write from left padding size + orignal +// tensor last dim size // into result tensor // Uses the ideas/code used for AtenReflectionPad2dOp namespace { @@ -138,7 +148,7 @@ class ConvertAtenReflectionPad1dOp ConversionPatternRewriter &rewriter) const override { if (failed(verifyLinalgCompatibleTypes(op, rewriter))) return failure(); - + SmallVector padInts; if (!matchPattern(op.getPadding(), m_TorchListOfConstantInts(padInts))) return rewriter.notifyMatchFailure( @@ -158,55 +168,68 @@ class ConvertAtenReflectionPad1dOp return rewriter.create(loc, x, y); }; - enum PadLocation {PAD_LEFT = 0, PAD_RIGHT = 1, PAD_CENTER=2}; + enum PadLocation { PAD_LEFT = 0, PAD_RIGHT = 1, PAD_CENTER = 2 }; Value input = adaptor.getSelf(); Type indexType = rewriter.getIndexType(); Value zero = getConstant(rewriter, loc, 0, indexType); Value one = getConstant(rewriter, loc, 1, indexType); auto inputType = llvm::cast(input.getType()); - auto outputType = llvm::cast(getTypeConverter()->convertType(op->getResult(0).getType())); + auto outputType = llvm::cast( + getTypeConverter()->convertType(op->getResult(0).getType())); unsigned numDims = inputType.getRank(); assert(numDims >= 2 && "Not enough input dimensions"); int64_t lastDim = numDims - 1; SmallVector inputShape = getTensorSizes(rewriter, loc, input); - Value lastDimSize = inputShape[lastDim]; // input [1,2,4], then lastDim = 2, inputShape[2] will give 4 + Value lastDimSize = inputShape[lastDim]; // input [1,2,4], then lastDim = 2, + // inputShape[2] will give 4 Value tileWidth[3], extractOffset[3], insertOffset[3]; - - tileWidth[PAD_LEFT] = getConstant(rewriter, loc, padInts[PAD_LEFT], indexType); - tileWidth[PAD_RIGHT] = getConstant(rewriter, loc, padInts[PAD_RIGHT], indexType); + + tileWidth[PAD_LEFT] = + getConstant(rewriter, loc, padInts[PAD_LEFT], indexType); + tileWidth[PAD_RIGHT] = + getConstant(rewriter, loc, padInts[PAD_RIGHT], indexType); tileWidth[PAD_CENTER] = lastDimSize; extractOffset[PAD_LEFT] = one; - // for (1,2,4) input, padding (3,1) lastDimSize=4, 4 - 1 - 1 = 2 [3,5, 6,7], so start offset to 6, which is right - // lasDimSize - (tileWidth[PAD_RIGHT] + one) - extractOffset[PAD_RIGHT] = createISub(lastDimSize, createIAdd(tileWidth[PAD_RIGHT], one)); + // The offset for the right hand padding "bar" is: + // [right] lastDimSize - (tileWidth[PAD_RIGHT] + one) + extractOffset[PAD_RIGHT] = + createISub(lastDimSize, createIAdd(tileWidth[PAD_RIGHT], one)); extractOffset[PAD_CENTER] = zero; insertOffset[PAD_LEFT] = zero; insertOffset[PAD_RIGHT] = createIAdd(lastDimSize, tileWidth[PAD_LEFT]); insertOffset[PAD_CENTER] = tileWidth[PAD_LEFT]; - SmallVector resultShape{inputShape}; - // Result's last dimension will have shape lastDimSize + left padding size + right padding size - resultShape[lastDim] = createIAdd(resultShape[lastDim], createIAdd(tileWidth[PAD_LEFT], tileWidth[PAD_RIGHT])); - Value resultTensor = createZeroInitTensor(rewriter, loc, resultShape, inputType.getElementType()); + // Result's last dimension will have size: + // lastDimSize + left padding size + right padding size + resultShape[lastDim] = + createIAdd(resultShape[lastDim], + createIAdd(tileWidth[PAD_LEFT], tileWidth[PAD_RIGHT])); + Value resultTensor = createZeroInitTensor(rewriter, loc, resultShape, + inputType.getElementType()); - // Helper to reflect/reverse the i-th dimension of an affine map without symbols. This only works if applied on a tensor - // for which the corresponding dimension has a statically known size - auto reflectDim = [](AffineMap map, unsigned numDims, int64_t i, int64_t size) { + // Helper to reflect/reverse the i-th dimension of an affine map without + // symbols. This only works if applied on a tensor for which the + // corresponding dimension has a statically known size + auto reflectDim = [](AffineMap map, unsigned numDims, int64_t i, + int64_t size) { AffineExpr d = map.getResult(i); - return map.replace(d, size - d - 1, numDims, 0); // left reflect for (3,1) on input shape (1,2,4). size = 3, lastDim=2, numDims=3 + return map.replace(d, size - d - 1, numDims, + 0); // left reflect for (3,1) on input shape (1,2,4). + // size = 3, lastDim=2, numDims=3 }; - SmallVector iteratorTypes{numDims, utils::IteratorType::parallel}; + SmallVector iteratorTypes{ + numDims, utils::IteratorType::parallel}; auto idMap = AffineMap::getMultiDimIdentityMap(numDims, context); SmallVector allOneStrides(numDims, one); auto addTileToResult = [&](PadLocation padPosition) { - // Create the tile by extracting a slice from the input tensor. + // Create the tile by extracting a slice from the input tensor. SmallVector extractShape{inputShape}; extractShape[lastDim] = tileWidth[padPosition]; SmallVector extractOffsets(numDims, zero); @@ -214,35 +237,39 @@ class ConvertAtenReflectionPad1dOp Value tile = rewriter.create( loc, input, extractOffsets, extractShape, allOneStrides); - auto inputMap = AffineMap::getMultiDimIdentityMap(numDims, context); - // Setup the affine map function to resverse the tile along the horizontal for left and right slices - if(padPosition < PAD_CENTER) { - inputMap = reflectDim(inputMap, numDims, lastDim, padInts[padPosition]); - // Take reflected slice as per inputMap - tile = rewriter.create(loc, llvm::cast(tile.getType()), tile, - tile, ArrayRef({inputMap, idMap}), iteratorTypes, - [](OpBuilder &b, Location nestedLoc, ValueRange args) { - b.create(nestedLoc, args[0]); - }).getResult(0); + // Setup the affine map function to resverse the tile along the horizontal + // for left and right slices + if (padPosition < PAD_CENTER) { + inputMap = reflectDim(inputMap, numDims, lastDim, padInts[padPosition]); + // Take reflected slice as per inputMap + tile = rewriter + .create( + loc, llvm::cast(tile.getType()), tile, + tile, ArrayRef({inputMap, idMap}), iteratorTypes, + [](OpBuilder &b, Location nestedLoc, ValueRange args) { + b.create(nestedLoc, args[0]); + }) + .getResult(0); } // Insert the tile in the resultTensor SmallVector insertOffsets(numDims, zero); insertOffsets[lastDim] = insertOffset[padPosition]; - resultTensor = rewriter.create(loc, tile, resultTensor, insertOffsets, extractShape, allOneStrides); + resultTensor = rewriter.create( + loc, tile, resultTensor, insertOffsets, extractShape, allOneStrides); }; - - if(padInts[PAD_LEFT] > 0) - addTileToResult(PAD_LEFT); - if(padInts[PAD_RIGHT] > 0) - addTileToResult(PAD_RIGHT); + + if (padInts[PAD_LEFT] > 0) + addTileToResult(PAD_LEFT); + if (padInts[PAD_RIGHT] > 0) + addTileToResult(PAD_RIGHT); addTileToResult(PAD_CENTER); rewriter.replaceOpWithNewOp(op, outputType, resultTensor); return success(); } }; -} +} // namespace namespace { diff --git a/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp b/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp index f9ee56070d61..bfbe45afe167 100644 --- a/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp @@ -79,7 +79,8 @@ class ConvertAtenGatherOp : public OpConversionPattern { int64_t dim; if (!matchPattern(dimValue, m_TorchConstantInt(&dim))) return op.emitError("unimplemented: dim is not constant"); - int64_t inputRank = adaptor.getSelf().getType().cast().getRank(); + int64_t inputRank = + adaptor.getSelf().getType().cast().getRank(); dim = toPositiveDim(dim, inputRank); if (!isValidDim(dim, inputRank)) return rewriter.notifyMatchFailure(op, "dim is statically invalid"); @@ -248,9 +249,9 @@ class ConvertAtenEmbeddingBagPaddingIdxOp } if (modeInt != torch_upstream::EmbeddingBagMode::MODE_SUM) { - return rewriter.notifyMatchFailure( - op, - "Unimplemented: Mean and Max mode are not supported yet for EmbeddingBag."); + return rewriter.notifyMatchFailure(op, + "Unimplemented: Mean and Max mode are " + "not supported yet for EmbeddingBag."); } bool isSparse; @@ -291,28 +292,28 @@ class ConvertAtenEmbeddingBagPaddingIdxOp SmallVector indicesExpr; indicesExpr.push_back(mlir::getAffineDimExpr(1, context)); auto indicesIndexingMap = - AffineMap::get(/*dimCount=*/iterationMapDimension, /*symbolCount=*/0, - indicesExpr, context); + AffineMap::get(/*dimCount=*/iterationMapDimension, /*symbolCount=*/0, + indicesExpr, context); SmallVector offsetsExpr; offsetsExpr.push_back(mlir::getAffineDimExpr(0, context)); auto offsetIndexingMap = - AffineMap::get(/*dimCount=*/iterationMapDimension, /*symbolCount=*/0, - offsetsExpr, context); + AffineMap::get(/*dimCount=*/iterationMapDimension, /*symbolCount=*/0, + offsetsExpr, context); SmallVector outputExpr; outputExpr.push_back(mlir::getAffineDimExpr(0, context)); outputExpr.push_back(mlir::getAffineDimExpr(2, context)); auto outputIndexingMap = - AffineMap::get(/*dimCount=*/iterationMapDimension, /*symbolCount=*/0, - outputExpr, context); + AffineMap::get(/*dimCount=*/iterationMapDimension, /*symbolCount=*/0, + outputExpr, context); SmallVector indexingMaps = { - indicesIndexingMap, - offsetIndexingMap, - outputIndexingMap, + indicesIndexingMap, + offsetIndexingMap, + outputIndexingMap, }; // Reduce along the indices dim @@ -326,15 +327,15 @@ class ConvertAtenEmbeddingBagPaddingIdxOp Value indicesLength; if (!discardLastOffset) { SmallVector sizes{getDimOp(rewriter, loc, offsets, 0), - embeddingDim}; + embeddingDim}; initTensor = createZeroInitTensor(rewriter, loc, sizes, weightElemTy); offsetsLength = getDimOp(rewriter, loc, offsets, 0); indicesLength = getDimOp(rewriter, loc, indices, 0); } else { return rewriter.notifyMatchFailure( - op, "Unimplemented: include last offset is not yet " - "supported for EmbeddingBag."); + op, "Unimplemented: include last offset is not yet " + "supported for EmbeddingBag."); } Value embeddingBagResult = @@ -351,10 +352,10 @@ class ConvertAtenEmbeddingBagPaddingIdxOp Value indexI = b.create(loc, /*value=*/0); Value indexIToInt = castIndexToInt64(b, loc, indexI); - Value one = getConstant( - b, loc, 1, - mlir::IntegerType::get(getContext(), 64, - IntegerType::Signless)); + Value one = + getConstant(b, loc, 1, + mlir::IntegerType::get( + getContext(), 64, IntegerType::Signless)); Value offsetIndexPlusOneInt = b.create(loc, indexIToInt, one); @@ -378,7 +379,7 @@ class ConvertAtenEmbeddingBagPaddingIdxOp loc, arith::CmpIPredicate::eq, offsetsI, indicesIndex); Value offsetLessThanOrEqualToIndicesIndex = b.create(loc, offsetLessThanIndicesIndex, - offsetEqualToIndicesIndex); + offsetEqualToIndicesIndex); Value indicesIndexLessThanNextOffset = b.create(loc, arith::CmpIPredicate::slt, @@ -393,19 +394,18 @@ class ConvertAtenEmbeddingBagPaddingIdxOp castIntToIndex(b, loc, indexInIndices)); indexIntoWeight.push_back( b.create(loc, /*value=*/2)); - Value weightElem = b.create( - loc, weight, indexIntoWeight); - - Value addResult = b.create(loc, weightElem, - initTensorElem); - Value select = - b.create(loc, indicesIndexWithinBounds, - addResult, initTensorElem); + Value weightElem = + b.create(loc, weight, indexIntoWeight); + + Value addResult = + b.create(loc, weightElem, initTensorElem); + Value select = b.create( + loc, indicesIndexWithinBounds, addResult, initTensorElem); b.create(loc, select); - }) - .getResult(0); + }) + .getResult(0); - // cast outputType. + // cast outputType. auto restulType0 = typeConverter->convertType(op->getResult(0).getType()); Value castedEmbeddingBagResult = rewriter.create(loc, restulType0, embeddingBagResult); @@ -439,7 +439,7 @@ class ConvertAtenEmbeddingBagPaddingIdxOp rewriter.create(loc, resultType3, indicesOut); rewriter.replaceOp(op, {castedEmbeddingBagResult, castedOffsetResult, - castedBagSizeResult, castedMaxIndices}); + castedBagSizeResult, castedMaxIndices}); return success(); } @@ -552,7 +552,8 @@ static Value makeIndexValuePositive(OpBuilder &b, Location loc, Value index, // e.g. x: [2, 3] // x[[4], [6, 1]] -> x[6, 4] namespace { -class ConvertAtenIndexTensorHackedTwinOp : public OpConversionPattern { +class ConvertAtenIndexTensorHackedTwinOp + : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; LogicalResult diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp index 6d0d72075d76..c0585df0bcd7 100644 --- a/lib/Conversion/TorchToLinalg/Linear.cpp +++ b/lib/Conversion/TorchToLinalg/Linear.cpp @@ -165,7 +165,8 @@ class ConvertAtenFlipOp : public OpConversionPattern { Location loc = op->getLoc(); MLIRContext *context = op.getContext(); Value self = adaptor.getSelf(); - auto selfRank = adaptor.getSelf().getType().cast().getRank(); + auto selfRank = + adaptor.getSelf().getType().cast().getRank(); Type elementType = adaptor.getSelf().getType().cast().getElementType(); Value c1 = @@ -535,7 +536,8 @@ class ConvertAtenBmmOp : public OpConversionPattern { RankedTensorType lhsType = lhs.getType().cast(); RankedTensorType rhsType = rhs.getType().cast(); Type newResultType = getTypeConverter()->convertType(op.getType()); - Type resultElementType = newResultType.cast().getElementType(); + Type resultElementType = + newResultType.cast().getElementType(); Type lhsElementType = lhsType.cast().getElementType(); Type rhsElementType = rhsType.cast().getElementType(); @@ -547,13 +549,15 @@ class ConvertAtenBmmOp : public OpConversionPattern { // Convert the inputs element type equivalent to the result' element type. if (lhsElementType != rhsElementType) { if (lhsElementType != resultElementType) { - // True if the lhs element type is not equal to the result' element type. - lhs = torch_to_linalg::convertTensorToElementType( - rewriter, loc, lhs, resultElementType); + // True if the lhs element type is not equal to the result' element + // type. + lhs = torch_to_linalg::convertTensorToElementType(rewriter, loc, lhs, + resultElementType); } else { - // True if the rhs element type is not equal to the result' element type. - rhs = torch_to_linalg::convertTensorToElementType( - rewriter, loc, rhs, resultElementType); + // True if the rhs element type is not equal to the result' element + // type. + rhs = torch_to_linalg::convertTensorToElementType(rewriter, loc, rhs, + resultElementType); } } @@ -571,7 +575,8 @@ class ConvertAtenBmmOp : public OpConversionPattern { checkDimEqualHelper(rewriter, loc, lhsDim2, rhsDim1); Value initTensor0 = createZeroInitTensor( - rewriter, loc, ValueRange{lhsDim0, lhsDim1, rhsDim2}, resultElementType); + rewriter, loc, ValueRange{lhsDim0, lhsDim1, rhsDim2}, + resultElementType); Value bmm = rewriter @@ -634,7 +639,8 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { return rewriter.notifyMatchFailure(op, "only support constant int strides"); SmallVector dilationInts; - if (!matchPattern(op.getDilation(), m_TorchListOfConstantInts(dilationInts))) + if (!matchPattern(op.getDilation(), + m_TorchListOfConstantInts(dilationInts))) return rewriter.notifyMatchFailure(op, "only support constant int dilations"); @@ -838,8 +844,10 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { Value conv; // the code so far is able to respect all numSpacialDims - // the code below this point is numSpacialDims specific and groupSize specific - // TODO: factor out the above code into a helper function, and then separate convolution into: + // the code below this point is numSpacialDims specific and groupSize + // specific + // TODO: factor out the above code into a helper function, and then separate + // convolution into: // - grouped 1d-3d // - ungrouped 1d-3d if (groupSize == 1) { @@ -854,20 +862,20 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { .getResult(0); break; case 2: - conv = - rewriter - .create( - loc, outputTensor.getType(), ValueRange{paddedInput, weight}, - outputTensor, stridesAttr, dilationAttr) - .getResult(0); + conv = rewriter + .create( + loc, outputTensor.getType(), + ValueRange{paddedInput, weight}, outputTensor, + stridesAttr, dilationAttr) + .getResult(0); break; case 3: - conv = - rewriter - .create( - loc, outputTensor.getType(), ValueRange{paddedInput, weight}, - outputTensor, stridesAttr, dilationAttr) - .getResult(0); + conv = rewriter + .create( + loc, outputTensor.getType(), + ValueRange{paddedInput, weight}, outputTensor, + stridesAttr, dilationAttr) + .getResult(0); break; default: return rewriter.notifyMatchFailure( @@ -877,7 +885,7 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { rewriter.replaceOpWithNewOp(op, newResultType, conv); return success(); } else { - if(numSpacialDims != 2) + if (numSpacialDims != 2) return rewriter.notifyMatchFailure( op, "unimplemented: only 2D grouped convolution supported"); @@ -901,11 +909,11 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { loc, collapsedType, weight, collapsedDims); conv = rewriter - .create( - loc, outputTensor.getType(), - ValueRange{paddedInput, collapsedWeight}, outputTensor, - stridesAttr, dilationAttr) - .getResult(0); + .create( + loc, outputTensor.getType(), + ValueRange{paddedInput, collapsedWeight}, outputTensor, + stridesAttr, dilationAttr) + .getResult(0); Type newResultType = getTypeConverter()->convertType(op.getType()); rewriter.replaceOpWithNewOp(op, newResultType, conv); @@ -979,7 +987,7 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { conv = rewriter.create( loc, outputTensor.getType(), conv, expandOutputTensor.getReassociationIndices()); - Type newResultType = getTypeConverter()->convertType(op.getType()); + Type newResultType = getTypeConverter()->convertType(op.getType()); rewriter.replaceOpWithNewOp(op, newResultType, conv); return success(); } diff --git a/lib/Conversion/TorchToLinalg/Random.cpp b/lib/Conversion/TorchToLinalg/Random.cpp index 26a2c0ea551a..35c349a6a673 100644 --- a/lib/Conversion/TorchToLinalg/Random.cpp +++ b/lib/Conversion/TorchToLinalg/Random.cpp @@ -194,7 +194,6 @@ class ConvertAtenUniformOp : public OpConversionPattern { }; } // namespace - void mlir::torch::torch_to_linalg::populateRandomPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target) { diff --git a/lib/Conversion/TorchToLinalg/Reduction.cpp b/lib/Conversion/TorchToLinalg/Reduction.cpp index 289851cd3d27..da5ee799a566 100644 --- a/lib/Conversion/TorchToLinalg/Reduction.cpp +++ b/lib/Conversion/TorchToLinalg/Reduction.cpp @@ -100,11 +100,11 @@ class ConvertAtenMinMaxDimOp : public OpConversionPattern { if (integerTy.isUnsigned()) return rewriter.notifyMatchFailure( op, opName + " to linalg.* requires input element type " - "to be signed in case of integer"); + "to be signed in case of integer"); } else { return rewriter.notifyMatchFailure( op, opName + " to linalg.* requires Float or Integer " - "input element type"); + "input element type"); } } @@ -144,8 +144,7 @@ class ConvertAtenMinMaxDimOp : public OpConversionPattern { } Value filledTensorVal = - rewriter.create(loc, fillValue, initTensorVal) - .result(); + rewriter.create(loc, fillValue, initTensorVal).result(); // Create the affine expressions that will be used to // iterate over the input and output tensors. @@ -186,7 +185,7 @@ class ConvertAtenMinMaxDimOp : public OpConversionPattern { Value resultVal, predicate; if (inElementType.isa()) { - arith::CmpFPredicate predType; + arith::CmpFPredicate predType; if (isMax) { predType = arith::CmpFPredicate::OGT; resultVal = rewriter.create( @@ -198,7 +197,7 @@ class ConvertAtenMinMaxDimOp : public OpConversionPattern { } predicate = rewriter.create(nestedLoc, predType, - newValue, oldValue); + newValue, oldValue); } else { arith::CmpIPredicate predType; if (isMax) { @@ -220,8 +219,8 @@ class ConvertAtenMinMaxDimOp : public OpConversionPattern { }); // This cast is required to fix the shape in the case of keepDim=True - Value valuesCast = rewriter.create( - loc, valResultType, linalgOp.getResult(0)); + Value valuesCast = rewriter.create(loc, valResultType, + linalgOp.getResult(0)); Value idxCast = rewriter.create(loc, idxResultType, linalgOp.getResult(1)); rewriter.replaceOp(op, {valuesCast, idxCast}); @@ -345,7 +344,8 @@ static Value createLinalgPayloadForReduceOp(OpBuilder &b, Location loc, Value self = convertScalarToDtype(b, loc, elem, resultElementType); auto abs = b.create(loc, self); AtenLinalgVectorNormOp::Adaptor adaptor(operands); - Value ord = convertScalarToDtype(b, loc, adaptor.getOrd(), resultElementType); + Value ord = + convertScalarToDtype(b, loc, adaptor.getOrd(), resultElementType); auto pow = b.create(loc, abs, ord); return b.create(loc, pow, result); } else if (isa(op)) { @@ -427,8 +427,8 @@ class ConvertReductionOp : public ConversionPattern { opInfo.tensorOperand = operands[0]; auto inputType = opInfo.tensorOperand.getType().cast(); - // `AtenSumOp`, `AtenMaxOp`, and `AtenMinOp` each reduce along all the dimensions of the - // input tensor. + // `AtenSumOp`, `AtenMaxOp`, and `AtenMinOp` each reduce along all the + // dimensions of the input tensor. for (int64_t i = 0; i < inputType.getRank(); i++) opInfo.dimSet.insert(i); diff --git a/lib/Conversion/TorchToLinalg/TensorConstructors.cpp b/lib/Conversion/TorchToLinalg/TensorConstructors.cpp index 6afae47c1325..2b8eac49447a 100644 --- a/lib/Conversion/TorchToLinalg/TensorConstructors.cpp +++ b/lib/Conversion/TorchToLinalg/TensorConstructors.cpp @@ -83,209 +83,224 @@ class ConvertAtenConstantPadNdOp namespace { - // Lower aten.replication_pad2d operator into a sequence of - // tensor.extract_slice and tensor.concat operations. - - class ConvertAtenReplicationPad2dOp - : public OpConversionPattern { - public: - using OpConversionPattern::OpConversionPattern; - LogicalResult - matchAndRewrite(AtenReplicationPad2dOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - if (failed(verifyLinalgCompatibleTypes(op, rewriter))) - return failure(); - - Location loc = op->getLoc(); - Value input = adaptor.getSelf(); - auto inputType = llvm::cast(input.getType()); - int64_t inputRank = inputType.getRank(); - unsigned numDims = inputType.getRank(); - assert(numDims >= 2 && "Not enough input dimensions"); - - SmallVector padInts; - if (!matchPattern(op.getPadding(), m_TorchListOfConstantInts(padInts))) - return rewriter.notifyMatchFailure( +// Lower aten.replication_pad2d operator into a sequence of +// tensor.extract_slice and tensor.concat operations. + +class ConvertAtenReplicationPad2dOp + : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(AtenReplicationPad2dOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (failed(verifyLinalgCompatibleTypes(op, rewriter))) + return failure(); + + Location loc = op->getLoc(); + Value input = adaptor.getSelf(); + auto inputType = llvm::cast(input.getType()); + int64_t inputRank = inputType.getRank(); + unsigned numDims = inputType.getRank(); + assert(numDims >= 2 && "Not enough input dimensions"); + + SmallVector padInts; + if (!matchPattern(op.getPadding(), m_TorchListOfConstantInts(padInts))) + return rewriter.notifyMatchFailure( op, "only support constant int pad ranges"); - uint64_t padRank = padInts.size() / 2; - if (padRank * 2 != padInts.size()) - return rewriter.notifyMatchFailure(op, "pad range size is not even"); - if (inputRank < 0 || padRank > (uint64_t)inputRank) - return rewriter.notifyMatchFailure(op, "padding exceeds tensor rank"); - - SmallVector inputShape = getTensorSizes(rewriter, loc, input); - int64_t hDim = numDims - 1; - int64_t vDim = numDims - 2; - Value hDimSize = inputShape[hDim]; - Value vDimSize = inputShape[vDim]; - - enum tileHLoc { LEFT = 0, HCENTER = 1, RIGHT = 2 }; - enum tileVLoc { TOP = 0, VCENTER = 2, BOTTOM = 1, }; - // vTile denotes the vertical size of the tile - // hTile denotes the horizontal size of the tile - // The padding results are composed of following tiles: - // vTile[TOP]hTile[LEFT], vTile[TOP]hTile[HCENTER], vTile[TOP]hTile[RIGHT] - // vTile[VCENTER]hTile[LEFT], vTile[VCENTER]hTile[HCENTER], vTile[VCENTER]hTile[RIGHT] - // vTile[BOTTOM]hTile[LEFT], vTile[BOTTOM]hTile[HCENTER], vTile[BOTTOM]hTile[RIGHT] - // vTile[VCENTER]hTile[HCENTER] is the original input tensor - Type indexType = rewriter.getIndexType(); - Value vTile[3]; - Value hTile[3]; - vTile[VCENTER] = vDimSize; - hTile[HCENTER] = hDimSize; - vTile[TOP] = getConstant(rewriter, loc, padInts[2], indexType); - vTile[BOTTOM] = getConstant(rewriter, loc, padInts[3], indexType); - hTile[LEFT] = getConstant(rewriter, loc, padInts[0], indexType); - hTile[RIGHT] = getConstant(rewriter, loc, padInts[1], indexType); - - bool hasLeftPadding = false; - bool hasRightPadding = false; - bool hasTopPadding = false; - bool hasBottomPadding = false; - - for (auto i: {TOP, VCENTER, BOTTOM}){ - for (auto j: {LEFT, HCENTER, RIGHT}) { - auto constVtile{ - mlir::dyn_cast(vTile[i].getDefiningOp()) - .getValue() - .dyn_cast_or_null()}; - - auto constHtile{ - mlir::dyn_cast(hTile[j].getDefiningOp()) - .getValue() - .dyn_cast_or_null()}; - auto vSize = constVtile.getInt(); - auto hSize = constHtile.getInt(); - - if ((i == TOP) && (vSize > 0)) - hasTopPadding = true; - if ((i == BOTTOM) && (vSize > 0)) - hasBottomPadding = true; - if ((j == LEFT) && (hSize > 0)) - hasLeftPadding = true; - if ((j == RIGHT) && (hSize > 0)) - hasRightPadding = true; - } + uint64_t padRank = padInts.size() / 2; + if (padRank * 2 != padInts.size()) + return rewriter.notifyMatchFailure(op, "pad range size is not even"); + if (inputRank < 0 || padRank > (uint64_t)inputRank) + return rewriter.notifyMatchFailure(op, "padding exceeds tensor rank"); + + SmallVector inputShape = getTensorSizes(rewriter, loc, input); + int64_t hDim = numDims - 1; + int64_t vDim = numDims - 2; + Value hDimSize = inputShape[hDim]; + Value vDimSize = inputShape[vDim]; + + enum tileHLoc { LEFT = 0, HCENTER = 1, RIGHT = 2 }; + enum tileVLoc { + TOP = 0, + VCENTER = 2, + BOTTOM = 1, + }; + // vTile denotes the vertical size of the tile + // hTile denotes the horizontal size of the tile + // The padding results are composed of following tiles: + // vTile[TOP]hTile[LEFT], vTile[TOP]hTile[HCENTER], vTile[TOP]hTile[RIGHT] + // vTile[VCENTER]hTile[LEFT], vTile[VCENTER]hTile[HCENTER], + // vTile[VCENTER]hTile[RIGHT] vTile[BOTTOM]hTile[LEFT], + // vTile[BOTTOM]hTile[HCENTER], vTile[BOTTOM]hTile[RIGHT] + // vTile[VCENTER]hTile[HCENTER] is the original input tensor + Type indexType = rewriter.getIndexType(); + Value vTile[3]; + Value hTile[3]; + vTile[VCENTER] = vDimSize; + hTile[HCENTER] = hDimSize; + vTile[TOP] = getConstant(rewriter, loc, padInts[2], indexType); + vTile[BOTTOM] = getConstant(rewriter, loc, padInts[3], indexType); + hTile[LEFT] = getConstant(rewriter, loc, padInts[0], indexType); + hTile[RIGHT] = getConstant(rewriter, loc, padInts[1], indexType); + + bool hasLeftPadding = false; + bool hasRightPadding = false; + bool hasTopPadding = false; + bool hasBottomPadding = false; + + for (auto i : {TOP, VCENTER, BOTTOM}) { + for (auto j : {LEFT, HCENTER, RIGHT}) { + auto constVtile{ + mlir::dyn_cast(vTile[i].getDefiningOp()) + .getValue() + .dyn_cast_or_null()}; + + auto constHtile{ + mlir::dyn_cast(hTile[j].getDefiningOp()) + .getValue() + .dyn_cast_or_null()}; + auto vSize = constVtile.getInt(); + auto hSize = constHtile.getInt(); + + if ((i == TOP) && (vSize > 0)) + hasTopPadding = true; + if ((i == BOTTOM) && (vSize > 0)) + hasBottomPadding = true; + if ((j == LEFT) && (hSize > 0)) + hasLeftPadding = true; + if ((j == RIGHT) && (hSize > 0)) + hasRightPadding = true; } + } - auto createSub = [&](Value x, Value y) { - return rewriter.create(loc, x, y); - }; - - // Extract left and right pad tiles. - Value zero = getConstant(rewriter, loc, 0, indexType); - Value one = getConstant(rewriter, loc, 1, indexType); - Value hDimSizeMinusOne = createSub(hDimSize, one); - Value vDimSizeMinusOne = createSub(vDimSize, one); - SmallVector allOneStrides(numDims, one); - - SmallVector extractOffsetsLT(numDims, zero); - extractOffsetsLT[hDim] = zero; - extractOffsetsLT[vDim] = zero; - SmallVector extractShapeLR(numDims, one); - extractShapeLR[hDim] = one; - extractShapeLR[vDim] = vDimSize; - - SmallVector extractOffsetsRight(numDims, zero); - extractOffsetsRight[hDim] = hDimSizeMinusOne; - extractOffsetsRight[vDim] = zero; - - SmallVector extractOffsetsBottom(numDims, zero); - extractOffsetsBottom[hDim] = zero; - extractOffsetsBottom[vDim] = vDimSizeMinusOne; - - SmallVector extractShapeTB(numDims, one); - extractShapeTB[hDim] = hDimSize; - extractShapeTB[vDim] = one; - - SmallVector tensorsLeft; - SmallVector tensorsRight; - SmallVector tensorsCenter; - Value centerTile; - SmallVector tensorsRes; - - if (hasLeftPadding) { - Value vCenterLeftSlice = rewriter.create( - loc, input, extractOffsetsLT, extractShapeLR, allOneStrides); - Value vLeftSlice = vCenterLeftSlice; - if (hasTopPadding) { - Value topLeftValue = rewriter.create( - loc, input, ValueRange{zero, zero, zero, zero}); - //pad vCenterLeftSlice on the top - SmallVector lowPadding(4, 0); - SmallVector highPadding(4, 0); - lowPadding[2] = padInts[2]; - vLeftSlice = torch_to_linalg::getPaddedTensor(op, rewriter, vLeftSlice, lowPadding, highPadding, topLeftValue); - } - if (hasBottomPadding) { - Value bottomLeftValue = rewriter.create (loc, input, ValueRange{zero, zero, vDimSizeMinusOne, zero}); - - //pad vLeftSlice at the bottom - SmallVector lowPadding(4, 0); - SmallVector highPadding(4, 0); - highPadding[2] = padInts[3]; - vLeftSlice = torch_to_linalg::getPaddedTensor(op, rewriter, vLeftSlice, lowPadding, highPadding, bottomLeftValue); - } - for (auto i=0; i(loc, 3, tensorsLeft); - tensorsRes.push_back(leftPadTile); + auto createSub = [&](Value x, Value y) { + return rewriter.create(loc, x, y); + }; + + // Extract left and right pad tiles. + Value zero = getConstant(rewriter, loc, 0, indexType); + Value one = getConstant(rewriter, loc, 1, indexType); + Value hDimSizeMinusOne = createSub(hDimSize, one); + Value vDimSizeMinusOne = createSub(vDimSize, one); + SmallVector allOneStrides(numDims, one); + + SmallVector extractOffsetsLT(numDims, zero); + extractOffsetsLT[hDim] = zero; + extractOffsetsLT[vDim] = zero; + SmallVector extractShapeLR(numDims, one); + extractShapeLR[hDim] = one; + extractShapeLR[vDim] = vDimSize; + + SmallVector extractOffsetsRight(numDims, zero); + extractOffsetsRight[hDim] = hDimSizeMinusOne; + extractOffsetsRight[vDim] = zero; + + SmallVector extractOffsetsBottom(numDims, zero); + extractOffsetsBottom[hDim] = zero; + extractOffsetsBottom[vDim] = vDimSizeMinusOne; + + SmallVector extractShapeTB(numDims, one); + extractShapeTB[hDim] = hDimSize; + extractShapeTB[vDim] = one; + + SmallVector tensorsLeft; + SmallVector tensorsRight; + SmallVector tensorsCenter; + Value centerTile; + SmallVector tensorsRes; + + if (hasLeftPadding) { + Value vCenterLeftSlice = rewriter.create( + loc, input, extractOffsetsLT, extractShapeLR, allOneStrides); + Value vLeftSlice = vCenterLeftSlice; + if (hasTopPadding) { + Value topLeftValue = rewriter.create( + loc, input, ValueRange{zero, zero, zero, zero}); + // pad vCenterLeftSlice on the top + SmallVector lowPadding(4, 0); + SmallVector highPadding(4, 0); + lowPadding[2] = padInts[2]; + vLeftSlice = torch_to_linalg::getPaddedTensor( + op, rewriter, vLeftSlice, lowPadding, highPadding, topLeftValue); + } + if (hasBottomPadding) { + Value bottomLeftValue = rewriter.create( + loc, input, ValueRange{zero, zero, vDimSizeMinusOne, zero}); + + // pad vLeftSlice at the bottom + SmallVector lowPadding(4, 0); + SmallVector highPadding(4, 0); + highPadding[2] = padInts[3]; + vLeftSlice = torch_to_linalg::getPaddedTensor( + op, rewriter, vLeftSlice, lowPadding, highPadding, bottomLeftValue); + } + for (auto i = 0; i < padInts[0]; ++i) { + tensorsLeft.push_back(vLeftSlice); + } + Value leftPadTile = + rewriter.create(loc, 3, tensorsLeft); + tensorsRes.push_back(leftPadTile); + } + if (hasTopPadding) { + Value topHcenterSlice = rewriter.create( + loc, input, extractOffsetsLT, extractShapeTB, allOneStrides); + for (auto i = 0; i < padInts[2]; ++i) { + tensorsCenter.push_back(topHcenterSlice); + } + } + tensorsCenter.push_back(input); + if (hasBottomPadding) { + Value bottomHcenterSlice = rewriter.create( + loc, input, extractOffsetsBottom, extractShapeTB, allOneStrides); + for (auto i = 0; i < padInts[3]; ++i) { + tensorsCenter.push_back(bottomHcenterSlice); } + } + centerTile = rewriter.create(loc, 2, tensorsCenter); + tensorsRes.push_back(centerTile); + + if (hasRightPadding) { + Value vCenterRightSlice = rewriter.create( + loc, input, extractOffsetsRight, extractShapeLR, allOneStrides); + Value vRightSlice = vCenterRightSlice; if (hasTopPadding) { - Value topHcenterSlice = rewriter.create( - loc, input, extractOffsetsLT, extractShapeTB, allOneStrides); - for (auto i = 0; i < padInts[2]; ++i) { - tensorsCenter.push_back(topHcenterSlice); - } + Value topRightValue = rewriter.create( + loc, input, ValueRange{zero, zero, zero, hDimSizeMinusOne}); + + // pad vCenterRightSlice on the top + SmallVector lowPadding(4, 0); + SmallVector highPadding(4, 0); + lowPadding[2] = padInts[2]; + vRightSlice = torch_to_linalg::getPaddedTensor( + op, rewriter, vRightSlice, lowPadding, highPadding, topRightValue); } - tensorsCenter.push_back(input); if (hasBottomPadding) { - Value bottomHcenterSlice = rewriter.create( - loc, input, extractOffsetsBottom, extractShapeTB, allOneStrides); - for (auto i = 0; i < padInts[3]; ++i) { - tensorsCenter.push_back(bottomHcenterSlice); - } + Value bottomRightValue = rewriter.create( + loc, input, + ValueRange{zero, zero, vDimSizeMinusOne, hDimSizeMinusOne}); + + // Pad vCenterRightSlice or vRightTopPaddedSlice at the bottom. + SmallVector lowPadding(4, 0); + SmallVector highPadding(4, 0); + highPadding[2] = padInts[3]; + vRightSlice = torch_to_linalg::getPaddedTensor( + op, rewriter, vRightSlice, lowPadding, highPadding, + bottomRightValue); + } + for (auto i = 0; i < padInts[1]; ++i) { + tensorsRight.push_back(vRightSlice); } - centerTile = rewriter.create(loc, 2, tensorsCenter); - tensorsRes.push_back(centerTile); - - if (hasRightPadding) { - Value vCenterRightSlice = rewriter.create( - loc, input, extractOffsetsRight, extractShapeLR, allOneStrides); - Value vRightSlice = vCenterRightSlice; - if (hasTopPadding) { - Value topRightValue = rewriter.create (loc, input, ValueRange{zero, zero, zero, hDimSizeMinusOne}); - - //pad vCenterRightSlice on the top - SmallVector lowPadding(4, 0); - SmallVector highPadding(4, 0); - lowPadding[2] = padInts[2]; - vRightSlice = torch_to_linalg::getPaddedTensor(op, rewriter, vRightSlice, lowPadding, highPadding, topRightValue); - } - if (hasBottomPadding) { - Value bottomRightValue = rewriter.create (loc, input, ValueRange{zero, zero, vDimSizeMinusOne, hDimSizeMinusOne}); - - // Pad vCenterRightSlice or vRightTopPaddedSlice at the bottom. - SmallVector lowPadding(4, 0); - SmallVector highPadding(4, 0); - highPadding[2] = padInts[3]; - vRightSlice = torch_to_linalg::getPaddedTensor(op, rewriter, vRightSlice, lowPadding, highPadding, bottomRightValue); - } - for (auto i=0; i(loc, 3, tensorsRight); - tensorsRes.push_back(rightPadTile); - } - Value resTensor = rewriter.create(loc, 3, tensorsRes); - Type newResultType = getTypeConverter()->convertType(op.getType()); - rewriter.replaceOpWithNewOp(op, newResultType, resTensor); - return success(); + Value rightPadTile = + rewriter.create(loc, 3, tensorsRight); + tensorsRes.push_back(rightPadTile); } - }; -} + Value resTensor = rewriter.create(loc, 3, tensorsRes); + Type newResultType = getTypeConverter()->convertType(op.getType()); + rewriter.replaceOpWithNewOp(op, newResultType, resTensor); + return success(); + } +}; +} // namespace namespace { // Converts constant tensor allocation like ops. @@ -348,8 +363,8 @@ class ConvertConstantTensorAllocOp : public OpConversionPattern { // Create an uninitialized tensor of `resultSize` shape and fill it with // value `fillVal`. Value constVal = getConstant(rewriter, loc, fillVal, resultElementType); - Value outputTensor = - createInitTensor(rewriter, loc, resultSizeIndex, resultElementType, constVal); + Value outputTensor = createInitTensor(rewriter, loc, resultSizeIndex, + resultElementType, constVal); rewriter.replaceOpWithNewOp(op, resultType, outputTensor); return success(); } @@ -384,7 +399,8 @@ class ConvertAtenEmptyMemoryFormatOp // Only `none`, `contiguous` and `preserve` memory_format is supported. if (!op.getMemoryFormat().getType().isa()) { int64_t memoryFormat; - if (!matchPattern(op.getMemoryFormat(), m_TorchConstantInt(&memoryFormat))) + if (!matchPattern(op.getMemoryFormat(), + m_TorchConstantInt(&memoryFormat))) return rewriter.notifyMatchFailure( op, "unimplemented: the memory format should be specified in " "an integer constant"); @@ -495,7 +511,8 @@ class ConvertAtenArangeStartStepOp typeConverter->convertType(op->getResult(0).getType()) .cast(); Type dtype = resultType.getElementType(); - Value start = convertScalarToDtype(rewriter, loc, adaptor.getStart(), dtype); + Value start = + convertScalarToDtype(rewriter, loc, adaptor.getStart(), dtype); Value end = convertScalarToDtype(rewriter, loc, adaptor.getEnd(), dtype); Value step = convertScalarToDtype(rewriter, loc, adaptor.getStep(), dtype); diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index 9ff4c63741b2..54317979353d 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -426,10 +426,11 @@ static Value createLinalgPayloadCalculationForElementwiseOp( } if (isa(op)) return b.create(loc, payloadArgs[0]); - if (isa(op)){ + if (isa(op)) { Value abs = b.create(loc, payloadArgs[0]); Value infinity = b.create( - loc, b.getFloatAttr(abs.getType(), std::numeric_limits::infinity())); + loc, + b.getFloatAttr(abs.getType(), std::numeric_limits::infinity())); return createEqual(b, loc, abs.getType(), abs, infinity); } if (isa(op)) { diff --git a/lib/Conversion/TorchToLinalg/Utils.cpp b/lib/Conversion/TorchToLinalg/Utils.cpp index 8bff5034c6b4..0d62010d7b55 100644 --- a/lib/Conversion/TorchToLinalg/Utils.cpp +++ b/lib/Conversion/TorchToLinalg/Utils.cpp @@ -7,13 +7,13 @@ // //===----------------------------------------------------------------------===// +#include "mlir/Dialect/Tensor/Utils/Utils.h" #include "../PassDetail.h" #include "PopulatePatterns.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Dialect/Tensor/Utils/Utils.h" #include "mlir/IR/Matchers.h" #include "torch-mlir/Conversion/TorchToLinalg/Utils.h" #include "torch-mlir/Conversion/Utils/Utils.h" diff --git a/lib/Conversion/TorchToStablehlo/Basic.cpp b/lib/Conversion/TorchToStablehlo/Basic.cpp index f0dc4aaf2dfa..00c9fcd7b88f 100644 --- a/lib/Conversion/TorchToStablehlo/Basic.cpp +++ b/lib/Conversion/TorchToStablehlo/Basic.cpp @@ -923,8 +923,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( op.getA().getType().template cast().getDtype(); Type resultType = this->getTypeConverter()->convertType(op->getResult(0).getType()); - auto result = - rewriter.create(loc, adaptor.getA()); + auto result = rewriter.create(loc, adaptor.getA()); rewriter.replaceOp( op, convertScalarToDtype(rewriter, loc, result, resultType, inputDtype)); @@ -1797,8 +1796,7 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality( #define INSERT_TENSOR_TO_SCALAR_PATTERN(AtenOp) \ target.addIllegalOp(); \ - patterns.add>(typeConverter, \ - context) + patterns.add>(typeConverter, context) INSERT_TENSOR_TO_SCALAR_PATTERN(AtenIntTensorOp); INSERT_TENSOR_TO_SCALAR_PATTERN(AtenFloatTensorOp); diff --git a/lib/Conversion/TorchToStablehlo/GatherScatter.cpp b/lib/Conversion/TorchToStablehlo/GatherScatter.cpp index 9c8123bfdbad..d2b0450cd19a 100644 --- a/lib/Conversion/TorchToStablehlo/GatherScatter.cpp +++ b/lib/Conversion/TorchToStablehlo/GatherScatter.cpp @@ -30,8 +30,8 @@ using namespace mlir::torch::torch_to_stablehlo; namespace { static Value createInitialValueForGatherScatterOp(Operation *op, - RankedTensorType constType, - PatternRewriter &rewriter) { + RankedTensorType constType, + PatternRewriter &rewriter) { auto elementTy = constType.getElementType(); if (isa(op)) { if (elementTy.isa()) { diff --git a/lib/Conversion/TorchToStablehlo/Pooling.cpp b/lib/Conversion/TorchToStablehlo/Pooling.cpp index e90f231c74f5..7ef69ae6712d 100644 --- a/lib/Conversion/TorchToStablehlo/Pooling.cpp +++ b/lib/Conversion/TorchToStablehlo/Pooling.cpp @@ -35,7 +35,8 @@ static Value createInitialValueForAtenPoolingOp(Operation *op, Type elementTy, PatternRewriter &rewriter) { auto constType = RankedTensorType::get({}, elementTy); // Avg pooling - if (isa(op)) { + if (isa(op)) { if (elementTy.isa()) { auto constAttr = DenseElementsAttr::get( constType, {APFloat::getZero( @@ -373,7 +374,6 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return success(); } - namespace { template class ConvertAtenAvgPoolOp : public ConvertAtenOp { @@ -388,45 +388,45 @@ class ConvertAtenAvgPoolOp : public ConvertAtenOp { Type inputElemTy = inputTy.getElementType(); int64_t inputRank = inputTy.getRank(); RankedTensorType outTy = ConvertAtenOp::getTypeConverter() - ->convertType(op.getType()) - .template cast(); + ->convertType(op.getType()) + .template cast(); auto outShape = outTy.getShape(); - if (inputRank <= Dim) { - return op.emitError( - "avg_pooling1d/2d only supports inputs with rank higher than 1/2"); + return op.emitError( + "avg_pooling1d/2d only supports inputs with rank higher than 1/2"); } SmallVector padding, kernelSize, stride; bool ceilMode = false; bool countIncludePad = true; if (!(matchPattern(op.getKernelSize(), - m_TorchListOfConstantInts(kernelSize)))) { - return rewriter.notifyMatchFailure( - op, "non-const int kernel size unsupported!"); + m_TorchListOfConstantInts(kernelSize)))) { + return rewriter.notifyMatchFailure( + op, "non-const int kernel size unsupported!"); } if (!(matchPattern(op.getStride(), m_TorchListOfConstantInts(stride)))) { - return rewriter.notifyMatchFailure(op, "non-const int stride unsupported!"); + return rewriter.notifyMatchFailure(op, + "non-const int stride unsupported!"); } if (!(matchPattern(op.getPadding(), m_TorchListOfConstantInts(padding)))) { - return rewriter.notifyMatchFailure(op, - "non-const int padding unsupported!"); + return rewriter.notifyMatchFailure(op, + "non-const int padding unsupported!"); } if (!(matchPattern(op.getCeilMode(), m_TorchConstantBool(&ceilMode)))) { - return rewriter.notifyMatchFailure(op, - "non-const bool ceil_mode unsupported!"); + return rewriter.notifyMatchFailure( + op, "non-const bool ceil_mode unsupported!"); } if (!(matchPattern(op.getCountIncludePad(), - m_TorchConstantBool(&countIncludePad)))) { - return rewriter.notifyMatchFailure( - op, "non-const bool count_include_pad unsupported!"); + m_TorchConstantBool(&countIncludePad)))) { + return rewriter.notifyMatchFailure( + op, "non-const bool count_include_pad unsupported!"); } if constexpr (std::is_same()) { - if (succeeded(checkNotNone(rewriter, op, op.getDivisorOverride()))) - return rewriter.notifyMatchFailure( - op, "only None divisor_override supported for now!"); + if (succeeded(checkNotNone(rewriter, op, op.getDivisorOverride()))) + return rewriter.notifyMatchFailure( + op, "only None divisor_override supported for now!"); } // Prepend 1 to kernelSize, stride, dilation until they are of same rank @@ -437,33 +437,35 @@ class ConvertAtenAvgPoolOp : public ConvertAtenOp { SmallVector stablehloPadding(inputRank * 2, 0); std::copy(stride.begin(), stride.end(), - stablehloStride.begin() + inputRank - Dim); + stablehloStride.begin() + inputRank - Dim); std::copy(kernelSize.begin(), kernelSize.end(), - stablehloKernelSize.begin() + inputRank - Dim); + stablehloKernelSize.begin() + inputRank - Dim); if (Dim == 1) { - stablehloPadding[stablehloPadding.size() - 2] = padding[0]; - stablehloPadding[stablehloPadding.size() - 1] = padding[0]; + stablehloPadding[stablehloPadding.size() - 2] = padding[0]; + stablehloPadding[stablehloPadding.size() - 1] = padding[0]; } else { - stablehloPadding[stablehloPadding.size() - 4] = padding[0]; - stablehloPadding[stablehloPadding.size() - 3] = padding[0]; - stablehloPadding[stablehloPadding.size() - 2] = padding[1]; - stablehloPadding[stablehloPadding.size() - 1] = padding[1]; + stablehloPadding[stablehloPadding.size() - 4] = padding[0]; + stablehloPadding[stablehloPadding.size() - 3] = padding[0]; + stablehloPadding[stablehloPadding.size() - 2] = padding[1]; + stablehloPadding[stablehloPadding.size() - 1] = padding[1]; } - Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter); + Value initVal = + createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter); DenseIntElementsAttr windowDimensions = DenseIntElementsAttr::get( - RankedTensorType::get({static_cast(stablehloKernelSize.size())}, - rewriter.getI64Type()), + RankedTensorType::get( + {static_cast(stablehloKernelSize.size())}, + rewriter.getI64Type()), stablehloKernelSize); DenseIntElementsAttr windowStrides = DenseIntElementsAttr::get( RankedTensorType::get({static_cast(stablehloStride.size())}, - rewriter.getI64Type()), + rewriter.getI64Type()), stablehloStride); DenseIntElementsAttr baseDilations; DenseIntElementsAttr windowDilations = DenseIntElementsAttr::get( RankedTensorType::get({static_cast(stablehloDilation.size())}, - rewriter.getI64Type()), + rewriter.getI64Type()), stablehloDilation); DenseIntElementsAttr pad = DenseIntElementsAttr::get( RankedTensorType::get( @@ -485,31 +487,31 @@ class ConvertAtenAvgPoolOp : public ConvertAtenOp { auto secondArg = *sumBlock.args_rbegin(); { - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPointToStart(&sumBlock); + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(&sumBlock); - Value sumResult = - rewriter.create(op->getLoc(), firstArg, secondArg); - rewriter.create(op->getLoc(), sumResult); + Value sumResult = + rewriter.create(op->getLoc(), firstArg, secondArg); + rewriter.create(op->getLoc(), sumResult); } // Use kernel size as the divisor if (countIncludePad) { - Value divisor; - if (Dim == 1) { - divisor = - hlo::getConstTensor(rewriter, op, {kernelSize[0]}, {}) - .value(); - } else { - divisor = hlo::getConstTensor( - rewriter, op, {kernelSize[0] * kernelSize[1]}, {}) - .value(); - } - divisor = hlo::promoteType(rewriter, op.getLoc(), divisor, outTy); - DenseIntElementsAttr bcastDimensions; - rewriter.replaceOpWithNewOp( - op, outTy, reduceWindowSum.getResult(0), divisor, bcastDimensions); - return success(); + Value divisor; + if (Dim == 1) { + divisor = + hlo::getConstTensor(rewriter, op, {kernelSize[0]}, {}) + .value(); + } else { + divisor = hlo::getConstTensor( + rewriter, op, {kernelSize[0] * kernelSize[1]}, {}) + .value(); + } + divisor = hlo::promoteType(rewriter, op.getLoc(), divisor, outTy); + DenseIntElementsAttr bcastDimensions; + rewriter.replaceOpWithNewOp( + op, outTy, reduceWindowSum.getResult(0), divisor, bcastDimensions); + return success(); } // Use another mhlo.ReduceWindowOp to get the divisor @@ -518,8 +520,8 @@ class ConvertAtenAvgPoolOp : public ConvertAtenOp { windowSizeConst = hlo::promoteType(rewriter, op.getLoc(), windowSizeConst, outTy); const auto &options = ConvertAtenOp::getOptions(); - auto inputShapeVec = - *hlo::getDimSizesOfTensor(rewriter, op, input, options.dimSizeIndexBits); + auto inputShapeVec = *hlo::getDimSizesOfTensor(rewriter, op, input, + options.dimSizeIndexBits); auto inputShapeTensor = rewriter.create( op->getLoc(), inputShapeVec); @@ -544,23 +546,20 @@ class ConvertAtenAvgPoolOp : public ConvertAtenOp { secondArg = *sizeBlock.args_rbegin(); { - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPointToStart(&sizeBlock); + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(&sizeBlock); - Value sumResult = - rewriter.create(op->getLoc(), firstArg, secondArg); - rewriter.create(op->getLoc(), sumResult); + Value sumResult = + rewriter.create(op->getLoc(), firstArg, secondArg); + rewriter.create(op->getLoc(), sumResult); } rewriter.replaceOpWithNewOp( op, outTy, reduceWindowSum.getResult(0), reduceWindowSize.getResult(0)); return success(); - } - }; -} - +} // namespace // AtenCumsumOp template <> @@ -660,10 +659,10 @@ void mlir::torch::torch_to_stablehlo::populatePoolingOpPatternsAndLegality( context, options); target.addIllegalOp(); patterns.add>(typeConverter, context, options); -#define INSERT_ATEN_AVGPOOL_PATTERN(AtenOp, Dim) \ +#define INSERT_ATEN_AVGPOOL_PATTERN(AtenOp, Dim) \ target.addIllegalOp(); \ - patterns.add>( \ - typeConverter, context, options) + patterns.add>(typeConverter, context, \ + options) INSERT_ATEN_AVGPOOL_PATTERN(AtenAvgPool1dOp, 1); INSERT_ATEN_AVGPOOL_PATTERN(AtenAvgPool2dOp, 2); #undef INSERT_ATEN_AVGPOOL_PATTERN diff --git a/lib/Conversion/TorchToStablehlo/Reduction.cpp b/lib/Conversion/TorchToStablehlo/Reduction.cpp index 36f4d49e9a99..f495aa39508f 100644 --- a/lib/Conversion/TorchToStablehlo/Reduction.cpp +++ b/lib/Conversion/TorchToStablehlo/Reduction.cpp @@ -16,13 +16,13 @@ #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "stablehlo/dialect/ChloOps.h" #include "stablehlo/dialect/StablehloOps.h" +#include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h" #include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h" -#include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h" using namespace mlir; using namespace mlir::torch; diff --git a/lib/Conversion/TorchToStablehlo/ViewLike.cpp b/lib/Conversion/TorchToStablehlo/ViewLike.cpp index ea19092e6c8b..507821dee638 100644 --- a/lib/Conversion/TorchToStablehlo/ViewLike.cpp +++ b/lib/Conversion/TorchToStablehlo/ViewLike.cpp @@ -15,6 +15,7 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "stablehlo/dialect/StablehloOps.h" +#include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h" #include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" @@ -22,7 +23,6 @@ #include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h" #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h" -#include "torch-mlir/Conversion/TorchToStablehlo/StablehloLegalizeUtils.h" #include using namespace mlir; @@ -403,7 +403,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( int64_t dim; if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) return op->emitError("dim must be a Scalar constant"); - int64_t inputRank = adaptor.getSelf().getType().cast().getRank(); + int64_t inputRank = + adaptor.getSelf().getType().cast().getRank(); dim = toPositiveDim(dim, inputRank + 1); if (!isValidDim(dim, inputRank + 1)) return rewriter.notifyMatchFailure(op, "dim is statically invalid"); diff --git a/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp b/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp index acaa60ffc9ad..5ed681c6e7f9 100644 --- a/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp +++ b/lib/Conversion/TorchToTosa/TosaLegalizeCommon.cpp @@ -131,10 +131,10 @@ tosa::DivOp createBinaryOpAndCast(PatternRewriter &rewriter, } std::optional convertTorchIndexToTfIndices(PatternRewriter &rewriter, - Operation *op, - Value paramsValue, - Value indexValue, - int32_t axis) { + Operation *op, + Value paramsValue, + Value indexValue, + int32_t axis) { // For easy understanding of this algorithm, the following comments are with // an exact example: torch.aten.gather(!torch.vtensor<[1,4,3],f32>, axis=2, // !torch.vtensor<[1,4,2],si64>) -> !torch.vtensor<[1,4,2],f32> @@ -210,9 +210,9 @@ std::optional convertTorchIndexToTfIndices(PatternRewriter &rewriter, // Lowers Gather operators to a sequence of TOSA ops. // taken from // https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc -std::optional convertGatherNdOp(PatternRewriter &rewriter, - Operation *op, Type outType, - Value paramsValue, Value indicesValue) { +std::optional convertGatherNdOp(PatternRewriter &rewriter, Operation *op, + Type outType, Value paramsValue, + Value indicesValue) { auto resultType = outType.dyn_cast(); auto paramsType = paramsValue.getType().dyn_cast(); auto indicesType = indicesValue.getType().dyn_cast(); @@ -683,7 +683,6 @@ std::optional convertScatterNdOp(PatternRewriter &rewriter, .getResult(); } - // Common function for lowering reduce operations to TOSA ops. template std::optional convertReduceOpCommon( @@ -721,9 +720,8 @@ std::optional convertReduceOpCommon( auto axis_attr = rewriter.getI32IntegerAttr(axis_val); shape_vec[axis_val] = 1; - RankedTensorType reduce_type = RankedTensorType::get( - shape_vec, - reduce_element_type); + RankedTensorType reduce_type = + RankedTensorType::get(shape_vec, reduce_element_type); auto reduce_op = CreateOpAndInfer(rewriter, op->getLoc(), reduce_type, val, axis_attr); diff --git a/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp b/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp index b8f719792476..781a5912d83c 100644 --- a/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp +++ b/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp @@ -176,7 +176,8 @@ std::optional getZerosLikeTensor(PatternRewriter &rewriter, // Default template creates a constant tensor in T. template std::optional getConstTensor(PatternRewriter &rewriter, Operation *op, - ArrayRef vec, ArrayRef shape, std::optional dtype) { + ArrayRef vec, ArrayRef shape, + std::optional dtype) { uint64_t num_total_elements = 1; for (int64_t a : shape) { num_total_elements *= a; @@ -188,7 +189,7 @@ std::optional getConstTensor(PatternRewriter &rewriter, Operation *op, } auto width = sizeof(T) * 8; - if constexpr(std::is_same_v) + if constexpr (std::is_same_v) width = 1; auto const_type = @@ -199,7 +200,7 @@ std::optional getConstTensor(PatternRewriter &rewriter, Operation *op, rewriter.create(op->getLoc(), const_type, const_attr); if (dtype) { - return rewriter.createOrFold( + return rewriter.createOrFold( op->getLoc(), RankedTensorType::get(shape, *dtype), const_op); } return const_op.getResult(); @@ -209,7 +210,8 @@ std::optional getConstTensor(PatternRewriter &rewriter, Operation *op, template <> std::optional getConstTensor(PatternRewriter &rewriter, Operation *op, ArrayRef vec, - ArrayRef shape, std::optional dtype) { + ArrayRef shape, + std::optional dtype) { uint64_t num_total_elements = 1; for (int64_t a : shape) { num_total_elements *= a; @@ -228,7 +230,7 @@ std::optional getConstTensor(PatternRewriter &rewriter, rewriter.create(op->getLoc(), const_type, const_attr); if (dtype) { - return rewriter.createOrFold( + return rewriter.createOrFold( op->getLoc(), RankedTensorType::get(shape, *dtype), const_op); } return const_op.getResult(); @@ -238,7 +240,8 @@ std::optional getConstTensor(PatternRewriter &rewriter, template <> std::optional getConstTensor(PatternRewriter &rewriter, Operation *op, ArrayRef vec, - ArrayRef shape, std::optional dtype) { + ArrayRef shape, + std::optional dtype) { uint64_t num_total_elements = 1; for (int64_t a : shape) { num_total_elements *= a; @@ -256,7 +259,7 @@ std::optional getConstTensor(PatternRewriter &rewriter, rewriter.create(op->getLoc(), const_type, const_attr); if (dtype) { - return rewriter.createOrFold( + return rewriter.createOrFold( op->getLoc(), RankedTensorType::get(shape, *dtype), const_op); } return const_op.getResult(); @@ -347,23 +350,17 @@ Value promoteType(PatternRewriter &rewriter, Value input, TensorType outType) { } // Template instantiation -template std::optional getConstTensor(PatternRewriter &, - Operation *, - ArrayRef vec, - ArrayRef shape, - std::optional dtype); - -template std::optional getConstTensor(PatternRewriter &, - Operation *, - ArrayRef vec, - ArrayRef shape, - std::optional dtype); - -template std::optional getConstTensor(PatternRewriter &, - Operation *, - ArrayRef vec, - ArrayRef shape, - std::optional dtype); +template std::optional +getConstTensor(PatternRewriter &, Operation *, ArrayRef vec, + ArrayRef shape, std::optional dtype); + +template std::optional +getConstTensor(PatternRewriter &, Operation *, ArrayRef vec, + ArrayRef shape, std::optional dtype); + +template std::optional +getConstTensor(PatternRewriter &, Operation *, ArrayRef vec, + ArrayRef shape, std::optional dtype); LogicalResult getAvgPool2dAccType(PatternRewriter &rewriter, Value input, TypeAttr &accType) { diff --git a/lib/Dialect/TMTensor/Transforms/Bufferize.cpp b/lib/Dialect/TMTensor/Transforms/Bufferize.cpp index 64352ad1d5ce..1e8c91e8afd4 100644 --- a/lib/Dialect/TMTensor/Transforms/Bufferize.cpp +++ b/lib/Dialect/TMTensor/Transforms/Bufferize.cpp @@ -87,7 +87,8 @@ static TMTensorOp createTMTensorOpOnBuffers(ConversionPatternRewriter &rewriter, ValueRange outputs) { SmallVector newOperands = inputs; newOperands.append(outputs.begin(), outputs.end()); - return cast(tmtensorOp.clone(rewriter, tmtensorOp->getLoc(), {}, newOperands)); + return cast( + tmtensorOp.clone(rewriter, tmtensorOp->getLoc(), {}, newOperands)); } /// Generic conversion pattern that matches any TMTensorOp. This avoids template diff --git a/lib/Dialect/Torch/IR/TorchDialect.cpp b/lib/Dialect/Torch/IR/TorchDialect.cpp index 5c90df8e6ac4..e7fcbb434a2c 100644 --- a/lib/Dialect/Torch/IR/TorchDialect.cpp +++ b/lib/Dialect/Torch/IR/TorchDialect.cpp @@ -157,7 +157,7 @@ Operation *TorchDialect::materializeConstant(OpBuilder &builder, return builder.create(loc, intValue); } } - + if (type.isa()) { return builder.create(loc, value.cast()); diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index e63a4e376013..4af9bcfc1e3b 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -203,8 +203,8 @@ static Value getScalarFloatValue(Value input, Location loc, //===----------------------------------------------------------------------===// LogicalResult MethodOp::verifySymbolUses(SymbolTableCollection &symbolTable) { - auto func = - symbolTable.lookupNearestSymbolFrom(*this, getFunctionAttr()); + auto func = symbolTable.lookupNearestSymbolFrom( + *this, getFunctionAttr()); if (!func) return emitError() << "'@" << getFunction() << "' does not reference a valid function"; @@ -453,11 +453,13 @@ void PrimIfOp::getCanonicalizationPatterns(RewritePatternSet &patterns, // If the condition is constant, delete the dead branch and inline the live // branch. patterns.add(+[](PrimIfOp op, PatternRewriter &rewriter) { - auto constantBool = op.getCondition().getDefiningOp(); + auto constantBool = + op.getCondition().getDefiningOp(); if (!constantBool) return rewriter.notifyMatchFailure(op, "non-constant condition"); - replaceOpWithRegion( - rewriter, op, constantBool.getValue() ? op.getThenRegion() : op.getElseRegion()); + replaceOpWithRegion(rewriter, op, + constantBool.getValue() ? op.getThenRegion() + : op.getElseRegion()); return success(); }); // If the thenRegion and elseRegion yield the same Value's, then use those @@ -515,14 +517,16 @@ void PrimIfOp::getCanonicalizationPatterns(RewritePatternSet &patterns, continue; newResultTypes.push_back(op->getResult(i).getType()); } - auto newIf = - rewriter.create(op->getLoc(), newResultTypes, op.getCondition()); + auto newIf = rewriter.create(op->getLoc(), newResultTypes, + op.getCondition()); rewriter.inlineRegionBefore(op.getThenRegion(), newIf.getThenRegion(), newIf.getThenRegion().end()); rewriter.inlineRegionBefore(op.getElseRegion(), newIf.getElseRegion(), newIf.getElseRegion().end()); - newIf.getThenRegion().front().getTerminator()->eraseOperands(resultsToErase); - newIf.getElseRegion().front().getTerminator()->eraseOperands(resultsToErase); + newIf.getThenRegion().front().getTerminator()->eraseOperands( + resultsToErase); + newIf.getElseRegion().front().getTerminator()->eraseOperands( + resultsToErase); SmallVector replacementValues; for (int i = 0, e = op->getNumResults(), nextNewValue = 0; i < e; ++i) { if (resultsToErase[i]) @@ -548,8 +552,8 @@ void RuntimeAssertOp::getCanonicalizationPatterns(RewritePatternSet &patterns, return failure(); if (value) { - rewriter.eraseOp(op); - return success(); + rewriter.eraseOp(op); + return success(); } // Even if the condition is statically false, the assert might never be // executed. @@ -898,10 +902,10 @@ void AtenToOtherOp::getCanonicalizationPatterns(RewritePatternSet &patterns, auto rhs = op.getOther(); auto getRhsDevice = rewriter.create(op.getLoc(), rhs); auto getRhsDtype = rewriter.create(op.getLoc(), rhs); - rewriter.replaceOpWithNewOp( - op, op.getType(), lhs, getRhsDevice.getResult(), - getRhsDtype.getResult(), op.getNonBlocking(), - op.getCopy(), op.getMemoryFormat()); + rewriter.replaceOpWithNewOp( + op, op.getType(), lhs, getRhsDevice.getResult(), + getRhsDtype.getResult(), op.getNonBlocking(), op.getCopy(), + op.getMemoryFormat()); return success(); }); } @@ -996,7 +1000,7 @@ void AtenMaxOtherOp::getCanonicalizationPatterns(RewritePatternSet &patterns, // `aten.max.other` -> `aten.maximum` patterns.add(+[](AtenMaxOtherOp op, PatternRewriter &rewriter) { rewriter.replaceOpWithNewOp(op, op.getType(), op.getSelf(), - op.getOther()); + op.getOther()); return success(); }); } @@ -1934,7 +1938,7 @@ void Torch::ConstantFloatOp::getAsmResultNames( // float string representation). SmallVector buf; getValue().toString(buf, /*FormatPrecision=*/6, /*FormatMaxPadding=*/0, - /*TruncateZero=*/false); + /*TruncateZero=*/false); auto isValidMLIRIdentifierChar = [](char c) { return isalpha(c) || isdigit(c) || c == '_' || c == '$' || c == '.' || c == '-'; @@ -2045,7 +2049,8 @@ void Aten__Getitem__TOp::getCanonicalizationPatterns( // compiler treat the size as having value semantics? // There's a small number of such ops, and they are marked as `inplace_view` // in PyTorch's `native_functions.yaml` file. - rewriter.replaceOpWithNewOp(op, sizeOp.getSelf(), op.getIdx()); + rewriter.replaceOpWithNewOp(op, sizeOp.getSelf(), + op.getIdx()); return success(); }); } @@ -2073,11 +2078,13 @@ OpFoldResult AtenIsFloatingPointOp::fold(FoldAdaptor adaptor) { void AtenAddTOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { patterns.add(+[](AtenAddTOp op, PatternRewriter &rewriter) { - auto lhsListConstruct = op.getA().getDefiningOp(); + auto lhsListConstruct = + op.getA().getDefiningOp(); if (!lhsListConstruct || isListPotentiallyMutated(lhsListConstruct)) return failure(); - auto rhsListConstruct = op.getB().getDefiningOp(); + auto rhsListConstruct = + op.getB().getDefiningOp(); if (!rhsListConstruct || isListPotentiallyMutated(rhsListConstruct)) return failure(); @@ -2195,7 +2202,8 @@ LogicalResult PrimTupleConstructOp::verify() { void PrimTupleIndexOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { patterns.add(+[](PrimTupleIndexOp op, PatternRewriter &rewriter) { - auto tupleConstruct = op.getTup().getDefiningOp(); + auto tupleConstruct = + op.getTup().getDefiningOp(); if (!tupleConstruct) return failure(); @@ -2245,7 +2253,8 @@ void PrimUninitializedOp::getCanonicalizationPatterns( void PrimTupleUnpackOp::getCanonicalizationPatterns(RewritePatternSet &patterns, MLIRContext *context) { patterns.add(+[](PrimTupleUnpackOp op, PatternRewriter &rewriter) { - auto tupleConstruct = op.getTup().getDefiningOp(); + auto tupleConstruct = + op.getTup().getDefiningOp(); if (!tupleConstruct) return failure(); @@ -2400,9 +2409,7 @@ atenBinaryFloatOperatorFoldHelper(ArrayRef operands, // AtenAliasOp //===----------------------------------------------------------------------===// -OpFoldResult AtenAliasOp::fold(FoldAdaptor adaptor) { - return getOperand(); -} +OpFoldResult AtenAliasOp::fold(FoldAdaptor adaptor) { return getOperand(); } //===----------------------------------------------------------------------===// // AtenFloordivIntOp @@ -2481,14 +2488,12 @@ OpFoldResult AtenBroadcastToOp::fold(FoldAdaptor adaptor) { //===----------------------------------------------------------------------===// OpFoldResult AtenSliceTensorOp::fold(FoldAdaptor adaptor) { - int64_t start, end, step; - if (matchPattern(getStart(), m_TorchConstantInt(&start)) && - matchPattern(getEnd(), m_TorchConstantInt(&end)) && - matchPattern(getStep(), m_TorchConstantInt(&step)) - && step == 1 - && start == 0 - && end == std::numeric_limits::max()) - return getOperand(0); + int64_t start, end, step; + if (matchPattern(getStart(), m_TorchConstantInt(&start)) && + matchPattern(getEnd(), m_TorchConstantInt(&end)) && + matchPattern(getStep(), m_TorchConstantInt(&step)) && step == 1 && + start == 0 && end == std::numeric_limits::max()) + return getOperand(0); auto inType = getOperand(0).getType().dyn_cast(); auto outType = getResult().getType().dyn_cast(); @@ -2744,7 +2749,7 @@ OpFoldResult AtenIntTensorOp::fold(FoldAdaptor adaptor) { // aten.Int.Tensor, fold to the scalar number. if (auto numToTensorScalar = getA().getDefiningOp()) return numToTensorScalar.getA(); - if (auto tensorIntOp = getA().getDefiningOp()) + if (auto tensorIntOp = getA().getDefiningOp()) return tensorIntOp.getT(); return nullptr; } @@ -2955,7 +2960,6 @@ LogicalResult AtenPermuteOp::verify() { << " elements, the output has rank " << outRank << '.'; } - // Initialization of the reverse permutation. -1 denotes an unknown // permutation index. SmallVector reversePermutation(outRank, -1); diff --git a/lib/Dialect/Torch/IR/TorchTypes.cpp b/lib/Dialect/Torch/IR/TorchTypes.cpp index a154fb4653c4..7e3f37a7b870 100644 --- a/lib/Dialect/Torch/IR/TorchTypes.cpp +++ b/lib/Dialect/Torch/IR/TorchTypes.cpp @@ -440,7 +440,7 @@ static Type convertDtypeToBuiltinElementType(MLIRContext *context, Type dtype) { } else if (auto integerType = dtype.dyn_cast()) { return IntegerType::get(context, integerType.getWidth(), IntegerType::Signless); - } else if (dtype.isa()){ + } else if (dtype.isa()) { return dtype; } @@ -556,9 +556,9 @@ Type Torch::meetTensorTypes(BaseTensorType lhs, BaseTensorType rhs) { // TODO: These are not DRY in that the two type predicates AnyTorchDictKeyType // and AnyTorchType generate the exact same code (in TorchOps.cpp.inc). -// Unfortunately the generated implementations aren't visible/exposed ("static" linkage) -// and the predicates themselves can't be added/used in the specification of the parameters -// of the Torch_DictType. +// Unfortunately the generated implementations aren't visible/exposed ("static" +// linkage) and the predicates themselves can't be added/used in the +// specification of the parameters of the Torch_DictType. static bool isAnyTorchDictKeyType(Type type) { return type.isa() || type.isa() || type.isa() || type.isa() || diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index f4e8a60ec1cd..d1794de930b4 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -355,7 +355,7 @@ static LogicalResult performMatmul(PatternRewriter &rewriter, Location loc, auto rhsType = rhs.getType().cast(); Type outputDType = lhsType.hasDtype() ? lhsType.getOptionalDtype() - : rhsType.getOptionalDtype(); + : rhsType.getOptionalDtype(); llvm::SmallDenseMap lhsDimShapeMap; for (size_t idx = 0; idx < lhsTokens.size(); ++idx) { @@ -457,7 +457,6 @@ static LogicalResult performMatmul(PatternRewriter &rewriter, Location loc, return success(); } - static Value performLastReduceAndPermute(PatternRewriter &rewriter, Location loc, Type outType, Value input, @@ -1269,7 +1268,8 @@ class DecomposeAten_LogSoftmaxBackwardDataOp }; } // namespace -// Decompose `AtenArgMaxOp` into `AtenMaxDimOp` as well as `AtenArgMinOp` into `AtenMinDimOp` +// Decompose `AtenArgMaxOp` into `AtenMaxDimOp` as well as `AtenArgMinOp` into +// `AtenMinDimOp` namespace { template class DecomposeAtenArgMinMaxOp : public OpRewritePattern { @@ -1300,9 +1300,9 @@ class DecomposeAtenArgMinMaxOp : public OpRewritePattern { .cast(); // If the dim type is `NoneType` i.e. reduce along all the dimensions. - // `AtenMaxDimOp` and `AtenMinDimOp` do not support dim as `NoneType` so first the input - // tensor is flattened to 1d tensor and then the reduction happens on the - // 0th dimension. + // `AtenMaxDimOp` and `AtenMinDimOp` do not support dim as `NoneType` so + // first the input tensor is flattened to 1d tensor and then the reduction + // happens on the 0th dimension. if (dim.getType().isa()) { BaseTensorType flattenType = inputType @@ -1317,11 +1317,11 @@ class DecomposeAtenArgMinMaxOp : public OpRewritePattern { } Value resultArg = - rewriter - .create(loc, valueTensorType, indicesTensorType, - input, dim, keepDim) - .getIndices(); - + rewriter + .create(loc, valueTensorType, indicesTensorType, input, + dim, keepDim) + .getIndices(); + rewriter.replaceOp(op, resultArg); return success(); } @@ -1959,10 +1959,12 @@ class DecomposeAtenSeluOp : public OpRewritePattern { // Define λ and α double scale = 1.0507009873554804934193349852946; double alpha = 1.6732632423543772848170429916717; - + // Create constants for λ and α - Value scaleVal = rewriter.create(loc, rewriter.getF64FloatAttr(scale)); - Value alphaVal = rewriter.create(loc, rewriter.getF64FloatAttr(alpha)); + Value scaleVal = rewriter.create( + loc, rewriter.getF64FloatAttr(scale)); + Value alphaVal = rewriter.create( + loc, rewriter.getF64FloatAttr(alpha)); // Create zero tensor for comparison Value constantZero = @@ -1972,17 +1974,21 @@ class DecomposeAtenSeluOp : public OpRewritePattern { // Calculate positive and negative parts Value constantOne = rewriter.create(loc, rewriter.getF64FloatAttr(1.0)); - Value positiveOutput = rewriter.create(loc, resType, zeroTensor, input); + Value positiveOutput = + rewriter.create(loc, resType, zeroTensor, input); Value minZeroX = rewriter.create(loc, resType, zeroTensor, input); Value expInput = rewriter.create(loc, resType, minZeroX); - Value expInputMinusOne = rewriter.create(loc, resType, expInput, constantOne, constantOne); - Value negativeOutput = rewriter.create(loc, resType, expInputMinusOne, alphaVal); + Value expInputMinusOne = rewriter.create( + loc, resType, expInput, constantOne, constantOne); + Value negativeOutput = rewriter.create( + loc, resType, expInputMinusOne, alphaVal); // Multiply the result by λ Value seluOutput = rewriter.create( loc, resType, positiveOutput, negativeOutput, constantOne); - seluOutput = rewriter.create(loc, resType, seluOutput, scaleVal); + seluOutput = + rewriter.create(loc, resType, seluOutput, scaleVal); // Replace the original operation rewriter.replaceOp(op, seluOutput); @@ -2592,79 +2598,89 @@ class DecomposeAten_ConvolutionLikeOp namespace { - static LogicalResult createTorchTransposeOpForConvTbc(PatternRewriter &rewriter, - Location loc, Value input, - int64_t dimA, int64_t dimB, - Value &transposed) { - Type transposedType; - if (failed(getTransposedType(input.getType().cast(), - dimA, dimB, transposedType))) - return failure(); - Value cstDimA = rewriter.create( - loc, rewriter.getI64IntegerAttr(dimA)); - Value cstDimB = rewriter.create( - loc, rewriter.getI64IntegerAttr(dimB)); - transposed = rewriter.create( - loc, transposedType, input, cstDimA, cstDimB); - return success(); - } +static LogicalResult createTorchTransposeOpForConvTbc(PatternRewriter &rewriter, + Location loc, Value input, + int64_t dimA, + int64_t dimB, + Value &transposed) { + Type transposedType; + if (failed(getTransposedType(input.getType().cast(), + dimA, dimB, transposedType))) + return failure(); + Value cstDimA = rewriter.create( + loc, rewriter.getI64IntegerAttr(dimA)); + Value cstDimB = rewriter.create( + loc, rewriter.getI64IntegerAttr(dimB)); + transposed = rewriter.create( + loc, transposedType, input, cstDimA, cstDimB); + return success(); +} - class DecomposeAtenConvTbcOp : public OpRewritePattern { - public: - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(AtenConvTbcOp op, - PatternRewriter &rewriter) const override { - Value emptyList = rewriter.create( - op.getLoc(), - Torch::ListType::get(Torch::IntType::get(op.getContext())), - SmallVector()); - Value cstFalse = rewriter.create(op.getLoc(), false); - Value oneList = rewriter.create( - op.getLoc(), Torch::ListType::get(Torch::IntType::get(op.getContext())), - SmallVector{rewriter.create(op.getLoc(), rewriter.getI64IntegerAttr(1))}); - Value padding = rewriter.create( - op.getLoc(), Torch::ListType::get(Torch::IntType::get(op.getContext())), - SmallVector{op.getPad()}); - Value groups = rewriter.create(op.getLoc(), rewriter.getI64IntegerAttr(1)); - - // convtbc has WNC layout for input and output - // and WCF layout for weight - // whereas Convolution is going to use Conv1DNcwFcwOp for 1d - // which means we need the inputs in NCW and the weight in FCW - Value selfWnc = op.getSelf(); - Value selfNwc; - Value selfNcw; - if(failed(createTorchTransposeOpForConvTbc(rewriter, op.getLoc(), selfWnc, 0, 1, selfNwc))) - return rewriter.notifyMatchFailure(op, "failed to transpose input to Nwc"); - if(failed(createTorchTransposeOpForConvTbc(rewriter, op.getLoc(), selfNwc, 1, 2, selfNcw))) - return rewriter.notifyMatchFailure(op, "failed to transpose input to Ncw"); - - Value weightWcf = op.getWeight(); - Value weightFcw; - if(failed(createTorchTransposeOpForConvTbc(rewriter, op.getLoc(), weightWcf, 0, 2, weightFcw))) - return rewriter.notifyMatchFailure(op, "failed to transpose weight to Fcw"); - - - Value outputNcw = rewriter.create( - op.getLoc(), op->getResultTypes(), selfNcw, weightFcw, op.getBias(), /*stride*/oneList, - /*padding*/ padding, /*dilation*/ oneList, - /*transpose*/ cstFalse, /*output_padding*/ emptyList, - groups); - - // convert output from Ncw to Wnc - Value outputNwc; - Value outputWnc; - if(failed(createTorchTransposeOpForConvTbc(rewriter, op.getLoc(), outputNcw, 1, 2, outputNwc))) - return rewriter.notifyMatchFailure(op, "failed to transpose output to Nwc"); - if(failed(createTorchTransposeOpForConvTbc(rewriter, op.getLoc(), outputNwc, 0, 1, outputWnc))) - return rewriter.notifyMatchFailure(op, "failed to transpose output to Wnc"); - rewriter.replaceOp(op, outputWnc); +class DecomposeAtenConvTbcOp : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenConvTbcOp op, + PatternRewriter &rewriter) const override { + Value emptyList = rewriter.create( + op.getLoc(), Torch::ListType::get(Torch::IntType::get(op.getContext())), + SmallVector()); + Value cstFalse = rewriter.create(op.getLoc(), false); + Value oneList = rewriter.create( + op.getLoc(), Torch::ListType::get(Torch::IntType::get(op.getContext())), + SmallVector{rewriter.create( + op.getLoc(), rewriter.getI64IntegerAttr(1))}); + Value padding = rewriter.create( + op.getLoc(), Torch::ListType::get(Torch::IntType::get(op.getContext())), + SmallVector{op.getPad()}); + Value groups = rewriter.create( + op.getLoc(), rewriter.getI64IntegerAttr(1)); + + // convtbc has WNC layout for input and output + // and WCF layout for weight + // whereas Convolution is going to use Conv1DNcwFcwOp for 1d + // which means we need the inputs in NCW and the weight in FCW + Value selfWnc = op.getSelf(); + Value selfNwc; + Value selfNcw; + if (failed(createTorchTransposeOpForConvTbc(rewriter, op.getLoc(), selfWnc, + 0, 1, selfNwc))) + return rewriter.notifyMatchFailure(op, + "failed to transpose input to Nwc"); + if (failed(createTorchTransposeOpForConvTbc(rewriter, op.getLoc(), selfNwc, + 1, 2, selfNcw))) + return rewriter.notifyMatchFailure(op, + "failed to transpose input to Ncw"); - return success(); - } - }; -} + Value weightWcf = op.getWeight(); + Value weightFcw; + if (failed(createTorchTransposeOpForConvTbc(rewriter, op.getLoc(), + weightWcf, 0, 2, weightFcw))) + return rewriter.notifyMatchFailure(op, + "failed to transpose weight to Fcw"); + + Value outputNcw = rewriter.create( + op.getLoc(), op->getResultTypes(), selfNcw, weightFcw, op.getBias(), + /*stride*/ oneList, + /*padding*/ padding, /*dilation*/ oneList, + /*transpose*/ cstFalse, /*output_padding*/ emptyList, groups); + + // convert output from Ncw to Wnc + Value outputNwc; + Value outputWnc; + if (failed(createTorchTransposeOpForConvTbc(rewriter, op.getLoc(), + outputNcw, 1, 2, outputNwc))) + return rewriter.notifyMatchFailure(op, + "failed to transpose output to Nwc"); + if (failed(createTorchTransposeOpForConvTbc(rewriter, op.getLoc(), + outputNwc, 0, 1, outputWnc))) + return rewriter.notifyMatchFailure(op, + "failed to transpose output to Wnc"); + rewriter.replaceOp(op, outputWnc); + return success(); + } +}; +} // namespace // Decompose aten.conv1d to aten.convolution namespace { @@ -3815,8 +3831,8 @@ class DecomposeAtenNormalFunctionalOp /*device=*/none, /*pin_memory=*/none, /*memory_format=*/none); Value stdRandN = rewriter.create(loc, resultType, randN, std); - rewriter.replaceOpWithNewOp(op, resultType, stdRandN, - mean, /*alpha=*/one); + rewriter.replaceOpWithNewOp(op, resultType, stdRandN, mean, + /*alpha=*/one); return success(); } }; @@ -6654,8 +6670,10 @@ class DecomposeComplexOpsPass addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); - addPatternIfTargetOpIsIllegal>(patterns); - addPatternIfTargetOpIsIllegal>(patterns); + addPatternIfTargetOpIsIllegal< + DecomposeAtenArgMinMaxOp>(patterns); + addPatternIfTargetOpIsIllegal< + DecomposeAtenArgMinMaxOp>(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); @@ -6768,8 +6786,6 @@ class DecomposeComplexOpsPass addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); - - GreedyRewriteConfig config; config.useTopDownTraversal = true; config.maxIterations = GreedyRewriteConfig::kNoLimit; diff --git a/lib/Dialect/Torch/Transforms/GlobalizeObjectGraph.cpp b/lib/Dialect/Torch/Transforms/GlobalizeObjectGraph.cpp index da8be9b17e0b..239960629797 100644 --- a/lib/Dialect/Torch/Transforms/GlobalizeObjectGraph.cpp +++ b/lib/Dialect/Torch/Transforms/GlobalizeObjectGraph.cpp @@ -170,8 +170,8 @@ class ObjectGraphInfo { auto attr = std::get<1>(t); nameStack.push_back(attr.getName().str()); if (attr.getType().isa()) { - if (failed( - recursivelyTraverse(slot.getValue().getDefiningOp()))) + if (failed(recursivelyTraverse( + slot.getValue().getDefiningOp()))) return failure(); } else if (usedSlots.find(slot) != usedSlots.end()) { // Only create the GlobalSlotOp if the slot is used at all. @@ -190,8 +190,8 @@ class ObjectGraphInfo { } for (auto method : classType.getOps()) { nameStack.push_back(method.getName().str()); - funcLinkageInfo[{nnModule, - symbolTable.lookup(method.getFunction())}] = + funcLinkageInfo[{ + nnModule, symbolTable.lookup(method.getFunction())}] = LinkageInfo{llvm::join(nameStack, "."), method.getIsPrivate()}; nameStack.pop_back(); } @@ -501,21 +501,24 @@ static LogicalResult rewriteMonomorphizedFuncClone( SmallVector toErase; auto handlePrimSetAttr = [&](PrimSetAttrOp op) { - auto instance = mapping.lookup(op.getReceiver()).getDefiningOp(); + auto instance = + mapping.lookup(op.getReceiver()).getDefiningOp(); SlotOp affectedSlot; for (auto slot : instance.getOps()) { if (slot.getName() == op.getName()) affectedSlot = slot; } OpBuilder(op).create( - op.getLoc(), objectGraphInfo.getGlobalSlotFor(affectedSlot).getSymName(), + op.getLoc(), + objectGraphInfo.getGlobalSlotFor(affectedSlot).getSymName(), op.getValue()); toErase.push_back(op); return WalkResult::advance(); }; auto handlePrimGetAttr = [&](PrimGetAttrOp op) { if (!op.getType().isa()) { - auto instance = mapping.lookup(op.getReceiver()).getDefiningOp(); + auto instance = + mapping.lookup(op.getReceiver()).getDefiningOp(); SlotOp affectedSlot; for (auto slot : instance.getOps()) { if (slot.getName() == op.getName()) diff --git a/lib/Dialect/Torch/Transforms/InlineGlobalSlots.cpp b/lib/Dialect/Torch/Transforms/InlineGlobalSlots.cpp index c67e6dc0d3a7..1e8c90deac4e 100644 --- a/lib/Dialect/Torch/Transforms/InlineGlobalSlots.cpp +++ b/lib/Dialect/Torch/Transforms/InlineGlobalSlots.cpp @@ -163,7 +163,8 @@ LogicalResult InlineGlobalSlotsAnalysis::initialize(Operation *top) { } if (auto globalSlotSet = dyn_cast(op)) { auto *state = getOrCreate( - getProgramPoint(globalSlotSet.getSlotAttr())); + getProgramPoint( + globalSlotSet.getSlotAttr())); propagateIfChanged(state, state->setSafe(false)); } // Save the InitializeGlobalSlotsOp for later referencee @@ -211,8 +212,8 @@ LogicalResult InlineGlobalSlotsAnalysis::visit(ProgramPoint point) { auto it = llvm::find(initializeGlobalSlotsOp.getSlotSymNames(), static_cast(flatSymbolRefPoint->getValue())); - Value value = initializeGlobalSlotsOp->getOperand( - std::distance(initializeGlobalSlotsOp.getSlotSymNames().begin(), it)); + Value value = initializeGlobalSlotsOp->getOperand(std::distance( + initializeGlobalSlotsOp.getSlotSymNames().begin(), it)); auto *flatSymbolRefState = getOrCreateFor(value, flatSymbolRefPoint); @@ -331,7 +332,8 @@ class InlineGlobalSlotsPass DenseSet safeToInline; for (int i = 0, e = initialize->getNumOperands(); i != e; i++) { - auto slotSymName = initialize.getSlotSymNames()[i].cast(); + auto slotSymName = + initialize.getSlotSymNames()[i].cast(); Value operand = initialize.getOperand(i); auto symbolRefPoint = solver.getProgramPoint( initialize.getSlotSymNames()[i].cast()); @@ -405,7 +407,8 @@ class InlineGlobalSlotsPass SmallVector newSlotSymNames; SmallVector newInitialValues; for (int i = 0, e = initialize.getNumOperands(); i != e; i++) { - auto slotSymName = initialize.getSlotSymNames()[i].cast(); + auto slotSymName = + initialize.getSlotSymNames()[i].cast(); if (!safeToInline.count(slotSymName)) { newSlotSymNames.push_back(slotSymName); newInitialValues.push_back(initialize.getOperand(i)); diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp index 34874cb59635..befdf808ad5b 100644 --- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp +++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp @@ -202,15 +202,16 @@ static bool satisfiesBackendContract(ModuleOp module, // Check for unimplemented operators first to give more direct diagnostics. walkResult0 = module.walk([&](Torch::OperatorOp op) { if (llvm::all_of(op.getResults(), [&op](auto res) { - return succeeded( - checkType(op.getOperation(), res.getType(), /*actuallyEmitDiagnostics=*/false)); + return succeeded(checkType(op.getOperation(), res.getType(), + /*actuallyEmitDiagnostics=*/false)); })) { return WalkResult::advance(); } if (actuallyEmitDiagnostics) { - op->emitError("unsupported by backend contract: Unimplemented operator '" - + op.getName() + "'"); + op->emitError( + "unsupported by backend contract: Unimplemented operator '" + + op.getName() + "'"); } return WalkResult::interrupt(); }); @@ -309,20 +310,22 @@ class LowerToBackendContractPass << " iterations of the simplification pipeline\n"; }); } + private: llvm::StringSet<> backendLegalOpsSet; }; class VerifyBackendContractNoDecompositionsPass - : public VerifyBackendContractNoDecompositionsBase { + : public VerifyBackendContractNoDecompositionsBase< + VerifyBackendContractNoDecompositionsPass> { public: VerifyBackendContractNoDecompositionsPass() = default; void runOnOperation() override { MLIRContext *context = &getContext(); ConversionTarget target = - getBackendContractTarget(context, /*decompose*/false, - /*backendLegalOpsSet*/{}); + getBackendContractTarget(context, /*decompose*/ false, + /*backendLegalOpsSet*/ {}); if (!satisfiesBackendContract(getOperation(), target, /*actuallyEmitDiagnostics=*/true)) { diff --git a/lib/Dialect/Torch/Transforms/ReifyAbstractInterpCalculationsUtils.cpp b/lib/Dialect/Torch/Transforms/ReifyAbstractInterpCalculationsUtils.cpp index 7c3ceab3afec..a34e0208c9d9 100644 --- a/lib/Dialect/Torch/Transforms/ReifyAbstractInterpCalculationsUtils.cpp +++ b/lib/Dialect/Torch/Transforms/ReifyAbstractInterpCalculationsUtils.cpp @@ -158,9 +158,11 @@ void Torch::importLibraryFunctions(ModuleOp module, ModuleOp library, } } -FailureOr Torch::adjustFunctionArg( - OpBuilder &b, Location loc, Value operand, Type desiredType, - function_ref baseTransformation) { +FailureOr +Torch::adjustFunctionArg(OpBuilder &b, Location loc, Value operand, + Type desiredType, + function_ref + baseTransformation) { operand = baseTransformation(b, loc, operand, desiredType); // No need for adjustment if they already match. diff --git a/lib/Dialect/Torch/Transforms/SimplifyDtypeCalculations.cpp b/lib/Dialect/Torch/Transforms/SimplifyDtypeCalculations.cpp index 6860fbb6eee8..fbbd6c48043b 100644 --- a/lib/Dialect/Torch/Transforms/SimplifyDtypeCalculations.cpp +++ b/lib/Dialect/Torch/Transforms/SimplifyDtypeCalculations.cpp @@ -90,7 +90,8 @@ class DecomposePromoteDtypesOp : public OpRewritePattern { PatternRewriter &rewriter) const override { SmallVector> ranks; SmallVector dtypes; - if (!matchPattern(op.getRanks(), m_TorchListOfOptionalConstantInts(ranks))) { + if (!matchPattern(op.getRanks(), + m_TorchListOfOptionalConstantInts(ranks))) { return rewriter.notifyMatchFailure( op, "Expected `ranks` to be a list of optional constant ints"); } diff --git a/lib/Dialect/Torch/Utils/Utils.cpp b/lib/Dialect/Torch/Utils/Utils.cpp index bf371d7c4687..e2abee51b817 100644 --- a/lib/Dialect/Torch/Utils/Utils.cpp +++ b/lib/Dialect/Torch/Utils/Utils.cpp @@ -344,9 +344,9 @@ FailureOr Torch::unsqueezeTensor(PatternRewriter &rewriter, // Checks whether the `shapeA` and `shapeB` are broadcast compatible or not. If // yes, then computes the final broadcast shape. void Torch::computeBroadcastShape(PatternRewriter &rewriter, Location loc, - Value inputA, Value inputB, - SmallVector &resultShape, - SmallVector &resultShapeValue) { + Value inputA, Value inputB, + SmallVector &resultShape, + SmallVector &resultShapeValue) { SmallVector shapeA{ inputA.getType().cast().getSizes()}; SmallVector shapeB{ @@ -514,7 +514,7 @@ Value Torch::createRank0Tensor(PatternRewriter &rewriter, Location loc, } LogicalResult Torch::getTransposedType(BaseTensorType inType, int64_t dimA, - int64_t dimB, Type &transposedType) { + int64_t dimB, Type &transposedType) { if (!inType.hasSizes()) return failure(); SmallVector shape(inType.getSizes()); diff --git a/lib/Dialect/TorchConversion/IR/TorchConversionDialect.cpp b/lib/Dialect/TorchConversion/IR/TorchConversionDialect.cpp index 4d38f4965df2..ac9a72586bef 100644 --- a/lib/Dialect/TorchConversion/IR/TorchConversionDialect.cpp +++ b/lib/Dialect/TorchConversion/IR/TorchConversionDialect.cpp @@ -54,14 +54,14 @@ void TorchConversionDialect::initialize() { addInterfaces(); } - //===----------------------------------------------------------------------===// // Constant materializer. //===----------------------------------------------------------------------===// Operation *TorchConversionDialect::materializeConstant(OpBuilder &builder, - Attribute value, Type type, - Location loc) { + Attribute value, + Type type, + Location loc) { if (auto integerType = type.dyn_cast()) return builder.create(loc, value.cast()); diff --git a/lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp b/lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp index 8a5c218e4f3e..1cda55724ee3 100644 --- a/lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp +++ b/lib/Dialect/TorchConversion/Transforms/BackendTypeConversion.cpp @@ -7,8 +7,8 @@ // //===----------------------------------------------------------------------===// -#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h" #include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h" +#include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h" using namespace mlir; using namespace mlir::torch; @@ -57,16 +57,16 @@ static void setupTorchBoolToI1Conversion(ConversionTarget &target, typeConverter.addConversion([](Torch::BoolType type) -> std::optional { return IntegerType::get(type.getContext(), 1); }); - typeConverter.addTargetMaterialization([](OpBuilder &builder, - IntegerType type, ValueRange inputs, - Location loc) -> std::optional { - // Other builtin integer types could be handled by other materializers. - if (!(type.getWidth() == 1 && type.isSignless())) - return std::nullopt; - assert(inputs.size() == 1); - assert(inputs[0].getType().isa()); - return builder.create(loc, inputs[0]).getResult(); - }); + typeConverter.addTargetMaterialization( + [](OpBuilder &builder, IntegerType type, ValueRange inputs, + Location loc) -> std::optional { + // Other builtin integer types could be handled by other materializers. + if (!(type.getWidth() == 1 && type.isSignless())) + return std::nullopt; + assert(inputs.size() == 1); + assert(inputs[0].getType().isa()); + return builder.create(loc, inputs[0]).getResult(); + }); auto sourceMaterialization = [](OpBuilder &builder, Torch::BoolType type, ValueRange inputs, Location loc) -> Value { assert(inputs.size() == 1); @@ -83,19 +83,19 @@ static void setupTorchIntToI64Conversion(ConversionTarget &target, typeConverter.addConversion([](Torch::IntType type) -> std::optional { return IntegerType::get(type.getContext(), 64); }); - typeConverter.addTargetMaterialization([](OpBuilder &builder, - IntegerType type, ValueRange inputs, - Location loc) -> std::optional { - // Other builtin integer types could be handled by other materializers. - if (!(type.getWidth() == 64 && type.isSignless())) - return std::nullopt; - // Other input type to be converted to i64 are handled by other - // materializers. - if (!inputs[0].getType().isa()) - return std::nullopt; - assert(inputs.size() == 1); - return builder.create(loc, inputs[0]).getResult(); - }); + typeConverter.addTargetMaterialization( + [](OpBuilder &builder, IntegerType type, ValueRange inputs, + Location loc) -> std::optional { + // Other builtin integer types could be handled by other materializers. + if (!(type.getWidth() == 64 && type.isSignless())) + return std::nullopt; + // Other input type to be converted to i64 are handled by other + // materializers. + if (!inputs[0].getType().isa()) + return std::nullopt; + assert(inputs.size() == 1); + return builder.create(loc, inputs[0]).getResult(); + }); auto sourceMaterialization = [](OpBuilder &builder, Torch::IntType type, ValueRange inputs, Location loc) -> Value { assert(inputs.size() == 1); @@ -112,13 +112,13 @@ static void setupTorchFloatToF64Conversion(ConversionTarget &target, typeConverter.addConversion([](Torch::FloatType type) -> std::optional { return Float64Type::get(type.getContext()); }); - typeConverter.addTargetMaterialization([](OpBuilder &builder, - Float64Type type, ValueRange inputs, - Location loc) -> std::optional { - assert(inputs.size() == 1); - assert(inputs[0].getType().isa()); - return builder.create(loc, inputs[0]).getResult(); - }); + typeConverter.addTargetMaterialization( + [](OpBuilder &builder, Float64Type type, ValueRange inputs, + Location loc) -> std::optional { + assert(inputs.size() == 1); + assert(inputs[0].getType().isa()); + return builder.create(loc, inputs[0]).getResult(); + }); auto sourceMaterialization = [](OpBuilder &builder, Torch::FloatType type, ValueRange inputs, Location loc) -> Value { assert(inputs.size() == 1); @@ -133,22 +133,23 @@ static void setupTorchGeneratorToI64Conversion(ConversionTarget &target, TypeConverter &typeConverter) { target.addLegalOp(); - typeConverter.addConversion([](Torch::GeneratorType type) -> std::optional { - return IntegerType::get(type.getContext(), 64); - }); - typeConverter.addTargetMaterialization([](OpBuilder &builder, - IntegerType type, ValueRange inputs, - Location loc) -> std::optional { - // Other builtin integer types could be handled by other materializers. - if (!(type.getWidth() == 64 && type.isSignless())) - return std::nullopt; - // Other input type to be converted to i64 are handled by other - // materializers. - if (!inputs[0].getType().isa()) - return std::nullopt; - assert(inputs.size() == 1); - return builder.create(loc, inputs[0]).getResult(); - }); + typeConverter.addConversion( + [](Torch::GeneratorType type) -> std::optional { + return IntegerType::get(type.getContext(), 64); + }); + typeConverter.addTargetMaterialization( + [](OpBuilder &builder, IntegerType type, ValueRange inputs, + Location loc) -> std::optional { + // Other builtin integer types could be handled by other materializers. + if (!(type.getWidth() == 64 && type.isSignless())) + return std::nullopt; + // Other input type to be converted to i64 are handled by other + // materializers. + if (!inputs[0].getType().isa()) + return std::nullopt; + assert(inputs.size() == 1); + return builder.create(loc, inputs[0]).getResult(); + }); auto sourceMaterialization = [](OpBuilder &builder, Torch::GeneratorType type, ValueRange inputs, Location loc) -> Value { assert(inputs.size() == 1); diff --git a/lib/Dialect/TorchConversion/Transforms/ConvertCustomQuantOp.cpp b/lib/Dialect/TorchConversion/Transforms/ConvertCustomQuantOp.cpp index 175a3cd14804..514d05234486 100644 --- a/lib/Dialect/TorchConversion/Transforms/ConvertCustomQuantOp.cpp +++ b/lib/Dialect/TorchConversion/Transforms/ConvertCustomQuantOp.cpp @@ -18,8 +18,8 @@ #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h" -#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h" #include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h" +#include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h" using namespace mlir; using namespace mlir::torch; @@ -65,7 +65,8 @@ class ConvertCustomQuantizedMatmulOp : public OpConversionPattern { auto getConstantIntegerFromDefiningOp = [](Value operand, int &extractedInt) { - auto castOp = dyn_cast(operand.getDefiningOp()); + auto castOp = + dyn_cast(operand.getDefiningOp()); if (!castOp) { return failure(); } @@ -83,7 +84,8 @@ class ConvertCustomQuantizedMatmulOp : public OpConversionPattern { return failure(); } int unpackedBitWidth; - if (failed(getConstantIntegerFromDefiningOp(unpackedTypeWidth, unpackedBitWidth))) { + if (failed(getConstantIntegerFromDefiningOp(unpackedTypeWidth, + unpackedBitWidth))) { return failure(); } if (unpackedBitWidth != @@ -103,32 +105,35 @@ class ConvertCustomQuantizedMatmulOp : public OpConversionPattern { // expand lhs std::vector lhsExpandedShape = {lhsShape[0], lhsShape[1], lhsReductDimSize / gs, gs}; - RankedTensorType lhsExpandedType = RankedTensorType::get(lhsExpandedShape, elementType); + RankedTensorType lhsExpandedType = + RankedTensorType::get(lhsExpandedShape, elementType); SmallVector lhsReassociation = {{0}, {1}, {2, 3}}; Value lhsExpanded = rewriter.create( - loc, lhsExpandedType, lhs, lhsReassociation); + loc, lhsExpandedType, lhs, lhsReassociation); // expand rhs - std::vector rhsExpandedShape = {rhsShape[0], rhsReductDimSize/gs, gs}; - RankedTensorType rhsExpandedType = RankedTensorType::get(rhsExpandedShape, rhsElementType); + std::vector rhsExpandedShape = {rhsShape[0], rhsReductDimSize / gs, + gs}; + RankedTensorType rhsExpandedType = + RankedTensorType::get(rhsExpandedShape, rhsElementType); SmallVector rhsReassociation = {{0}, {1, 2}}; Value rhsExpanded = rewriter.create( - loc, rhsExpandedType, rhsQuant, rhsReassociation); + loc, rhsExpandedType, rhsQuant, rhsReassociation); Value cst0 = rewriter.create( - loc, FloatAttr::get(elementType, 0.0)); + loc, FloatAttr::get(elementType, 0.0)); - Value emptyDequant = rewriter.create( - loc, rhsExpandedShape, elementType); + Value emptyDequant = + rewriter.create(loc, rhsExpandedShape, elementType); SmallVector dynDims; for (int i = 0; i < lhsType.getRank(); i++) { if (lhsType.isDynamicDim(i)) { dynDims.push_back(rewriter.create(loc, lhs, i)); } } - Value empty = rewriter.create( - loc, resultShape, elementType, dynDims); - Value output = rewriter.create( - loc, cst0, empty).getResult(0); + Value empty = rewriter.create(loc, resultShape, + elementType, dynDims); + Value output = + rewriter.create(loc, cst0, empty).getResult(0); AffineExpr d0, d1, d2, d3, d4; bindDims(getContext(), d0, d1, d2, d3, d4); @@ -141,12 +146,12 @@ class ConvertCustomQuantizedMatmulOp : public OpConversionPattern { SmallVector dqIndexingMaps = {map, map1, map1, map}; SmallVector matIndexingMaps = {map2, map3, map4}; - SmallVector dequantIteratorTypes(3, utils::IteratorType::parallel); + SmallVector dequantIteratorTypes( + 3, utils::IteratorType::parallel); SmallVector matmulIteratorTypes = { - utils::IteratorType::parallel, utils::IteratorType::parallel, - utils::IteratorType::parallel, utils::IteratorType::reduction, - utils::IteratorType::reduction - }; + utils::IteratorType::parallel, utils::IteratorType::parallel, + utils::IteratorType::parallel, utils::IteratorType::reduction, + utils::IteratorType::reduction}; Value rhsDequant = rewriter @@ -157,9 +162,12 @@ class ConvertCustomQuantizedMatmulOp : public OpConversionPattern { /*iteratorTypes=*/dequantIteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) { Value w = args[0], scale = args[1], zeroPoint = args[2]; - Value extw = b.create(loc, rewriter.getI32Type(), w); - Value fp_extw = b.create(loc, rewriter.getF16Type(), extw); - Value shifted = b.create(loc, fp_extw, zeroPoint); + Value extw = + b.create(loc, rewriter.getI32Type(), w); + Value fp_extw = b.create( + loc, rewriter.getF16Type(), extw); + Value shifted = + b.create(loc, fp_extw, zeroPoint); Value dqw = b.create(loc, shifted, scale); b.create(loc, dqw); }) @@ -168,8 +176,8 @@ class ConvertCustomQuantizedMatmulOp : public OpConversionPattern { Value matmulDequant = rewriter .create( - loc, output.getType(), - ValueRange{lhsExpanded, rhsDequant}, output, + loc, output.getType(), ValueRange{lhsExpanded, rhsDequant}, + output, /*indexingMaps=*/matIndexingMaps, /*iteratorTypes=*/matmulIteratorTypes, [&](OpBuilder &b, Location loc, ValueRange args) { @@ -188,7 +196,8 @@ class ConvertCustomQuantizedMatmulOp : public OpConversionPattern { namespace { class ConvertCustomQuantOpPass - : public TorchConversion::ConvertCustomQuantOpBase { + : public TorchConversion::ConvertCustomQuantOpBase< + ConvertCustomQuantOpPass> { void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); registry.insert(); @@ -213,8 +222,8 @@ class ConvertCustomQuantOpPass target.addIllegalOp(); patterns.add(typeConverter, context); - if (failed( - applyPartialConversion(getOperation(), target, std::move(patterns)))) + if (failed(applyPartialConversion(getOperation(), target, + std::move(patterns)))) signalPassFailure(); } }; diff --git a/lib/Dialect/TorchConversion/Transforms/VerifyLinalgOnTensorsBackendContract.cpp b/lib/Dialect/TorchConversion/Transforms/VerifyLinalgOnTensorsBackendContract.cpp index 93d7de8250a7..5ad3fa1c9f4f 100644 --- a/lib/Dialect/TorchConversion/Transforms/VerifyLinalgOnTensorsBackendContract.cpp +++ b/lib/Dialect/TorchConversion/Transforms/VerifyLinalgOnTensorsBackendContract.cpp @@ -33,7 +33,6 @@ using namespace mlir::torch; using namespace mlir::torch::TorchConversion; using namespace TMTensor; - namespace { class VerifyLinalgOnTensorsBackendContractPass : public VerifyLinalgOnTensorsBackendContractBase< @@ -96,7 +95,8 @@ class VerifyLinalgOnTensorsBackendContractPass // We avoid `module.emitError()` so that mlir-print-op-on-diagnostics // doesn't unnecessarily spew out the entire module. emitError(module.getLoc()) - << "Module does not conform to the linalg-on-tensors backend contract. " + << "Module does not conform to the linalg-on-tensors backend " + "contract. " "See dialect conversion legality information above."; return signalPassFailure(); } diff --git a/lib/Dialect/TorchConversion/Transforms/VerifyStablehloBackendContract.cpp b/lib/Dialect/TorchConversion/Transforms/VerifyStablehloBackendContract.cpp index 888f29adedb2..c6085f419eac 100644 --- a/lib/Dialect/TorchConversion/Transforms/VerifyStablehloBackendContract.cpp +++ b/lib/Dialect/TorchConversion/Transforms/VerifyStablehloBackendContract.cpp @@ -45,7 +45,8 @@ class VerifyStablehloBackendContractPass ConversionTarget target(*context); // Structural operations. - target.addDynamicallyLegalOp(opHasLegalTypes); + target.addDynamicallyLegalOp( + opHasLegalTypes); // Shape operations. target.addDynamicallyLegalOp(opHasLegalTypes); diff --git a/projects/ltc/csrc/base_lazy_backend/backend_impl.cpp b/projects/ltc/csrc/base_lazy_backend/backend_impl.cpp index bd4fe52b7b22..dc044879669e 100644 --- a/projects/ltc/csrc/base_lazy_backend/backend_impl.cpp +++ b/projects/ltc/csrc/base_lazy_backend/backend_impl.cpp @@ -31,18 +31,18 @@ TorchMlirBackendData::TorchMlirBackendData(BackendDevice device, Shape shape) PRINT_FUNCTION(); } TorchMlirBackendData::TorchMlirBackendData( - BackendDevice device, Shape shape, std::shared_ptr info) + BackendDevice device, Shape shape, std::shared_ptr info) : BackendData(device, shape), info_(info) { PRINT_FUNCTION(); } -TorchMlirBackendData::TorchMlirBackendData( - const at::Scalar& scalar, BackendDevice device) +TorchMlirBackendData::TorchMlirBackendData(const at::Scalar &scalar, + BackendDevice device) : BackendData(device, Shape(scalar.type(), {})), info_(std::make_shared(scalar)) { PRINT_FUNCTION(); } -TorchMlirBackendData::TorchMlirBackendData( - const at::Tensor& tensor, BackendDevice device, Shape shape) +TorchMlirBackendData::TorchMlirBackendData(const at::Tensor &tensor, + BackendDevice device, Shape shape) : BackendData(device, shape), info_(std::make_shared(tensor)) { PRINT_FUNCTION(); @@ -52,19 +52,18 @@ BackendData::Handle TorchMlirBackendData::GetHandle() { return reinterpret_cast(this); } -void TorchMlirBackendData::Assign(const BackendData& data) { - const TorchMlirBackendData* torch_mlir_data = - dynamic_cast(&data); - TORCH_CHECK( - torch_mlir_data, - "Invalid Backend Data Pointer. Expected TorchMlirBackendData."); +void TorchMlirBackendData::Assign(const BackendData &data) { + const TorchMlirBackendData *torch_mlir_data = + dynamic_cast(&data); + TORCH_CHECK(torch_mlir_data, + "Invalid Backend Data Pointer. Expected TorchMlirBackendData."); info_ = torch_mlir_data->info_; } bool TorchMlirBackendData::HasValue() const { return bool(info_); } -BackendData::Info* TorchMlirBackendData::mlir_info() const { +BackendData::Info *TorchMlirBackendData::mlir_info() const { return info_.get(); } @@ -77,8 +76,8 @@ void TorchMlirBackendImpl::PrepareToExit() const {} * IR Tracing * */ -const IrBuilder* TorchMlirBackendImpl::GetIrBuilder() const { - static const IrBuilder* builder = new TorchMlirIrBuilder(); +const IrBuilder *TorchMlirBackendImpl::GetIrBuilder() const { + static const IrBuilder *builder = new TorchMlirIrBuilder(); return builder; } @@ -87,28 +86,29 @@ const IrBuilder* TorchMlirBackendImpl::GetIrBuilder() const { * */ BackendDataPtr TorchMlirBackendImpl::MakeComputationDataFromTensor( - const at::Tensor& tensor, const Shape& shape, - const BackendDevice& device) const { + const at::Tensor &tensor, const Shape &shape, + const BackendDevice &device) const { PRINT_FUNCTION(); return std::make_shared(tensor, device, shape); } BackendDataPtr TorchMlirBackendImpl::MakeComputationDataFromScalar( - const at::Scalar& scalar, const BackendDevice& device) const { + const at::Scalar &scalar, const BackendDevice &device) const { PRINT_FUNCTION(); return std::make_shared(scalar, device); } -BackendDataPtr TorchMlirBackendImpl::CreateDataPlaceholder( - const BackendDevice& device, const Shape& shape) const { +BackendDataPtr +TorchMlirBackendImpl::CreateDataPlaceholder(const BackendDevice &device, + const Shape &shape) const { PRINT_FUNCTION(); return std::make_shared(device, shape); } BackendDataPtr -TorchMlirBackendImpl::GetComputationDataFromNode(const Node* node) const { +TorchMlirBackendImpl::GetComputationDataFromNode(const Node *node) const { PRINT_FUNCTION(); - const auto* device_data_node = dynamic_cast(node); + const auto *device_data_node = dynamic_cast(node); if (!device_data_node) { return nullptr; } @@ -120,14 +120,13 @@ at::Tensor TorchMlirBackendImpl::MakeTensorFromComputationData( c10::optional logical_scalar_type) const { PRINT_FUNCTION(); - TorchMlirBackendData* torch_mlir_data = - dynamic_cast(data.get()); - TORCH_CHECK( - torch_mlir_data, - "Invalid Backend Data Pointer. Expected TorchMlirBackendData."); + TorchMlirBackendData *torch_mlir_data = + dynamic_cast(data.get()); + TORCH_CHECK(torch_mlir_data, + "Invalid Backend Data Pointer. Expected TorchMlirBackendData."); - TorchMlirBackendData::Info* info = - dynamic_cast(torch_mlir_data->mlir_info()); + TorchMlirBackendData::Info *info = + dynamic_cast(torch_mlir_data->mlir_info()); TORCH_CHECK( info, "Invalid Backend Data Pointer. Expected TorchMlirBackendData::Info."); @@ -140,17 +139,19 @@ at::Tensor TorchMlirBackendImpl::MakeTensorFromComputationData( * */ std::unique_ptr TorchMlirBackendImpl::CreateLoweringContext( - const std::string& name, BackendDevice device, - c10::ArrayRef post_order, Util::EmissionMap emit_status) const { + const std::string &name, BackendDevice device, + c10::ArrayRef post_order, + Util::EmissionMap emit_status) const { PRINT_FUNCTION(); return std::make_unique( name, std::forward(device), - std::forward>(post_order), + std::forward>(post_order), std::forward(emit_status)); } -std::unique_ptr TorchMlirBackendImpl::CreateLoweringContext( - const std::string& name, BackendDevice device) const { +std::unique_ptr +TorchMlirBackendImpl::CreateLoweringContext(const std::string &name, + BackendDevice device) const { PRINT_FUNCTION(); return std::make_unique( name, std::forward(device)); @@ -175,9 +176,8 @@ at::DeviceType TorchMlirBackendImpl::EagerFallbackDeviceType() const { // Query all available backend devices std::vector TorchMlirBackendImpl::GetBackendDevices() const { PRINT_FUNCTION(); - return { - GetBackendDevice(c10::Device(c10::kLazy, 0)), - GetBackendDevice(c10::Device(c10::kCPU, 0))}; + return {GetBackendDevice(c10::Device(c10::kLazy, 0)), + GetBackendDevice(c10::Device(c10::kCPU, 0))}; } // Map a particular c10:: device to a concrete backend device diff --git a/projects/ltc/csrc/base_lazy_backend/backend_impl.h b/projects/ltc/csrc/base_lazy_backend/backend_impl.h index c77033593ba3..4029cab1ea90 100644 --- a/projects/ltc/csrc/base_lazy_backend/backend_impl.h +++ b/projects/ltc/csrc/base_lazy_backend/backend_impl.h @@ -41,27 +41,28 @@ class TORCH_API TorchMlirBackendData : public BackendData { name = ss.str(); ++i; } - Info(const Info& other) + Info(const Info &other) : tensor{other.tensor}, scalar{other.scalar}, requires_grad{other.requires_grad}, name{other.name} {} - Info(const at::Tensor& tensor) + Info(const at::Tensor &tensor) : tensor{tensor}, requires_grad{tensor.requires_grad()} {} - Info(const at::Scalar& scalar) : scalar{scalar}, requires_grad(false) {} + Info(const at::Scalar &scalar) : scalar{scalar}, requires_grad(false) {} }; TorchMlirBackendData(BackendDevice device, Shape shape); - TorchMlirBackendData(BackendDevice device, Shape shape, std::shared_ptr info); - TorchMlirBackendData(const at::Scalar& scalar, BackendDevice device); - TorchMlirBackendData( - const at::Tensor& tensor, BackendDevice device, Shape shape); + TorchMlirBackendData(BackendDevice device, Shape shape, + std::shared_ptr info); + TorchMlirBackendData(const at::Scalar &scalar, BackendDevice device); + TorchMlirBackendData(const at::Tensor &tensor, BackendDevice device, + Shape shape); virtual BackendData::Handle GetHandle() override; - virtual void Assign(const BackendData& data) override; + virtual void Assign(const BackendData &data) override; virtual bool HasValue() const override; - BackendData::Info* mlir_info() const; + BackendData::Info *mlir_info() const; protected: std::shared_ptr info_; @@ -80,7 +81,7 @@ class TORCH_API TorchMlirBackendImpl : public BackendImplInterface { * IR Tracing * */ - const IrBuilder* GetIrBuilder() const override; + const IrBuilder *GetIrBuilder() const override; /** * Configuration @@ -91,19 +92,22 @@ class TORCH_API TorchMlirBackendImpl : public BackendImplInterface { * Data Transfer * */ - virtual BackendDataPtr MakeComputationDataFromTensor( - const at::Tensor& tensor, const Shape& shape, - const BackendDevice& device) const override; + virtual BackendDataPtr + MakeComputationDataFromTensor(const at::Tensor &tensor, const Shape &shape, + const BackendDevice &device) const override; - virtual BackendDataPtr MakeComputationDataFromScalar( - const at::Scalar& scalar, const BackendDevice& device) const override; + virtual BackendDataPtr + MakeComputationDataFromScalar(const at::Scalar &scalar, + const BackendDevice &device) const override; - virtual BackendDataPtr CreateDataPlaceholder( - const BackendDevice& device, const Shape& shape) const override; + virtual BackendDataPtr + CreateDataPlaceholder(const BackendDevice &device, + const Shape &shape) const override; // Gets backend data if the node is a device data node. Otherwise returns // nullptr. - virtual BackendDataPtr GetComputationDataFromNode(const Node*) const override; + virtual BackendDataPtr + GetComputationDataFromNode(const Node *) const override; virtual at::Tensor MakeTensorFromComputationData( const BackendDataPtr data, @@ -113,13 +117,14 @@ class TORCH_API TorchMlirBackendImpl : public BackendImplInterface { * Lowering, Compilation, Execution * */ - virtual std::unique_ptr CreateLoweringContext( - const std::string& name, BackendDevice device, - c10::ArrayRef post_order, - Util::EmissionMap emit_status) const override; + virtual std::unique_ptr + CreateLoweringContext(const std::string &name, BackendDevice device, + c10::ArrayRef post_order, + Util::EmissionMap emit_status) const override; - virtual std::unique_ptr CreateLoweringContext( - const std::string& name, BackendDevice device) const override; + virtual std::unique_ptr + CreateLoweringContext(const std::string &name, + BackendDevice device) const override; // TODO(whc) need to keep this? // virtual std::vector GetCompilationDevices( diff --git a/projects/ltc/csrc/base_lazy_backend/dynamic_ir.cpp b/projects/ltc/csrc/base_lazy_backend/dynamic_ir.cpp index ca6d80f1f419..c11c1563bb5d 100644 --- a/projects/ltc/csrc/base_lazy_backend/dynamic_ir.cpp +++ b/projects/ltc/csrc/base_lazy_backend/dynamic_ir.cpp @@ -16,20 +16,18 @@ namespace torch { namespace lazy { DimensionNode::DimensionNode(OpKind op, OpList operands, hash_t hash_seed) - : TorchMlirNode( - op, operands, /*num_outputs=*/1, - /* hash_seed */ HashCombine(op.hash(), hash_seed)) {} + : TorchMlirNode(op, operands, /*num_outputs=*/1, + /* hash_seed */ HashCombine(op.hash(), hash_seed)) {} std::string DimensionNode::ToString() const { return "DimensionNode"; } SizeNode::SizeNode(Value input, size_t dim) - : DimensionNode( - OpKind{c10::Symbol::fromQualString("aten::size")}, {input}, - MHash(dim)), + : DimensionNode(OpKind{c10::Symbol::fromQualString("aten::size")}, {input}, + MHash(dim)), dim_(dim){}; int64_t SizeNode::getStaticValue() const { - return dynamic_cast(operand(0).node) + return dynamic_cast(operand(0).node) ->shape(0) .size(dim_); } @@ -40,8 +38,9 @@ SizeAdd::SizeAdd(Value a, Value b) : DimensionNode(OpKind{c10::Symbol::fromQualString("aten::add")}, {a, b}){}; int64_t SizeAdd::getStaticValue() const { - return dynamic_cast(operand(0).node)->getStaticValue() + - dynamic_cast(operand(1).node)->getStaticValue(); + return dynamic_cast(operand(0).node) + ->getStaticValue() + + dynamic_cast(operand(1).node)->getStaticValue(); } std::string SizeAdd::ToString() const { return "SizeAdd"; } @@ -50,8 +49,9 @@ SizeMul::SizeMul(Value a, Value b) : DimensionNode(OpKind{c10::Symbol::fromQualString("aten::mul")}, {a, b}){}; int64_t SizeMul::getStaticValue() const { - return dynamic_cast(operand(0).node)->getStaticValue() * - dynamic_cast(operand(1).node)->getStaticValue(); + return dynamic_cast(operand(0).node) + ->getStaticValue() * + dynamic_cast(operand(1).node)->getStaticValue(); } std::string SizeMul::ToString() const { return "SizeMul"; } @@ -61,11 +61,12 @@ SizeDiv::SizeDiv(Value a, Value b) int64_t SizeDiv::getStaticValue() const { TORCH_CHECK( - dynamic_cast(operand(1).node)->getStaticValue() != + dynamic_cast(operand(1).node)->getStaticValue() != 0, "Can't divide a dimension by zero"); - return dynamic_cast(operand(0).node)->getStaticValue() / - dynamic_cast(operand(1).node)->getStaticValue(); + return dynamic_cast(operand(0).node) + ->getStaticValue() / + dynamic_cast(operand(1).node)->getStaticValue(); } std::string SizeDiv::ToString() const { return "SizeDiv"; } diff --git a/projects/ltc/csrc/base_lazy_backend/mlir_lowering_context.cpp b/projects/ltc/csrc/base_lazy_backend/mlir_lowering_context.cpp index 7e6f40c5c2e9..a27889ad0895 100644 --- a/projects/ltc/csrc/base_lazy_backend/mlir_lowering_context.cpp +++ b/projects/ltc/csrc/base_lazy_backend/mlir_lowering_context.cpp @@ -12,14 +12,14 @@ #include +#include "mlir-c/IR.h" +#include "mlir-c/Pass.h" +#include "torch-mlir-c/Registration.h" +#include "torch-mlir-c/Transforms.h" #include #include -#include #include -#include "torch-mlir-c/Registration.h" -#include "torch-mlir-c/Transforms.h" -#include "mlir-c/IR.h" -#include "mlir-c/Pass.h" +#include #include "backend_impl.h" #include "jit_ir_importer/function_importer.h" @@ -38,8 +38,8 @@ namespace lazy { // TorchMlir Lowering Context /////////////////////////////////////////////////////////////////////////////// -TorchMlirLoweringContext::TorchMlirLoweringContext( - const std::string& name, BackendDevice device) +TorchMlirLoweringContext::TorchMlirLoweringContext(const std::string &name, + BackendDevice device) : LoweringContext(name, std::forward(device)), graph_(std::make_shared()), function_( @@ -49,11 +49,12 @@ TorchMlirLoweringContext::TorchMlirLoweringContext( } TorchMlirLoweringContext::TorchMlirLoweringContext( - const std::string& name, BackendDevice device, - c10::ArrayRef post_order, Util::EmissionMap emit_status) + const std::string &name, BackendDevice device, + c10::ArrayRef post_order, + Util::EmissionMap emit_status) : LoweringContext( name, std::forward(device), - std::forward>(post_order), + std::forward>(post_order), std::forward(emit_status)), graph_(std::make_shared()), function_( @@ -66,9 +67,9 @@ TorchMlirLoweringContext::TorchMlirLoweringContext( } } -void TorchMlirLoweringContext::Lower(const Node* node) { - if (auto* torch_mlir_node = - dynamic_cast(node)) { +void TorchMlirLoweringContext::Lower(const Node *node) { + if (auto *torch_mlir_node = + dynamic_cast(node)) { TorchMlirOpVector ops = torch_mlir_node->Lower(function_, this); CHECK(!ops.empty()) << "Failed to lower: " << *node; TORCH_CHECK_EQ(node->num_outputs(), ops.size()); @@ -82,19 +83,19 @@ void TorchMlirLoweringContext::Lower(const Node* node) { } void TorchMlirLoweringContext::SetUpAlias( - const std::vector& output_index, int64_t param_number, - const std::vector& param_index, bool must_alias) { + const std::vector &output_index, int64_t param_number, + const std::vector ¶m_index, bool must_alias) { input_output_aliases_.push_back( {output_index, param_number, param_index, must_alias}); } bool TorchMlirLoweringContext::CheckResultShape( - const BackendDataPtr& parameter_data, size_t result_idx) { - TORCH_CHECK( - result_idx < root_tuple_.size(), "Tried getting result shape at index ", - result_idx, " which is out of bounds!"); + const BackendDataPtr ¶meter_data, size_t result_idx) { + TORCH_CHECK(result_idx < root_tuple_.size(), + "Tried getting result shape at index ", result_idx, + " which is out of bounds!"); - torch::jit::Value* output = root_tuple_[result_idx]; + torch::jit::Value *output = root_tuple_[result_idx]; if (c10::TensorTypePtr tensor_type = output->type()->cast()) { @@ -111,7 +112,7 @@ bool TorchMlirLoweringContext::CheckResultShape( return false; } -size_t TorchMlirLoweringContext::AddResult(const Output& output) { +size_t TorchMlirLoweringContext::AddResult(const Output &output) { PRINT_FUNCTION(); return AddResult(GetOutputOp(output)); @@ -120,9 +121,10 @@ size_t TorchMlirLoweringContext::AddResult(const Output& output) { // Associates the given output with the input parameter of the given index and // shape. Only used for the operator-by-operator execution, mostly for // debugging purposes. -void TorchMlirLoweringContext::AddParameter( - const torch::lazy::Output& output, size_t index, - const torch::lazy::Shape& shape, const std::string& name) { +void TorchMlirLoweringContext::AddParameter(const torch::lazy::Output &output, + size_t index, + const torch::lazy::Shape &shape, + const std::string &name) { UNIMPLEMENTED_FUNCTION_ERROR(); } @@ -136,7 +138,7 @@ ComputationPtr TorchMlirLoweringContext::Build() { torch::jit::RefineTupleTypes(graph_); // Insert return values into graph. - for (torch::jit::Value* output : root_tuple_) { + for (torch::jit::Value *output : root_tuple_) { graph_->block()->registerOutput(output); } @@ -152,7 +154,6 @@ ComputationPtr TorchMlirLoweringContext::Build() { /*getArgAttribute=*/[](int) -> MlirAttribute { return {nullptr}; }, /*importOptions=*/{/*assumeTensorsHaveValueSemantics=*/true}); - // Convert MlirOperation to MlirModule. MlirLocation loc = mlirLocationUnknownGet(mlir_context_); MlirModule module_op = mlirModuleCreateEmpty(loc); @@ -162,14 +163,10 @@ ComputationPtr TorchMlirLoweringContext::Build() { // Apply passes to verify generated MLIR. auto pass_manager = mlirPassManagerCreate(mlir_context_); mlirPassManagerAddOwnedPass( - pass_manager, - mlirCreateVerifyBackendContractNoDecompositions() - ); + pass_manager, mlirCreateVerifyBackendContractNoDecompositions()); - MlirLogicalResult result = mlirPassManagerRunOnOp( - pass_manager, - mlirModuleGetOperation(module_op) - ); + MlirLogicalResult result = + mlirPassManagerRunOnOp(pass_manager, mlirModuleGetOperation(module_op)); if (mlirLogicalResultIsFailure(result)) { throw std::runtime_error("MLIR verification has failed."); @@ -178,12 +175,14 @@ ComputationPtr TorchMlirLoweringContext::Build() { return CreateComputation(module_op); } -ComputationPtr TorchMlirLoweringContext::CreateComputation(MlirModule module_op) { - return std::make_shared( - module_op, mlir_context_, graph_, parameter_names_, input_output_aliases_); +ComputationPtr +TorchMlirLoweringContext::CreateComputation(MlirModule module_op) { + return std::make_shared(module_op, mlir_context_, + graph_, parameter_names_, + input_output_aliases_); } -torch::jit::Value* TorchMlirLoweringContext::GetOutputOp(const Output& output) { +torch::jit::Value *TorchMlirLoweringContext::GetOutputOp(const Output &output) { PRINT_FUNCTION(); auto it = emitted_outputs_.find(output); @@ -195,15 +194,14 @@ torch::jit::Value* TorchMlirLoweringContext::GetOutputOp(const Output& output) { // At this point the output better be present, otherwise there is an issue // with the lowering code. it = emitted_outputs_.find(output); - TORCH_CHECK( - it != emitted_outputs_.end(), - "No MLIR operation emitted for output: ", output.ToString()); + TORCH_CHECK(it != emitted_outputs_.end(), + "No MLIR operation emitted for output: ", output.ToString()); } return it->second; } -void TorchMlirLoweringContext::AssignOutputOp( - const Output& output, torch::jit::Value* op) { +void TorchMlirLoweringContext::AssignOutputOp(const Output &output, + torch::jit::Value *op) { PRINT_FUNCTION(); auto torch_mlir_node = @@ -211,48 +209,44 @@ void TorchMlirLoweringContext::AssignOutputOp( std::vector source_files, functions; std::vector line_numbers; - const auto& metadata = torch_mlir_node->metadata(); - const auto& frames = metadata.frame_info; + const auto &metadata = torch_mlir_node->metadata(); + const auto &frames = metadata.frame_info; if (!frames.empty()) { static std::vector g_roots = - string_split(sys_util::GetEnvString("LTC_IR_DEBUG_ROOT_PATH", ""), ":"); + string_split(sys_util::GetEnvString("LTC_IR_DEBUG_ROOT_PATH", ""), ":"); std::for_each(frames.rbegin(), frames.rend(), - [&](const torch::lazy::SourceLocation& location) { - functions.push_back(location.function); - line_numbers.push_back(location.line); - - std::string file_name = location.file; - for (const std::string& root : g_roots) { - if (startswith(file_name, root)) { - // location.file starts with root, strip it off - file_name = file_name.substr(root.size()); - break; - } - } - source_files.push_back(file_name); - }); + [&](const torch::lazy::SourceLocation &location) { + functions.push_back(location.function); + line_numbers.push_back(location.line); + + std::string file_name = location.file; + for (const std::string &root : g_roots) { + if (startswith(file_name, root)) { + // location.file starts with root, strip it off + file_name = file_name.substr(root.size()); + break; + } + } + source_files.push_back(file_name); + }); if (!source_files.empty()) { - op->node()->ss_( - c10::Symbol::attr("source_files"), source_files); - op->node()->ss_( - c10::Symbol::attr("functions"), functions); - op->node()->is_( - c10::Symbol::attr("line_numbers"), line_numbers); + op->node()->ss_(c10::Symbol::attr("source_files"), source_files); + op->node()->ss_(c10::Symbol::attr("functions"), functions); + op->node()->is_(c10::Symbol::attr("line_numbers"), line_numbers); } } auto scope = ::c10::Symbol::scope(metadata.scope); - op->node()->setScope( - c10::make_intrusive()->push(scope)); + op->node()->setScope(c10::make_intrusive()->push(scope)); emitted_outputs_[output] = std::move(op); } -torch::jit::Value* TorchMlirLoweringContext::GetParameter(BackendDataPtr data) { +torch::jit::Value *TorchMlirLoweringContext::GetParameter(BackendDataPtr data) { PRINT_FUNCTION(); - if (!dynamic_cast(data.get())) { + if (!dynamic_cast(data.get())) { TORCH_CHECK( false, "Expected TorchMlirBackendData. Got some other BackendData type"); @@ -263,20 +257,21 @@ torch::jit::Value* TorchMlirLoweringContext::GetParameter(BackendDataPtr data) { auto it = parameters_map_.find(handle); if (it == parameters_map_.end()) { - torch::jit::Value* param = + torch::jit::Value *param = graph_->addInput(c10::str("p", parameters_.size())); - auto* info = dynamic_cast(mlir_data->mlir_info()); + auto *info = + dynamic_cast(mlir_data->mlir_info()); TORCH_CHECK(info, "Expected TorchMlirBackendData::Info"); if (info->scalar.has_value()) { - auto& scalar = info->scalar.value(); + auto &scalar = info->scalar.value(); if (scalar.isFloatingPoint()) { param->setType(c10::FloatType::get()); } else if (scalar.isIntegral(true)) { param->setType(c10::IntType::get()); } else { - TORCH_CHECK( - false, "Unhandled scalar type: ", c10::toString(scalar.type())); + TORCH_CHECK(false, + "Unhandled scalar type: ", c10::toString(scalar.type())); } } else { // Save parameter shape information. @@ -305,7 +300,7 @@ std::shared_ptr TorchMlirLoweringContext::graph() const { return graph_; } -size_t TorchMlirLoweringContext::AddResult(torch::jit::Value* op) { +size_t TorchMlirLoweringContext::AddResult(torch::jit::Value *op) { PRINT_FUNCTION(); root_tuple_.push_back(std::move(op)); return root_tuple_.size() - 1; @@ -313,9 +308,9 @@ size_t TorchMlirLoweringContext::AddResult(torch::jit::Value* op) { // Sync vector of c10::Argument with type specified from parallel list of // jit::Value. There must be a 1:1 map between elements of args and values. -std::vector sync_argument_types( - const std::vector& args, - c10::ArrayRef values) { +std::vector +sync_argument_types(const std::vector &args, + c10::ArrayRef values) { TORCH_CHECK( args.size() == values.size(), "Expected 1:1 mapping between list of c10::Argument and jit::Value! Got ", @@ -362,7 +357,7 @@ void TorchMlirLoweringContext::RegisterMlirDialects() { TorchMlirComputation::TorchMlirComputation( MlirModule module_op, MlirContext mlir_context, - const std::shared_ptr& graph, + const std::shared_ptr &graph, std::unordered_map parameters_map, InputOutputAliases input_output_aliases) : module_op_(std::move(module_op)), mlir_context_(std::move(mlir_context)), @@ -377,26 +372,25 @@ TorchMlirComputation::TorchMlirComputation( } } -int TorchMlirComputation::parameters_size() const { - return num_parameters_; -} +int TorchMlirComputation::parameters_size() const { return num_parameters_; } -const std::vector& +const std::vector & TorchMlirComputation::parameter_shapes() const { throw std::runtime_error( "todo(whc) implement ts computation shapes or change interface"); return parameter_shapes_; } -const std::vector& TorchMlirComputation::parameter_names() const { +const std::vector &TorchMlirComputation::parameter_names() const { return parameter_names_; } -const std::unordered_map& TorchMlirComputation::parameters_map() const { +const std::unordered_map & +TorchMlirComputation::parameters_map() const { return parameters_map_; } -const torch::lazy::Shape& TorchMlirComputation::result_shape() const { +const torch::lazy::Shape &TorchMlirComputation::result_shape() const { throw std::runtime_error( "todo(whc) implement ts computation shapes or change interface"); return result_shape_; @@ -411,13 +405,9 @@ MlirOperation TorchMlirComputation::func_op() const { return mlirBlockGetFirstOperation(block); } -MlirModule TorchMlirComputation::module_op() const { - return module_op_; -} +MlirModule TorchMlirComputation::module_op() const { return module_op_; } -MlirContext TorchMlirComputation::mlir_context() const { - return mlir_context_; -} +MlirContext TorchMlirComputation::mlir_context() const { return mlir_context_; } const std::string TorchMlirComputation::debug_string() const { std::stringstream ss; @@ -430,7 +420,7 @@ const std::string TorchMlirComputation::debug_string() const { // Parameter names ss << "Parameter names:\n"; - for (auto& p : parameter_names_) { + for (auto &p : parameter_names_) { ss << " " << p << "\n"; } ss << "\n"; @@ -451,10 +441,10 @@ const std::string TorchMlirComputation::debug_string() const { const std::string TorchMlirComputation::to_string() const { // Since we use the C-MLIR API, we need to use a callback to print. - MlirStringCallback print_callback = [](MlirStringRef part, void* user_data) { + MlirStringCallback print_callback = [](MlirStringRef part, void *user_data) { // user_data is a void ptr to some data structure of our choice -- in this // case, the string stream where we'll be accumulating the strings. - std::stringstream* ss_ptr = static_cast(user_data); + std::stringstream *ss_ptr = static_cast(user_data); *ss_ptr << std::string(part.data, part.length); }; std::stringstream ss; @@ -462,7 +452,8 @@ const std::string TorchMlirComputation::to_string() const { // Setup flags for MLIR serialization. MlirOpPrintingFlags flags = mlirOpPrintingFlagsCreate(); mlirOpPrintingFlagsEnableDebugInfo(flags, FLAGS_torch_lazy_ir_debug, false); - mlirOperationPrintWithFlags(mlirModuleGetOperation(module_op_), flags, print_callback, &ss); + mlirOperationPrintWithFlags(mlirModuleGetOperation(module_op_), flags, + print_callback, &ss); return ss.str(); } diff --git a/projects/ltc/csrc/base_lazy_backend/mlir_lowering_context.h b/projects/ltc/csrc/base_lazy_backend/mlir_lowering_context.h index f62a71ce7945..3b226b46896a 100644 --- a/projects/ltc/csrc/base_lazy_backend/mlir_lowering_context.h +++ b/projects/ltc/csrc/base_lazy_backend/mlir_lowering_context.h @@ -39,35 +39,34 @@ class TORCH_API TorchMlirLoweringContext : public torch::lazy::LoweringContext { }; using InputOutputAliases = std::vector; - TorchMlirLoweringContext( - const std::string& name, torch::lazy::BackendDevice device); - TorchMlirLoweringContext( - const std::string& name, torch::lazy::BackendDevice device, - c10::ArrayRef post_order, - torch::lazy::Util::EmissionMap emit_status); + TorchMlirLoweringContext(const std::string &name, + torch::lazy::BackendDevice device); + TorchMlirLoweringContext(const std::string &name, + torch::lazy::BackendDevice device, + c10::ArrayRef post_order, + torch::lazy::Util::EmissionMap emit_status); - void Lower(const Node* node); + void Lower(const Node *node); // Adds a new input/output alias. - void SetUpAlias( - const std::vector& output_index, int64_t param_number, - const std::vector& param_index, - bool must_alias = false) override; + void SetUpAlias(const std::vector &output_index, + int64_t param_number, const std::vector ¶m_index, + bool must_alias = false) override; // Check if parameter shape matches result at index. - bool CheckResultShape( - const BackendDataPtr& parameter_data, size_t result_idx) override; + bool CheckResultShape(const BackendDataPtr ¶meter_data, + size_t result_idx) override; // Adds the given output as a component of the result tuple and returns its // assigned position within the tuple. - size_t AddResult(const torch::lazy::Output& output) override; + size_t AddResult(const torch::lazy::Output &output) override; // Associates the given output with the input parameter of the given index and // shape. Only used for the operator-by-operator execution, mostly for // debugging purposes. - void AddParameter( - const torch::lazy::Output& output, size_t index, - const torch::lazy::Shape& shape, const std::string& name) override; + void AddParameter(const torch::lazy::Output &output, size_t index, + const torch::lazy::Shape &shape, + const std::string &name) override; // Build the computation capturing all the operations created with the // embedded builder (returned by the builder() API). @@ -78,27 +77,27 @@ class TORCH_API TorchMlirLoweringContext : public torch::lazy::LoweringContext { // Retrieves the lowered operation for an output. If the requested output is // not available yet, the graph behind the output's Node is lowered, and the // corresponding TS operation returned. - torch::jit::Value* GetOutputOp(const Output& output); + torch::jit::Value *GetOutputOp(const Output &output); // Assigns the given TS operation to the specified output. As outputs are // lowered in a post-order fashion, later nodes should always find their // operands among the emitted outputs. - void AssignOutputOp(const Output& output, torch::jit::Value* op); + void AssignOutputOp(const Output &output, torch::jit::Value *op); // If a parameter associated with data has already been declared, it will be // returned. Otherwise a new one will be created, associated with the tensor // held in data. - torch::jit::Value* GetParameter(BackendDataPtr data); + torch::jit::Value *GetParameter(BackendDataPtr data); std::shared_ptr graph() const; protected: struct Parameter { - torch::jit::Value* param; + torch::jit::Value *param; size_t index = 0; }; - size_t AddResult(torch::jit::Value* op); + size_t AddResult(torch::jit::Value *op); // Creates a jit::Function from the current jit::Graph. Input and output // type information is patched to include shape. @@ -113,8 +112,8 @@ class TORCH_API TorchMlirLoweringContext : public torch::lazy::LoweringContext { MlirContext mlir_context_; std::unordered_map parameters_map_; std::unordered_map parameter_names_; - std::vector root_tuple_; - OutputMap emitted_outputs_; + std::vector root_tuple_; + OutputMap emitted_outputs_; }; class TORCH_API TorchMlirComputation : public torch::lazy::Computation { @@ -122,21 +121,20 @@ class TORCH_API TorchMlirComputation : public torch::lazy::Computation { using InputOutputAliases = TorchMlirLoweringContext::InputOutputAliases; using InputOutputAlias = TorchMlirLoweringContext::InputOutputAlias; - TorchMlirComputation( - MlirModule module_op, MlirContext mlir_context, - const std::shared_ptr& graph, - std::unordered_map parameters_map, - InputOutputAliases input_output_aliases); + TorchMlirComputation(MlirModule module_op, MlirContext mlir_context, + const std::shared_ptr &graph, + std::unordered_map parameters_map, + InputOutputAliases input_output_aliases); int parameters_size() const override; - const std::vector& parameter_shapes() const override; + const std::vector ¶meter_shapes() const override; - const std::vector& parameter_names() const override; + const std::vector ¶meter_names() const override; - const std::unordered_map& parameters_map() const; + const std::unordered_map ¶meters_map() const; - const torch::lazy::Shape& result_shape() const override; + const torch::lazy::Shape &result_shape() const override; std::shared_ptr graph() const; diff --git a/projects/ltc/csrc/base_lazy_backend/mlir_native_functions.cpp b/projects/ltc/csrc/base_lazy_backend/mlir_native_functions.cpp index 7d9fe056dc30..a0e4bae76db6 100644 --- a/projects/ltc/csrc/base_lazy_backend/mlir_native_functions.cpp +++ b/projects/ltc/csrc/base_lazy_backend/mlir_native_functions.cpp @@ -10,8 +10,8 @@ // https://github.com/pytorch/pytorch/blob/master/torch/csrc/lazy/ts_backend/ts_native_functions.cpp //===----------------------------------------------------------------------===// -#include #include +#include #include #include #include @@ -33,16 +33,16 @@ #include "generated/LazyIr.h" #include "generated/LazyNativeFunctions.h" #include "generated/shape_inference.h" -#include "ops/to_copy.h" -#include "ops/unbind_int.h" -#include "ops/split.h" #include "ops/index.h" #include "ops/ivalue.h" +#include "ops/split.h" +#include "ops/to_copy.h" +#include "ops/unbind_int.h" #include "utils/exception.h" #include "utils/sys_utils.h" namespace { -at::Tensor to_meta(const at::Tensor& tensor) { +at::Tensor to_meta(const at::Tensor &tensor) { // undefined tensors can't be converted to the meta device, since they don't // have sizes/strides if (!tensor.defined()) @@ -60,7 +60,7 @@ at::Tensor to_meta(const at::Tensor& tensor) { return out; } -c10::optional to_meta(const c10::optional& tensor) { +c10::optional to_meta(const c10::optional &tensor) { if (tensor.has_value()) { return to_meta(*tensor); } @@ -70,16 +70,17 @@ c10::optional to_meta(const c10::optional& tensor) { std::vector to_meta(at::ITensorListRef t_list) { std::vector outs; outs.reserve(t_list.size()); - for (const auto& tensor : t_list) { + for (const auto &tensor : t_list) { outs.push_back(to_meta(tensor)); } return outs; } -c10::List> to_meta(const c10::List>& t_list) { +c10::List> +to_meta(const c10::List> &t_list) { c10::List> outs; outs.reserve(t_list.size()); - for (const auto& tensor : t_list) { + for (const auto &tensor : t_list) { outs.push_back(to_meta(tensor)); } return outs; @@ -91,9 +92,9 @@ namespace lazy { namespace { -at::Tensor CreateLtcTensor( - const at::Tensor& tensor, - const c10::optional& device) { +at::Tensor +CreateLtcTensor(const at::Tensor &tensor, + const c10::optional &device) { if (tensor.defined() && device) { return torch::lazy::CreateAtenFromLtcTensor( torch::lazy::LazyTensor::Create(tensor, *device)); @@ -102,7 +103,7 @@ at::Tensor CreateLtcTensor( } c10::optional -GetLtcDevice(const c10::optional& device) { +GetLtcDevice(const c10::optional &device) { if (!device) { return c10::nullopt; } @@ -112,24 +113,23 @@ GetLtcDevice(const c10::optional& device) { return torch::lazy::atenDeviceToBackendDevice(*device); } -torch::lazy::Value MaybeExpand( - const torch::lazy::Value& input, const torch::lazy::Shape& target_shape) { +torch::lazy::Value MaybeExpand(const torch::lazy::Value &input, + const torch::lazy::Shape &target_shape) { if (input.shape().sizes() == target_shape.sizes()) { return input; } - return torch::lazy::MakeExpand( - input, target_shape.sizes().vec(), - /*is_scalar_expand=*/false); + return torch::lazy::MakeExpand(input, target_shape.sizes().vec(), + /*is_scalar_expand=*/false); } -void copy_(torch::lazy::LazyTensorPtr& input, torch::lazy::LazyTensorPtr& src) { +void copy_(torch::lazy::LazyTensorPtr &input, torch::lazy::LazyTensorPtr &src) { if (input->GetDevice() == src->GetDevice()) { torch::lazy::Value copy_value; if (input->dtype() == src->dtype()) { copy_value = src->GetIrValue(); } else { - copy_value = torch::lazy::MakeCast( - src->GetIrValue(), input->dtype(), src->dtype()); + copy_value = torch::lazy::MakeCast(src->GetIrValue(), input->dtype(), + src->dtype()); } input->SetIrValue(MaybeExpand(copy_value, input->shape())); } else { @@ -146,15 +146,17 @@ void copy_(torch::lazy::LazyTensorPtr& input, torch::lazy::LazyTensorPtr& src) { // clone is special in LT because we make it a no-op. // This should be safe to do, because every operator in the LT is functional. -at::Tensor LazyNativeFunctions::clone( - const at::Tensor& self, c10::optional memory_format) { +at::Tensor +LazyNativeFunctions::clone(const at::Tensor &self, + c10::optional memory_format) { auto self_lt = torch::lazy::TryGetLtcTensor(self); return torch::lazy::CreateAtenFromLtcTensor( self_lt->Create(self_lt->GetIrValue(), self_lt->GetDevice())); } -at::Tensor LazyNativeFunctions::_copy_from( - const at::Tensor& self, const at::Tensor& dst, bool non_blocking) { +at::Tensor LazyNativeFunctions::_copy_from(const at::Tensor &self, + const at::Tensor &dst, + bool non_blocking) { TORCH_LAZY_FN_COUNTER("lazy::"); auto dst_tensor = torch::lazy::TryGetLtcTensor(dst); auto self_tensor = torch::lazy::TryGetLtcTensor(self); @@ -199,16 +201,16 @@ at::Tensor LazyNativeFunctions::_copy_from( } } else { copy_(dst_tensor, self_tensor); - auto* impl = - dynamic_cast(dst.unsafeGetTensorImpl()); + auto *impl = + dynamic_cast(dst.unsafeGetTensorImpl()); impl->set_tensor(dst_tensor); } } return dst; } -at::Tensor LazyNativeFunctions::_copy_from_and_resize( - const at::Tensor& self, const at::Tensor& dst) { +at::Tensor LazyNativeFunctions::_copy_from_and_resize(const at::Tensor &self, + const at::Tensor &dst) { TORCH_LAZY_FN_COUNTER("lazy::"); auto dst_tensor = torch::lazy::TryGetLtcTensor(dst); auto self_tensor = torch::lazy::TryGetLtcTensor(self); @@ -223,8 +225,8 @@ at::Tensor LazyNativeFunctions::_copy_from_and_resize( dst.resize_as_(typed_tensor).copy_(typed_tensor); } else { // at this point we know dst is a lazy tensor - auto* dest_impl = - dynamic_cast(dst.unsafeGetTensorImpl()); + auto *dest_impl = + dynamic_cast(dst.unsafeGetTensorImpl()); dest_impl->tensor()->UpdateFromTensorOut(self_tensor); dest_impl->force_refresh_sizes(); } @@ -232,15 +234,16 @@ at::Tensor LazyNativeFunctions::_copy_from_and_resize( } at::Tensor LazyNativeFunctions::_to_copy( - const at::Tensor& self, c10::optional dtype, + const at::Tensor &self, c10::optional dtype, c10::optional layout, c10::optional device, c10::optional pin_memory, bool non_blocking, c10::optional memory_format) { PRINT_FUNCTION(); auto options = self.options(); if (dtype) { - // I put each of these setters in a conditional instead of doing `self.options().dtype(dtype).layout(layout)... - // because calling .dtype(nullopt) on an options() that already has dtype appears to wipe it + // I put each of these setters in a conditional instead of doing + // `self.options().dtype(dtype).layout(layout)... because calling + // .dtype(nullopt) on an options() that already has dtype appears to wipe it options = options.dtype(dtype); } if (layout) { @@ -261,8 +264,9 @@ at::Tensor LazyNativeFunctions::_to_copy( if (!lazy_self && device && device->type() == c10::kLazy) { // Case 1: eager->lazy (we create a new lazy tensor) // See Note [Lazy Tensor Functionalization] - // Invariant: if the functionalization key is in the exclude set, then we're expected - // to return an ordinary tensor, which will be "lifted" into a functional wrapper later. + // Invariant: if the functionalization key is in the exclude set, then we're + // expected to return an ordinary tensor, which will be "lifted" into a + // functional wrapper later. bool functionalize_output = !c10::impl::tls_local_dispatch_key_set().excluded_.has( c10::DispatchKey::Functionalize); @@ -270,7 +274,8 @@ at::Tensor LazyNativeFunctions::_to_copy( self, options, *device, /*non_blocking=*/non_blocking, /*functionalize_output=*/functionalize_output); } else if (device && device->type() != c10::kLazy) { - // Case 2: lazy->eager (forces a graph break since we are materializing a tensor) + // Case 2: lazy->eager (forces a graph break since we are materializing a + // tensor) TORCH_INTERNAL_ASSERT(lazy_self); auto eager_tensor = lazy_self->ToTensor(/*detached=*/true); @@ -278,22 +283,24 @@ at::Tensor LazyNativeFunctions::_to_copy( auto moved_eager_tensor = eager_tensor.to(options, /*non_blocking=*/non_blocking, /*copy=*/true); return moved_eager_tensor; - } else if ( - device && device->type() == c10::kLazy && device->has_index() && - device->index() != self.device().index()) { + } else if (device && device->type() == c10::kLazy && device->has_index() && + device->index() != self.device().index()) { // Case 3: lazy:0 -> lazy:1 // TODO(whc) what do we actually want to do here? // option 1: materialize, move eager tensor, create new lazy tensor - // - this should be our default, as it is what would happen before we implemented _to_copy + // - this should be our default, as it is what would happen before we + // implemented _to_copy // - actually combines case 1 + case 2 // option 2: support multiple devices inside one lazy/TS executor (case 4) - // - but: we may have other assumptions that there is just one device per executor? so don't take this lightly + // - but: we may have other assumptions that there is just one device + // per executor? so don't take this lightly TORCH_INTERNAL_ASSERT(lazy_self); auto eager_tensor = lazy_self->ToTensor(/*detached=*/true); // we move the eager tensor to the 'eager' equivalent of our lazy device - // e.g. if our device is lazy:1, the backend maps that to cuda:1, which is what we use + // e.g. if our device is lazy:1, the backend maps that to cuda:1, which is + // what we use auto eager_device = c10::Device( torch::lazy::getBackend()->EagerFallbackDeviceType(), device->index()); options = options.device(eager_device); @@ -305,12 +312,14 @@ at::Tensor LazyNativeFunctions::_to_copy( return torch::lazy::CreateAtenFromLtcTensor(lazy_self); } else { - // Case 4: lazy->lazy (special case: keep the _to_copy INSIDE the lazy graph) - - // Note: captured _to_copy will be executed with real eager tensors, not lazy tensors. - // We DO NOT want to burn 'lazy:0' as the device into this captured IR, or we will try to - // convert an eager tensor back to a lazy one inside the torchscript executor - // lazy:0 -> lazy:1 is handled in case3, so we can safely drop the device argument + // Case 4: lazy->lazy (special case: keep the _to_copy INSIDE the lazy + // graph) + + // Note: captured _to_copy will be executed with real eager tensors, not + // lazy tensors. We DO NOT want to burn 'lazy:0' as the device into this + // captured IR, or we will try to convert an eager tensor back to a lazy one + // inside the torchscript executor lazy:0 -> lazy:1 is handled in case3, so + // we can safely drop the device argument device = c10::nullopt; auto shapes = torch::lazy::compute_shape__to_copy( @@ -327,257 +336,297 @@ at::Tensor LazyNativeFunctions::_to_copy( } }; -at::Tensor LazyNativeFunctions::_unsafe_view( - const at::Tensor& self, at::IntArrayRef size) { +at::Tensor LazyNativeFunctions::_unsafe_view(const at::Tensor &self, + at::IntArrayRef size) { TORCH_LAZY_FN_COUNTER("lazy::"); - return LazyNativeFunctions::view_copy_symint(self, c10::fromIntArrayRefSlow(size)); + return LazyNativeFunctions::view_copy_symint(self, + c10::fromIntArrayRefSlow(size)); } -at::Tensor LazyNativeFunctions::t(const at::Tensor& self) { +at::Tensor LazyNativeFunctions::t(const at::Tensor &self) { TORCH_LAZY_FN_COUNTER("lazy::"); return at::functionalization::functionalize_aten_op::call(self); } -std::vector LazyNativeFunctions::unbind_copy(const at::Tensor & self, int64_t dim) { +std::vector LazyNativeFunctions::unbind_copy(const at::Tensor &self, + int64_t dim) { TORCH_LAZY_FN_COUNTER("lazy::"); auto common_device = torch::lazy::GetBackendDevice(self); TORCH_INTERNAL_ASSERT(common_device); - - LazyTensorPtr lazy_self = torch::lazy::GetLtcTensorOrCreateForWrappedNumber(self, *common_device); - torch::lazy::NodePtr node = torch::lazy::ReuseNode(lazy_self->GetIrValue(), dim); + + LazyTensorPtr lazy_self = + torch::lazy::GetLtcTensorOrCreateForWrappedNumber(self, *common_device); + torch::lazy::NodePtr node = + torch::lazy::ReuseNode(lazy_self->GetIrValue(), dim); if (!node) { auto self_meta = to_meta(self); - auto out_meta = at::compositeexplicitautogradnonfunctional::unbind_copy(self_meta, dim); - + auto out_meta = + at::compositeexplicitautogradnonfunctional::unbind_copy(self_meta, dim); + std::vector shapes; - for (const auto & shape : out_meta) { + for (const auto &shape : out_meta) { shapes.push_back( - torch::lazy::Shape(shape.scalar_type(), shape.sizes().vec()) - ); + torch::lazy::Shape(shape.scalar_type(), shape.sizes().vec())); } - if(torch::lazy::symbolicShapeEnabled()){ - std::vector inputs = { self, dim }; - const char* schema_str = "aten::unbind_copy.int(Tensor self, int dim=0) -> Tensor[]"; + if (torch::lazy::symbolicShapeEnabled()) { + std::vector inputs = {self, dim}; + const char *schema_str = + "aten::unbind_copy.int(Tensor self, int dim=0) -> Tensor[]"; applySymbolicShapesOnLT(schema_str, inputs, shapes); } - node = torch::lazy::MakeNode(lazy_self->GetIrValue(), dim, std::move(shapes)); + node = torch::lazy::MakeNode(lazy_self->GetIrValue(), dim, + std::move(shapes)); CacheNode(node); } - + std::vector result; for (size_t i = 0; i < node->num_outputs(); ++i) { result.push_back( - torch::lazy::CreateAtenFromLtcTensor( - torch::lazy::LazyTensor::Create(torch::lazy::Value(node, i), *common_device) - ) - ); + torch::lazy::CreateAtenFromLtcTensor(torch::lazy::LazyTensor::Create( + torch::lazy::Value(node, i), *common_device))); } return result; } -std::vector LazyNativeFunctions::split_with_sizes_copy_symint(const at::Tensor & self, c10::SymIntArrayRef split_sizes, int64_t dim) { +std::vector LazyNativeFunctions::split_with_sizes_copy_symint( + const at::Tensor &self, c10::SymIntArrayRef split_sizes, int64_t dim) { TORCH_LAZY_FN_COUNTER("lazy::"); auto common_device = torch::lazy::GetBackendDevice(self); TORCH_INTERNAL_ASSERT(common_device); - LazyTensorPtr lazy_self = torch::lazy::GetLtcTensorOrCreateForWrappedNumber(self, *common_device); - torch::lazy::NodePtr node = torch::lazy::ReuseNode(lazy_self->GetIrValue(), GetSymIntArrayRefValue(split_sizes), dim); + LazyTensorPtr lazy_self = + torch::lazy::GetLtcTensorOrCreateForWrappedNumber(self, *common_device); + torch::lazy::NodePtr node = torch::lazy::ReuseNode( + lazy_self->GetIrValue(), GetSymIntArrayRefValue(split_sizes), dim); if (!node) { auto self_meta = to_meta(self); - auto out_meta = at::compositeexplicitautogradnonfunctional::split_with_sizes_copy_symint(self_meta, split_sizes, dim); + auto out_meta = at::compositeexplicitautogradnonfunctional:: + split_with_sizes_copy_symint(self_meta, split_sizes, dim); std::vector shapes; - for (const auto & shape : out_meta) { + for (const auto &shape : out_meta) { shapes.push_back( - torch::lazy::Shape(shape.scalar_type(), shape.sizes().vec()) - ); + torch::lazy::Shape(shape.scalar_type(), shape.sizes().vec())); } - if(torch::lazy::symbolicShapeEnabled()){ - std::vector inputs = { self, split_sizes, dim }; - const char* schema_str = "aten::split_with_sizes_copy(Tensor self, SymInt[] split_sizes, int dim=0) -> Tensor[]"; - applySymbolicShapesOnLT(schema_str, inputs, shapes); + if (torch::lazy::symbolicShapeEnabled()) { + std::vector inputs = {self, split_sizes, dim}; + const char *schema_str = "aten::split_with_sizes_copy(Tensor self, " + "SymInt[] split_sizes, int dim=0) -> Tensor[]"; + applySymbolicShapesOnLT(schema_str, inputs, shapes); } - node = torch::lazy::MakeNode(lazy_self->GetIrValue(), GetSymIntArrayRefValue(split_sizes), dim, std::move(shapes)); + node = torch::lazy::MakeNode( + lazy_self->GetIrValue(), GetSymIntArrayRefValue(split_sizes), dim, + std::move(shapes)); CacheNode(node); } std::vector result; for (size_t i = 0; i < node->num_outputs(); ++i) { result.push_back( - torch::lazy::CreateAtenFromLtcTensor( - torch::lazy::LazyTensor::Create(torch::lazy::Value(node, i), *common_device) - ) - ); + torch::lazy::CreateAtenFromLtcTensor(torch::lazy::LazyTensor::Create( + torch::lazy::Value(node, i), *common_device))); } return result; } -std::vector LazyNativeFunctions::split_copy_symint(const at::Tensor & self, c10::SymInt split_size, int64_t dim) { +std::vector +LazyNativeFunctions::split_copy_symint(const at::Tensor &self, + c10::SymInt split_size, int64_t dim) { TORCH_LAZY_FN_COUNTER("lazy::"); auto common_device = torch::lazy::GetBackendDevice(self); TORCH_INTERNAL_ASSERT(common_device); - LazyTensorPtr lazy_self = torch::lazy::GetLtcTensorOrCreateForWrappedNumber(self, *common_device); - torch::lazy::NodePtr node = torch::lazy::ReuseNode(lazy_self->GetIrValue(), GetSymIntValue(split_size), dim); + LazyTensorPtr lazy_self = + torch::lazy::GetLtcTensorOrCreateForWrappedNumber(self, *common_device); + torch::lazy::NodePtr node = torch::lazy::ReuseNode( + lazy_self->GetIrValue(), GetSymIntValue(split_size), dim); if (!node) { auto self_meta = to_meta(self); - auto out_meta = at::compositeexplicitautogradnonfunctional::split_copy_symint(self_meta, split_size, dim); + auto out_meta = + at::compositeexplicitautogradnonfunctional::split_copy_symint( + self_meta, split_size, dim); std::vector shapes; - for (const auto & shape : out_meta) { + for (const auto &shape : out_meta) { shapes.push_back( - torch::lazy::Shape(shape.scalar_type(), shape.sizes().vec()) - ); + torch::lazy::Shape(shape.scalar_type(), shape.sizes().vec())); } const size_t num_outputs = shapes.size(); - if(torch::lazy::symbolicShapeEnabled()){ - std::vector inputs = { self, split_size, dim }; - const char* schema_str = "aten::split_copy.Tensor(Tensor self, SymInt split_size, int dim=0) -> Tensor[]"; - applySymbolicShapesOnLT(schema_str, inputs, shapes); + if (torch::lazy::symbolicShapeEnabled()) { + std::vector inputs = {self, split_size, dim}; + const char *schema_str = "aten::split_copy.Tensor(Tensor self, SymInt " + "split_size, int dim=0) -> Tensor[]"; + applySymbolicShapesOnLT(schema_str, inputs, shapes); } - node = torch::lazy::MakeNode(lazy_self->GetIrValue(), GetSymIntValue(split_size), dim, std::move(shapes), num_outputs); + node = torch::lazy::MakeNode( + lazy_self->GetIrValue(), GetSymIntValue(split_size), dim, + std::move(shapes), num_outputs); CacheNode(node); } std::vector result; for (size_t i = 0; i < node->num_outputs(); ++i) { result.push_back( - torch::lazy::CreateAtenFromLtcTensor( - torch::lazy::LazyTensor::Create(torch::lazy::Value(node, i), *common_device) - ) - ); + torch::lazy::CreateAtenFromLtcTensor(torch::lazy::LazyTensor::Create( + torch::lazy::Value(node, i), *common_device))); } return result; } -at::Tensor LazyNativeFunctions::index(const at::Tensor & self, const c10::List> & indices) { +at::Tensor LazyNativeFunctions::index( + const at::Tensor &self, + const c10::List> &indices) { TORCH_LAZY_FN_COUNTER("lazy::"); auto common_device = torch::lazy::GetBackendDevice(self); TORCH_INTERNAL_ASSERT(common_device); - LazyTensorPtr lazy_self = torch::lazy::GetLtcTensorOrCreateForWrappedNumber(self, *common_device); + LazyTensorPtr lazy_self = + torch::lazy::GetLtcTensorOrCreateForWrappedNumber(self, *common_device); std::vector values; - for (const auto & it : indices) { + for (const auto &it : indices) { c10::optional tensor = it; - LazyTensorPtr lazy_tensor = torch::lazy::TryGetLtcTensor(tensor.value_or(at::Tensor())); - values.push_back(lazy_tensor ? lazy_tensor->GetIrValue() : torch::lazy::Value(MakeNode(c10::IValue()), 0)); + LazyTensorPtr lazy_tensor = + torch::lazy::TryGetLtcTensor(tensor.value_or(at::Tensor())); + values.push_back( + lazy_tensor + ? lazy_tensor->GetIrValue() + : torch::lazy::Value(MakeNode(c10::IValue()), 0)); } auto list = MakeNode(values); - torch::lazy::NodePtr node = torch::lazy::ReuseNode(lazy_self->GetIrValue(), list); + torch::lazy::NodePtr node = + torch::lazy::ReuseNode(lazy_self->GetIrValue(), list); if (!node) { auto self_meta = to_meta(self); auto indices_meta = to_meta(indices); auto out_meta = at::meta::index(self_meta, indices_meta); - std::vector shapes{torch::lazy::Shape(out_meta.scalar_type(), out_meta.sizes().vec())}; + std::vector shapes{ + torch::lazy::Shape(out_meta.scalar_type(), out_meta.sizes().vec())}; TORCH_INTERNAL_ASSERT(shapes.size() == 1); - if(torch::lazy::symbolicShapeEnabled()) { - std::vector inputs = { self, indices }; - const char* schema_str = "aten::index.Tensor(Tensor self, Tensor?[] indices) -> Tensor"; + if (torch::lazy::symbolicShapeEnabled()) { + std::vector inputs = {self, indices}; + const char *schema_str = + "aten::index.Tensor(Tensor self, Tensor?[] indices) -> Tensor"; applySymbolicShapesOnLT(schema_str, inputs, shapes); } - node = torch::lazy::MakeNode(lazy_self->GetIrValue(), list, std::move(shapes)); + node = torch::lazy::MakeNode(lazy_self->GetIrValue(), list, + std::move(shapes)); CacheNode(node); } auto result = torch::lazy::CreateAtenFromLtcTensor( - torch::lazy::LazyTensor::Create(std::move(node), *common_device)); + torch::lazy::LazyTensor::Create(std::move(node), *common_device)); return result; } -at::Tensor LazyNativeFunctions::index_put(const at::Tensor & self, const c10::List> & indices, const at::Tensor & values, bool accumulate) { +at::Tensor LazyNativeFunctions::index_put( + const at::Tensor &self, const c10::List> &indices, + const at::Tensor &values, bool accumulate) { TORCH_LAZY_FN_COUNTER("lazy::"); auto common_device = torch::lazy::GetBackendDevice(self); TORCH_INTERNAL_ASSERT(common_device); - LazyTensorPtr lazy_self = torch::lazy::GetLtcTensorOrCreateForWrappedNumber(self, *common_device); - LazyTensorPtr lazy_valeus = torch::lazy::GetLtcTensorOrCreateForWrappedNumber(values, *common_device); + LazyTensorPtr lazy_self = + torch::lazy::GetLtcTensorOrCreateForWrappedNumber(self, *common_device); + LazyTensorPtr lazy_valeus = + torch::lazy::GetLtcTensorOrCreateForWrappedNumber(values, *common_device); std::vector indices_vector; - for (const auto & it : indices) { + for (const auto &it : indices) { c10::optional tensor = it; - LazyTensorPtr lazy_tensor = torch::lazy::TryGetLtcTensor(tensor.value_or(at::Tensor())); - indices_vector.push_back(lazy_tensor ? lazy_tensor->GetIrValue() : torch::lazy::Value(MakeNode(c10::IValue()), 0)); + LazyTensorPtr lazy_tensor = + torch::lazy::TryGetLtcTensor(tensor.value_or(at::Tensor())); + indices_vector.push_back( + lazy_tensor + ? lazy_tensor->GetIrValue() + : torch::lazy::Value(MakeNode(c10::IValue()), 0)); } auto indices_list = MakeNode(indices_vector); - torch::lazy::NodePtr node = torch::lazy::ReuseNode(lazy_self->GetIrValue(), indices_list, lazy_valeus->GetIrValue(), accumulate); + torch::lazy::NodePtr node = + torch::lazy::ReuseNode(lazy_self->GetIrValue(), indices_list, + lazy_valeus->GetIrValue(), accumulate); if (!node) { auto self_meta = to_meta(self); auto indices_meta = to_meta(indices); auto values_meta = to_meta(values); - auto out_meta = at::compositeexplicitautograd::index_put(self_meta, indices_meta, values_meta, accumulate); + auto out_meta = at::compositeexplicitautograd::index_put( + self_meta, indices_meta, values_meta, accumulate); - std::vector shapes{torch::lazy::Shape(out_meta.scalar_type(), out_meta.sizes().vec())}; + std::vector shapes{ + torch::lazy::Shape(out_meta.scalar_type(), out_meta.sizes().vec())}; TORCH_INTERNAL_ASSERT(shapes.size() == 1); - if(torch::lazy::symbolicShapeEnabled()) { - std::vector inputs = { self, indices, values }; - const char* schema_str = "aten::index_put(Tensor self, Tensor?[] indices, Tensor values, bool accumulate=False) -> Tensor"; + if (torch::lazy::symbolicShapeEnabled()) { + std::vector inputs = {self, indices, values}; + const char *schema_str = + "aten::index_put(Tensor self, Tensor?[] indices, Tensor values, bool " + "accumulate=False) -> Tensor"; applySymbolicShapesOnLT(schema_str, inputs, shapes); } - node = torch::lazy::MakeNode(lazy_self->GetIrValue(), indices_list, lazy_valeus->GetIrValue(), accumulate, std::move(shapes)); + node = torch::lazy::MakeNode( + lazy_self->GetIrValue(), indices_list, lazy_valeus->GetIrValue(), + accumulate, std::move(shapes)); CacheNode(node); } auto result = torch::lazy::CreateAtenFromLtcTensor( - torch::lazy::LazyTensor::Create(std::move(node), *common_device)); + torch::lazy::LazyTensor::Create(std::move(node), *common_device)); return result; } // This is needed by the torch.tensor constructor. // LazyTensor always opts into functionalization. -// "lifting" a tensor for functionalization means wrapping it in a FunctionalTensorWrapper object. -at::Tensor LazyNativeFunctions::lift(const at::Tensor& tensor) { +// "lifting" a tensor for functionalization means wrapping it in a +// FunctionalTensorWrapper object. +at::Tensor LazyNativeFunctions::lift(const at::Tensor &tensor) { TORCH_INTERNAL_ASSERT( !at::functionalization::impl::isFunctionalTensor(tensor)); return at::functionalization::impl::to_functional_tensor(tensor); } -at::Tensor LazyNativeFunctions::lift_fresh(const at::Tensor& tensor) { +at::Tensor LazyNativeFunctions::lift_fresh(const at::Tensor &tensor) { TORCH_INTERNAL_ASSERT( !at::functionalization::impl::isFunctionalTensor(tensor)); return at::functionalization::impl::to_functional_tensor(tensor); } -// All of the below ops correspond to CompositeExplicitAutograd kernels from core -// that call into view operators internally. -// These are all composite ops that LTC can technically re-use / get for free, -// but we need to "functionalize" them to remove the view ops before we can use them. +// All of the below ops correspond to CompositeExplicitAutograd kernels from +// core that call into view operators internally. These are all composite ops +// that LTC can technically re-use / get for free, but we need to +// "functionalize" them to remove the view ops before we can use them. at::Tensor LazyNativeFunctions::block_diag(at::TensorList tensors) { return at::functionalization::functionalize_aten_op::call(tensors); } at::Tensor LazyNativeFunctions::new_empty_strided_symint( - const at::Tensor& self, - c10::SymIntArrayRef size, - c10::SymIntArrayRef stride, - c10::optional dtype, - c10::optional layout, - c10::optional device, + const at::Tensor &self, c10::SymIntArrayRef size, + c10::SymIntArrayRef stride, c10::optional dtype, + c10::optional layout, c10::optional device, c10::optional pin_memory) { if (!device || device->type() == c10::DeviceType::Lazy) { - return at::functionalization::functionalize_aten_op_symint< - ATEN_OP(new_empty_strided)>::call(self, size, stride, dtype, layout, - device, pin_memory); + return at::functionalization::functionalize_aten_op_symint::call(self, size, stride, dtype, layout, device, + pin_memory); } - // For cases when device != lazy, for example: lazy_tensor.new_empty_strided(..., "cpu") - // we need to avoid explicit functionalization. To do that we create regular cpu tensors. + // For cases when device != lazy, for example: + // lazy_tensor.new_empty_strided(..., "cpu") we need to avoid explicit + // functionalization. To do that we create regular cpu tensors. at::Tensor t = at::empty_symint( size, (dtype ? dtype : c10::optional(self.scalar_type())), (layout ? layout : c10::optional(self.layout())), device, @@ -585,65 +634,63 @@ at::Tensor LazyNativeFunctions::new_empty_strided_symint( return t.as_strided_symint(size, stride, /*storage_offset=*/0); } -at::Tensor LazyNativeFunctions::narrow_copy_symint( - const at::Tensor& self, - int64_t dim, - c10::SymInt start, - c10::SymInt length) { +at::Tensor LazyNativeFunctions::narrow_copy_symint(const at::Tensor &self, + int64_t dim, + c10::SymInt start, + c10::SymInt length) { return at::functionalization::functionalize_aten_op_symint::call(self, dim, start, length); } -at::Tensor LazyNativeFunctions::pixel_shuffle( - const at::Tensor& self, int64_t upscale_factor) { +at::Tensor LazyNativeFunctions::pixel_shuffle(const at::Tensor &self, + int64_t upscale_factor) { return at::functionalization::functionalize_aten_op::call(self, upscale_factor); } -at::Tensor LazyNativeFunctions::pixel_unshuffle( - const at::Tensor& self, int64_t downscale_factor) { +at::Tensor LazyNativeFunctions::pixel_unshuffle(const at::Tensor &self, + int64_t downscale_factor) { return at::functionalization::functionalize_aten_op::call(self, downscale_factor); } -at::Tensor LazyNativeFunctions::select_backward( - const at::Tensor& grad_output, at::IntArrayRef input_sizes, int64_t dim, - int64_t index) { +at::Tensor LazyNativeFunctions::select_backward(const at::Tensor &grad_output, + at::IntArrayRef input_sizes, + int64_t dim, int64_t index) { return at::functionalization::functionalize_aten_op::call(grad_output, input_sizes, dim, index); } at::Tensor LazyNativeFunctions::slice_backward_symint( - const at::Tensor& grad_output, - at::SymIntArrayRef input_sizes, - int64_t dim, - c10::SymInt start, - c10::SymInt end, - c10::SymInt step) { + const at::Tensor &grad_output, at::SymIntArrayRef input_sizes, int64_t dim, + c10::SymInt start, c10::SymInt end, c10::SymInt step) { return at::functionalization::functionalize_aten_op_symint::call(grad_output, input_sizes, dim, start, end, step); } -at::Tensor LazyNativeFunctions::diagonal_backward( - const at::Tensor& grad_output, at::IntArrayRef input_sizes, int64_t offset, - int64_t dim1, int64_t dim2) { +at::Tensor LazyNativeFunctions::diagonal_backward(const at::Tensor &grad_output, + at::IntArrayRef input_sizes, + int64_t offset, int64_t dim1, + int64_t dim2) { return at::functionalization::functionalize_aten_op::call(grad_output, input_sizes, offset, dim1, dim2); } at::Tensor LazyNativeFunctions::_trilinear( - const at::Tensor& i1, const at::Tensor& i2, const at::Tensor& i3, + const at::Tensor &i1, const at::Tensor &i2, const at::Tensor &i3, at::IntArrayRef expand1, at::IntArrayRef expand2, at::IntArrayRef expand3, at::IntArrayRef sumdim, int64_t unroll_dim) { - return at::functionalization::functionalize_aten_op:: - call(i1, i2, i3, expand1, expand2, expand3, sumdim, unroll_dim); + return at::functionalization::functionalize_aten_op::call(i1, i2, i3, expand1, expand2, expand3, sumdim, + unroll_dim); } at::Tensor LazyNativeFunctions::linalg_pinv( - const at::Tensor& self, const c10::optional& atol, - const c10::optional& rtol, bool hermitian) { + const at::Tensor &self, const c10::optional &atol, + const c10::optional &rtol, bool hermitian) { return at::functionalization::functionalize_aten_op::call(self, atol, rtol, hermitian); } // functionalize_aten_op can't handle out= ops directly. -// Instead, we can call the composite kernel from core, and copy and mutations back to the inputs. -at::Tensor& LazyNativeFunctions::logsumexp_out( - const at::Tensor& self, at::IntArrayRef dim, bool keepdim, - at::Tensor& out) { +// Instead, we can call the composite kernel from core, and copy and mutations +// back to the inputs. +at::Tensor &LazyNativeFunctions::logsumexp_out(const at::Tensor &self, + at::IntArrayRef dim, + bool keepdim, at::Tensor &out) { auto self_wrapped = at::functionalization::impl::to_functional_tensor(self); auto out_wrapped = at::functionalization::impl::to_functional_tensor(out); // directly call the composite kernel from core. diff --git a/projects/ltc/csrc/base_lazy_backend/mlir_node.cpp b/projects/ltc/csrc/base_lazy_backend/mlir_node.cpp index 39dc1ad0cd58..0f31fab2c1e0 100644 --- a/projects/ltc/csrc/base_lazy_backend/mlir_node.cpp +++ b/projects/ltc/csrc/base_lazy_backend/mlir_node.cpp @@ -18,11 +18,10 @@ namespace lazy { namespace { -hash_t OperandHashes( - const OpList& operands, const c10::ArrayRef& shapes, - const hash_t& seed, bool bakeInSizes) { +hash_t OperandHashes(const OpList &operands, const c10::ArrayRef &shapes, + const hash_t &seed, bool bakeInSizes) { hash_t hash = seed; - for (auto& operand : operands) { + for (auto &operand : operands) { if (!operand) { hash = HashCombine(hash, static_cast(kNullOpt)); continue; @@ -30,7 +29,7 @@ hash_t OperandHashes( auto operand_hash = bakeInSizes ? operand.shapeHash() : operand.hash(); hash = HashCombine(hash, operand_hash); } - for (auto& shape : shapes) { + for (auto &shape : shapes) { hash = HashCombine(hash, shape.hash(bakeInSizes)); } return hash; @@ -38,53 +37,51 @@ hash_t OperandHashes( } // namespace - -// Adds a static hook that is run after every single TorchMlirNode is initialized -static std::vector> constructor_hooks; -void TorchMlirNode::addConstructorHook(std::function f) { +// Adds a static hook that is run after every single TorchMlirNode is +// initialized +static std::vector> constructor_hooks; +void TorchMlirNode::addConstructorHook(std::function f) { constructor_hooks.emplace_back(f); } -TorchMlirNode::TorchMlirNode( - OpKind op, OpList operands, std::vector&& shapes, size_t num_outputs, - hash_t hash_seed) +TorchMlirNode::TorchMlirNode(OpKind op, OpList operands, + std::vector &&shapes, size_t num_outputs, + hash_t hash_seed) : Node(op, operands, std::move(shapes), num_outputs) { hash_seed = HashCombine(op.hash(), hash_seed); shape_hash_ = OperandHashes(operands, this->shapes(), hash_seed, true); - dag_hash_ = - (enableDynamicShape() - ? OperandHashes(operands, this->shapes(), hash_seed, false) - : shape_hash_); + dag_hash_ = (enableDynamicShape() + ? OperandHashes(operands, this->shapes(), hash_seed, false) + : shape_hash_); - for (std::function& f : constructor_hooks) { + for (std::function &f : constructor_hooks) { f(this); } } -TorchMlirNode::TorchMlirNode( - OpKind op, OpList operands, const std::function& shape_fn, - size_t num_outputs, hash_t hash_seed) - : TorchMlirNode( - op, operands, std::vector{}, num_outputs, hash_seed) { +TorchMlirNode::TorchMlirNode(OpKind op, OpList operands, + const std::function &shape_fn, + size_t num_outputs, hash_t hash_seed) + : TorchMlirNode(op, operands, std::vector{}, num_outputs, + hash_seed) { addComputedShape(shape_fn); } -TorchMlirNode::TorchMlirNode( - OpKind op, OpList operands, size_t num_outputs, hash_t hash_seed) - : TorchMlirNode( - op, operands, std::vector{}, num_outputs, hash_seed) {} +TorchMlirNode::TorchMlirNode(OpKind op, OpList operands, size_t num_outputs, + hash_t hash_seed) + : TorchMlirNode(op, operands, std::vector{}, num_outputs, + hash_seed) {} -TorchMlirNode::TorchMlirNode( - OpKind op, Shape shape, size_t num_outputs, hash_t hash_seed) +TorchMlirNode::TorchMlirNode(OpKind op, Shape shape, size_t num_outputs, + hash_t hash_seed) : TorchMlirNode(op, {}, {std::move(shape)}, num_outputs, hash_seed) {} hash_t TorchMlirNode::hash() const { return dag_hash_; } hash_t TorchMlirNode::shapeHash() const { return shape_hash_; } - -TorchMlirNode* TorchMlirNode::mlir_node(int index) const { - return dynamic_cast(operands_.at(index).get()); +TorchMlirNode *TorchMlirNode::mlir_node(int index) const { + return dynamic_cast(operands_.at(index).get()); } /////////////////////////////////////////////////////////////////////////////// @@ -107,11 +104,12 @@ TorchMlirTensorList::TorchMlirTensorList(OpList values) /*num_outputs=*/1, /*hash_seed=*/kHashSeed) {} -torch::lazy::TorchMlirOpVector TorchMlirTensorList::Lower( - TorchMlirFunction function, TorchMlirLoweringContext* loctx) const { - std::vector tensor_list; +torch::lazy::TorchMlirOpVector +TorchMlirTensorList::Lower(TorchMlirFunction function, + TorchMlirLoweringContext *loctx) const { + std::vector tensor_list; CHECK(!operands().empty()); - for (const torch::lazy::Output& operand : operands()) { + for (const torch::lazy::Output &operand : operands()) { tensor_list.emplace_back(loctx->GetOutputOp(operand)); } auto graph = function->graph(); @@ -140,16 +138,17 @@ TorchMlirOptionalTensorList::TorchMlirOptionalTensorList(OpList values) /*num_outputs=*/1, /*hash_seed=*/kHashSeed) {} -torch::lazy::TorchMlirOpVector TorchMlirOptionalTensorList::Lower( - TorchMlirFunction function, TorchMlirLoweringContext* loctx) const { - std::vector tensor_list; +torch::lazy::TorchMlirOpVector +TorchMlirOptionalTensorList::Lower(TorchMlirFunction function, + TorchMlirLoweringContext *loctx) const { + std::vector tensor_list; CHECK(!operands().empty()); - for (const torch::lazy::Output& operand : operands()) { + for (const torch::lazy::Output &operand : operands()) { tensor_list.emplace_back(loctx->GetOutputOp(operand)); } auto graph = function->graph(); - auto listnode = - graph->insertNode(graph->createList(c10::OptionalType::create(c10::TensorType::get()), tensor_list)); + auto listnode = graph->insertNode(graph->createList( + c10::OptionalType::create(c10::TensorType::get()), tensor_list)); return {listnode->output()}; } diff --git a/projects/ltc/csrc/base_lazy_backend/mlir_node.h b/projects/ltc/csrc/base_lazy_backend/mlir_node.h index a76ec0b05064..e5738a92176d 100644 --- a/projects/ltc/csrc/base_lazy_backend/mlir_node.h +++ b/projects/ltc/csrc/base_lazy_backend/mlir_node.h @@ -27,23 +27,22 @@ namespace lazy { class TORCH_API TorchMlirNode : public torch::lazy::Node { public: - TorchMlirNode( - OpKind op, OpList operands, std::vector&& shapes, - size_t num_outputs, hash_t hash_seed = kHashSeed); + TorchMlirNode(OpKind op, OpList operands, std::vector &&shapes, + size_t num_outputs, hash_t hash_seed = kHashSeed); - TorchMlirNode( - OpKind op, OpList operands, const std::function& shape_fn, - size_t num_outputs, hash_t hash_seed = kHashSeed); + TorchMlirNode(OpKind op, OpList operands, + const std::function &shape_fn, size_t num_outputs, + hash_t hash_seed = kHashSeed); - TorchMlirNode( - OpKind op, OpList operands, size_t num_outputs, - hash_t hash_seed = kHashSeed); + TorchMlirNode(OpKind op, OpList operands, size_t num_outputs, + hash_t hash_seed = kHashSeed); - TorchMlirNode( - OpKind op, Shape shape, size_t num_outputs, hash_t hash_seed = kHashSeed); + TorchMlirNode(OpKind op, Shape shape, size_t num_outputs, + hash_t hash_seed = kHashSeed); - // Adds a static hook that is run after every single TorchMlirNode is constructed - static void addConstructorHook(std::function); + // Adds a static hook that is run after every single TorchMlirNode is + // constructed + static void addConstructorHook(std::function); ~TorchMlirNode() override = default; @@ -51,10 +50,10 @@ class TORCH_API TorchMlirNode : public torch::lazy::Node { hash_t shapeHash() const override; - TorchMlirNode* mlir_node(int index) const; + TorchMlirNode *mlir_node(int index) const; - virtual TorchMlirOpVector - Lower(TorchMlirFunction function, TorchMlirLoweringContext* loctx) const; + virtual TorchMlirOpVector Lower(TorchMlirFunction function, + TorchMlirLoweringContext *loctx) const; private: // The hash of the dag WITH size info. Used for shape caching @@ -86,22 +85,23 @@ struct TORCH_API TorchMlirTensorList : public TorchMlirNode { TorchMlirTensorList() = delete; TorchMlirTensorList(OpList values); - torch::lazy::TorchMlirOpVector Lower( - TorchMlirFunction function, - TorchMlirLoweringContext* loctx) const override; + torch::lazy::TorchMlirOpVector + Lower(TorchMlirFunction function, + TorchMlirLoweringContext *loctx) const override; }; -// TorchMlirOptionalTensorList is similar to TorchMlirTensorList but it can also represent -// optional tensors, so the output type for this op is !torch.list>. +// TorchMlirOptionalTensorList is similar to TorchMlirTensorList but it can also +// represent optional tensors, so the output type for this op is +// !torch.list>. struct TORCH_API TorchMlirOptionalTensorList : public TorchMlirNode { static OpKind ClassOpKind(); TorchMlirOptionalTensorList() = delete; TorchMlirOptionalTensorList(OpList values); - torch::lazy::TorchMlirOpVector Lower( - TorchMlirFunction function, - TorchMlirLoweringContext* loctx) const override; + torch::lazy::TorchMlirOpVector + Lower(TorchMlirFunction function, + TorchMlirLoweringContext *loctx) const override; }; } // namespace lazy diff --git a/projects/ltc/csrc/base_lazy_backend/mlir_node_lowering.cpp b/projects/ltc/csrc/base_lazy_backend/mlir_node_lowering.cpp index a21bb93f0854..b52b724f0f16 100644 --- a/projects/ltc/csrc/base_lazy_backend/mlir_node_lowering.cpp +++ b/projects/ltc/csrc/base_lazy_backend/mlir_node_lowering.cpp @@ -31,21 +31,23 @@ namespace torch { namespace lazy { -TorchMlirOpVector LowerTorchMlirBuiltin( - TorchMlirFunction function, c10::Symbol sym, - const std::vector tensor_types, - const std::vector& arguments, - const std::vector& kwarguments) { +TorchMlirOpVector +LowerTorchMlirBuiltin(TorchMlirFunction function, c10::Symbol sym, + const std::vector tensor_types, + const std::vector &arguments, + const std::vector &kwarguments) { // Workaround for ListType::isSubtypeOfExt behavior which leads to // the problems with JIT schema matching, so we need to keep // c10::ListType empty before magic_method->call function call. auto dummy_graph = torch::jit::Graph(); for (auto arg : arguments) { - torch::jit::Value* value = arg.value(dummy_graph); + torch::jit::Value *value = arg.value(dummy_graph); if (value->type()->kind() == c10::TypeKind::ListType) { - auto list_element_type = value->type()->cast()->getElementType(); + auto list_element_type = + value->type()->cast()->getElementType(); if (list_element_type->cast()) { - value->setType(c10::ListType::create(c10::OptionalType::create(c10::TensorType::get()))); + value->setType(c10::ListType::create( + c10::OptionalType::create(c10::TensorType::get()))); } else { value->setType(c10::ListType::create(c10::TensorType::get())); } @@ -56,25 +58,27 @@ TorchMlirOpVector LowerTorchMlirBuiltin( std::make_shared(sym, at::nullopt); auto magic_method = std::make_shared("", builtin); auto ret = magic_method->call({}, *function, arguments, kwarguments, 0); - auto sv = dynamic_cast(ret.get()); + auto sv = dynamic_cast(ret.get()); CHECK(sv); TorchMlirOpVector results; if (sv->getValue()->type()->kind() == c10::TypeKind::ListType) { - // Unpack dynamic multi-output operations like aten::split with Tensor[] output type. - // This is required to have consistent input types for multi-output node consumers. - torch::jit::Node * node = function->graph()->createListUnpack(sv->getValue(), tensor_types.size()); + // Unpack dynamic multi-output operations like aten::split with Tensor[] + // output type. This is required to have consistent input types for + // multi-output node consumers. + torch::jit::Node *node = function->graph()->createListUnpack( + sv->getValue(), tensor_types.size()); function->graph()->insertNode(node); - for (const auto & output : node->outputs()) { + for (const auto &output : node->outputs()) { results.push_back(output); } } else if (sv->getValue()->type()->kind() == c10::TypeKind::TupleType) { - // Op returns multiple values and the number of outputs is static and defined - // by the operation schema. + // Op returns multiple values and the number of outputs is static and + // defined by the operation schema. const auto tuple_call_result = sv->asTuple({}, *function); - for (const auto& tuple_component : tuple_call_result) { + for (const auto &tuple_component : tuple_call_result) { auto tuple_component_sv = - dynamic_cast(tuple_component.get()); + dynamic_cast(tuple_component.get()); results.push_back(tuple_component_sv->getValue()); } } else { @@ -84,7 +88,7 @@ TorchMlirOpVector LowerTorchMlirBuiltin( // Insert known tensor type information. unsigned tensor_type_idx = 0; - for (jit::Value* value : results) { + for (jit::Value *value : results) { if (value->type()->kind() == c10::TypeKind::TensorType) { TORCH_CHECK( tensor_type_idx < tensor_types.size(), function->graph()->toString(), @@ -97,23 +101,22 @@ TorchMlirOpVector LowerTorchMlirBuiltin( } // Ensure that we use up all the known tensor type information available. - TORCH_CHECK( - tensor_type_idx == tensor_types.size(), tensor_type_idx, - " known types were injected into jit::Value, but ", tensor_types.size(), - " were provided from lazy::Node!"); + TORCH_CHECK(tensor_type_idx == tensor_types.size(), tensor_type_idx, + " known types were injected into jit::Value, but ", + tensor_types.size(), " were provided from lazy::Node!"); return results; } -TorchMlirOpVector LowerTorchMlirBuiltin( - TorchMlirFunction function, c10::Symbol sym, - const c10::ArrayRef result_shapes, - const std::vector& arguments, - const std::vector& kwarguments) { +TorchMlirOpVector +LowerTorchMlirBuiltin(TorchMlirFunction function, c10::Symbol sym, + const c10::ArrayRef result_shapes, + const std::vector &arguments, + const std::vector &kwarguments) { std::vector tensor_types; // Generate types with fixed tensor shape information. - for (const Shape& shape : result_shapes) { + for (const Shape &shape : result_shapes) { tensor_types.push_back(torch::jit::TensorType::create( /*scalar_type=*/shape.scalar_type(), /*device=*/c10::nullopt, @@ -122,34 +125,34 @@ TorchMlirOpVector LowerTorchMlirBuiltin( /*requires_grad=*/c10::nullopt)); } - return LowerTorchMlirBuiltin( - function, sym, tensor_types, arguments, kwarguments); + return LowerTorchMlirBuiltin(function, sym, tensor_types, arguments, + kwarguments); } -TorchMlirOpVector LowerBuiltin( - const torch::lazy::Node* node, TorchMlirFunction function, - const std::vector& arguments, - const std::vector& kwarguments = {}) { - return LowerTorchMlirBuiltin( - function, node->op().op, node->shapes(), arguments, kwarguments); +TorchMlirOpVector +LowerBuiltin(const torch::lazy::Node *node, TorchMlirFunction function, + const std::vector &arguments, + const std::vector &kwarguments = {}) { + return LowerTorchMlirBuiltin(function, node->op().op, node->shapes(), + arguments, kwarguments); } -TorchMlirOpVector LowerBuiltin( - c10::Symbol sym, const c10::ArrayRef result_shapes, - TorchMlirFunction function, - const std::vector& arguments, - const std::vector& kwarguments = {}) { - return LowerTorchMlirBuiltin( - function, sym, result_shapes, arguments, kwarguments); +TorchMlirOpVector +LowerBuiltin(c10::Symbol sym, const c10::ArrayRef result_shapes, + TorchMlirFunction function, + const std::vector &arguments, + const std::vector &kwarguments = {}) { + return LowerTorchMlirBuiltin(function, sym, result_shapes, arguments, + kwarguments); } -TorchMlirOpVector LowerBuiltin( - c10::Symbol sym, const std::vector types, - TorchMlirFunction function, - const std::vector& arguments, - const std::vector& kwarguments = {}) { +TorchMlirOpVector +LowerBuiltin(c10::Symbol sym, const std::vector types, + TorchMlirFunction function, + const std::vector &arguments, + const std::vector &kwarguments = {}) { return LowerTorchMlirBuiltin(function, sym, types, arguments, kwarguments); } -c10::TensorType& cast_tensor_type(c10::TypePtr value_type) { +c10::TensorType &cast_tensor_type(c10::TypePtr value_type) { auto tensor_type = value_type->cast(); TORCH_CHECK(tensor_type, "Unable to cast Value type to TensorType!"); @@ -157,8 +160,8 @@ c10::TensorType& cast_tensor_type(c10::TypePtr value_type) { } c10::optional> -get_tensor_type_shape(c10::TensorType& tensor_type) { - auto& symbolic_shape = tensor_type.symbolic_sizes(); +get_tensor_type_shape(c10::TensorType &tensor_type) { + auto &symbolic_shape = tensor_type.symbolic_sizes(); if (!symbolic_shape.rank()) { return c10::nullopt; } @@ -175,21 +178,21 @@ get_tensor_type_shape(c10::TensorType& tensor_type) { } std::vector compute_shape_copy(c10::TypePtr value_type) { - c10::TensorType& tensor_type = cast_tensor_type(value_type); + c10::TensorType &tensor_type = cast_tensor_type(value_type); auto maybe_dims = get_tensor_type_shape(tensor_type); TORCH_CHECK(maybe_dims.has_value(), "Cannot copy unranked tensor!"); auto scalar_type = tensor_type.scalarType(); - TORCH_CHECK( - scalar_type.has_value(), "Unable to copy due to lack of scalar type!"); + TORCH_CHECK(scalar_type.has_value(), + "Unable to copy due to lack of scalar type!"); return {Shape(scalar_type.value(), maybe_dims.value())}; } -std::vector compute_shape_slice( - c10::TypePtr value_type, int64_t dim, int64_t start, int64_t end, - int64_t step) { - c10::TensorType& tensor_type = cast_tensor_type(value_type); +std::vector compute_shape_slice(c10::TypePtr value_type, + int64_t dim, int64_t start, + int64_t end, int64_t step) { + c10::TensorType &tensor_type = cast_tensor_type(value_type); auto maybe_dims = get_tensor_type_shape(tensor_type); TORCH_CHECK(maybe_dims.has_value(), "Cannot slice unranked tensor!"); @@ -217,13 +220,13 @@ std::vector compute_shape_slice( } auto scalar_type = tensor_type.scalarType(); - TORCH_CHECK( - scalar_type.has_value(), "Unable to slice due to lack of scalar type!"); + TORCH_CHECK(scalar_type.has_value(), + "Unable to slice due to lack of scalar type!"); return {Shape(scalar_type.value(), dims)}; } -torch::jit::Value* -GenerateClone(torch::jit::Value* val, TorchMlirFunction function) { +torch::jit::Value *GenerateClone(torch::jit::Value *val, + TorchMlirFunction function) { std::vector clone_arguments; clone_arguments.emplace_back(val); @@ -234,20 +237,19 @@ GenerateClone(torch::jit::Value* val, TorchMlirFunction function) { return cloned.front(); } -void GenerateCopy( - torch::jit::Value* destination, torch::jit::Value* source, - TorchMlirFunction function) { +void GenerateCopy(torch::jit::Value *destination, torch::jit::Value *source, + TorchMlirFunction function) { std::vector arguments; arguments.emplace_back(destination); arguments.emplace_back(source); - LowerBuiltin( - at::aten::copy_, c10::ArrayRef(compute_shape_copy(source->type())), - function, arguments); + LowerBuiltin(at::aten::copy_, + c10::ArrayRef(compute_shape_copy(source->type())), + function, arguments); } -torch::jit::Value* GenerateSlice( - torch::jit::Value* base, int64_t dim, int64_t start, int64_t end, - int64_t step, TorchMlirFunction function) { +torch::jit::Value *GenerateSlice(torch::jit::Value *base, int64_t dim, + int64_t start, int64_t end, int64_t step, + TorchMlirFunction function) { std::vector arguments; arguments.emplace_back(base); arguments.emplace_back(dim); @@ -255,11 +257,11 @@ torch::jit::Value* GenerateSlice( arguments.emplace_back(end); arguments.emplace_back(step); - TorchMlirOpVector selected = LowerBuiltin( - at::aten::slice, - c10::ArrayRef( - compute_shape_slice(base->type(), dim, start, end, step)), - function, arguments); + TorchMlirOpVector selected = + LowerBuiltin(at::aten::slice, + c10::ArrayRef(compute_shape_slice(base->type(), dim, + start, end, step)), + function, arguments); TORCH_CHECK_EQ(selected.size(), 1); return selected.front(); } @@ -267,10 +269,10 @@ torch::jit::Value* GenerateSlice( // Node Lowerings // Default Node Lowering -TorchMlirOpVector TorchMlirNode::Lower( - TorchMlirFunction function, TorchMlirLoweringContext* loctx) const { +TorchMlirOpVector TorchMlirNode::Lower(TorchMlirFunction function, + TorchMlirLoweringContext *loctx) const { std::vector arguments; - for (const torch::lazy::Output& output : operands()) { + for (const torch::lazy::Output &output : operands()) { arguments.emplace_back(loctx->GetOutputOp(output)); } return LowerBuiltin(this, function, arguments); @@ -280,19 +282,19 @@ TorchMlirOpVector TorchMlirNode::Lower( // Non-native nodes -TorchMlirOpVector -Cast::Lower(TorchMlirFunction function, TorchMlirLoweringContext* loctx) const { +TorchMlirOpVector Cast::Lower(TorchMlirFunction function, + TorchMlirLoweringContext *loctx) const { std::vector arguments; arguments.emplace_back(loctx->GetOutputOp(operand(0))); arguments.emplace_back(dtype); return LowerBuiltin(at::aten::to, shapes(), function, arguments); } -TorchMlirOpVector DeviceData::Lower( - TorchMlirFunction function, TorchMlirLoweringContext* loctx) const { +TorchMlirOpVector DeviceData::Lower(TorchMlirFunction function, + TorchMlirLoweringContext *loctx) const { auto infoptr = data_->info(); auto deviceDataInfoPtr = - (torch::lazy::LazyGraphExecutor::DeviceDataInfo*)infoptr; + (torch::lazy::LazyGraphExecutor::DeviceDataInfo *)infoptr; if (GRAPH_DUMP_ENABLED) { LOG(ERROR) << "Lowering device data node, tensor id " << deviceDataInfoPtr->tensor_id << std::endl; @@ -300,8 +302,8 @@ TorchMlirOpVector DeviceData::Lower( return {loctx->GetParameter(data_)}; } -TorchMlirOpVector Scalar::Lower( - TorchMlirFunction function, TorchMlirLoweringContext* loctx) const { +TorchMlirOpVector Scalar::Lower(TorchMlirFunction function, + TorchMlirLoweringContext *loctx) const { auto options = at::TensorOptions() .device(torch::lazy::getBackend()->EagerFallbackDeviceType()) @@ -309,8 +311,8 @@ TorchMlirOpVector Scalar::Lower( return {loctx->graph()->insertConstant(at::scalar_tensor(value, options))}; } -TorchMlirOpVector Expand::Lower( - TorchMlirFunction function, TorchMlirLoweringContext* loctx) const { +TorchMlirOpVector Expand::Lower(TorchMlirFunction function, + TorchMlirLoweringContext *loctx) const { std::vector arguments; arguments.emplace_back(loctx->GetOutputOp(operand(0))); arguments.emplace_back(size); diff --git a/projects/ltc/csrc/base_lazy_backend/mlir_node_lowering.h b/projects/ltc/csrc/base_lazy_backend/mlir_node_lowering.h index f9e028a5cc15..650bed045c25 100644 --- a/projects/ltc/csrc/base_lazy_backend/mlir_node_lowering.h +++ b/projects/ltc/csrc/base_lazy_backend/mlir_node_lowering.h @@ -18,14 +18,14 @@ namespace torch { namespace lazy { -typedef std::vector TorchMlirOpVector; +typedef std::vector TorchMlirOpVector; typedef std::shared_ptr TorchMlirFunction; TORCH_API TorchMlirOpVector LowerTorchMlirBuiltin( TorchMlirFunction function, c10::Symbol sym, const c10::ArrayRef result_shapes, - const std::vector& arguments, - const std::vector& kwarguments = {}); + const std::vector &arguments, + const std::vector &kwarguments = {}); } // namespace lazy } // namespace torch diff --git a/projects/ltc/csrc/base_lazy_backend/ops/device_data.cpp b/projects/ltc/csrc/base_lazy_backend/ops/device_data.cpp index b4271df6691e..c4255068fcb5 100644 --- a/projects/ltc/csrc/base_lazy_backend/ops/device_data.cpp +++ b/projects/ltc/csrc/base_lazy_backend/ops/device_data.cpp @@ -2,18 +2,16 @@ #include -#include "device_data.h" #include "../backend_impl.h" +#include "device_data.h" namespace torch { namespace lazy { DeviceData::DeviceData(std::shared_ptr data) - : TorchMlirNode( - ClassOpKind(), - data->shape(), - /*num_outputs=*/1, - /*hash_seed=*/static_cast(101)), + : TorchMlirNode(ClassOpKind(), data->shape(), + /*num_outputs=*/1, + /*hash_seed=*/static_cast(101)), data_(std::move(data)) { propagate_name(); } @@ -21,9 +19,11 @@ DeviceData::DeviceData(std::shared_ptr data) void DeviceData::propagate_name() { if (data_ && name_ != "") { // Add device data name to backend data - TorchMlirBackendData* mlir_data = dynamic_cast(data_.get()); + TorchMlirBackendData *mlir_data = + dynamic_cast(data_.get()); TORCH_CHECK(mlir_data); - auto* info = dynamic_cast(mlir_data->mlir_info()); + auto *info = + dynamic_cast(mlir_data->mlir_info()); TORCH_CHECK(info); info->name = name_; } @@ -34,7 +34,7 @@ void DeviceData::SetData(std::shared_ptr data) { propagate_name(); } -void DeviceData::SetName(const std::string& name) { +void DeviceData::SetName(const std::string &name) { name_ = name; propagate_name(); } @@ -43,12 +43,12 @@ std::string DeviceData::ToString() const { std::stringstream ss; ss << TorchMlirNode::ToString() << ", device=" << data_->device(); if (name_ != "") { - ss << ", name=" << name_; + ss << ", name=" << name_; } return ss.str(); } -const DeviceData* DeviceData::Cast(const Node* node) { +const DeviceData *DeviceData::Cast(const Node *node) { return NodeCast(node); } @@ -59,7 +59,7 @@ NodePtr DeviceData::Create(std::shared_ptr data) { // Ditching the old data_ is safe because tracing is done iteration // by iteration, and after we lauch the async device execution for the // previous iteration, data_ in DeviceData nodes are not needed anymore. - DeviceData* device_data = static_cast(node.get()); + DeviceData *device_data = static_cast(node.get()); device_data->SetData(data); return node; } diff --git a/projects/ltc/csrc/base_lazy_backend/ops/device_data.h b/projects/ltc/csrc/base_lazy_backend/ops/device_data.h index ad9d9d0eb94b..6f96d074962f 100644 --- a/projects/ltc/csrc/base_lazy_backend/ops/device_data.h +++ b/projects/ltc/csrc/base_lazy_backend/ops/device_data.h @@ -6,15 +6,12 @@ #include #include - namespace torch { namespace lazy { class TORCH_API DeviceData : public TorchMlirNode { - public: - static OpKind ClassOpKind() { - return ltc_device_data; - } +public: + static OpKind ClassOpKind() { return ltc_device_data; } explicit DeviceData(std::shared_ptr data); @@ -27,22 +24,23 @@ class TORCH_API DeviceData : public TorchMlirNode { std::string ToString() const override; - const std::shared_ptr& data() const { return data_; } + const std::shared_ptr &data() const { return data_; } void SetData(std::shared_ptr data); - TorchMlirOpVector Lower(TorchMlirFunction function, TorchMlirLoweringContext* loctx) const override; + TorchMlirOpVector Lower(TorchMlirFunction function, + TorchMlirLoweringContext *loctx) const override; - static const DeviceData* Cast(const Node* node); + static const DeviceData *Cast(const Node *node); // To reuse IR nodes, use this method to create DeviceData nodes // instead of calling the constructor directly. static NodePtr Create(std::shared_ptr data); - const std::string& GetName() const { return name_; } - void SetName(const std::string& name); + const std::string &GetName() const { return name_; } + void SetName(const std::string &name); - private: +private: void propagate_name(); std::shared_ptr data_; diff --git a/projects/ltc/csrc/base_lazy_backend/ops/generic.cpp b/projects/ltc/csrc/base_lazy_backend/ops/generic.cpp index 1df8be231023..17e578946fb2 100644 --- a/projects/ltc/csrc/base_lazy_backend/ops/generic.cpp +++ b/projects/ltc/csrc/base_lazy_backend/ops/generic.cpp @@ -15,12 +15,8 @@ namespace torch { namespace lazy { -Generic::Generic( - OpKind op, - OpList operands, - Shape shape, - size_t num_outputs, - hash_t hash_seed) +Generic::Generic(OpKind op, OpList operands, Shape shape, size_t num_outputs, + hash_t hash_seed) : TorchMlirNode(op, operands, {std::move(shape)}, num_outputs, hash_seed), hash_seed_(hash_seed) {} diff --git a/projects/ltc/csrc/base_lazy_backend/ops/generic.h b/projects/ltc/csrc/base_lazy_backend/ops/generic.h index f294b1cfaed2..01794355a8b4 100644 --- a/projects/ltc/csrc/base_lazy_backend/ops/generic.h +++ b/projects/ltc/csrc/base_lazy_backend/ops/generic.h @@ -23,15 +23,11 @@ namespace lazy { // captured by the LowerFn), but they should instead create a dedicated IR node. // Doing the former would limit IR introspection. class TORCH_API Generic : public TorchMlirNode { - public: - Generic( - OpKind op, - OpList operands, - Shape shape, - size_t num_outputs = 1, - hash_t hash_seed = static_cast(0x5a2d296e9)); +public: + Generic(OpKind op, OpList operands, Shape shape, size_t num_outputs = 1, + hash_t hash_seed = static_cast(0x5a2d296e9)); - private: +private: hash_t hash_seed_; }; diff --git a/projects/ltc/csrc/base_lazy_backend/ops/index.cpp b/projects/ltc/csrc/base_lazy_backend/ops/index.cpp index 34af3e590162..ffa2f06bbccf 100644 --- a/projects/ltc/csrc/base_lazy_backend/ops/index.cpp +++ b/projects/ltc/csrc/base_lazy_backend/ops/index.cpp @@ -12,9 +12,9 @@ namespace torch { namespace lazy { -IndexTensor::IndexTensor(const torch::lazy::Value& self, - const torch::lazy::Value& indices, - std::vector&& shapes) +IndexTensor::IndexTensor(const torch::lazy::Value &self, + const torch::lazy::Value &indices, + std::vector &&shapes) : torch::lazy::TorchMlirNode(IndexTensor::ClassOpKind(), OpList{self, indices}, std::move(shapes), /* num_outputs */ 1, torch::lazy::MHash()) {} @@ -25,13 +25,13 @@ std::string IndexTensor::ToString() const { return ss.str(); } -bool IndexTensor::CanBeReused(const torch::lazy::Value& self, - const torch::lazy::Value& indices) const { +bool IndexTensor::CanBeReused(const torch::lazy::Value &self, + const torch::lazy::Value &indices) const { return false; } TorchMlirOpVector IndexTensor::Lower(TorchMlirFunction function, - TorchMlirLoweringContext* loctx) const { + TorchMlirLoweringContext *loctx) const { PRINT_FUNCTION(); std::vector arguments; std::vector kwarguments; @@ -49,10 +49,10 @@ TorchMlirOpVector IndexTensor::Lower(TorchMlirFunction function, return index_out; } -IndexPut::IndexPut(const torch::lazy::Value& self, - const torch::lazy::Value& indices, - const torch::lazy::Value& values, bool accumulate, - std::vector&& shapes) +IndexPut::IndexPut(const torch::lazy::Value &self, + const torch::lazy::Value &indices, + const torch::lazy::Value &values, bool accumulate, + std::vector &&shapes) : torch::lazy::TorchMlirNode( IndexPut::ClassOpKind(), OpList{self, indices, values}, std::move(shapes), @@ -66,15 +66,15 @@ std::string IndexPut::ToString() const { return ss.str(); } -bool IndexPut::CanBeReused(const torch::lazy::Value& self, - const torch::lazy::Value& indices, - const torch::lazy::Value& values, +bool IndexPut::CanBeReused(const torch::lazy::Value &self, + const torch::lazy::Value &indices, + const torch::lazy::Value &values, bool accumulate) const { return false; } TorchMlirOpVector IndexPut::Lower(TorchMlirFunction function, - TorchMlirLoweringContext* loctx) const { + TorchMlirLoweringContext *loctx) const { PRINT_FUNCTION(); std::vector arguments; std::vector kwarguments; @@ -95,5 +95,5 @@ TorchMlirOpVector IndexPut::Lower(TorchMlirFunction function, return index_out; } -} // namespace lazy -} // namespace torch +} // namespace lazy +} // namespace torch diff --git a/projects/ltc/csrc/base_lazy_backend/ops/index.h b/projects/ltc/csrc/base_lazy_backend/ops/index.h index e97760fc37ad..6f63cbc686a6 100644 --- a/projects/ltc/csrc/base_lazy_backend/ops/index.h +++ b/projects/ltc/csrc/base_lazy_backend/ops/index.h @@ -15,44 +15,44 @@ namespace torch { namespace lazy { class IndexTensor : public torch::lazy::TorchMlirNode { - public: +public: static torch::lazy::OpKind ClassOpKind() { return torch::lazy::OpKind(at::aten::index); } - IndexTensor(const torch::lazy::Value& self, const torch::lazy::Value& indices, - std::vector&& shapes); + IndexTensor(const torch::lazy::Value &self, const torch::lazy::Value &indices, + std::vector &&shapes); std::string ToString() const override; - bool CanBeReused(const torch::lazy::Value& self, - const torch::lazy::Value& indices) const; + bool CanBeReused(const torch::lazy::Value &self, + const torch::lazy::Value &indices) const; TorchMlirOpVector Lower(TorchMlirFunction function, - TorchMlirLoweringContext* loctx) const override; + TorchMlirLoweringContext *loctx) const override; }; class IndexPut : public torch::lazy::TorchMlirNode { - public: +public: static torch::lazy::OpKind ClassOpKind() { return torch::lazy::OpKind(at::aten::index_put); } - IndexPut(const torch::lazy::Value& self, const torch::lazy::Value& indices, - const torch::lazy::Value& values, bool accumulate, - std::vector&& shapes); + IndexPut(const torch::lazy::Value &self, const torch::lazy::Value &indices, + const torch::lazy::Value &values, bool accumulate, + std::vector &&shapes); std::string ToString() const override; - bool CanBeReused(const torch::lazy::Value& self, - const torch::lazy::Value& indices, - const torch::lazy::Value& values, bool accumulate) const; + bool CanBeReused(const torch::lazy::Value &self, + const torch::lazy::Value &indices, + const torch::lazy::Value &values, bool accumulate) const; TorchMlirOpVector Lower(TorchMlirFunction function, - TorchMlirLoweringContext* loctx) const override; + TorchMlirLoweringContext *loctx) const override; bool accumulate; }; -} // namespace lazy -} // namespace torch +} // namespace lazy +} // namespace torch diff --git a/projects/ltc/csrc/base_lazy_backend/ops/ivalue.cpp b/projects/ltc/csrc/base_lazy_backend/ops/ivalue.cpp index 0653e4467313..e3db5ca37608 100644 --- a/projects/ltc/csrc/base_lazy_backend/ops/ivalue.cpp +++ b/projects/ltc/csrc/base_lazy_backend/ops/ivalue.cpp @@ -15,7 +15,7 @@ namespace torch { namespace lazy { -IValueConstant::IValueConstant(const c10::IValue& value) +IValueConstant::IValueConstant(const c10::IValue &value) : torch::lazy::TorchMlirNode(IValueConstant::ClassOpKind(), OpList{}, std::vector{}, /* num_outputs */ 1, torch::lazy::MHash()), @@ -28,9 +28,9 @@ std::string IValueConstant::ToString() const { } TorchMlirOpVector IValueConstant::Lower(TorchMlirFunction function, - TorchMlirLoweringContext* loctx) const { + TorchMlirLoweringContext *loctx) const { return {loctx->graph()->insertConstant(value)}; } -} // namespace lazy -} // namespace torch +} // namespace lazy +} // namespace torch diff --git a/projects/ltc/csrc/base_lazy_backend/ops/ivalue.h b/projects/ltc/csrc/base_lazy_backend/ops/ivalue.h index 8f488ff47336..48fb95b73ddd 100644 --- a/projects/ltc/csrc/base_lazy_backend/ops/ivalue.h +++ b/projects/ltc/csrc/base_lazy_backend/ops/ivalue.h @@ -18,20 +18,20 @@ namespace lazy { // parameter which is helpful in different usecases when we need custom // native ops lowering to torch-mlir IR nodes. class IValueConstant : public torch::lazy::TorchMlirNode { - public: +public: static torch::lazy::OpKind ClassOpKind() { return torch::lazy::OpKind(at::prim::Constant); } - IValueConstant(const c10::IValue& value); + IValueConstant(const c10::IValue &value); std::string ToString() const override; TorchMlirOpVector Lower(TorchMlirFunction function, - TorchMlirLoweringContext* loctx) const override; + TorchMlirLoweringContext *loctx) const override; c10::IValue value; }; -} // namespace lazy -} // namespace torch +} // namespace lazy +} // namespace torch diff --git a/projects/ltc/csrc/base_lazy_backend/ops/split.cpp b/projects/ltc/csrc/base_lazy_backend/ops/split.cpp index d20d298dfdd0..91cbd2a52e3d 100644 --- a/projects/ltc/csrc/base_lazy_backend/ops/split.cpp +++ b/projects/ltc/csrc/base_lazy_backend/ops/split.cpp @@ -13,10 +13,10 @@ namespace torch { namespace lazy { SplitWithSizesCopy::SplitWithSizesCopy( - const torch::lazy::Value& self, const ::std::vector& split_sizes, - const int64_t& dim, std::vector&& shapes) + const torch::lazy::Value &self, const ::std::vector &split_sizes, + const int64_t &dim, std::vector &&shapes) : torch::lazy::TorchMlirNode(SplitWithSizesCopy::ClassOpKind(), - OpList{ self }, std::move(shapes), + OpList{self}, std::move(shapes), split_sizes.size() /* num_outputs */, torch::lazy::MHash(split_sizes, dim)), split_sizes(split_sizes), dim(dim) {} @@ -29,15 +29,15 @@ std::string SplitWithSizesCopy::ToString() const { return ss.str(); } -bool SplitWithSizesCopy::CanBeReused(const torch::lazy::Value& self, - const ::std::vector& split_sizes, - const int64_t& dim) const { +bool SplitWithSizesCopy::CanBeReused(const torch::lazy::Value &self, + const ::std::vector &split_sizes, + const int64_t &dim) const { return false; } TorchMlirOpVector SplitWithSizesCopy::Lower(TorchMlirFunction function, - TorchMlirLoweringContext* loctx) const { + TorchMlirLoweringContext *loctx) const { PRINT_FUNCTION(); std::vector arguments; std::vector kwarguments; @@ -55,13 +55,13 @@ SplitWithSizesCopy::Lower(TorchMlirFunction function, return split_with_sizes_copy_out; } -SplitCopyTensor::SplitCopyTensor(const torch::lazy::Value& self, - const torch::lazy::Value& split_size, - const int64_t& dim, - std::vector&& shapes, +SplitCopyTensor::SplitCopyTensor(const torch::lazy::Value &self, + const torch::lazy::Value &split_size, + const int64_t &dim, + std::vector &&shapes, const size_t num_outputs) : torch::lazy::TorchMlirNode(SplitCopyTensor::ClassOpKind(), - OpList{ self, split_size }, std::move(shapes), + OpList{self, split_size}, std::move(shapes), num_outputs, torch::lazy::MHash(dim)), dim(dim) {} @@ -72,15 +72,15 @@ std::string SplitCopyTensor::ToString() const { return ss.str(); } -bool SplitCopyTensor::CanBeReused(const torch::lazy::Value& self, - const torch::lazy::Value& split_size, - const int64_t& dim) const { +bool SplitCopyTensor::CanBeReused(const torch::lazy::Value &self, + const torch::lazy::Value &split_size, + const int64_t &dim) const { return false; } TorchMlirOpVector SplitCopyTensor::Lower(TorchMlirFunction function, - TorchMlirLoweringContext* loctx) const { + TorchMlirLoweringContext *loctx) const { PRINT_FUNCTION(); std::vector arguments; std::vector kwarguments; diff --git a/projects/ltc/csrc/base_lazy_backend/ops/split.h b/projects/ltc/csrc/base_lazy_backend/ops/split.h index 8593d5628c2e..116ddd64ab2b 100644 --- a/projects/ltc/csrc/base_lazy_backend/ops/split.h +++ b/projects/ltc/csrc/base_lazy_backend/ops/split.h @@ -20,19 +20,19 @@ class SplitWithSizesCopy : public torch::lazy::TorchMlirNode { return torch::lazy::OpKind(at::aten::split_with_sizes_copy); } - SplitWithSizesCopy(const torch::lazy::Value& self, - const ::std::vector& split_sizes, - const int64_t& dim, - std::vector&& shapes); + SplitWithSizesCopy(const torch::lazy::Value &self, + const ::std::vector &split_sizes, + const int64_t &dim, + std::vector &&shapes); std::string ToString() const override; - bool CanBeReused(const torch::lazy::Value& self, - const ::std::vector& split_sizes, - const int64_t& dim) const; + bool CanBeReused(const torch::lazy::Value &self, + const ::std::vector &split_sizes, + const int64_t &dim) const; TorchMlirOpVector Lower(TorchMlirFunction function, - TorchMlirLoweringContext* loctx) const override; + TorchMlirLoweringContext *loctx) const override; std::vector split_sizes; int64_t dim; @@ -44,19 +44,19 @@ class SplitCopyTensor : public torch::lazy::TorchMlirNode { return torch::lazy::OpKind(at::aten::split_copy); } - SplitCopyTensor(const torch::lazy::Value& self, - const torch::lazy::Value& split_size, const int64_t& dim, - std::vector&& shapes, + SplitCopyTensor(const torch::lazy::Value &self, + const torch::lazy::Value &split_size, const int64_t &dim, + std::vector &&shapes, const size_t num_outputs = 1); std::string ToString() const override; - bool CanBeReused(const torch::lazy::Value& self, - const torch::lazy::Value& split_size, - const int64_t& dim) const; + bool CanBeReused(const torch::lazy::Value &self, + const torch::lazy::Value &split_size, + const int64_t &dim) const; TorchMlirOpVector Lower(TorchMlirFunction function, - TorchMlirLoweringContext* loctx) const override; + TorchMlirLoweringContext *loctx) const override; int64_t dim; }; diff --git a/projects/ltc/csrc/base_lazy_backend/ops/to_copy.h b/projects/ltc/csrc/base_lazy_backend/ops/to_copy.h index c6b75baaf8f3..402355031474 100644 --- a/projects/ltc/csrc/base_lazy_backend/ops/to_copy.h +++ b/projects/ltc/csrc/base_lazy_backend/ops/to_copy.h @@ -17,61 +17,65 @@ namespace torch { namespace lazy { - -// This IR was copied from code-generated output, but the entire _to_copy operator -// cannot be trivially code genereated since it is only desirable to capture IR for -// certain permutaions of _to_copy (e.g. dtype), and for the others it is difficult to even invoke -// the aten/eager fallback necessitating directly implementing the right to(device) behavior +// This IR was copied from code-generated output, but the entire _to_copy +// operator cannot be trivially code genereated since it is only desirable to +// capture IR for certain permutaions of _to_copy (e.g. dtype), and for the +// others it is difficult to even invoke the aten/eager fallback necessitating +// directly implementing the right to(device) behavior class ToCopy : public torch::lazy::TorchMlirNode { - public: - ToCopy(const torch::lazy::Value& self, const c10::optional& dtype, const c10::optional& layout, const c10::optional& device, const c10::optional& pin_memory, const bool& non_blocking, const c10::optional& memory_format, std::vector&& shapes) - : torch::lazy::TorchMlirNode(torch::lazy::OpKind(at::aten::_to_copy), - {self}, std::move(shapes), - /* num_outputs */ 1, - torch::lazy::MHash(dtype, layout, device, pin_memory, non_blocking, memory_format)), +public: + ToCopy(const torch::lazy::Value &self, + const c10::optional &dtype, + const c10::optional &layout, + const c10::optional &device, + const c10::optional &pin_memory, const bool &non_blocking, + const c10::optional &memory_format, + std::vector &&shapes) + : torch::lazy::TorchMlirNode( + torch::lazy::OpKind(at::aten::_to_copy), {self}, std::move(shapes), + /* num_outputs */ 1, + torch::lazy::MHash(dtype, layout, device, pin_memory, non_blocking, + memory_format)), - dtype(dtype), - layout(layout), - device(device), - pin_memory(pin_memory), - non_blocking(non_blocking), - memory_format(memory_format) {} + dtype(dtype), layout(layout), device(device), pin_memory(pin_memory), + non_blocking(non_blocking), memory_format(memory_format) {} std::string ToString() const override { std::stringstream ss; ss << torch::lazy::TorchMlirNode::ToString(); if (dtype.has_value()) { - ss << ", dtype=" << dtype.value(); + ss << ", dtype=" << dtype.value(); } else { - ss << ", dtype=null"; + ss << ", dtype=null"; } if (layout.has_value()) { - ss << ", layout=" << layout.value(); + ss << ", layout=" << layout.value(); } else { - ss << ", layout=null"; + ss << ", layout=null"; } if (device.has_value()) { - ss << ", device=" << device.value(); + ss << ", device=" << device.value(); } else { - ss << ", device=null"; + ss << ", device=null"; } if (pin_memory.has_value()) { - ss << ", pin_memory=" << pin_memory.value(); + ss << ", pin_memory=" << pin_memory.value(); } else { - ss << ", pin_memory=null"; + ss << ", pin_memory=null"; } ss << ", non_blocking=" << non_blocking; if (memory_format.has_value()) { - ss << ", memory_format=" << memory_format.value(); + ss << ", memory_format=" << memory_format.value(); } else { - ss << ", memory_format=null"; + ss << ", memory_format=null"; } return ss.str(); } - torch::lazy::TorchMlirOpVector Lower(TorchMlirFunction function, - torch::lazy::TorchMlirLoweringContext* loctx) const override { - std::vector arguments; + torch::lazy::TorchMlirOpVector + Lower(TorchMlirFunction function, + torch::lazy::TorchMlirLoweringContext *loctx) const override { + std::vector arguments; std::vector kwarguments; arguments.reserve(1); kwarguments.reserve(6); @@ -83,11 +87,12 @@ class ToCopy : public torch::lazy::TorchMlirNode { kwarguments.emplace_back("pin_memory", pin_memory); kwarguments.emplace_back("non_blocking", non_blocking); kwarguments.emplace_back("memory_format", memory_format); - torch::lazy::TorchMlirOpVector _to_copy_out = torch::lazy::LowerTorchMlirBuiltin(function, op().op, shapes(), arguments, kwarguments); + torch::lazy::TorchMlirOpVector _to_copy_out = + torch::lazy::LowerTorchMlirBuiltin(function, op().op, shapes(), + arguments, kwarguments); TORCH_CHECK_EQ(_to_copy_out.size(), 1); return _to_copy_out; - } c10::optional dtype; @@ -97,5 +102,5 @@ class ToCopy : public torch::lazy::TorchMlirNode { bool non_blocking; c10::optional memory_format; }; -} // namespace lazy -} // namespace torch +} // namespace lazy +} // namespace torch diff --git a/projects/ltc/csrc/base_lazy_backend/ops/unbind_int.cpp b/projects/ltc/csrc/base_lazy_backend/ops/unbind_int.cpp index a5526366cd2b..c43c84d24d5e 100644 --- a/projects/ltc/csrc/base_lazy_backend/ops/unbind_int.cpp +++ b/projects/ltc/csrc/base_lazy_backend/ops/unbind_int.cpp @@ -12,9 +12,9 @@ namespace torch { namespace lazy { -UnbindCopyInt::UnbindCopyInt(const torch::lazy::Value& self, const int64_t& dim, - std::vector&& shapes) - : torch::lazy::TorchMlirNode(UnbindCopyInt::ClassOpKind(), OpList{ self }, +UnbindCopyInt::UnbindCopyInt(const torch::lazy::Value &self, const int64_t &dim, + std::vector &&shapes) + : torch::lazy::TorchMlirNode(UnbindCopyInt::ClassOpKind(), OpList{self}, std::move(shapes), self.shape().size(dim), /* num_outputs */ torch::lazy::MHash(dim)), @@ -27,13 +27,13 @@ std::string UnbindCopyInt::ToString() const { return ss.str(); } -bool UnbindCopyInt::CanBeReused(const torch::lazy::Value& self, - const int64_t& dim) const { +bool UnbindCopyInt::CanBeReused(const torch::lazy::Value &self, + const int64_t &dim) const { return false; } TorchMlirOpVector UnbindCopyInt::Lower(TorchMlirFunction function, - TorchMlirLoweringContext* loctx) const { + TorchMlirLoweringContext *loctx) const { PRINT_FUNCTION(); std::vector arguments; std::vector kwarguments; diff --git a/projects/ltc/csrc/base_lazy_backend/ops/unbind_int.h b/projects/ltc/csrc/base_lazy_backend/ops/unbind_int.h index 766752c16517..9d6d83842b10 100644 --- a/projects/ltc/csrc/base_lazy_backend/ops/unbind_int.h +++ b/projects/ltc/csrc/base_lazy_backend/ops/unbind_int.h @@ -20,15 +20,15 @@ class UnbindCopyInt : public torch::lazy::TorchMlirNode { return torch::lazy::OpKind(at::aten::unbind_copy); } - UnbindCopyInt(const torch::lazy::Value& self, const int64_t& dim, - std::vector&& shapes); + UnbindCopyInt(const torch::lazy::Value &self, const int64_t &dim, + std::vector &&shapes); std::string ToString() const override; - bool CanBeReused(const torch::lazy::Value& self, const int64_t& dim) const; + bool CanBeReused(const torch::lazy::Value &self, const int64_t &dim) const; TorchMlirOpVector Lower(TorchMlirFunction function, - TorchMlirLoweringContext* loctx) const override; + TorchMlirLoweringContext *loctx) const override; int64_t dim; }; diff --git a/projects/ltc/csrc/base_lazy_backend/shape_inference.cpp b/projects/ltc/csrc/base_lazy_backend/shape_inference.cpp index 325e89e14d5e..8e3b2c0702d3 100644 --- a/projects/ltc/csrc/base_lazy_backend/shape_inference.cpp +++ b/projects/ltc/csrc/base_lazy_backend/shape_inference.cpp @@ -21,21 +21,20 @@ namespace lazy { // TODO(henrytu): Upstream these shape inference functions to PyTorch in the // future. -std::vector compute_shape_add(const at::Tensor& self, - const at::Scalar& other, - const at::Scalar& alpha) { +std::vector compute_shape_add(const at::Tensor &self, + const at::Scalar &other, + const at::Scalar &alpha) { return {Shape(self.scalar_type(), self.sizes().vec())}; } - -std::vector compute_shape_sub(const at::Tensor& self, - const at::Scalar& other, - const at::Scalar& alpha) { +std::vector compute_shape_sub(const at::Tensor &self, + const at::Scalar &other, + const at::Scalar &alpha) { return {Shape(self.scalar_type(), self.sizes().vec())}; } -std::vector compute_shape_div(const at::Tensor& self, - const at::Scalar& other) { +std::vector compute_shape_div(const at::Tensor &self, + const at::Scalar &other) { return {Shape(self.scalar_type(), self.sizes().vec())}; } @@ -85,7 +84,7 @@ compute_shape_quantize_per_tensor(const at::Tensor &self, double scale, return {Shape(dtype, self.sizes().vec())}; } -std::vector compute_shape_isinf(const at::Tensor& self) { +std::vector compute_shape_isinf(const at::Tensor &self) { return {Shape(at::kBool, self.sizes().vec())}; } @@ -96,9 +95,8 @@ std::vector compute_shape_quantize_per_channel( } std::vector compute_shape_max_pool3d_with_indices( - const at::Tensor& self, at::IntArrayRef kernel_size, - at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, - bool ceil_mode) { + const at::Tensor &self, at::IntArrayRef kernel_size, at::IntArrayRef stride, + at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode) { auto in_sizes = self.sizes().vec(); std::vector dhw(3, 0); std::vector paddings = padding.vec(); @@ -106,18 +104,19 @@ std::vector compute_shape_max_pool3d_with_indices( std::vector dilations = dilation.vec(); std::vector strides = stride.vec(); TORCH_CHECK(in_sizes.size() == 5, "max_pool3d requires 5D inputs, but got ", - in_sizes); - TORCH_CHECK(kernel_size.size() == 3 && - stride.size() == 3 && - padding.size() == 3 && - dilation.size() == 3, "max_pool3d requires 3D operands, but got ", - kernel_size, stride, padding, dilation); + in_sizes); + TORCH_CHECK(kernel_size.size() == 3 && stride.size() == 3 && + padding.size() == 3 && dilation.size() == 3, + "max_pool3d requires 3D operands, but got ", kernel_size, stride, + padding, dilation); int64_t batch = in_sizes[0]; int64_t channel = in_sizes[1]; // NCDHW // https://pytorch.org/docs/stable/generated/torch.nn.MaxPool3d.html - for (auto i = 0UL; i<3; ++i) { - double out_size = (in_sizes[2+i] + 2 * paddings[i] - dilations[i] * - (ksizes[i] - 1) - 1) / (double)strides[i] + 1; + for (auto i = 0UL; i < 3; ++i) { + double out_size = (in_sizes[2 + i] + 2 * paddings[i] - + dilations[i] * (ksizes[i] - 1) - 1) / + (double)strides[i] + + 1; if (ceil_mode) dhw[i] = (int64_t)std::ceil(out_size); else @@ -129,52 +128,54 @@ std::vector compute_shape_max_pool3d_with_indices( } std::vector compute_shape_max_pool3d_with_indices_backward( - const at::Tensor & grad_output, const at::Tensor & self, + const at::Tensor &grad_output, const at::Tensor &self, at::IntArrayRef kernel_size, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool ceil_mode, - const at::Tensor & indices) { + const at::Tensor &indices) { return {Shape(self.scalar_type(), self.sizes().vec())}; } -std::vector compute_shape_mse_loss_backward( - const at::Tensor& grad_output, const at::Tensor& self, - const at::Tensor& target, int64_t reduction) { +std::vector +compute_shape_mse_loss_backward(const at::Tensor &grad_output, + const at::Tensor &self, + const at::Tensor &target, int64_t reduction) { return {Shape(self.scalar_type(), self.sizes().vec())}; } -std::vector compute_shape_mul(const at::Tensor& self, - const at::Scalar& other) { +std::vector compute_shape_mul(const at::Tensor &self, + const at::Scalar &other) { return {Shape(self.scalar_type(), self.sizes().vec())}; } -std::vector compute_shape_var( - const at::Tensor& self, at::OptionalIntArrayRef dim, - const c10::optional & correction, bool keepdim) { +std::vector +compute_shape_var(const at::Tensor &self, at::OptionalIntArrayRef dim, + const c10::optional &correction, bool keepdim) { // Result of variance is scalar tensor. return {Shape(self.scalar_type(), {})}; } -std::vector compute_shape_nan_to_num( - const at::Tensor & self, c10::optional nan, - c10::optional posinf, c10::optional neginf) { +std::vector +compute_shape_nan_to_num(const at::Tensor &self, c10::optional nan, + c10::optional posinf, + c10::optional neginf) { return {Shape(self.scalar_type(), self.sizes().vec())}; } -std::vector compute_shape_hardtanh( - const at::Tensor& self, const at::Scalar& min_val, - const at::Scalar& max_val) { +std::vector +compute_shape_hardtanh(const at::Tensor &self, const at::Scalar &min_val, + const at::Scalar &max_val) { return {Shape(self.scalar_type(), self.sizes().vec())}; } std::vector compute_shape_hardtanh_backward( - const at::Tensor& grad_output, const at::Tensor& self, - const at::Scalar& min_val, const at::Scalar& max_val) { + const at::Tensor &grad_output, const at::Tensor &self, + const at::Scalar &min_val, const at::Scalar &max_val) { return {Shape(self.scalar_type(), self.sizes().vec())}; } -std::vector compute_shape_where(const at::Tensor& condition, - const at::Tensor& self, - const at::Tensor& other) { +std::vector compute_shape_where(const at::Tensor &condition, + const at::Tensor &self, + const at::Tensor &other) { // There are cases like - // torch.aten.where.self %42, %arg17, %37 : !torch.vtensor<[15,10],i1>, // !torch.vtensor<[],f32>, !torch.vtensor<[15,10],f32>. @@ -201,32 +202,32 @@ std::vector compute_shape_where(const at::Tensor& condition, return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())}; } -std::vector compute_shape_bucketize( - const at::Tensor& self, const at::Tensor& boundaries, bool out_int32, - bool right) { +std::vector +compute_shape_bucketize(const at::Tensor &self, const at::Tensor &boundaries, + bool out_int32, bool right) { auto dtype = out_int32 ? at::kInt : at::kLong; return {Shape(dtype, self.sizes().vec())}; } -std::vector compute_shape_copy(const at::Tensor& self, - const at::Tensor& src, +std::vector compute_shape_copy(const at::Tensor &self, + const at::Tensor &src, bool non_blocking) { return {Shape(self.scalar_type(), self.sizes().vec())}; } -std::vector compute_shape_floor_divide( - const at::Tensor& self, const at::Tensor& other) { +std::vector +compute_shape_floor_divide(const at::Tensor &self, const at::Tensor &other) { return {Shape(self.scalar_type(), self.sizes().vec())}; } -std::vector compute_shape_fmod(const at::Tensor& self, - const at::Scalar& other) { +std::vector compute_shape_fmod(const at::Tensor &self, + const at::Scalar &other) { return {Shape(self.scalar_type(), self.sizes().vec())}; } std::vector compute_shape_native_group_norm( - const at::Tensor& input, const c10::optional& weight, - const c10::optional& bias, int64_t N, int64_t C, int64_t HxW, + const at::Tensor &input, const c10::optional &weight, + const c10::optional &bias, int64_t N, int64_t C, int64_t HxW, int64_t group, double eps) { TORCH_CHECK(input.sizes().size() >= 2, @@ -244,9 +245,10 @@ std::vector compute_shape_native_group_norm( return shapes; } -std::vector compute_shape_im2col( - const at::Tensor& self, at::IntArrayRef kernel_size, - at::IntArrayRef dilation, at::IntArrayRef padding, at::IntArrayRef stride) { +std::vector +compute_shape_im2col(const at::Tensor &self, at::IntArrayRef kernel_size, + at::IntArrayRef dilation, at::IntArrayRef padding, + at::IntArrayRef stride) { auto self_meta = at::native::empty_strided_meta_symint( self.sym_sizes(), self.sym_strides(), @@ -260,8 +262,8 @@ std::vector compute_shape_im2col( } std::vector compute_shape_native_group_norm_backward( - const at::Tensor& grad_out, const at::Tensor& input, const at::Tensor& mean, - const at::Tensor& rstd, const c10::optional& weight, int64_t N, + const at::Tensor &grad_out, const at::Tensor &input, const at::Tensor &mean, + const at::Tensor &rstd, const c10::optional &weight, int64_t N, int64_t C, int64_t HxW, int64_t group, ::std::array output_mask) { TORCH_CHECK(input.sizes().size() >= 2, @@ -280,8 +282,8 @@ std::vector compute_shape_native_group_norm_backward( return shapes; } -std::vector compute_shape_remainder( - const at::Tensor& self, const at::Scalar& other) { +std::vector +compute_shape_remainder(const at::Tensor &self, const at::Scalar &other) { return {Shape(self.scalar_type(), self.sizes().vec())}; } @@ -313,21 +315,22 @@ compute_shape_reflection_pad2d(const at::Tensor &self, return {Shape(self.scalar_type(), out_sizes)}; } -std::vector compute_shape_uniform( - const at::Tensor& self, double from, double to, - c10::optional generator) { +std::vector +compute_shape_uniform(const at::Tensor &self, double from, double to, + c10::optional generator) { return {Shape(self.scalar_type(), self.sizes().vec())}; } -std::vector compute_shape_normal_functional( - const at::Tensor& self, double mean, double std, - c10::optional generator) { +std::vector +compute_shape_normal_functional(const at::Tensor &self, double mean, double std, + c10::optional generator) { return {Shape(self.scalar_type(), self.sizes().vec())}; } -std::vector compute_shape_multinomial( - const at::Tensor& self, int64_t num_samples, bool replacement, - c10::optional generator) { +std::vector +compute_shape_multinomial(const at::Tensor &self, int64_t num_samples, + bool replacement, + c10::optional generator) { // Input tensor can be either 1D or 2D. The last dim of output // should be 'num_samples'. So the output shape can be either // [num_samples] or [m, num_samples]. @@ -337,35 +340,38 @@ std::vector compute_shape_multinomial( return {Shape(at::kLong, ishape)}; } -std::vector compute_shape_eye( - int64_t n, c10::optional dtype, - c10::optional layout, c10::optional device, - c10::optional pin_memory) { +std::vector +compute_shape_eye(int64_t n, c10::optional dtype, + c10::optional layout, + c10::optional device, + c10::optional pin_memory) { auto out_meta = at::eye(n, dtype, layout, c10::Device(c10::kMeta), pin_memory); return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())}; } -std::vector compute_shape_eye( - int64_t n, int64_t m, c10::optional dtype, - c10::optional layout, c10::optional device, - c10::optional pin_memory) { +std::vector +compute_shape_eye(int64_t n, int64_t m, c10::optional dtype, + c10::optional layout, + c10::optional device, + c10::optional pin_memory) { auto out_meta = at::eye(n, m, dtype, layout, c10::Device(c10::kMeta), pin_memory); return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())}; } -std::vector compute_shape_arange( - const at::Scalar& end, c10::optional dtype, - c10::optional layout, c10::optional device, - c10::optional pin_memory) { +std::vector +compute_shape_arange(const at::Scalar &end, c10::optional dtype, + c10::optional layout, + c10::optional device, + c10::optional pin_memory) { auto out_meta = at::arange(end, dtype, layout, c10::Device(c10::kMeta), pin_memory); return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())}; } std::vector compute_shape_arange( - const at::Scalar& start, const at::Scalar& end, + const at::Scalar &start, const at::Scalar &end, c10::optional dtype, c10::optional layout, c10::optional device, c10::optional pin_memory) { auto out_meta = at::arange(start, end, dtype, layout, c10::Device(c10::kMeta), @@ -374,7 +380,7 @@ std::vector compute_shape_arange( } std::vector compute_shape_arange( - const at::Scalar& start, const at::Scalar& end, const at::Scalar& step, + const at::Scalar &start, const at::Scalar &end, const at::Scalar &step, c10::optional dtype, c10::optional layout, c10::optional device, c10::optional pin_memory) { auto out_meta = at::arange(start, end, step, dtype, layout, @@ -383,34 +389,37 @@ std::vector compute_shape_arange( } std::vector compute_shape_full( - at::IntArrayRef size, const at::Scalar& fill_value, + at::IntArrayRef size, const at::Scalar &fill_value, c10::optional dtype, c10::optional layout, c10::optional device, c10::optional pin_memory) { return { Shape(dtype.value_or(at::get_default_dtype_as_scalartype()), size.vec())}; } -std::vector compute_shape_ones( - at::IntArrayRef size, c10::optional dtype, - c10::optional layout, c10::optional device, - c10::optional pin_memory) { +std::vector +compute_shape_ones(at::IntArrayRef size, c10::optional dtype, + c10::optional layout, + c10::optional device, + c10::optional pin_memory) { return { Shape(dtype.value_or(at::get_default_dtype_as_scalartype()), size.vec())}; } -std::vector compute_shape_zeros( - at::IntArrayRef size, c10::optional dtype, - c10::optional layout, c10::optional device, - c10::optional pin_memory) { +std::vector +compute_shape_zeros(at::IntArrayRef size, c10::optional dtype, + c10::optional layout, + c10::optional device, + c10::optional pin_memory) { return { Shape(dtype.value_or(at::get_default_dtype_as_scalartype()), size.vec())}; } -std::vector compute_shape_empty( - at::IntArrayRef size, c10::optional dtype, - c10::optional layout, c10::optional device, - c10::optional pin_memory, - c10::optional memory_format) { +std::vector +compute_shape_empty(at::IntArrayRef size, c10::optional dtype, + c10::optional layout, + c10::optional device, + c10::optional pin_memory, + c10::optional memory_format) { return { Shape(dtype.value_or(at::get_default_dtype_as_scalartype()), size.vec())}; } @@ -423,20 +432,21 @@ std::vector compute_shape_empty_strided( Shape(dtype.value_or(at::get_default_dtype_as_scalartype()), size.vec())}; } -std::vector compute_shape_fill(const at::Tensor& self, - const at::Scalar& value) { +std::vector compute_shape_fill(const at::Tensor &self, + const at::Scalar &value) { return {Shape(self.scalar_type(), self.sizes().vec())}; } -std::vector compute_shape_fill(const at::Tensor& self, - const at::Tensor& value) { +std::vector compute_shape_fill(const at::Tensor &self, + const at::Tensor &value) { return {Shape(self.scalar_type(), self.sizes().vec())}; } -std::vector compute_shape_randn( - at::IntArrayRef size, c10::optional dtype, - c10::optional layout, c10::optional device, - c10::optional pin_memory) { +std::vector +compute_shape_randn(at::IntArrayRef size, c10::optional dtype, + c10::optional layout, + c10::optional device, + c10::optional pin_memory) { return { Shape(dtype.value_or(at::get_default_dtype_as_scalartype()), size.vec())}; } @@ -457,36 +467,39 @@ std::vector compute_shape_randint( Shape(dtype.value_or(at::get_default_dtype_as_scalartype()), size.vec())}; } -std::vector compute_shape_resize( - const at::Tensor & self, at::IntArrayRef size, - c10::optional memory_format) { +std::vector +compute_shape_resize(const at::Tensor &self, at::IntArrayRef size, + c10::optional memory_format) { return {Shape(self.scalar_type(), size.vec())}; } -std::vector compute_shape_bernoulli( - const at::Tensor& self, const at::Tensor &p, - c10::optional generator) { +std::vector +compute_shape_bernoulli(const at::Tensor &self, const at::Tensor &p, + c10::optional generator) { return {Shape(self.scalar_type(), self.sizes().vec())}; } std::vector compute_shape_scalar_tensor( - const at::Scalar & s, c10::optional dtype, + const at::Scalar &s, c10::optional dtype, c10::optional layout, c10::optional device, c10::optional pin_memory) { return {Shape(dtype.value_or(s.type()), c10::ArrayRef{})}; } -std::vector compute_shape_roll( - const at::Tensor& self, at::IntArrayRef shifts, at::IntArrayRef dims) { +std::vector compute_shape_roll(const at::Tensor &self, + at::IntArrayRef shifts, + at::IntArrayRef dims) { return {Shape(self.scalar_type(), self.sizes().vec())}; } -std::vector compute_shape_linspace(const at::Scalar & start, const at::Scalar & end, int64_t steps, c10::optional dtype, c10::optional layout, c10::optional device, c10::optional pin_memory) { - auto out_meta = - at::linspace(start, end, steps, dtype, layout, c10::Device(c10::kMeta), pin_memory); +std::vector compute_shape_linspace( + const at::Scalar &start, const at::Scalar &end, int64_t steps, + c10::optional dtype, c10::optional layout, + c10::optional device, c10::optional pin_memory) { + auto out_meta = at::linspace(start, end, steps, dtype, layout, + c10::Device(c10::kMeta), pin_memory); return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())}; } - -} // namespace lazy +} // namespace lazy } // namespace torch diff --git a/projects/ltc/csrc/base_lazy_backend/tensor.cpp b/projects/ltc/csrc/base_lazy_backend/tensor.cpp index 82ae6cc27f4a..5be4ab369ff1 100644 --- a/projects/ltc/csrc/base_lazy_backend/tensor.cpp +++ b/projects/ltc/csrc/base_lazy_backend/tensor.cpp @@ -14,16 +14,16 @@ namespace torch { namespace lazy { -at::Tensor CreateFunctionalizedAtenFromLtcTensor( - const LazyTensorPtr& ltc_tensor) { +at::Tensor +CreateFunctionalizedAtenFromLtcTensor(const LazyTensorPtr <c_tensor) { at::Tensor tensor = CreateAtenFromLtcTensor(ltc_tensor); if (!c10::impl::tls_is_dispatch_key_excluded( - c10::DispatchKey::Functionalize) && + c10::DispatchKey::Functionalize) && !at::functionalization::impl::isFunctionalTensor(tensor)) { return at::functionalization::impl::to_functional_tensor(tensor); } return tensor; } -} // namespace lazy -} // namespace torch +} // namespace lazy +} // namespace torch diff --git a/projects/ltc/csrc/base_lazy_backend/tensor.h b/projects/ltc/csrc/base_lazy_backend/tensor.h index 4e39dd095aa5..18e63ef68cd6 100644 --- a/projects/ltc/csrc/base_lazy_backend/tensor.h +++ b/projects/ltc/csrc/base_lazy_backend/tensor.h @@ -18,7 +18,8 @@ namespace lazy { // should have explicit tensor functinoalization. Otherwise we can get // unfanctionalized primitives or in the worst case if we apply inplace // operations to unfunctionalized tensor it won't be captured in LTC graph. -TORCH_API at::Tensor CreateFunctionalizedAtenFromLtcTensor(const LazyTensorPtr& ltc_tensor); +TORCH_API at::Tensor +CreateFunctionalizedAtenFromLtcTensor(const LazyTensorPtr <c_tensor); } // namespace lazy } // namespace torch diff --git a/projects/ltc/csrc/base_lazy_backend/utils/exception.h b/projects/ltc/csrc/base_lazy_backend/utils/exception.h index 96510d830aef..533677ad86eb 100644 --- a/projects/ltc/csrc/base_lazy_backend/utils/exception.h +++ b/projects/ltc/csrc/base_lazy_backend/utils/exception.h @@ -21,8 +21,8 @@ } #define UNIMPLEMENTED_FUNCTION_ERROR() \ - UNIMPLEMENTED_ERROR( \ - "\n\t" << __FILE__ << ":" << __LINE__ << " " << __PRETTY_FUNCTION__) + UNIMPLEMENTED_ERROR("\n\t" << __FILE__ << ":" << __LINE__ << " " \ + << __PRETTY_FUNCTION__) #define UNSUPPORTED_ERROR(msg) \ { \ diff --git a/projects/ltc/csrc/base_lazy_backend/utils/jit_utils.cpp b/projects/ltc/csrc/base_lazy_backend/utils/jit_utils.cpp index 9ca8b666a42e..a4f3673715e5 100644 --- a/projects/ltc/csrc/base_lazy_backend/utils/jit_utils.cpp +++ b/projects/ltc/csrc/base_lazy_backend/utils/jit_utils.cpp @@ -7,9 +7,9 @@ namespace torch { namespace jit { -void ConvertScalarImplicit(std::shared_ptr& graph) { +void ConvertScalarImplicit(std::shared_ptr &graph) { DepthFirstGraphNodeIterator it(graph); - for (auto* node = it.next(); node != nullptr; node = it.next()) { + for (auto *node = it.next(); node != nullptr; node = it.next()) { if (node->kind() != c10::aten::ScalarImplicit) { continue; } @@ -27,15 +27,13 @@ void ConvertScalarImplicit(std::shared_ptr& graph) { node_type = c10::aten::FloatImplicit; output_type = FloatType::get(); } else { - throw std::runtime_error( - "Expected isIntegralType or isFloatingType"); + throw std::runtime_error("Expected isIntegralType or isFloatingType"); } - Value * output = graph - ->create(node_type, {input}) - ->insertBefore(node) - ->output() - ->setType(output_type); + Value *output = graph->create(node_type, {input}) + ->insertBefore(node) + ->output() + ->setType(output_type); node->output()->replaceAllUsesWith(output); node->destroy(); } diff --git a/projects/ltc/csrc/base_lazy_backend/utils/jit_utils.h b/projects/ltc/csrc/base_lazy_backend/utils/jit_utils.h index 2c4214cfc1ab..d9e47b464235 100644 --- a/projects/ltc/csrc/base_lazy_backend/utils/jit_utils.h +++ b/projects/ltc/csrc/base_lazy_backend/utils/jit_utils.h @@ -4,7 +4,7 @@ namespace torch { namespace jit { // Convert ScalarImplicit to IntImplicit or FloatImplicit. -TORCH_API void ConvertScalarImplicit(std::shared_ptr& graph); +TORCH_API void ConvertScalarImplicit(std::shared_ptr &graph); } // namespace jit } // namespace torch diff --git a/projects/ltc/csrc/base_lazy_backend/utils/string_utils.h b/projects/ltc/csrc/base_lazy_backend/utils/string_utils.h index 281331992e49..a5a524b05353 100644 --- a/projects/ltc/csrc/base_lazy_backend/utils/string_utils.h +++ b/projects/ltc/csrc/base_lazy_backend/utils/string_utils.h @@ -1,49 +1,49 @@ #pragma once -#include #include +#include #include - template -std::ostream& string_join(std::ostream& out, const std::vector& v, const std::string& delimiter) { - size_t i = 0; - for (const T& e : v) { - if ((i++) > 0) { out << delimiter; } - out << e; +std::ostream &string_join(std::ostream &out, const std::vector &v, + const std::string &delimiter) { + size_t i = 0; + for (const T &e : v) { + if ((i++) > 0) { + out << delimiter; } - return out; + out << e; + } + return out; } template -std::string string_join(const std::vector& v, const std::string& delimiter) { - std::ostringstream joined; - string_join(joined, v, delimiter); - return joined.str(); +std::string string_join(const std::vector &v, const std::string &delimiter) { + std::ostringstream joined; + string_join(joined, v, delimiter); + return joined.str(); } -inline std::vector string_split( - const std::string& str, - const std::string& sep -) { - std::vector tokens; - std::size_t pos1 = str.find_first_not_of(sep); - while (pos1 != std::string::npos) { - std::size_t pos2 = str.find_first_of(sep, pos1); - if (pos2 == std::string::npos) { - tokens.push_back(str.substr(pos1)); - pos1 = pos2; - } else { - tokens.push_back(str.substr(pos1, pos2 - pos1)); - pos1 = str.find_first_not_of(sep, pos2 + 1); - } +inline std::vector string_split(const std::string &str, + const std::string &sep) { + std::vector tokens; + std::size_t pos1 = str.find_first_not_of(sep); + while (pos1 != std::string::npos) { + std::size_t pos2 = str.find_first_of(sep, pos1); + if (pos2 == std::string::npos) { + tokens.push_back(str.substr(pos1)); + pos1 = pos2; + } else { + tokens.push_back(str.substr(pos1, pos2 - pos1)); + pos1 = str.find_first_not_of(sep, pos2 + 1); } - return tokens; + } + return tokens; } /* * Returns true if str starts with prefix */ -inline bool startswith(const std::string& str, const std::string& prefix) { - return str.rfind(prefix, 0) == 0; +inline bool startswith(const std::string &str, const std::string &prefix) { + return str.rfind(prefix, 0) == 0; } diff --git a/projects/ltc/csrc/base_lazy_backend/utils/sys_utils.h b/projects/ltc/csrc/base_lazy_backend/utils/sys_utils.h index 5ae14904909a..f6c51ba6158f 100644 --- a/projects/ltc/csrc/base_lazy_backend/utils/sys_utils.h +++ b/projects/ltc/csrc/base_lazy_backend/utils/sys_utils.h @@ -6,24 +6,25 @@ namespace sys_util { template -static T GetEnv(const std::string& name, const T& default_value = T(0)) { - const char* env = std::getenv(name.c_str()); +static T GetEnv(const std::string &name, const T &default_value = T(0)) { + const char *env = std::getenv(name.c_str()); if (!env) { return default_value; } return T(std::atoi(env)); } -static std::string GetEnvString(const std::string& name, const std::string& default_value) { - const char* env = std::getenv(name.c_str()); +static std::string GetEnvString(const std::string &name, + const std::string &default_value) { + const char *env = std::getenv(name.c_str()); if (!env) { return default_value; } return std::string(env); } -static bool GetEnvBool(const char* name, bool defval) { - const char* env = std::getenv(name); +static bool GetEnvBool(const char *name, bool defval) { + const char *env = std::getenv(name); if (env == nullptr) { return defval; } diff --git a/projects/ltc/csrc/base_lazy_backend/utils/tensor_utils.cpp b/projects/ltc/csrc/base_lazy_backend/utils/tensor_utils.cpp index cdd97168031b..71a0e89f4c64 100644 --- a/projects/ltc/csrc/base_lazy_backend/utils/tensor_utils.cpp +++ b/projects/ltc/csrc/base_lazy_backend/utils/tensor_utils.cpp @@ -3,84 +3,90 @@ #include "../generated/LazyIr.h" #include "../mlir_node.h" - namespace torch { namespace lazy { -bool is_detach_copy(const torch::lazy::Node* node) { - return node && node->op() == torch::lazy::DetachCopy::ClassOpKind(); +bool is_detach_copy(const torch::lazy::Node *node) { + return node && node->op() == torch::lazy::DetachCopy::ClassOpKind(); } -bool is_detach_copy(const torch::lazy::Value& value) { - return is_detach_copy(value.node.get()); +bool is_detach_copy(const torch::lazy::Value &value) { + return is_detach_copy(value.node.get()); } -torch::lazy::Node* extract_non_detach_copy_node(torch::lazy::Node* node) { - if (!node) { return nullptr; } +torch::lazy::Node *extract_non_detach_copy_node(torch::lazy::Node *node) { + if (!node) { + return nullptr; + } - torch::lazy::TorchMlirNode* mlir_node = dynamic_cast(node); - while(mlir_node && is_detach_copy(mlir_node)) { - mlir_node = mlir_node->mlir_node(0); - } - if (!mlir_node) { - return node; - } - return mlir_node; + torch::lazy::TorchMlirNode *mlir_node = + dynamic_cast(node); + while (mlir_node && is_detach_copy(mlir_node)) { + mlir_node = mlir_node->mlir_node(0); + } + if (!mlir_node) { + return node; + } + return mlir_node; } -const torch::lazy::Node* extract_non_detach_copy_node(const torch::lazy::Node* node) { - if (!node) { return nullptr; } +const torch::lazy::Node * +extract_non_detach_copy_node(const torch::lazy::Node *node) { + if (!node) { + return nullptr; + } - const torch::lazy::TorchMlirNode* mlir_node = dynamic_cast(node); - while(mlir_node && is_detach_copy(mlir_node)) { - mlir_node = mlir_node->mlir_node(0); - } - if (!mlir_node) { - return node; - } - return mlir_node; + const torch::lazy::TorchMlirNode *mlir_node = + dynamic_cast(node); + while (mlir_node && is_detach_copy(mlir_node)) { + mlir_node = mlir_node->mlir_node(0); + } + if (!mlir_node) { + return node; + } + return mlir_node; } - -torch::lazy::DeviceData* device_data_cast(torch::lazy::Node* node) { - if (!node) { - return nullptr; - } - node = extract_non_detach_copy_node(node); - if (node && node->op() == torch::lazy::DeviceData::ClassOpKind()) { - return dynamic_cast(node); - } +torch::lazy::DeviceData *device_data_cast(torch::lazy::Node *node) { + if (!node) { return nullptr; + } + node = extract_non_detach_copy_node(node); + if (node && node->op() == torch::lazy::DeviceData::ClassOpKind()) { + return dynamic_cast(node); + } + return nullptr; } -const torch::lazy::DeviceData* device_data_cast(const torch::lazy::Node* node) { - if (!node) { - return nullptr; - } - node = extract_non_detach_copy_node(node); - if (node && node->op() == torch::lazy::DeviceData::ClassOpKind()) { - return dynamic_cast(node); - } +const torch::lazy::DeviceData *device_data_cast(const torch::lazy::Node *node) { + if (!node) { return nullptr; + } + node = extract_non_detach_copy_node(node); + if (node && node->op() == torch::lazy::DeviceData::ClassOpKind()) { + return dynamic_cast(node); + } + return nullptr; } -torch::lazy::DeviceData* device_data_cast(const torch::lazy::Value& value) { - if (!value) { - return nullptr; - } - return device_data_cast(value.node.get()); +torch::lazy::DeviceData *device_data_cast(const torch::lazy::Value &value) { + if (!value) { + return nullptr; + } + return device_data_cast(value.node.get()); } -torch::lazy::DeviceData* device_data_cast( - const at::Tensor& tensor, c10::optional device -) { - if (!device) { - device = torch::lazy::GetBackendDevice(tensor); - } - TORCH_CHECK(device); - torch::lazy::LazyTensorPtr lazy_tensor = torch::lazy::GetLtcTensorOrCreateForWrappedNumber(tensor, *device); - if (lazy_tensor) { - return device_data_cast(lazy_tensor->GetIrValue()); - } - return nullptr; +torch::lazy::DeviceData * +device_data_cast(const at::Tensor &tensor, + c10::optional device) { + if (!device) { + device = torch::lazy::GetBackendDevice(tensor); + } + TORCH_CHECK(device); + torch::lazy::LazyTensorPtr lazy_tensor = + torch::lazy::GetLtcTensorOrCreateForWrappedNumber(tensor, *device); + if (lazy_tensor) { + return device_data_cast(lazy_tensor->GetIrValue()); + } + return nullptr; } -} // namespace lazy -} // namespace torch +} // namespace lazy +} // namespace torch diff --git a/projects/ltc/csrc/base_lazy_backend/utils/tensor_utils.h b/projects/ltc/csrc/base_lazy_backend/utils/tensor_utils.h index 745be78c35d2..f8e5e317294a 100644 --- a/projects/ltc/csrc/base_lazy_backend/utils/tensor_utils.h +++ b/projects/ltc/csrc/base_lazy_backend/utils/tensor_utils.h @@ -8,18 +8,21 @@ namespace torch { namespace lazy { -TORCH_API bool is_detach_copy(const torch::lazy::Node*); -TORCH_API bool is_detach_copy(const torch::lazy::Value&); +TORCH_API bool is_detach_copy(const torch::lazy::Node *); +TORCH_API bool is_detach_copy(const torch::lazy::Value &); -TORCH_API torch::lazy::Node* extract_non_detach_copy_node(torch::lazy::Node*); -TORCH_API const torch::lazy::Node* extract_non_detach_copy_node(const torch::lazy::Node*); +TORCH_API torch::lazy::Node *extract_non_detach_copy_node(torch::lazy::Node *); +TORCH_API const torch::lazy::Node * +extract_non_detach_copy_node(const torch::lazy::Node *); -TORCH_API torch::lazy::DeviceData* device_data_cast(torch::lazy::Node*); -TORCH_API const torch::lazy::DeviceData* device_data_cast(const torch::lazy::Node*); -TORCH_API torch::lazy::DeviceData* device_data_cast(const torch::lazy::Value& value); -TORCH_API torch::lazy::DeviceData* device_data_cast( - const at::Tensor& tensor, c10::optional device = c10::nullopt -); +TORCH_API torch::lazy::DeviceData *device_data_cast(torch::lazy::Node *); +TORCH_API const torch::lazy::DeviceData * +device_data_cast(const torch::lazy::Node *); +TORCH_API torch::lazy::DeviceData * +device_data_cast(const torch::lazy::Value &value); +TORCH_API torch::lazy::DeviceData *device_data_cast( + const at::Tensor &tensor, + c10::optional device = c10::nullopt); -} // namespace lazy -} // namespace torch +} // namespace lazy +} // namespace torch diff --git a/projects/pt1/python/torch_mlir/csrc/reference_lazy_backend/backend_impl.cpp b/projects/pt1/python/torch_mlir/csrc/reference_lazy_backend/backend_impl.cpp index 4bcb9347b5aa..8708ff06a5a2 100644 --- a/projects/pt1/python/torch_mlir/csrc/reference_lazy_backend/backend_impl.cpp +++ b/projects/pt1/python/torch_mlir/csrc/reference_lazy_backend/backend_impl.cpp @@ -30,7 +30,7 @@ namespace lazy { /// Returns true if a string begins with another. inline bool beginswith(const std::string& s, const std::string& t) { - return s.size() >= t.size() && s.compare(0, t.size(), t) == 0; + return s.size() >= t.size() && s.compare(0, t.size(), t) == 0; } struct ReferenceLazyBackendDeviceType : public BackendDeviceType { @@ -73,10 +73,8 @@ class ReferenceLazyBackendImpl : public torch::lazy::TorchMlirBackendImpl { // Vendor backend specific lowering can be exec here before returning. for (const auto& instance : instances) { TORCH_CHECK( - instance->in_mark_step, - "Compile outside of mark step:\n", - GetComputationBackendText(instance) - ); + instance->in_mark_step, "Compile outside of mark step:\n", + GetComputationBackendText(instance)); // Store computation instance for external access after compilation. GetLatestComputation() = instance; } @@ -114,16 +112,17 @@ class ReferenceLazyBackendImpl : public torch::lazy::TorchMlirBackendImpl { // Convert any lazy devices to cpu devices to ensure // that the values are actually computed if (node->outputs().size() == 1 && - node->output()->type()->kind() == - c10::TypeKind::DeviceObjType) { - auto value_sym = torch::jit::Symbol::attr("value"); - TORCH_CHECK(node->hasAttribute(value_sym), - "Expected node to have 'value' attribute."); - TORCH_CHECK(node->kindOf(value_sym) == torch::jit::AttributeKind::s, - "Expected 'value' attribute to be a string."); - if (beginswith(node->s(value_sym), "lazy")) { - node->s_(value_sym, "cpu"); - } + node->output()->type()->kind() == c10::TypeKind::DeviceObjType) { + auto value_sym = torch::jit::Symbol::attr("value"); + TORCH_CHECK( + node->hasAttribute(value_sym), + "Expected node to have 'value' attribute."); + TORCH_CHECK( + node->kindOf(value_sym) == torch::jit::AttributeKind::s, + "Expected 'value' attribute to be a string."); + if (beginswith(node->s(value_sym), "lazy")) { + node->s_(value_sym, "cpu"); + } } } @@ -132,7 +131,8 @@ class ReferenceLazyBackendImpl : public torch::lazy::TorchMlirBackendImpl { for (const auto& argument : arguments) { const auto mlir_data = std::static_pointer_cast(argument); - auto* info = dynamic_cast(mlir_data->mlir_info()); + auto* info = + dynamic_cast(mlir_data->mlir_info()); TORCH_CHECK(info); if (info->scalar.has_value()) { stack.emplace_back(info->scalar.value()); diff --git a/projects/pt1/python/torch_mlir/csrc/reference_lazy_backend/reference_lazy_backend_pybind.cpp b/projects/pt1/python/torch_mlir/csrc/reference_lazy_backend/reference_lazy_backend_pybind.cpp index f4b8cd9ba579..2cbb6d6f16dc 100644 --- a/projects/pt1/python/torch_mlir/csrc/reference_lazy_backend/reference_lazy_backend_pybind.cpp +++ b/projects/pt1/python/torch_mlir/csrc/reference_lazy_backend/reference_lazy_backend_pybind.cpp @@ -8,8 +8,8 @@ //===----------------------------------------------------------------------===// #include "torch/csrc/jit/python/pybind.h" -#include "torch/csrc/lazy/core/config.h" #include "torch/csrc/lazy/backend/backend_interface.h" +#include "torch/csrc/lazy/core/config.h" #include #include @@ -56,8 +56,8 @@ void Initialize() { } if (ir_debug) { - FLAGS_torch_lazy_ir_debug = true; - std::cout << "Enabled lazy tensor IR debugging." << std::endl; + FLAGS_torch_lazy_ir_debug = true; + std::cout << "Enabled lazy tensor IR debugging." << std::endl; } } @@ -82,15 +82,17 @@ PYBIND11_MODULE(_REFERENCE_LAZY_BACKEND, m) { torch::lazy::GetLatestComputation().get()); return py::cast(computation); }); - m.def("set_parameter_name", - [](const at::Tensor& tensor, const std::string& name) -> bool { - torch::lazy::DeviceData* ir_node = torch::lazy::device_data_cast(tensor); - if (ir_node) { - ir_node->SetName(name); - return true; - } - return false; - }); + m.def( + "set_parameter_name", + [](const at::Tensor& tensor, const std::string& name) -> bool { + torch::lazy::DeviceData* ir_node = + torch::lazy::device_data_cast(tensor); + if (ir_node) { + ir_node->SetName(name); + return true; + } + return false; + }); m.def("_initialize", []() { NoGilSection gil; Initialize(); From 1d6aca3823e33a026320dadfd4d680452758edfa Mon Sep 17 00:00:00 2001 From: Quinn Dawkins Date: Mon, 29 Jan 2024 13:29:51 -0500 Subject: [PATCH 133/283] Add .git-blame-ignore-revs to allow ignoring sweeping formatting changes (#2823) This allows the following command to be used to ignore sweeping formatting changes. ``` git blame --ignore-revs-file .git-blame-ignore-revs ``` --- .git-blame-ignore-revs | 11 +++++++++++ 1 file changed, 11 insertions(+) create mode 100644 .git-blame-ignore-revs diff --git a/.git-blame-ignore-revs b/.git-blame-ignore-revs new file mode 100644 index 000000000000..0fe0ceed72d3 --- /dev/null +++ b/.git-blame-ignore-revs @@ -0,0 +1,11 @@ +# This file contains the list of commits to exclude from 'git blame'. +# Such commits do not meaningfully contribute to git history, and include +# large-scale mechanical changes like code formatting style changes. +# +# To set this file as the default ignore file for 'git blame', run: +# ```shell +# git config blame.ignoreRevsFile .git-blame-ignore-revs +# ``` + +# Refresh clang-format +494089d53db4c183b3ba12e36f61ce1c7553984c From eff325abc3e7bba294e50d3c5c320a61bd62fab6 Mon Sep 17 00:00:00 2001 From: aldesilv <35313749+aldesilv@users.noreply.github.com> Date: Mon, 29 Jan 2024 22:14:48 -0800 Subject: [PATCH 134/283] OnnxToTorch ReduceMax lowering (#2768) Fixes https://github.com/nod-ai/SHARK-Turbine/issues/352 --- .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 103 ++++++++++++++++++ .../TorchOnnxToTorch/simple_ops_q_to_z.mlir | 64 +++++++++++ 2 files changed, 167 insertions(+) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 87f68375a593..3cba62d7691c 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -656,6 +656,109 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( binder.op, resultType, operand, vAlpha, vScale, vInputScale); return success(); }); + patterns.onOp( + "ReduceMax", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + SmallVector operands; + int64_t keepDims, noop_with_empty_axes; + + if (binder.tensorOperandsList(operands) || + binder.tensorResultType(resultType) || + binder.s64IntegerAttr(keepDims, "keepdims", 1) || + binder.s64IntegerAttr(noop_with_empty_axes, "noop_with_empty_axes", + 0)) + return failure(); + Value data = operands[0]; + + if (operands.size() == 1) { + if (noop_with_empty_axes == 0) { + MLIRContext *context = binder.op->getContext(); + auto rank = + data.getType().cast().getSizes().size(); + SmallVector dims; + for (int i = 0; i < rank; i++) { + dims.push_back(rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(i))); + } + Value dimsList = rewriter.create( + binder.getLoc(), + Torch::ListType::get(Torch::IntType::get(context)), dims); + Value keepDimsConstInt = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), keepDims)); + Value keepDimsBool = rewriter.create( + binder.getLoc(), keepDimsConstInt); + rewriter.replaceOpWithNewOp( + binder.op, resultType, data, /*dim=*/dimsList, + /*keepdim=*/keepDimsBool); + } else { + rewriter.replaceOp(binder.op, data); + } + return success(); + } + + Value axes = operands[1]; + + SmallVector dimList; + + Torch::BaseTensorType axesType = + axes.getType().cast(); + SmallVector selectSizes; + selectSizes.push_back(1); + Type selectResultType = axesType.getWithSizesAndDtype( + llvm::ArrayRef(selectSizes), axesType.getOptionalDtype()); + Value noneVal = rewriter.create(binder.getLoc()); + + auto sizes = + dyn_cast(axes.getType()).getSizes(); + + Value zero = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); + int64_t adjustmentInt = + cast(data.getType()).getSizes().size(); + Value adjustment = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), + adjustmentInt)); + for (int i = 0; i < sizes[0]; i++) { + // Go through the axes list and get each dim in the list + Value selectIndex = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), i)); + Value extract = rewriter.create( + binder.getLoc(), selectResultType, axes, zero, selectIndex); + Value dim = rewriter.create( + binder.getLoc(), rewriter.getType(), extract); + // deal with neg axis: if (axis < 0) axis += rank + Value isNegative = + rewriter.create(binder.getLoc(), dim, zero); + isNegative = rewriter.create(binder.getLoc(), + isNegative); + Value finalOffset = rewriter.create( + binder.getLoc(), isNegative, adjustment); + Value finalDim = rewriter.create( + binder.getLoc(), dim, finalOffset); + dimList.push_back(finalDim); + } + + Value dimValueList = rewriter.create( + binder.getLoc(), + Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), + dimList); + + Value keepDimBool; + if (keepDims == 1) { + keepDimBool = + rewriter.create(binder.getLoc(), true); + } else { + keepDimBool = + rewriter.create(binder.getLoc(), false); + } + rewriter.replaceOpWithNewOp( + binder.op, resultType, data, dimValueList, keepDimBool); + return success(); + }); patterns.onOp( "ReduceSum", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) { diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir index a9f6098a26d2..97f698e5eb3e 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir @@ -624,6 +624,70 @@ func.func @test_selu(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4, // ----- +// CHECK-LABEL: func.func @test_reduce_max_keepdims_example +func.func @test_reduce_max_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[2],si64>) -> !torch.vtensor<[3,1,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[INT0:.*]] = torch.constant.int 0 + // CHECK: %[[RANK:.*]] = torch.constant.int 3 + // CHECK: %[[INT0_0:.*]] = torch.constant.int 0 + // CHECK: %[[SELECT_DIM0:.*]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[2],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: %[[ITEM0:.*]] = torch.aten.item %[[SELECT_DIM0]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[LTZERO_0:.*]] = torch.aten.lt.int %[[ITEM0]], %[[INT0]] : !torch.int, !torch.int -> !torch.bool + // CHECK: %[[ISNEG_0:.*]] = torch.aten.Int.bool %[[LTZERO_0]] : !torch.bool -> !torch.int + // CHECK: %[[ADJUSTMENT_0:.*]] = torch.aten.mul.int %[[ISNEG_0]], %[[RANK]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[FINAL_0:.*]] = torch.aten.add.int %[[ITEM0]], %[[ADJUSTMENT_0]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: %[[SELECT_DIM1:.*]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT1]] : !torch.vtensor<[2],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: %[[ITEM1:.*]] = torch.aten.item %[[SELECT_DIM1]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[LTZERO_1:.*]] = torch.aten.lt.int %[[ITEM1]], %[[INT0]] : !torch.int, !torch.int -> !torch.bool + // CHECK: %[[ISNEG_1:.*]] = torch.aten.Int.bool %[[LTZERO_1]] : !torch.bool -> !torch.int + // CHECK: %[[ADJUSTMENT_1:.*]] = torch.aten.mul.int %[[ISNEG_1]], %[[RANK]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[FINAL_1:.*]] = torch.aten.add.int %[[ITEM1]], %[[ADJUSTMENT_1]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[DIMS:.*]] = torch.prim.ListConstruct %[[FINAL_0]], %[[FINAL_1]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[KEEPDIMS:.*]] = torch.constant.bool true + // CHECK: torch.aten.amax %arg0, %[[DIMS]], %[[KEEPDIMS]] : !torch.vtensor<[3,2,2],f32>, !torch.list, !torch.bool -> !torch.vtensor<[3,1,1],f32> + %0 = torch.operator "onnx.ReduceMax"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[2],si64>) -> !torch.vtensor<[3,1,1],f32> + return %0 : !torch.vtensor<[3,1,1],f32> + } + +// ----- + +// CHECK-LABEL: func.func @test_reduce_max_default_axes_keepdim_example +func.func @test_reduce_max_default_axes_keepdim_example(%arg0: !torch.vtensor<[3,2,2],f32>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT0:.*]] = torch.constant.int 0 + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: %[[INT2:.*]] = torch.constant.int 2 + // CHECK: %[[DIMS:.*]] = torch.prim.ListConstruct %[[INT0]], %[[INT1]], %[[INT2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[INT1_0:.*]] = torch.constant.int 1 + // CHECK: %[[KEEPDIMS:.*]] = torch.aten.Bool.int %[[INT1_0]] : !torch.int -> !torch.bool + // CHECK: torch.aten.amax %arg0, %[[DIMS]], %[[KEEPDIMS]] : !torch.vtensor<[3,2,2],f32>, !torch.list, !torch.bool -> !torch.vtensor<[1,1,1],f32> + %0 = torch.operator "onnx.ReduceMax"(%arg0) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>) -> !torch.vtensor<[1,1,1],f32> + return %0 : !torch.vtensor<[1,1,1],f32> + } + +// ----- + +// CHECK-LABEL: func.func @test_reduce_max_do_not_keepdims_example + func.func @test_reduce_max_do_not_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[INT0:.*]] = torch.constant.int 0 + // CHECK: %[[RANK:.*]] = torch.constant.int 3 + // CHECK: %[[INT0_0:.*]] = torch.constant.int 0 + // CHECK: %[[SELECT_DIM:.*]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: %[[ITEM:.*]] = torch.aten.item %[[SELECT_DIM]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[LTZERO:.*]] = torch.aten.lt.int %[[ITEM]], %[[INT0]] : !torch.int, !torch.int -> !torch.bool + // CHECK: %[[ISNEG:.*]] = torch.aten.Int.bool %[[LTZERO]] : !torch.bool -> !torch.int + // CHECK: %[[ADJUSTMENT:.*]] = torch.aten.mul.int %[[ISNEG]], %[[RANK]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[FINAL:.*]] = torch.aten.add.int %[[ITEM]], %[[ADJUSTMENT]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[DIMS:.*]] = torch.prim.ListConstruct %[[FINAL]] : (!torch.int) -> !torch.list + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: torch.aten.amax %arg0, %[[DIMS]], %[[FALSE]] : !torch.vtensor<[3,2,2],f32>, !torch.list, !torch.bool -> !torch.vtensor<[3,2],f32> + %0 = torch.operator "onnx.ReduceMax"(%arg0, %arg1) {torch.onnx.keepdims = 0 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32> + return %0 : !torch.vtensor<[3,2],f32> + } + +// ----- + // CHECK-LABEL: func.func @test_reduce_sum_default_axes_keepdims_example func.func @test_reduce_sum_default_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[NONE:.+]] = torch.constant.none From e18fcebd3af1659ab3eeaa9db75534cf0fe03eaf Mon Sep 17 00:00:00 2001 From: Vivek Khandelwal Date: Tue, 30 Jan 2024 16:42:18 +0530 Subject: [PATCH 135/283] [CI] Change Roll PyTorch runner (#2828) Signed-Off By: Vivek Khandelwal --- .github/workflows/RollPyTorch.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/RollPyTorch.yml b/.github/workflows/RollPyTorch.yml index 975b538c5d95..1c0f8f568728 100644 --- a/.github/workflows/RollPyTorch.yml +++ b/.github/workflows/RollPyTorch.yml @@ -9,7 +9,7 @@ on: jobs: build_linux: name: Manylinux Build - runs-on: a100 + runs-on: torch-mlir-cpubuilder-manylinux-x86-64 # Don't run this in everyone's forks. if: github.repository == 'llvm/torch-mlir' From 1e882f58035171ae2c9d2d1533fbf32569be7d12 Mon Sep 17 00:00:00 2001 From: James Newling Date: Tue, 30 Jan 2024 08:28:08 -0800 Subject: [PATCH 136/283] Additional information in error message (#2783) See change in test for what the new message looks like. --- .../csrc/jit_ir_importer/class_annotator.cpp | 22 ++++++++++++++++--- .../ivalue_import/annotations/arg-error.py | 5 ++++- 2 files changed, 23 insertions(+), 4 deletions(-) diff --git a/projects/jit_ir_common/csrc/jit_ir_importer/class_annotator.cpp b/projects/jit_ir_common/csrc/jit_ir_importer/class_annotator.cpp index b144e946ba5e..47f7a974c8e8 100644 --- a/projects/jit_ir_common/csrc/jit_ir_importer/class_annotator.cpp +++ b/projects/jit_ir_common/csrc/jit_ir_importer/class_annotator.cpp @@ -9,6 +9,7 @@ #include "class_annotator.h" +#include #include using namespace torch_mlir; @@ -150,11 +151,26 @@ ClassAnnotator::getOrCreateClassAnnotation(c10::ClassType *classType) { } static void fillArgAnnotations(MethodAnnotation &methodAnnotation, - std::vector argAnnotations, + const std::vector &argAnnotations, torch::jit::Function *function) { if (argAnnotations.size() != function->num_inputs()) { - throw std::invalid_argument("Arg annotations should have one entry per " - "function parameter (including self)."); + + std::ostringstream oss; + oss << "There must be one argument annotation per function parameter. " + << "Including 'self' the number of argument annotations is: " + << argAnnotations.size() + << ". The number of function parameters is: " << function->num_inputs() + << ". "; + const auto &args = function->getSchema().arguments(); + if (args.size() > 0) { + oss << "The function signature is ("; + oss << args[0]; + for (auto iter = args.begin() + 1; iter != args.end(); iter++) { + oss << ", " << *iter; + } + oss << ')' << '.'; + } + throw std::invalid_argument(oss.str()); } if (!methodAnnotation.argAnnotations.has_value()) { methodAnnotation.argAnnotations.emplace(function->num_inputs(), diff --git a/projects/pt1/test/python/importer/jit_ir/ivalue_import/annotations/arg-error.py b/projects/pt1/test/python/importer/jit_ir/ivalue_import/annotations/arg-error.py index 26eaa5bd0cb1..0979d04228b5 100644 --- a/projects/pt1/test/python/importer/jit_ir/ivalue_import/annotations/arg-error.py +++ b/projects/pt1/test/python/importer/jit_ir/ivalue_import/annotations/arg-error.py @@ -33,7 +33,10 @@ def forward(self, tensor): try: annotator.annotateArgs(class_type, ['forward'], [None]) except Exception as e: - # CHECK: Arg annotations should have one entry per function parameter (including self). + # CHECK: There must be one argument annotation per function parameter. + # CHECK-SAME: Including 'self' the number of argument annotations is: 1. + # CHECK-SAME: The number of function parameters is: 2. + # CHECK-SAME: The function signature is (__torch__.TestModule self, Tensor tensor) print(e) try: From 9d983161fc6c7c78eaaab9cf4a40115cffb416a4 Mon Sep 17 00:00:00 2001 From: James Newling Date: Tue, 30 Jan 2024 08:30:00 -0800 Subject: [PATCH 137/283] Describe how to get --debug and --debug-only flags in dev notes (#2793) Change should be visible : https://github.com/newling/torch-mlir/blob/docs_update/docs/development.md --- docs/development.md | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/docs/development.md b/docs/development.md index 7927bb39a35f..782058a63ea7 100644 --- a/docs/development.md +++ b/docs/development.md @@ -45,12 +45,12 @@ cmake -GNinja -Bbuild \ -DLLVM_TARGETS_TO_BUILD=host \ externals/llvm-project/llvm ``` -The following additional quality of life flags can be used to reduce build time: +#### Flags that can reduce build time: * Enabling clang on Linux ```shell -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ ``` -* Enabling ccache: +* Enabling ccache ```shell -DCMAKE_C_COMPILER_LAUNCHER=ccache -DCMAKE_CXX_COMPILER_LAUNCHER=ccache ``` @@ -72,6 +72,14 @@ By default we download the latest version of libtorch. We have an experimental p -DLIBTORCH_VARIANT=shared # Set the variant of libtorch to build / link against. (`shared`|`static` and optionally `cxxabi11`) ``` +#### Flags to enable MLIR debugging: + +* Enabling `--debug` and `--debug-only` flags (see [MLIR docs](https://mlir.llvm.org/getting_started/Debugging/)) for the `torch-mlir-opt` tool +```shell + -DCMAKE_BUILD_TYPE=RelWithDebInfo \ # or =Debug + -DIREE_ENABLE_ASSERTIONS=ON \ +``` + ### Building against a pre-built LLVM If you have built llvm-project separately in the directory `$LLVM_INSTALL_DIR`, you can also build the project *out-of-tree* using the following command as template: From db67bc555ade76a26ea96d733050508e0a60822b Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Tue, 30 Jan 2024 09:01:42 -0800 Subject: [PATCH 138/283] Bump LLVM to llvm/llvm-project@70eb0e3 (#2827) --- externals/llvm-project | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/externals/llvm-project b/externals/llvm-project index 5fcf907b3435..70eb0e37a867 160000 --- a/externals/llvm-project +++ b/externals/llvm-project @@ -1 +1 @@ -Subproject commit 5fcf907b34355980f77d7665a175b05fea7a6b7b +Subproject commit 70eb0e37a86747f9266e4c8380baa89746f5e23b From 4c557847bdd44bdfff90fa6d56089529ef065843 Mon Sep 17 00:00:00 2001 From: Aaron St George Date: Tue, 30 Jan 2024 09:45:51 -0800 Subject: [PATCH 139/283] Don't fold `aten.detach` if result isn't same type as input. (#2824) We were seeing some assertion failures after some checks around folders were tightened up in LLVM: https://github.com/llvm/llvm-project/pull/75887 . This PR essentially moves the logic that used to be applied at the LLVM level into the folder, which seems to be the suggested fix. I'm not sure if the IR that caused issues for us _should_ be valid? ``` %1 = torch.aten.detach %arg0 : !torch.tensor<[1],f32> -> !torch.tensor ``` A better fix might be to create a verifier ensuring the result of `aten.detach` has the same type as its operand. --------- Co-authored-by: aaron-stgeorge --- lib/Dialect/Torch/IR/TorchOps.cpp | 6 +++++- test/Dialect/Torch/canonicalize.mlir | 7 +++++++ 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index 4af9bcfc1e3b..4aacd8d7693e 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -1465,7 +1465,11 @@ static OpFoldResult intComparatorFoldHelper(OpTy op, // AtenDetachOp //===----------------------------------------------------------------------===// -OpFoldResult AtenDetachOp::fold(FoldAdaptor adaptor) { return getSelf(); } +OpFoldResult AtenDetachOp::fold(FoldAdaptor adaptor) { + if (getSelf().getType() != getResult().getType()) + return {}; + return getSelf(); +} //===----------------------------------------------------------------------===// // AtenNeIntOp diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir index 9172d4642759..3cf82d9ed6a7 100644 --- a/test/Dialect/Torch/canonicalize.mlir +++ b/test/Dialect/Torch/canonicalize.mlir @@ -2146,3 +2146,10 @@ func.func @torch.aten.masked_fill.Tensor$canonicalize(%arg0: !torch.vtensor<[?,? %1 = torch.aten.masked_fill.Tensor %arg0, %arg1, %0 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],i1>, !torch.vtensor<[],f32> -> !torch.vtensor<[?,?],f32> return %1 : !torch.vtensor<[?,?],f32> } + +// CHECK-LABEL: func.func @torch.aten.detach$canonicalize +// CHECK-NEXT: torch.aten.detach +func.func @torch.aten.detach$canonicalize(%arg0: !torch.tensor<[1],f32>) -> !torch.tensor { + %1 = torch.aten.detach %arg0 : !torch.tensor<[1],f32> -> !torch.tensor + return %1 : !torch.tensor +} From 25a5a22cbd35de8a9a639f5c1eb0d6ed8c11e6e4 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Tue, 30 Jan 2024 13:46:47 -0800 Subject: [PATCH 140/283] [torch] Support `torch.convolution` quantized lowering to `linalg` (#2811) Linalg has quantized specific operations. We can lower to these operations when there is a known zeropoint and scale operations. This allows the `convolution` to occur with lower bitwidth's, improving the overall performance. --- .../Conversion/TorchToLinalg/Utils.h | 2 +- lib/Conversion/TorchToLinalg/Linear.cpp | 382 ++++++++++++------ .../TorchToLinalg/Uncategorized.cpp | 28 +- lib/Conversion/TorchToLinalg/Utils.cpp | 6 +- .../Torch/Transforms/FuseQuantizedOps.cpp | 32 +- projects/pt1/e2e_testing/xfail_sets.py | 2 + .../torch_mlir_e2e_test/test_suite/conv.py | 35 ++ test/Dialect/Torch/fuse-quantized-ops.mlir | 28 +- 8 files changed, 362 insertions(+), 153 deletions(-) diff --git a/include/torch-mlir/Conversion/TorchToLinalg/Utils.h b/include/torch-mlir/Conversion/TorchToLinalg/Utils.h index 7c9257075824..5d2095f04f14 100644 --- a/include/torch-mlir/Conversion/TorchToLinalg/Utils.h +++ b/include/torch-mlir/Conversion/TorchToLinalg/Utils.h @@ -36,7 +36,7 @@ Value getZeroPaddedTensor(Operation *op, OpBuilder &b, Value &input, // padding value is zero. Value getDynamicZeroPaddedTensor(Operation *op, OpBuilder &b, Value &input, SmallVectorImpl &padding, - int unpaddedDims = 0); + int unpaddedDims = 0, Value pad = {}); // Helper function to caculate the output tensor dims for convolution-like ops. // Along each dim: diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp index c0585df0bcd7..4523febb9b9d 100644 --- a/lib/Conversion/TorchToLinalg/Linear.cpp +++ b/lib/Conversion/TorchToLinalg/Linear.cpp @@ -36,6 +36,27 @@ static void getZeroPoint(Value value, Value &zeropoint) { } } +static Value transposeValue(Location loc, Value value, ArrayRef perms, + PatternRewriter &rewriter) { + auto valueTy = value.getType().cast(); + auto inShape = valueTy.getShape(); + llvm::SmallVector outShape; + llvm::SmallVector dynDims; + for (size_t i = 0; i < perms.size(); ++i) { + outShape.push_back(inShape[perms[i]]); + if (ShapedType::isDynamic(inShape[perms[i]])) { + dynDims.push_back(rewriter.create(loc, value, perms[i])); + } + } + + auto outTy = RankedTensorType::get(outShape, valueTy.getElementType()); + Value empty = rewriter.create(loc, outTy, dynDims); + Value transpose = + rewriter.create(loc, value, empty, perms) + ->getResult(0); + return transpose; +} + class ConvertAtenMmOp : public OpConversionPattern { public: using OpConversionPattern::OpConversionPattern; @@ -600,19 +621,62 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { MLIRContext *context = op->getContext(); Value input = adaptor.getInput(); /* in form of N*C*H*W */ Value weight = adaptor.getWeight(); /* in form of F*C*H*W */ + Value bias = adaptor.getBias(); + auto resultTy = op.getType().cast(); + + Value inputZp, weightZp; + if (auto make = op.getInput() + .getDefiningOp()) { + input = make.getSelf(); + inputZp = make.getZeroPoint(); + input = typeConverter->materializeTargetConversion( + rewriter, loc, typeConverter->convertType(input.getType()), input); + inputZp = typeConverter->materializeTargetConversion( + rewriter, loc, typeConverter->convertType(inputZp.getType()), + inputZp); + } + + if (auto make = op.getWeight() + .getDefiningOp()) { + weight = make.getSelf(); + weightZp = make.getZeroPoint(); + + weight = typeConverter->materializeTargetConversion( + rewriter, loc, typeConverter->convertType(weight.getType()), weight); + weightZp = typeConverter->materializeTargetConversion( + rewriter, loc, typeConverter->convertType(weightZp.getType()), + weightZp); + } + + if (static_cast(inputZp) != static_cast(weightZp)) { + return rewriter.notifyMatchFailure( + op, "lhs and rhs of convolution must either be both int or fp"); + } + + if (inputZp && weightZp) { + auto biasDTy = bias.getType().cast().getElementType(); + if (!biasDTy.isInteger(32)) { + return rewriter.notifyMatchFailure( + op, "quantized result ty should be i32 accumulator"); + } + } bool transposed = true; if (!matchPattern(op.getTransposed(), m_TorchConstantBool(&transposed))) return rewriter.notifyMatchFailure( op, "unimplemented: only constant transposed supported"); - Type elementType = - input.getType().cast().getElementType(); - if (!elementType.isa()) - return op.emitError("unimplemented: non-floating point type"); + auto inputDTy = input.getType().cast().getElementType(); + auto weightDTy = weight.getType().cast().getElementType(); + auto resultDTy = resultTy.toBuiltinTensor().getElementType(); + + if (!inputDTy.isa() || + !weightDTy.isa() || + !resultDTy.isa()) + return op.emitError("unimplemented: non-fp not-int type"); size_t inRank = input.getType().cast().getRank(); - size_t numSpacialDims = inRank - 2; - if (numSpacialDims < 1 || numSpacialDims > 3) + size_t numSpatialDims = inRank - 2; + if (numSpatialDims < 1 || numSpatialDims > 3) return rewriter.notifyMatchFailure( op, "unimplemented: only 1d-3d convolution currently supported"); @@ -684,6 +748,12 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { SmallVector outDims{inBatch, weightBatch}; Value paddedInput; if (transposed) { + if (!inputDTy.isa() || + !weightDTy.isa() || + !resultDTy.isa()) + return rewriter.notifyMatchFailure( + op, "transpose does not support non-fp type yet"); + Value c0 = rewriter.create(loc, rewriter.getIndexAttr(0)); Value c1 = @@ -696,7 +766,7 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { std::iter_swap(weightInitDims.begin(), weightInitDims.begin() + 1); outDims[1] = weightInitDims[0]; Value weightInitTensor = - createZeroInitTensor(rewriter, loc, weightInitDims, elementType); + createZeroInitTensor(rewriter, loc, weightInitDims, weightDTy); SmallVector iteratorTypes( inRank, utils::IteratorType::parallel); SmallVector indexingMaps{ @@ -729,7 +799,7 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { SmallVector outerSizes{inBatch, inChannels}; SmallVector innerSizes{inBatch, inChannels}; SmallVector offsets{c0, c0}; - for (size_t i = 0; i < numSpacialDims; i++) { + for (size_t i = 0; i < numSpatialDims; i++) { Value innerSize = rewriter.create(loc, inDims[i], c1); innerSize = rewriter.create( loc, innerSize, castIntToIndex(rewriter, loc, strideIntValues[i])); @@ -753,7 +823,7 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { // Allocate padded input tensor Value initTensor = - createZeroInitTensor(rewriter, loc, outerSizes, elementType); + createZeroInitTensor(rewriter, loc, outerSizes, inputDTy); // Insert input into allocated tensor SmallVector strideIndexValues{c1, c1}; @@ -766,7 +836,7 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { initTensor, offsets, insertSizes, strideIndexValues); // Calculate output dims - for (size_t i = 0; i < numSpacialDims; i++) + for (size_t i = 0; i < numSpatialDims; i++) outDims.push_back(torch_to_linalg::getOutputDimForConvTransposeOps( rewriter, loc, inDims[i], paddingIntValues[i], dilationIntValues[i], castIndexToInt(weightDims[i]), strideIntValues[i], @@ -774,36 +844,57 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { // Set stride to 1 strideInts.clear(); - strideInts.append(numSpacialDims, 1); - + strideInts.append(numSpatialDims, 1); } else { + Value pad = inputZp; + if (!pad) { + if (isa(inputDTy)) + pad = rewriter.create( + op.getLoc(), rewriter.getFloatAttr(inputDTy, 0.0)); + if (isa(inputDTy)) + pad = rewriter.create( + op.getLoc(), rewriter.getIntegerAttr(inputDTy, 0)); + } + + if (pad.getType() != inputDTy) { + if (isa(inputDTy)) + pad = rewriter.create(op.getLoc(), inputDTy, pad); + + if (isa(inputDTy)) + pad = rewriter.create(op.getLoc(), inputDTy, pad); + } + // Pad input paddedInput = torch_to_linalg::getDynamicZeroPaddedTensor( - op, rewriter, input, paddingIntValues, /*unpaddedDims=*/2); + op, rewriter, input, paddingIntValues, /*unpaddedDims=*/2, pad); // Calculate output dims - for (size_t i = 0; i < numSpacialDims; i++) + for (size_t i = 0; i < numSpatialDims; i++) outDims.push_back(torch_to_linalg::getOutputDimForConvOps( rewriter, loc, inDims[i], paddingIntValues[i], dilationIntValues[i], castIndexToInt(weightDims[i]), strideIntValues[i])); } Value initTensor = rewriter.create( - loc, getAsOpFoldResult(outDims), elementType); + loc, getAsOpFoldResult(outDims), resultDTy); - Value bias = adaptor.getBias(); Value outputTensor; if (bias.getType().isa()) { - Value c0float = rewriter.create( - loc, FloatAttr::get(elementType, 0.0)); - outputTensor = rewriter.create(loc, c0float, initTensor) + Value c0; + if (resultDTy.isa()) { + c0 = rewriter.create( + loc, FloatAttr::get(resultDTy, 0.0)); + } else if (resultDTy.isa()) { + c0 = rewriter.create( + loc, IntegerAttr::get(resultDTy, 0)); + } + outputTensor = rewriter.create(loc, c0, initTensor) .getResult(0); + } else { auto biasType = bias.getType().cast(); if (biasType.getRank() != 1) return rewriter.notifyMatchFailure(op, "expect bias to be rank 1"); - if (elementType != biasType.getElementType()) - return rewriter.notifyMatchFailure(op, "unimplemented: type promotion"); auto resultRank = initTensor.getType().cast().getRank(); SmallVector indexingMaps = { @@ -843,16 +934,16 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { weightSliceSizes.append(weightDims); Value conv; - // the code so far is able to respect all numSpacialDims - // the code below this point is numSpacialDims specific and groupSize + // the code so far is able to respect all numSpatialDims + // the code below this point is numSpatialDims specific and groupSize // specific // TODO: factor out the above code into a helper function, and then separate // convolution into: // - grouped 1d-3d + // - grouped 1d-3d (quantized) // - ungrouped 1d-3d - if (groupSize == 1) { - // TODO: 3D case - switch (numSpacialDims) { + if (groupSize == 1 && !inputZp && !weightZp) { + switch (numSpatialDims) { case 1: conv = rewriter .create( @@ -884,113 +975,170 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { Type newResultType = getTypeConverter()->convertType(op.getType()); rewriter.replaceOpWithNewOp(op, newResultType, conv); return success(); - } else { - if (numSpacialDims != 2) - return rewriter.notifyMatchFailure( - op, "unimplemented: only 2D grouped convolution supported"); - - // Special depthwise case - auto inShape = makeShapeTorchCompatible( - input.getType().cast().getShape()); - auto weightShape = makeShapeTorchCompatible( - weight.getType().cast().getShape()); - if (weightShape[0] != kUnknownSize && inShape[1] == groupSize && - weightShape[0] % inShape[1] == 0 && weightShape[1] == 1) { - // Collapse weight shape - SmallVector collapsedDims = {{0, 1}, {2}, {3}}; - SmallVector collapsedShape{ - (weightShape[0] == kUnknownSize ? kUnknownSize - : weightShape[0] * weightShape[1]), - weightShape[2], weightShape[3]}; - // TODO: audit possibility of sparsity on this tensor - Type collapsedType = RankedTensorType::get( - makeShapeLLVMCompatible(collapsedShape), elementType); - Value collapsedWeight = rewriter.create( - loc, collapsedType, weight, collapsedDims); + } + + if (groupSize == 1 && inputZp && weightZp) { + // The quantized version uses a different channel ordering so we need to + // permute the tensors in order to use the existing path. We should + // eventually directly support this channel ordering. + llvm::SmallVector inPerms, weightPerms; + inPerms.push_back(0); // N stays at the front for input. + // Then we expect the spatial dimensions + for (size_t i = 0; i < numSpatialDims; ++i) { + inPerms.push_back(i + 2); + weightPerms.push_back(i + 2); + } + inPerms.push_back(1); + weightPerms.append({1, 0}); + paddedInput = transposeValue(op.getLoc(), paddedInput, inPerms, rewriter); + weight = transposeValue(op.getLoc(), weight, weightPerms, rewriter); + outputTensor = + transposeValue(op.getLoc(), outputTensor, inPerms, rewriter); + + switch (numSpatialDims) { + case 2: conv = rewriter - .create( + .create( loc, outputTensor.getType(), - ValueRange{paddedInput, collapsedWeight}, outputTensor, - stridesAttr, dilationAttr) + ValueRange{paddedInput, weight, inputZp, weightZp}, + outputTensor, stridesAttr, dilationAttr) .getResult(0); + break; + case 3: + conv = rewriter + .create( + loc, outputTensor.getType(), + ValueRange{paddedInput, weight, inputZp, weightZp}, + outputTensor, stridesAttr, dilationAttr) + .getResult(0); + break; + default: + return rewriter.notifyMatchFailure( + op, "unimplemented: only 1D, 2D, and 3D convolution supported"); + }; - Type newResultType = getTypeConverter()->convertType(op.getType()); - rewriter.replaceOpWithNewOp(op, newResultType, conv); - return success(); + llvm::SmallVector outPerms; + outPerms.push_back(0); + outPerms.push_back(inPerms.size() - 1); + for (size_t i = 0; i < numSpatialDims; ++i) { + outPerms.push_back(i + 1); } + conv = transposeValue(op.getLoc(), conv, outPerms, rewriter); - // Grouped case, use the grouped conv linalg op - auto expandGroups = [&](Value tensor, size_t dim) { - auto inType = tensor.getType().cast(); - auto inShape = makeShapeTorchCompatible(inType.getShape()); - - SmallVector outShape; - for (auto i = 0; i < (long)inShape.size(); i++) { - if (i == 1) { - outShape.push_back(groupSize); - } - if (i == (long)dim) { - outShape.push_back(inShape[i] == kUnknownSize - ? kUnknownSize - : inShape[i] / groupSize); - } else { - outShape.push_back(inShape[i]); - } - } + Type newResultType = getTypeConverter()->convertType(op.getType()); + rewriter.replaceOpWithNewOp(op, newResultType, conv); + return success(); + } + + if (inputZp || weightZp) + return rewriter.notifyMatchFailure( + op, "unimplemented: quantized grouped convolutions"); + + if (numSpatialDims != 2) + return rewriter.notifyMatchFailure( + op, "unimplemented: only 2D grouped convolution supported"); + + // Special depthwise case + auto inShape = makeShapeTorchCompatible( + input.getType().cast().getShape()); + auto weightShape = makeShapeTorchCompatible( + weight.getType().cast().getShape()); + if (weightShape[0] != kUnknownSize && inShape[1] == groupSize && + weightShape[0] % inShape[1] == 0 && weightShape[1] == 1) { + // Collapse weight shape + SmallVector collapsedDims = {{0, 1}, {2}, {3}}; + SmallVector collapsedShape{ + (weightShape[0] == kUnknownSize ? kUnknownSize + : weightShape[0] * weightShape[1]), + weightShape[2], weightShape[3]}; + Type collapsedType = RankedTensorType::get( + makeShapeLLVMCompatible(collapsedShape), weightDTy); + Value collapsedWeight = rewriter.create( + loc, collapsedType, weight, collapsedDims); - SmallVector indices; - for (auto i = 0; i <= (long)inShape.size(); i++) { - if (i == (long)dim) { - indices.push_back({i, ++i}); - continue; - } - indices.push_back({i}); + conv = rewriter + .create( + loc, outputTensor.getType(), + ValueRange{paddedInput, collapsedWeight}, outputTensor, + stridesAttr, dilationAttr) + .getResult(0); + + Type newResultType = getTypeConverter()->convertType(op.getType()); + rewriter.replaceOpWithNewOp(op, newResultType, conv); + return success(); + } + + // Grouped case, use the grouped conv linalg op + auto expandGroups = [&](Value tensor, size_t dim) { + auto inType = tensor.getType().cast(); + auto inShape = makeShapeTorchCompatible(inType.getShape()); + + SmallVector outShape; + for (auto i = 0; i < (long)inShape.size(); i++) { + if (i == 1) { + outShape.push_back(groupSize); + } + if (i == (long)dim) { + outShape.push_back(inShape[i] == kUnknownSize + ? kUnknownSize + : inShape[i] / groupSize); + } else { + outShape.push_back(inShape[i]); } + } - auto retType = inType.clone(makeShapeLLVMCompatible(outShape)); - return rewriter.create(loc, retType, tensor, - indices); - }; + SmallVector indices; + for (auto i = 0; i <= (long)inShape.size(); i++) { + if (i == (long)dim) { + indices.push_back({i, ++i}); + continue; + } + indices.push_back({i}); + } - // expand F,C,H,W -> G,F/G,C,H,W - auto expandWeight = [&](Value tensor) { - auto inType = tensor.getType().cast(); - auto inShape = makeShapeTorchCompatible(inType.getShape()); + auto retType = inType.clone(makeShapeLLVMCompatible(outShape)); + return rewriter.create(loc, retType, tensor, + indices); + }; - SmallVector outShape{ - groupSize, (inShape[0] == kUnknownSize ? kUnknownSize - : inShape[0] / groupSize)}; - outShape.append(inShape.begin() + 1, inShape.end()); + // expand F,C,H,W -> G,F/G,C,H,W + auto expandWeight = [&](Value tensor) { + auto inType = tensor.getType().cast(); + auto inShape = makeShapeTorchCompatible(inType.getShape()); - SmallVector indices{{0, 1}}; - for (auto i = 2; i <= (long)inShape.size(); i++) - indices.push_back({i}); + SmallVector outShape{ + groupSize, + (inShape[0] == kUnknownSize ? kUnknownSize : inShape[0] / groupSize)}; + outShape.append(inShape.begin() + 1, inShape.end()); - auto retType = inType.clone(makeShapeLLVMCompatible(outShape)); - return rewriter.create(loc, retType, tensor, - indices); - }; + SmallVector indices{{0, 1}}; + for (auto i = 2; i <= (long)inShape.size(); i++) + indices.push_back({i}); - Value paddedInputExpanded = expandGroups(paddedInput, 1); - Value weightExpanded = expandWeight(weight); - auto expandOutputTensor = expandGroups(outputTensor, 1); + auto retType = inType.clone(makeShapeLLVMCompatible(outShape)); + return rewriter.create(loc, retType, tensor, + indices); + }; - // TODO: add 1D and 3D case - conv = rewriter - .create( - loc, expandOutputTensor.getResultType(), - ValueRange{paddedInputExpanded, weightExpanded}, - expandOutputTensor.getResult(), stridesAttr, dilationAttr) - .getResult(0); - - conv = rewriter.create( - loc, outputTensor.getType(), conv, - expandOutputTensor.getReassociationIndices()); - Type newResultType = getTypeConverter()->convertType(op.getType()); - rewriter.replaceOpWithNewOp(op, newResultType, conv); - return success(); - } + Value paddedInputExpanded = expandGroups(paddedInput, 1); + Value weightExpanded = expandWeight(weight); + auto expandOutputTensor = expandGroups(outputTensor, 1); + + // TODO: add 1D and 3D case + conv = rewriter + .create( + loc, expandOutputTensor.getResultType(), + ValueRange{paddedInputExpanded, weightExpanded}, + expandOutputTensor.getResult(), stridesAttr, dilationAttr) + .getResult(0); + + conv = rewriter.create( + loc, outputTensor.getType(), conv, + expandOutputTensor.getReassociationIndices()); + Type newResultType = getTypeConverter()->convertType(op.getType()); + rewriter.replaceOpWithNewOp(op, newResultType, conv); + return success(); } }; } // namespace diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index 54317979353d..794b755998fe 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -1342,11 +1342,20 @@ static Value createLinalgPayloadCalculationForElementwiseOp( auto valueTy = value.getType(); auto qtensor = op->getOperand(0); auto qtensorTy = qtensor.getType().cast().getDtype(); - auto makeQTensor = - qtensor.getDefiningOp(); - if (!makeQTensor) { - op->emitWarning( - "unimplemented: dequantizing tensor of unknown scale / zero-point"); + + Value zp, scale; + if (auto makeQTensor = + qtensor.getDefiningOp()) { + zp = makeQTensor.getZeroPoint(); + scale = makeQTensor.getScale(); + } + + if (auto quant = qtensor.getDefiningOp()) { + zp = quant.getZeroPoint(); + scale = quant.getScale(); + } + + if (!zp || !scale) { return nullptr; } @@ -1362,10 +1371,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp( } } - Value zp = makeQTensor.getZeroPoint(); zp = converter->materializeTargetConversion( - b, loc, converter->convertType(zp.getType()), - makeQTensor.getZeroPoint()); + b, loc, converter->convertType(zp.getType()), zp); auto zpTy = zp.getType(); if (zpTy != outIntTy) { @@ -1380,10 +1387,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp( value = b.create(loc, outFpTy, value); } - Value scale = makeQTensor.getScale(); scale = converter->materializeTargetConversion( - b, loc, converter->convertType(scale.getType()), - makeQTensor.getScale()); + b, loc, converter->convertType(scale.getType()), scale); if (scale.getType() != value.getType()) { scale = b.create(loc, value.getType(), scale); } @@ -2233,7 +2238,6 @@ class ConvertDequantizePerChannel auto qoperand = op.getOperand(); auto make = qoperand.getDefiningOp(); if (!make) { - llvm::errs() << "Did not find make per channel\n"; return rewriter.notifyMatchFailure(op, "did not find per channel qint"); } diff --git a/lib/Conversion/TorchToLinalg/Utils.cpp b/lib/Conversion/TorchToLinalg/Utils.cpp index 0d62010d7b55..20b32cd1fe73 100644 --- a/lib/Conversion/TorchToLinalg/Utils.cpp +++ b/lib/Conversion/TorchToLinalg/Utils.cpp @@ -70,7 +70,7 @@ Value torch_to_linalg::getZeroPaddedTensor( // padding value is zero. Value torch_to_linalg::getDynamicZeroPaddedTensor( Operation *op, OpBuilder &b, Value &input, SmallVectorImpl &padding, - int unpaddedDims) { + int unpaddedDims, Value pad) { assert(input.getType().isa() && "input must be RankedTensorType"); unsigned int inRank = input.getType().cast().getRank(); @@ -93,12 +93,10 @@ Value torch_to_linalg::getDynamicZeroPaddedTensor( SmallVector(inRank, kUnknownSize))), elementType); - Value cf0 = - b.create(loc, b.getFloatAttr(elementType, 0.0)); SmallVector paddingValues = getAsOpFoldResult(paddingIncludingUnchanged); return b.create(loc, inputType, input, /*low=*/paddingValues, - /*high=*/paddingValues, cf0); + /*high=*/paddingValues, pad); } Value torch_to_linalg::getOutputDimForConvOps(OpBuilder &b, Location loc, diff --git a/lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp b/lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp index 85dadb755112..15d5ec105ed4 100644 --- a/lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp +++ b/lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp @@ -30,7 +30,7 @@ class QuantizeOperands : public OpRewritePattern { llvm::SmallVector operands(op->getOperands()); bool dequanted = false; - for (auto &operand : operands) { + auto f = [&dequanted](Value operand) { if (auto dequant = operand.getDefiningOp()) { operand = dequant.getOperand(); dequanted = true; @@ -39,7 +39,11 @@ class QuantizeOperands : public OpRewritePattern { operand = dequant.getOperand(); dequanted = true; } - } + return operand; + }; + + operands[0] = f(operands[0]); + operands[1] = f(operands[1]); if (!dequanted) { return rewriter.notifyMatchFailure(op, "no dequantizations found"); @@ -77,6 +81,7 @@ template class QuantizeBias : public OpRewritePattern { if (!rhsScale || !lhsScale) return failure(); + auto resultTy = cast(op.getType()); auto biasTy = bias.getType().cast(); auto biasETy = biasTy.getOptionalDtype(); if (!biasETy || !isa(biasETy)) @@ -95,9 +100,27 @@ template class QuantizeBias : public OpRewritePattern { Value dtype = getDtypeIntValueForType(rewriter, op.getLoc(), qi32Ty); bias = rewriter.create( op.getLoc(), newBiasTy, bias, biasScale, zero, dtype); + bias = rewriter.create( + op.getLoc(), + rewriter.getType( + biasTy.getOptionalSizes(), + rewriter.getIntegerType(32, IntegerType::Signed)), + bias); operands[2] = bias; - rewriter.replaceOpWithNewOp(op, op.getType(), operands); + + auto convTy = rewriter.getType( + resultTy.getOptionalSizes(), + rewriter.getIntegerType(32, IntegerType::Signed)); + auto conv = rewriter.create(op.getLoc(), convTy, operands); + + auto convQTy = + rewriter.getType(resultTy.getOptionalSizes(), qi32Ty); + auto makeOut = rewriter.create( + op.getLoc(), convQTy, conv, biasScale, zero); + rewriter.replaceOpWithNewOp(op, op.getType(), + makeOut); + return success(); } }; @@ -151,7 +174,7 @@ class QuantizeAccumulator : public OpRewritePattern { rewriter.getType(resultTy.getOptionalSizes(), qi32Ty); auto conv = rewriter.create(op.getLoc(), newResultTy, operands); - // Attach the quantize information to the resulting quint32: + // Attach the quantize information to the resulting qint32: auto intReprTy = rewriter.getType( resultTy.getOptionalSizes(), rewriter.getIntegerType(32, IntegerType::Signed)); @@ -194,7 +217,6 @@ class FuseQuantizedOpsPass : public FuseQuantizedOpsBase { RemoveUnused, RemoveUnused, QuantizeOperands, QuantizeOperands, - QuantizeAccumulator, QuantizeAccumulator, QuantizeBias>( context); diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index f43c325069ce..d6fe18809b5f 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -317,6 +317,7 @@ "ElementwiseDequantizePerTensorModule_basic", "ElementwiseQuantizePerTensorModule_basic", "AtenMmQuint8_basic", + "Conv2dQInt8Module_basic", # Dynamo not supporting conv_tbc "ConvTbcModule_basic", @@ -1541,4 +1542,5 @@ "ElementwiseBitwiseAndScalarInt64Module_basic", "ElementwiseBitwiseAndScalarInt32Module_basic", "ElementwiseBitwiseAndScalarInt8Module_basic", + "Conv2dQInt8Module_basic", } diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py index f75e17a4f6cd..b12424cbb7b2 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py @@ -857,3 +857,38 @@ def forward(self, x, weight, bias): @register_test_case(module_factory=lambda: ConvTbcModule()) def ConvTbcModule_basic(module, tu: TestUtils): module.forward(tu.rand(9, 4, 5), tu.rand(3, 5, 6), tu.rand(6)) + +class Conv2dQInt8Module(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1, -1], torch.int8, True), + ([-1, -1, -1, -1], torch.int8, True), + ([-1], torch.float, True), + ]) + def forward(self, inputVec, weight, bias): + inputVec = torch._make_per_tensor_quantized_tensor(inputVec, 0.01, 7) + inputVec = torch.dequantize(inputVec) + + weight = torch._make_per_tensor_quantized_tensor(weight, 0.01, 3) + weight = torch.dequantize(weight) + + bias = torch.quantize_per_tensor(bias, 0.0001, 0, torch.qint32) + bias = torch.dequantize(bias) + + return torch.ops.aten.conv2d(inputVec, + weight, + bias=bias, + stride=[1, 1], + padding=[0, 0], + dilation=[1, 1], + groups=1) +@register_test_case(module_factory=lambda: Conv2dQInt8Module()) +def Conv2dQInt8Module_basic(module, tu: TestUtils): + inputVec = tu.randint(2, 4, 7, 8, low=-128, high=127).to(torch.int8) + weight = tu.randint(3, 4, 3, 2, low=-128, high=127).to(torch.int8) + bias = torch.rand(3) + module.forward(inputVec, weight, bias) diff --git a/test/Dialect/Torch/fuse-quantized-ops.mlir b/test/Dialect/Torch/fuse-quantized-ops.mlir index c62a0d13d9cf..1aaeb9ce1cd8 100644 --- a/test/Dialect/Torch/fuse-quantized-ops.mlir +++ b/test/Dialect/Torch/fuse-quantized-ops.mlir @@ -43,20 +43,20 @@ func.func @convolution(%arg0: !torch.vtensor<[1,3,8,8],si8>, %arg1: !torch.vtens %15 = torch.prim.ListConstruct %zero, %zero : (!torch.int, !torch.int) -> !torch.list %16 = torch.aten.convolution %7, %13, %arg2, %14, %15, %14, %false, %15, %one : !torch.vtensor<[1,3,8,8],f32>, !torch.vtensor<[3,3,2,2],f32>, !torch.vtensor<[3],f32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[1,3,7,7],f32> - // CHECK-DAG: %[[ZERO:.+]] = torch.constant.int 0 - // CHECK-DAG: %[[SCALEO:.+]] = torch.constant.float 2.500000e-01 - // CHECK-DAG: %[[DTYPE:.+]] = torch.constant.int 14 - // CHECK-DAG: %[[HALF:.+]] = torch.constant.float 5.000000e-01 - // CHECK-DAG: %[[FALSE:.+]] = torch.constant.bool false - // CHECK-DAG: %[[ONE:.+]] = torch.constant.int 1 - // CHECK-DAG: %[[QLHS:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[HALF]], %[[ONE]] : !torch.vtensor<[1,3,8,8],si8>, !torch.float, !torch.int -> !torch.vtensor<[1,3,8,8],!torch.qint8> - // CHECK-DAG: %[[QRHS:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg1, %[[HALF]], %[[ZERO]] : !torch.vtensor<[3,3,2,2],si8>, !torch.float, !torch.int -> !torch.vtensor<[3,3,2,2],!torch.qint8> - // CHECK-DAG: %[[ONES:.+]] = torch.prim.ListConstruct %[[ONE]], %[[ONE]] : (!torch.int, !torch.int) -> !torch.list - // CHECK-DAG: %[[ZEROS:.+]] = torch.prim.ListConstruct %[[ZERO]], %[[ZERO]] : (!torch.int, !torch.int) -> !torch.list - // CHECK-DAG: %[[QBIAS:.+]] = torch.aten.quantize_per_tensor %arg2, %[[SCALEO]], %[[ZERO]], %[[DTYPE]] : !torch.vtensor<[3],f32>, !torch.float, !torch.int, !torch.int -> !torch.vtensor<[3],!torch.qint32> - // CHECK-DAG: %[[CONV:.+]] = torch.aten.convolution %[[QLHS]], %[[QRHS]], %[[QBIAS]], %[[ONES]], %[[ZEROS]], %[[ONES]], %[[FALSE]], %[[ZEROS]], %[[ONE]] : !torch.vtensor<[1,3,8,8],!torch.qint8>, !torch.vtensor<[3,3,2,2],!torch.qint8>, !torch.vtensor<[3],!torch.qint32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[1,3,7,7],!torch.qint32> - // CHECK-DAG: %[[INT:.+]] = torch.aten.int_repr %[[CONV]] : !torch.vtensor<[1,3,7,7],!torch.qint32> -> !torch.vtensor<[1,3,7,7],si32> - // CHECK-DAG: %[[QOUT:.+]] = torch.aten._make_per_tensor_quantized_tensor %[[INT]], %[[SCALEO]], %[[ZERO]] : !torch.vtensor<[1,3,7,7],si32>, !torch.float, !torch.int -> !torch.vtensor<[1,3,7,7],!torch.qint32> + // CHECK: %[[DTYPE:.+]] = torch.constant.int 14 + // CHECK: %[[SCALEO:.+]] = torch.constant.float 2.500000e-01 + // CHECK: %[[HALF:.+]] = torch.constant.float 5.000000e-01 + // CHECK: %[[FALSE:.+]] = torch.constant.bool false + // CHECK: %[[ZERO:.+]] = torch.constant.int 0 + // CHECK: %[[ONE:.+]] = torch.constant.int 1 + // CHECK: %[[QLHS:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[HALF]], %[[ONE]] : !torch.vtensor<[1,3,8,8],si8>, !torch.float, !torch.int -> !torch.vtensor<[1,3,8,8],!torch.qint8> + // CHECK: %[[QRHS:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg1, %[[HALF]], %[[ZERO]] : !torch.vtensor<[3,3,2,2],si8>, !torch.float, !torch.int -> !torch.vtensor<[3,3,2,2],!torch.qint8> + // CHECK: %[[ONES:.+]] = torch.prim.ListConstruct %[[ONE]], %[[ONE]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[ZEROS:.+]] = torch.prim.ListConstruct %[[ZERO]], %[[ZERO]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[QBIAS:.+]] = torch.aten.quantize_per_tensor %arg2, %[[SCALEO]], %[[ZERO]], %[[DTYPE]] : !torch.vtensor<[3],f32>, !torch.float, !torch.int, !torch.int -> !torch.vtensor<[3],!torch.qint32> + // CHECK: %[[INT:.+]] = torch.aten.int_repr %[[QBIAS]] : !torch.vtensor<[3],!torch.qint32> -> !torch.vtensor<[3],si32> + // CHECK: %[[CONV:.+]] = torch.aten.convolution %[[QLHS]], %[[QRHS]], %[[INT]], %[[ONES]], %[[ZEROS]], %[[ONES]], %[[FALSE]], %[[ZEROS]], %[[ONE]] : !torch.vtensor<[1,3,8,8],!torch.qint8>, !torch.vtensor<[3,3,2,2],!torch.qint8>, !torch.vtensor<[3],si32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[1,3,7,7],si32> + // CHECK: %[[QOUT:.+]] = torch.aten._make_per_tensor_quantized_tensor %[[CONV]], %[[SCALEO]], %[[ZERO]] : !torch.vtensor<[1,3,7,7],si32>, !torch.float, !torch.int -> !torch.vtensor<[1,3,7,7],!torch.qint32> // CHECK: %[[FOUT:.+]] = torch.aten.dequantize.tensor %[[QOUT]] : !torch.vtensor<[1,3,7,7],!torch.qint32> -> !torch.vtensor<[1,3,7,7],f32> return %16 : !torch.vtensor<[1,3,7,7],f32> } From d778950f455ff479875de78f6b421a3e9adfd7a9 Mon Sep 17 00:00:00 2001 From: Yuanqiang Liu Date: Wed, 31 Jan 2024 09:43:21 +0800 Subject: [PATCH 141/283] [Torch Dialect] add fold pattern for aten.clone (#2804) --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 1 + lib/Conversion/TorchToStablehlo/Basic.cpp | 1 - lib/Dialect/Torch/IR/TorchOps.cpp | 13 ++++++++++++ projects/pt1/e2e_testing/xfail_sets.py | 1 + .../build_tools/torch_ods_gen.py | 2 +- .../torch_mlir_e2e_test/test_suite/basic.py | 20 +++++++++++++++++++ test/Conversion/TorchToStablehlo/basic.mlir | 15 -------------- 7 files changed, 36 insertions(+), 17 deletions(-) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index c09900ce8ecc..81ee9844af30 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -9101,6 +9101,7 @@ def Torch_AtenCloneOp : Torch_Op<"aten.clone", [ printDefaultTorchOp(printer, *this, 2, 1); } }]; + let hasFolder = 1; } def Torch_AtenLiftFreshCopyOp : Torch_Op<"aten.lift_fresh_copy", [ diff --git a/lib/Conversion/TorchToStablehlo/Basic.cpp b/lib/Conversion/TorchToStablehlo/Basic.cpp index 00c9fcd7b88f..33db9ac9ee54 100644 --- a/lib/Conversion/TorchToStablehlo/Basic.cpp +++ b/lib/Conversion/TorchToStablehlo/Basic.cpp @@ -1763,7 +1763,6 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality( #define INSERT_UNARY_PATTERN(AtenOp, StablehloOp) \ target.addIllegalOp(); \ patterns.add>(typeConverter, context) - INSERT_UNARY_PATTERN(AtenCloneOp, stablehlo::ConvertOp); INSERT_UNARY_PATTERN(AtenNegOp, stablehlo::NegOp); INSERT_UNARY_PATTERN(AtenLogicalNotOp, stablehlo::NotOp); INSERT_UNARY_PATTERN(AtenBitwiseNotOp, stablehlo::NotOp); diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index 4aacd8d7693e..5877a35495e0 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -1662,6 +1662,19 @@ void AtenMaskedFillTensorOp::getCanonicalizationPatterns( }); } +//===----------------------------------------------------------------------===// +// AtenCloneOp +//===----------------------------------------------------------------------===// + +OpFoldResult AtenCloneOp::fold(FoldAdaptor adaptor) { + // note: memory_format would be ignored + if (llvm::dyn_cast(getSelf().getType())) { + // self should have value semantics + return getSelf(); + } + return {}; +} + //===----------------------------------------------------------------------===// // AtenSortIntOp //===----------------------------------------------------------------------===// diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index d6fe18809b5f..7c1cc16261d0 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -1021,6 +1021,7 @@ "BroadcastZeroRankInputStaticModule_basic", "BucketizeTensorStaticFloatModule_basic", "BucketizeTensorStaticModule_basic", + "CloneModule_basic", "ChunkListUnpackUneven_Module_basic", "ChunkListUnpack_Module_basic", "ConstantBoolParameterModule_basic", diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index 3b930c20e79d..a329c1ae01a4 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -592,7 +592,7 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::one_hot : (Tensor, int) -> (Tensor)") emit("aten::einsum : (str, Tensor[], int[]?) -> (Tensor)") emit("aten::bucketize.Tensor : (Tensor, Tensor, bool, bool) -> (Tensor)") - emit("aten::clone : (Tensor, int?) -> (Tensor)") + emit("aten::clone : (Tensor, int?) -> (Tensor)", has_folder=True) emit("aten::lift_fresh_copy : (Tensor) -> (Tensor)") emit("aten::contiguous : (Tensor, int) -> (Tensor)") emit_with_mutating_variants("aten::copy : (Tensor, Tensor, bool) -> (Tensor)") diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py index 51deffb6175a..91c3112135d0 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py @@ -4994,3 +4994,23 @@ def forward(self, x): @register_test_case(module_factory=lambda: IscloseStaticModuleTrue()) def IscloseStaticModuleTrue_basic(module, tu: TestUtils): module.forward(torch.ones(5, 5)) + + +# ============================================================================== + +class CloneModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([5, 5], torch.float32, True), + ]) + def forward(self, x): + return torch.ops.aten.clone(x) + +@register_test_case(module_factory=lambda: CloneModule()) +def CloneModule_basic(module, tu: TestUtils): + module.forward(tu.rand(5, 5)) diff --git a/test/Conversion/TorchToStablehlo/basic.mlir b/test/Conversion/TorchToStablehlo/basic.mlir index e0ab6bf1502b..b502d3ffcce9 100644 --- a/test/Conversion/TorchToStablehlo/basic.mlir +++ b/test/Conversion/TorchToStablehlo/basic.mlir @@ -1,21 +1,6 @@ // RUN: torch-mlir-opt <%s -convert-torch-to-stablehlo -split-input-file -verify-diagnostics | FileCheck %s -// ----- - -// CHECK-LABEL: func.func @torch.aten.clone$basic( -// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { -// CHECK: %[[T0:.*]] = torch_c.to_builtin_tensor %[[ARG0]] : !torch.vtensor<[?,?],f32> -> tensor -// CHECK: %[[NONE:.*]] = torch.constant.none -// CHECK: %[[T1:.*]] = stablehlo.convert %[[T0]] : tensor -// CHECK: %[[T2:.*]] = torch_c.from_builtin_tensor %[[T1]] : tensor -> !torch.vtensor<[?,?],f32> -// CHECK: return %[[T2]] : !torch.vtensor<[?,?],f32> -func.func @torch.aten.clone$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> { - %none = torch.constant.none - %0 = torch.aten.clone %arg0, %none : !torch.vtensor<[?,?],f32>, !torch.none -> !torch.vtensor<[?,?],f32> - return %0 : !torch.vtensor<[?,?],f32> -} - // ----- // CHECK-LABEL: func.func @torch.vtensor.literal$basic() -> !torch.vtensor<[],f32> { From 1a7442e0aa685ea7ab786e4f5b5703619e15c09e Mon Sep 17 00:00:00 2001 From: Ramiro Leal-Cavazos Date: Tue, 30 Jan 2024 19:59:46 -0800 Subject: [PATCH 142/283] Add clang-format check to CI (#2816) This PR adds a check to the CI right after checking out the Torch-MLIR repository to make sure that the changes in the PR don't require any `git clang-format` modifications. --- .github/workflows/lint.yml | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/.github/workflows/lint.yml b/.github/workflows/lint.yml index 464ebdad93c0..364e9fa9d378 100644 --- a/.github/workflows/lint.yml +++ b/.github/workflows/lint.yml @@ -1,3 +1,4 @@ +# yamllint disable rule:line-length name: Lint Checks on: @@ -12,6 +13,14 @@ jobs: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 + with: + # `git-clang-format` needs access to the commit before the + # current merge commit to know what changes to format. + fetch-depth: 2 - name: Validate GitHub Actions yaml files run: | yamllint ./.github/workflows/ ./.github/actions/ + - name: Check clang-format + run: | + wget -q https://raw.githubusercontent.com/llvm/llvm-project/main/clang/tools/clang-format/git-clang-format + python3 git-clang-format --diff HEAD~1 From 105aad6f57a19db1cfcf17bb394367431973b65e Mon Sep 17 00:00:00 2001 From: Aart Bik <39774503+aartbik@users.noreply.github.com> Date: Tue, 30 Jan 2024 21:22:12 -0800 Subject: [PATCH 143/283] [torch-mlir] provide FX traced graph importer for sparse tensors (#2817) Note that we are waiting for actual FX traced graph support for sparse tensors. For details see https://github.com/pytorch/pytorch/issues/117188 Until then, however, we provide this clever importer that builds the FX traced graph for for the dense case and then puts a sparse annotation back on the parameters. With import test. --- python/torch_mlir/extras/fx_importer.py | 46 +++++++-- test/python/fx_importer/sparse_test.py | 130 ++++++++++++++++++++++++ 2 files changed, 166 insertions(+), 10 deletions(-) create mode 100644 test/python/fx_importer/sparse_test.py diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py index d799d61f6a92..8cffcb1ea935 100644 --- a/python/torch_mlir/extras/fx_importer.py +++ b/python/torch_mlir/extras/fx_importer.py @@ -207,10 +207,32 @@ } -"""Check whether an object in our graph is symbolic""" +def sparsity_encoding(shape: torch.Size, sparse_layout : torch.layout) -> str: + """Returns sparse tensor encoding for the given sparse layout as string. + + The method currently just supports 2-dim sparse formats. This should be + generalized to the torch.sparse encodings for prefix dense batch dimensions + and suffix dense subtensor dimensions. Since MLIR supports a superset of what + is currently implememented in torch.sparse, this should not a be problem. + """ + + # TODO: any rank + if len(shape) != 2: + raise RuntimeError(f"Unsupported sparse rank {len(shape)}") + + if sparse_layout is torch.sparse_coo: + return '#sparse_tensor.encoding<{map=(i,j)->(i:compressed(nonunique),j:singleton)}>' + if sparse_layout is torch.sparse_csr: + return '#sparse_tensor.encoding<{map=(i,j)->(i:dense,j:compressed)}>' + if sparse_layout is torch.sparse_csc: + return '#sparse_tensor.encoding<{map=(i,j)->(j:dense,i:compressed)}>' + # TODO: block format (derive block size!) + + raise RuntimeError(f"Unsupported sparse layout {sparse_layout}") def is_symbolic(obj: Any) -> bool: + """Check whether an object in our graph is symbolic""" return isinstance(obj, (torch.SymInt, torch.SymFloat, torch.SymBool)) @@ -337,7 +359,7 @@ def import_frozen_exported_program(self, prog: torch.export.ExportedProgram): ) from e arg_replacements[input_name] = state_value - # Remove any lifted placeholders, replacing their uses with the state + # Remove any lifted placeholders, replacing their uses with the state # replacement value. g = prog.graph for node in g.nodes: @@ -455,17 +477,21 @@ def format_asm_shape(self, shape: torch.Size) -> str: """Return IrType for !torch.vtensor with the given shape and dtype""" - def get_vtensor_type(self, shape: torch.Size, dtype: torch.dtype): + def get_vtensor_type(self, shape: torch.Size, dtype: torch.dtype, sparse_layout: torch.layout = None): shape_asm = self.format_asm_shape(shape) mlir_dtype = str(self.dtype_to_type(dtype)) + if sparse_layout is not None: + sparsity = sparsity_encoding(shape, sparse_layout) + return IrType.parse( + f"!torch.vtensor<[{shape_asm}],{str(mlir_dtype)},{sparsity}>", context=self._c) return IrType.parse( - f"!torch.vtensor<[{shape_asm}],{str(mlir_dtype)}>", context=self._c - ) + f"!torch.vtensor<[{shape_asm}],{str(mlir_dtype)}>", context=self._c) def node_val_to_type(self, node: torch_fx.Node) -> IrType: try: tensor_meta = node.meta.get("tensor_meta") val = node.meta.get("val") + sparse_layout = node.meta.get("sparsity", None) if tensor_meta is not None: assert isinstance(tensor_meta, TensorMetadata) # Quantized tensor meta data is not preserved in our lowering, @@ -475,12 +501,12 @@ def node_val_to_type(self, node: torch_fx.Node) -> IrType: f"Quantized tensor meta data is not supported." ) else: - return self.tensor_metadata_to_type(tensor_meta) + return self.tensor_metadata_to_type(tensor_meta, sparse_layout) elif val is not None: # some nodes with symbolic inputs pass a 'val' attribute rather than # tensor_meta if isinstance(val, TorchFakeTensor): - return self.get_vtensor_type(val.size(), val.dtype) + return self.get_vtensor_type(val.size(), val.dtype, sparse_layout) t = SCALAR_TYPE_TO_TORCH_MLIR_TYPE.get(type(val)) if t is not None: @@ -495,15 +521,15 @@ def node_val_to_type(self, node: torch_fx.Node) -> IrType: f"FIXME: Illegal access to torch.fx.Node.meta: {e} ({node.meta.keys()} : {node.meta})" ) - def tensor_metadata_to_type(self, tm: TensorMetadata) -> IrType: + def tensor_metadata_to_type(self, tm: TensorMetadata, sparse_layout : torch.layout = None) -> IrType: tm_shape = tuple( item.node if is_symbolic(item) else item for item in list(tm.shape) ) - key = (tm_shape, tm.dtype) + key = (tm_shape, tm.dtype, sparse_layout) t = self._tensor_metadata_cache.get(key) if t is None: - t = self.get_vtensor_type(tm.shape, tm.dtype) + t = self.get_vtensor_type(tm.shape, tm.dtype, sparse_layout) self._tensor_metadata_cache[key] = t return t diff --git a/test/python/fx_importer/sparse_test.py b/test/python/fx_importer/sparse_test.py new file mode 100644 index 000000000000..1490c160c3f1 --- /dev/null +++ b/test/python/fx_importer/sparse_test.py @@ -0,0 +1,130 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. + +# RUN: %PYTHON %s | FileCheck %s + +from typing import Any, Callable, Dict, Optional, Tuple + +import torch +import torch.export +import torch.nn as nn + +from torch_mlir.extras.fx_importer import FxImporter +from torch_mlir import ir +from torch_mlir.dialects import torch as torch_d + + +# All sparse layouts currently supported in torch.sparse. +SPARSE_LAYOUTS = [ + torch.sparse_coo, + torch.sparse_csr, + torch.sparse_csc, + torch.sparse_bsr, + torch.sparse_bsc +] + + +def sparse_export(f: Callable, + args: Tuple[Any, ...], + kwargs: Optional[Dict[str, Any]] = None) -> torch.export.ExportedProgram: + """ + This is a ***temporary*** wrapper around `torch.export.export` + that eventually should be removed and simply replaced by the + standard API for exporting traced graphs. + + But until issue + + https://github.com/pytorch/pytorch/pull/117907 + + is addressed, this wrapper provides support for the sparse + tensor types by first converting all operands to dense tensors, + building the traced graph as for the dense case, and then + annotation sparse parameters with their actual sparse layout + attributes. This temporary solution accelerates testing + torch-mlir with PyTorch sparse tensors until the issue is + resovled. + """ + # Convert all arguments to dense. + dargs = tuple( a.to_dense() if a.layout in SPARSE_LAYOUTS else a for a in args ) + mask = [ a.layout in SPARSE_LAYOUTS for a in args ] + # Build the regular FX traced graph with only dense arguments + # (the current version would crash otherwise, see issue above). + prog = torch.export.export(f, dargs, kwargs, constraints=None) + # Annotate sparse arguments in the graph. + alen = len(args) + for i, node in enumerate(prog.graph.nodes): + if node.op == "placeholder" and i < alen and mask[i]: + node.meta['sparsity'] = args[i].layout + # TODO: annotate inputs to change calling conventions! + return prog + + +def export_and_import(f, *args, **kwargs): + """This method implements Stella's importer, stripped down to essentials.""" + context = ir.Context() + torch_d.register_dialect(context) + fx_importer = FxImporter(context=context) + prog = sparse_export(f, args, kwargs) + fx_importer.import_frozen_exported_program(prog) + return fx_importer.module_op + + +def run(f): + print(f"{f.__name__}") + print("-" * len(f.__name__)) + f() + print() + + +@run +# CHECK-LABEL: test_sparse_sum +# CHECK: #[[$CSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }> +# CHECK: func.func @main( +# CHECK-SAME: %[[A:.*]]: !torch.vtensor<[64,64],f32,#[[$CSR]]>) -> !torch.vtensor<[],f32> { +# CHECK: %[[N:.*]] = torch.constant.none +# CHECK: %[[R:.*]] = torch.aten.sum %[[A]], %[[N]] : !torch.vtensor<[64,64],f32,#[[$CSR]]>, !torch.none -> !torch.vtensor<[],f32> +# CHECK: return %[[R]] : !torch.vtensor<[],f32> +# CHECK: } +def test_sparse_sum(): + + class SumNet(torch.nn.Module): + + def __init__(self): + super(SumNet, self).__init__() + + def forward(self, x): + return x.sum() + + + dense_input = torch.ones(64, 64) + sparse_input = dense_input.to_sparse_csr() + m = export_and_import(SumNet(), sparse_input) + print(m) + + +@run +# CHECK-LABEL: test_sparse_SpMM +# CHECK: #[[$COO:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton) }> +# CHECK: func.func @main( +# CHECK-SAME: %[[A:.*0]]: !torch.vtensor<[64,64],f32,#[[$COO]]>, +# CHECK-SAME: %[[B:.*1]]: !torch.vtensor<[64,64],f32>) -> !torch.vtensor<[64,64],f32> { +# CHECK: %[[R:.*]] = torch.aten.mm %[[A]], %[[B]] : !torch.vtensor<[64,64],f32,#[[$COO]]>, !torch.vtensor<[64,64],f32> -> !torch.vtensor<[64,64],f32> +# CHECK: return %[[R]] : !torch.vtensor<[64,64],f32> +# CHECK: } +def test_sparse_SpMM(): + + class MatMulNet(torch.nn.Module): + + def __init__(self): + super(MatMulNet, self).__init__() + + def forward(self, x, y): + return torch.matmul(x, y) + + + dense_input = torch.ones(64, 64) + sparse_input = dense_input.to_sparse_coo() + m = export_and_import(MatMulNet(), sparse_input, dense_input) + print(m) From 26c0ecd09c532204b80e2145bcbe1539238e4c33 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Tue, 30 Jan 2024 22:18:13 -0800 Subject: [PATCH 144/283] [nfc] Remove unused var causing error downstream --- lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 2 -- 1 file changed, 2 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 3cba62d7691c..ff0bcd524e4f 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -707,8 +707,6 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( selectSizes.push_back(1); Type selectResultType = axesType.getWithSizesAndDtype( llvm::ArrayRef(selectSizes), axesType.getOptionalDtype()); - Value noneVal = rewriter.create(binder.getLoc()); - auto sizes = dyn_cast(axes.getType()).getSizes(); From 943164d797f2f20a4c9a3d792f824b3ecd5ecc70 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Tue, 30 Jan 2024 22:39:22 -0800 Subject: [PATCH 145/283] Fix some spurious `None` values in tests (broken at head). (#2840) --- .../TorchOnnxToTorch/simple_ops_q_to_z.mlir | 21 ++++++++----------- 1 file changed, 9 insertions(+), 12 deletions(-) diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir index 97f698e5eb3e..9d947dce5ce5 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir @@ -626,7 +626,6 @@ func.func @test_selu(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4, // CHECK-LABEL: func.func @test_reduce_max_keepdims_example func.func @test_reduce_max_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[2],si64>) -> !torch.vtensor<[3,1,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[NONE:.*]] = torch.constant.none // CHECK: %[[INT0:.*]] = torch.constant.int 0 // CHECK: %[[RANK:.*]] = torch.constant.int 3 // CHECK: %[[INT0_0:.*]] = torch.constant.int 0 @@ -669,7 +668,6 @@ func.func @test_reduce_max_default_axes_keepdim_example(%arg0: !torch.vtensor<[3 // CHECK-LABEL: func.func @test_reduce_max_do_not_keepdims_example func.func @test_reduce_max_do_not_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[NONE:.*]] = torch.constant.none // CHECK: %[[INT0:.*]] = torch.constant.int 0 // CHECK: %[[RANK:.*]] = torch.constant.int 3 // CHECK: %[[INT0_0:.*]] = torch.constant.int 0 @@ -693,7 +691,7 @@ func.func @test_reduce_sum_default_axes_keepdims_example(%arg0: !torch.vtensor<[ // CHECK: %[[NONE:.+]] = torch.constant.none // CHECK: %[[INT1:.+]] = torch.constant.int 1 // CHECK: torch.aten.Bool.int %int1 : !torch.int -> !torch.bool - // CHECK: torch.aten.sum.dim_IntList %arg0, %none, %0, %none : !torch.vtensor<[3,2,2],f32>, !torch.none, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1],f32> + // CHECK: torch.aten.sum.dim_IntList %arg0, %[[NONE]], %0, %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.none, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1],f32> %0 = torch.operator "onnx.ReduceSum"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> return %0 : !torch.vtensor<[1,1,1],f32> } @@ -712,14 +710,13 @@ func.func @test_reduce_sum_do_not_keepdims_example(%arg0: !torch.vtensor<[3,2,2] // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int // CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list // CHECK: %[[FALSE:.+]] = torch.constant.bool false - // CHECK: torch.aten.sum.dim_IntList %arg0, %6, %false, %none : !torch.vtensor<[3,2,2],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[3,2],f32> + // CHECK: torch.aten.sum.dim_IntList %arg0, %6, %false, %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[3,2],f32> %0 = torch.operator "onnx.ReduceSum"(%arg0, %arg1) {torch.onnx.keepdims = 0 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32> return %0 : !torch.vtensor<[3,2],f32> } // CHECK-LABEL: func.func @test_reduce_sum_empty_axes_input_noop_example func.func @test_reduce_sum_empty_axes_input_noop_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[0],si64>) -> !torch.vtensor<[3,2,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[NONE:.+]] = torch.constant.none %0 = torch.operator "onnx.ReduceSum"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64, torch.onnx.noop_with_empty_axes = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[0],si64>) -> !torch.vtensor<[3,2,2],f32> return %0 : !torch.vtensor<[3,2,2],f32> } @@ -738,7 +735,7 @@ func.func @test_reduce_sum_empty_set_non_reduced_axis_zero(%arg0: !torch.vtensor // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int // CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list // CHECK: %[[TRUE:.+]] = torch.constant.bool true - // CHECK: torch.aten.sum.dim_IntList %arg0, %6, %true, %none : !torch.vtensor<[2,0,4],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[2,0,1],f32> + // CHECK: torch.aten.sum.dim_IntList %arg0, %6, %true, %[[NONE]] : !torch.vtensor<[2,0,4],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[2,0,1],f32> %0 = torch.operator "onnx.ReduceSum"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[2,0,4],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[2,0,1],f32> return %0 : !torch.vtensor<[2,0,1],f32> } @@ -757,7 +754,7 @@ func.func @test_reduce_sum_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int // CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list // CHECK: %[[TRUE:.+]] = torch.constant.bool true - // CHECK: torch.aten.sum.dim_IntList %arg0, %6, %true, %none : !torch.vtensor<[3,2,2],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[3,1,2],f32> + // CHECK: torch.aten.sum.dim_IntList %arg0, %6, %true, %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[3,1,2],f32> %0 = torch.operator "onnx.ReduceSum"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32> return %0 : !torch.vtensor<[3,1,2],f32> } @@ -776,7 +773,7 @@ func.func @test_reduce_sum_negative_axes_keepdims_example(%arg0: !torch.vtensor< // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int // CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list // CHECK: %[[TRUE:.+]] = torch.constant.bool true - // CHECK: torch.aten.sum.dim_IntList %arg0, %6, %true, %none : !torch.vtensor<[3,2,2],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[3,1,2],f32> + // CHECK: torch.aten.sum.dim_IntList %arg0, %6, %true, %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[3,1,2],f32> %0 = torch.operator "onnx.ReduceSum"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32> return %0 : !torch.vtensor<[3,1,2],f32> } @@ -788,7 +785,7 @@ func.func @test_reduce_mean_default_axes_keepdims_example(%arg0: !torch.vtensor< // CHECK: %[[NONE:.+]] = torch.constant.none // CHECK: %[[INT1:.+]] = torch.constant.int 1 // CHECK: torch.aten.Bool.int %int1 : !torch.int -> !torch.bool - // CHECK: torch.aten.mean.dim %arg0, %none, %0, %none : !torch.vtensor<[3,2,2],f32>, !torch.none, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1],f32> + // CHECK: torch.aten.mean.dim %arg0, %[[NONE]], %0, %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.none, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1],f32> %0 = torch.operator "onnx.ReduceMean"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> return %0 : !torch.vtensor<[1,1,1],f32> } @@ -807,7 +804,7 @@ func.func @test_reduce_mean_do_not_keepdims_example(%arg0: !torch.vtensor<[3,2,2 // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int // CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list // CHECK: %[[FALSE:.+]] = torch.constant.bool false - // CHECK: torch.aten.mean.dim %arg0, %6, %false, %none : !torch.vtensor<[3,2,2],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[3,2],f32> + // CHECK: torch.aten.mean.dim %arg0, %6, %false, %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[3,2],f32> %0 = torch.operator "onnx.ReduceMean"(%arg0, %arg1) {torch.onnx.keepdims = 0 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32> return %0 : !torch.vtensor<[3,2],f32> } @@ -826,7 +823,7 @@ func.func @test_reduce_mean_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int // CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list // CHECK: %[[TRUE:.+]] = torch.constant.bool true - // CHECK: torch.aten.mean.dim %arg0, %6, %true, %none : !torch.vtensor<[3,2,2],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[3,1,2],f32> + // CHECK: torch.aten.mean.dim %arg0, %6, %true, %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[3,1,2],f32> %0 = torch.operator "onnx.ReduceMean"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32> return %0 : !torch.vtensor<[3,1,2],f32> } @@ -845,7 +842,7 @@ func.func @test_reduce_mean_negative_axes_keepdims_example(%arg0: !torch.vtensor // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int // CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list // CHECK: %[[TRUE:.+]] = torch.constant.bool true - // CHECK: torch.aten.mean.dim %arg0, %6, %true, %none : !torch.vtensor<[3,2,2],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[3,1,2],f32> + // CHECK: torch.aten.mean.dim %arg0, %6, %true, %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[3,1,2],f32> %0 = torch.operator "onnx.ReduceMean"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32> return %0 : !torch.vtensor<[3,1,2],f32> } From 7301aa80fd38ed30a7a9b2006b4af1c1e2557eea Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Tue, 30 Jan 2024 23:33:21 -0800 Subject: [PATCH 146/283] Enable -Werror in lib/ and LTC. (#2841) Required some massaging of LTC to make it warning clean, and I had to manually disable some warnings on the generated source files (which we don't control). The project is warning clean now. The `-Werror` flag is disabled by default as we can't control everywhere people will try to build/install. The CI enables it via -DTORCH_MLIR_ENABLE_WERROR_FLAG=ON. --- CMakeLists.txt | 9 +++++ build_tools/ci/build_posix.sh | 1 + include/torch-mlir-c/Registration.h | 2 +- include/torch-mlir-c/TorchTypes.h | 40 +++++++++---------- lib/CMakeLists.txt | 2 + .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 2 +- .../ltc/csrc/base_lazy_backend/CMakeLists.txt | 10 +++++ .../ltc/csrc/base_lazy_backend/dynamic_ir.cpp | 8 ++-- .../base_lazy_backend/mlir_lowering_context.h | 9 ++--- .../mlir_native_functions.cpp | 8 ++-- .../csrc/base_lazy_backend/utils/sys_utils.h | 6 +-- 11 files changed, 59 insertions(+), 38 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 376aea80eea3..44f02ac6af38 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -31,6 +31,7 @@ include(CMakeDependentOption) # Project options #------------------------------------------------------------------------------- +option(TORCH_MLIR_ENABLE_WERROR_FLAG "Enable `-Werror` flag on supported directories, treat error as warning" OFF) option(TORCH_MLIR_USE_INSTALLED_PYTORCH "If depending on PyTorch use it as installed in the current Python environment" ON) option(TORCH_MLIR_ENABLE_REFBACKEND "Enable reference backend" ON) @@ -53,6 +54,14 @@ cmake_dependent_option(TORCH_MLIR_ENABLE_LTC "Enables LTC backend" OFF TORCH_MLI option(TORCH_MLIR_ENABLE_ONNX_C_IMPORTER "Enables the ONNX C importer" OFF) +macro(torch_mlir_enable_werror) + if(TORCH_MLIR_ENABLE_WERROR_FLAG) + if(NOT MSVC) + add_compile_options(-Werror) + endif() + endif() +endmacro() + #------------------------------------------------------------------------------- # Configure out-of-tree vs in-tree build #------------------------------------------------------------------------------- diff --git a/build_tools/ci/build_posix.sh b/build_tools/ci/build_posix.sh index 438a55c74389..fec5e252e8d7 100755 --- a/build_tools/ci/build_posix.sh +++ b/build_tools/ci/build_posix.sh @@ -42,6 +42,7 @@ cmake -S "$repo_root/externals/llvm-project/llvm" -B "$build_dir" \ -DCMAKE_BUILD_TYPE=Release \ -DPython3_EXECUTABLE="$(which python)" \ -DLLVM_ENABLE_ASSERTIONS=ON \ + -DTORCH_MLIR_ENABLE_WERROR_FLAG=ON \ -DCMAKE_INSTALL_PREFIX="$install_dir" \ -DCMAKE_INSTALL_LIBDIR=lib \ -DLLVM_ENABLE_PROJECTS=mlir \ diff --git a/include/torch-mlir-c/Registration.h b/include/torch-mlir-c/Registration.h index 4d582e61f132..7d607693d56b 100644 --- a/include/torch-mlir-c/Registration.h +++ b/include/torch-mlir-c/Registration.h @@ -23,7 +23,7 @@ extern "C" { MLIR_CAPI_EXPORTED void torchMlirRegisterAllDialects(MlirContext context); /** Registers all passes for symbolic access with the global registry. */ -MLIR_CAPI_EXPORTED void torchMlirRegisterAllPasses(); +MLIR_CAPI_EXPORTED void torchMlirRegisterAllPasses(void); #ifdef __cplusplus } diff --git a/include/torch-mlir-c/TorchTypes.h b/include/torch-mlir-c/TorchTypes.h index c852dd61387d..b214e147d5d9 100644 --- a/include/torch-mlir-c/TorchTypes.h +++ b/include/torch-mlir-c/TorchTypes.h @@ -35,7 +35,7 @@ MLIR_CAPI_EXPORTED MlirType torchMlirTorchNnModuleTypeGet(MlirContext context, MlirStringRef className); /// Gets the !torch.nn.Module typeid. -MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchNnModuleTypeGetTypeID(); +MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchNnModuleTypeGetTypeID(void); //===----------------------------------------------------------------------===// // torch.optional type. @@ -53,7 +53,7 @@ MLIR_CAPI_EXPORTED MlirType torchMlirTorchOptionalTypeGetContained(MlirType containedType); /// Gets the !torch.optional typeid. -MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchOptionalTypeGetTypeID(); +MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchOptionalTypeGetTypeID(void); //===----------------------------------------------------------------------===// // torch.tuple type. @@ -75,7 +75,7 @@ MLIR_CAPI_EXPORTED MlirType torchMlirTorchTupleTypeGetType(MlirType t, intptr_t pos); /// Gets the !torch.tuple typeid. -MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchTupleTypeGetTypeID(); +MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchTupleTypeGetTypeID(void); //===----------------------------------------------------------------------===// // torch.union type. @@ -97,7 +97,7 @@ MLIR_CAPI_EXPORTED MlirType torchMlirTorchUnionTypeGetType(MlirType t, intptr_t pos); /// Gets the !torch.union typeid. -MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchUnionTypeGetTypeID(); +MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchUnionTypeGetTypeID(void); //===----------------------------------------------------------------------===// // torch.list type. @@ -113,7 +113,7 @@ MLIR_CAPI_EXPORTED MlirType torchMlirTorchListTypeGet(MlirType containedType); MLIR_CAPI_EXPORTED MlirType torchMlirTorchListTypeGetContainedType(MlirType t); /// Gets the !torch.list typeid. -MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchListTypeGetTypeID(); +MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchListTypeGetTypeID(void); //===----------------------------------------------------------------------===// // torch.Device type. @@ -126,7 +126,7 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchDevice(MlirType t); MLIR_CAPI_EXPORTED MlirType torchMlirTorchDeviceTypeGet(MlirContext context); /// Gets the !torch.device typeid. -MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchDeviceTypeGetTypeID(); +MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchDeviceTypeGetTypeID(void); //===----------------------------------------------------------------------===// // torch.Generator type. @@ -139,7 +139,7 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchGenerator(MlirType t); MLIR_CAPI_EXPORTED MlirType torchMlirTorchGeneratorTypeGet(MlirContext context); /// Gets the !torch.generator typeid. -MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchGeneratorTypeGetTypeID(); +MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchGeneratorTypeGetTypeID(void); //===----------------------------------------------------------------------===// // torch.bool type. @@ -152,7 +152,7 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchBool(MlirType t); MLIR_CAPI_EXPORTED MlirType torchMlirTorchBoolTypeGet(MlirContext context); /// Gets the !torch.bool typeid. -MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchBoolTypeGetTypeID(); +MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchBoolTypeGetTypeID(void); //===----------------------------------------------------------------------===// // torch.int type. @@ -165,7 +165,7 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchInt(MlirType t); MLIR_CAPI_EXPORTED MlirType torchMlirTorchIntTypeGet(MlirContext context); /// Gets the !torch.int typeid. -MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchIntTypeGetTypeID(); +MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchIntTypeGetTypeID(void); //===----------------------------------------------------------------------===// // torch.float type. @@ -178,7 +178,7 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchFloat(MlirType t); MLIR_CAPI_EXPORTED MlirType torchMlirTorchFloatTypeGet(MlirContext context); /// Gets the !torch.float typeid. -MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchFloatTypeGetTypeID(); +MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchFloatTypeGetTypeID(void); //===----------------------------------------------------------------------===// // torch.LinearParams type. @@ -192,7 +192,7 @@ MLIR_CAPI_EXPORTED MlirType torchMlirTorchLinearParamsTypeGet(MlirContext context); /// Gets the !torch.linearparams typeid. -MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchLinearParamsTypeGetTypeID(); +MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchLinearParamsTypeGetTypeID(void); //===----------------------------------------------------------------------===// // torch.qint8 type. @@ -205,7 +205,7 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchQInt8(MlirType t); MLIR_CAPI_EXPORTED MlirType torchMlirTorchQInt8TypeGet(MlirContext context); /// Gets the !torch.qint8 typeid. -MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchQInt8TypeGetTypeID(); +MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchQInt8TypeGetTypeID(void); //===----------------------------------------------------------------------===// // torch.quint8 type. @@ -218,7 +218,7 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchQUInt8(MlirType t); MLIR_CAPI_EXPORTED MlirType torchMlirTorchQUInt8TypeGet(MlirContext context); /// Gets the !torch.quint8 typeid. -MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchQUInt8TypeGetTypeID(); +MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchQUInt8TypeGetTypeID(void); //===----------------------------------------------------------------------===// // torch.tensor type. @@ -266,7 +266,7 @@ MLIR_CAPI_EXPORTED MlirType torchMlirTorchNonValueTensorTypeGetDtype(MlirType t); /// Gets the !torch.tensor typeid. -MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchNonValueTensorTypeGetTypeID(); +MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchNonValueTensorTypeGetTypeID(void); //===----------------------------------------------------------------------===// // torch.vtensor type. @@ -312,7 +312,7 @@ torchMlirTorchValueTensorTypeGetSizes(MlirType t, int64_t *sizes); MLIR_CAPI_EXPORTED MlirType torchMlirTorchValueTensorTypeGetDtype(MlirType t); /// Gets the !torch.vtensor typeid. -MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchValueTensorTypeGetTypeID(); +MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchValueTensorTypeGetTypeID(void); //===----------------------------------------------------------------------===// // !torch.none type. @@ -325,7 +325,7 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchNone(MlirType t); MLIR_CAPI_EXPORTED MlirType torchMlirTorchNoneTypeGet(MlirContext context); /// Gets the !torch.none typeid. -MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchNoneTypeGetTypeID(); +MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchNoneTypeGetTypeID(void); //===----------------------------------------------------------------------===// // !torch.str type. @@ -338,7 +338,7 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchString(MlirType t); MLIR_CAPI_EXPORTED MlirType torchMlirTorchStringTypeGet(MlirContext context); /// Gets the !torch.str typeid. -MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchStringTypeGetTypeID(); +MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchStringTypeGetTypeID(void); //===----------------------------------------------------------------------===// // !torch.any type. @@ -351,7 +351,7 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchAny(MlirType t); MLIR_CAPI_EXPORTED MlirType torchMlirTorchAnyTypeGet(MlirContext context); /// Gets the !torch.any typeid. -MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchAnyTypeGetTypeID(); +MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchAnyTypeGetTypeID(void); //===----------------------------------------------------------------------===// // !torch.number type. @@ -364,7 +364,7 @@ MLIR_CAPI_EXPORTED bool torchMlirTypeIsATorchNumber(MlirType t); MLIR_CAPI_EXPORTED MlirType torchMlirTorchNumberTypeGet(MlirContext context); /// Gets the !torch.number typeid. -MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchNumberTypeGetTypeID(); +MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchNumberTypeGetTypeID(void); //===----------------------------------------------------------------------===// // !torch.dict type. @@ -387,7 +387,7 @@ MLIR_CAPI_EXPORTED MlirType torchMlirTorchDictTypeGetKeyType(MlirType t); MLIR_CAPI_EXPORTED MlirType torchMlirTorchDictTypeGetValueType(MlirType t); /// Gets the !torch.dict typeid. -MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchDictTypeGetTypeID(); +MLIR_CAPI_EXPORTED MlirTypeID torchMlirTorchDictTypeGetTypeID(void); #ifdef __cplusplus } diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt index d9030c23a66f..0db753e4746a 100644 --- a/lib/CMakeLists.txt +++ b/lib/CMakeLists.txt @@ -1,3 +1,5 @@ +torch_mlir_enable_werror() + add_subdirectory(CAPI) add_subdirectory(Conversion) add_subdirectory(Dialect) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index ff0bcd524e4f..d54c1e1b9dd0 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -673,7 +673,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( if (operands.size() == 1) { if (noop_with_empty_axes == 0) { MLIRContext *context = binder.op->getContext(); - auto rank = + int rank = data.getType().cast().getSizes().size(); SmallVector dims; for (int i = 0; i < rank; i++) { diff --git a/projects/ltc/csrc/base_lazy_backend/CMakeLists.txt b/projects/ltc/csrc/base_lazy_backend/CMakeLists.txt index eee3044f0fc9..2bbdbd233344 100644 --- a/projects/ltc/csrc/base_lazy_backend/CMakeLists.txt +++ b/projects/ltc/csrc/base_lazy_backend/CMakeLists.txt @@ -2,11 +2,21 @@ # Setup PyTorch/LTC #------------------------------------------------------------------------------- +torch_mlir_enable_werror() + set(LTC_GENERATED generated/LazyNativeFunctions.cpp generated/RegisterLazy.cpp generated/shape_inference.cpp ) + +# The auto generated files trigger some warnings we can't do anything about. +if(NOT MSVC) + set_source_files_properties(${LTC_GENERATED} + PROPERTIES COMPILE_FLAGS "-Wno-sign-compare -Wno-unused-function" + ) +endif() + set(LTC_BACKEND_DEPENDS mlir_lowering_context.cpp mlir_native_functions.cpp diff --git a/projects/ltc/csrc/base_lazy_backend/dynamic_ir.cpp b/projects/ltc/csrc/base_lazy_backend/dynamic_ir.cpp index c11c1563bb5d..363bac959281 100644 --- a/projects/ltc/csrc/base_lazy_backend/dynamic_ir.cpp +++ b/projects/ltc/csrc/base_lazy_backend/dynamic_ir.cpp @@ -24,7 +24,7 @@ std::string DimensionNode::ToString() const { return "DimensionNode"; } SizeNode::SizeNode(Value input, size_t dim) : DimensionNode(OpKind{c10::Symbol::fromQualString("aten::size")}, {input}, MHash(dim)), - dim_(dim){}; + dim_(dim) {} int64_t SizeNode::getStaticValue() const { return dynamic_cast(operand(0).node) @@ -35,7 +35,7 @@ int64_t SizeNode::getStaticValue() const { std::string SizeNode::ToString() const { return "SizeNode"; } SizeAdd::SizeAdd(Value a, Value b) - : DimensionNode(OpKind{c10::Symbol::fromQualString("aten::add")}, {a, b}){}; + : DimensionNode(OpKind{c10::Symbol::fromQualString("aten::add")}, {a, b}) {} int64_t SizeAdd::getStaticValue() const { return dynamic_cast(operand(0).node) @@ -46,7 +46,7 @@ int64_t SizeAdd::getStaticValue() const { std::string SizeAdd::ToString() const { return "SizeAdd"; } SizeMul::SizeMul(Value a, Value b) - : DimensionNode(OpKind{c10::Symbol::fromQualString("aten::mul")}, {a, b}){}; + : DimensionNode(OpKind{c10::Symbol::fromQualString("aten::mul")}, {a, b}) {} int64_t SizeMul::getStaticValue() const { return dynamic_cast(operand(0).node) @@ -57,7 +57,7 @@ int64_t SizeMul::getStaticValue() const { std::string SizeMul::ToString() const { return "SizeMul"; } SizeDiv::SizeDiv(Value a, Value b) - : DimensionNode(OpKind{c10::Symbol::fromQualString("aten::div")}, {a, b}){}; + : DimensionNode(OpKind{c10::Symbol::fromQualString("aten::div")}, {a, b}) {} int64_t SizeDiv::getStaticValue() const { TORCH_CHECK( diff --git a/projects/ltc/csrc/base_lazy_backend/mlir_lowering_context.h b/projects/ltc/csrc/base_lazy_backend/mlir_lowering_context.h index 3b226b46896a..e69820535cb8 100644 --- a/projects/ltc/csrc/base_lazy_backend/mlir_lowering_context.h +++ b/projects/ltc/csrc/base_lazy_backend/mlir_lowering_context.h @@ -150,15 +150,14 @@ class TORCH_API TorchMlirComputation : public torch::lazy::Computation { protected: size_t num_parameters_; - std::unordered_map parameters_map_; - std::vector parameter_names_; - std::vector parameter_shapes_; - Shape result_shape_; - MlirModule module_op_; MlirContext mlir_context_; std::shared_ptr graph_; InputOutputAliases input_output_aliases_; + std::unordered_map parameters_map_; + std::vector parameter_names_; + std::vector parameter_shapes_; + Shape result_shape_; }; } // namespace lazy diff --git a/projects/ltc/csrc/base_lazy_backend/mlir_native_functions.cpp b/projects/ltc/csrc/base_lazy_backend/mlir_native_functions.cpp index a0e4bae76db6..af680f224095 100644 --- a/projects/ltc/csrc/base_lazy_backend/mlir_native_functions.cpp +++ b/projects/ltc/csrc/base_lazy_backend/mlir_native_functions.cpp @@ -67,7 +67,7 @@ c10::optional to_meta(const c10::optional &tensor) { return c10::nullopt; } -std::vector to_meta(at::ITensorListRef t_list) { +[[maybe_unused]] std::vector to_meta(at::ITensorListRef t_list) { std::vector outs; outs.reserve(t_list.size()); for (const auto &tensor : t_list) { @@ -92,7 +92,7 @@ namespace lazy { namespace { -at::Tensor +[[maybe_unused]] at::Tensor CreateLtcTensor(const at::Tensor &tensor, const c10::optional &device) { if (tensor.defined() && device) { @@ -102,7 +102,7 @@ CreateLtcTensor(const at::Tensor &tensor, return tensor; } -c10::optional +[[maybe_unused]] c10::optional GetLtcDevice(const c10::optional &device) { if (!device) { return c10::nullopt; @@ -334,7 +334,7 @@ at::Tensor LazyNativeFunctions::_to_copy( std::move(node), lazy_self->GetDevice())); return result; } -}; +} at::Tensor LazyNativeFunctions::_unsafe_view(const at::Tensor &self, at::IntArrayRef size) { diff --git a/projects/ltc/csrc/base_lazy_backend/utils/sys_utils.h b/projects/ltc/csrc/base_lazy_backend/utils/sys_utils.h index f6c51ba6158f..5804bce5fd93 100644 --- a/projects/ltc/csrc/base_lazy_backend/utils/sys_utils.h +++ b/projects/ltc/csrc/base_lazy_backend/utils/sys_utils.h @@ -14,8 +14,8 @@ static T GetEnv(const std::string &name, const T &default_value = T(0)) { return T(std::atoi(env)); } -static std::string GetEnvString(const std::string &name, - const std::string &default_value) { +[[maybe_unused]] static std::string +GetEnvString(const std::string &name, const std::string &default_value) { const char *env = std::getenv(name.c_str()); if (!env) { return default_value; @@ -23,7 +23,7 @@ static std::string GetEnvString(const std::string &name, return std::string(env); } -static bool GetEnvBool(const char *name, bool defval) { +[[maybe_unused]] static bool GetEnvBool(const char *name, bool defval) { const char *env = std::getenv(name); if (env == nullptr) { return defval; From 54ef18c556c10402027eca04b27b0384a2647da1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ilija=20Kalini=C4=87?= Date: Wed, 31 Jan 2024 18:39:38 +0100 Subject: [PATCH 147/283] Implement lowering of torch.aten.lerp.Scalar (#2773) Closes nod-ai/SHARK-Turbine#356 --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 49 +++++++++++++++++++ .../Transforms/AbstractInterpLibrary.cpp | 14 ++++++ .../Torch/Transforms/DecomposeComplexOps.cpp | 30 ++++++++++++ .../Transforms/LowerToBackendContract.cpp | 1 + projects/pt1/e2e_testing/xfail_sets.py | 4 ++ .../build_tools/abstract_interp_lib_gen.py | 24 +++++++++ .../build_tools/torch_ods_gen.py | 1 + .../test_suite/elementwise.py | 42 ++++++++++++++++ 8 files changed, 165 insertions(+) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 81ee9844af30..0ae45798d188 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -1620,6 +1620,55 @@ def Torch_AtenLerp_TensorOp : Torch_Op<"aten.lerp_.Tensor", [ }]; } +def Torch_AtenLerpScalarOp : Torch_Op<"aten.lerp.Scalar", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::lerp.Scalar : (Tensor, Tensor, Scalar) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$end, + AnyTorchScalarType:$weight + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenLerpScalarOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); + } + void AtenLerpScalarOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); + } + }]; +} + +def Torch_AtenLerp_ScalarOp : Torch_Op<"aten.lerp_.Scalar", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::lerp_.Scalar : (Tensor, Tensor, Scalar) -> (Tensor)`"; + let arguments = (ins + Torch_NonValueTensorType:$self, + Torch_NonValueTensorType:$end, + AnyTorchScalarType:$weight + ); + let results = (outs + Torch_NonValueTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenLerp_ScalarOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); + } + void AtenLerp_ScalarOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); + } + }]; +} + def Torch_AtenEqTensorOp : Torch_Op<"aten.eq.Tensor", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index bb9717303e6b..57ece8cfdd7e 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -8438,6 +8438,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %1 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %0) : (!torch.list, !torch.list) -> !torch.list\n" " return %1 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.lerp.Scalar\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.float) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.addcmul\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.list, %arg3: !torch.float) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg1, %arg2) : (!torch.list, !torch.list) -> !torch.list\n" " %1 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %0) : (!torch.list, !torch.list) -> !torch.list\n" @@ -11198,6 +11202,16 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %5 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%3, %4) : (!torch.list>, !torch.list) -> !torch.int\n" " return %5 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.lerp.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.number) -> !torch.int {\n" +" %none = torch.constant.none\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %0#0, %1#0, %none : (!torch.int, !torch.int, !torch.none) -> !torch.list>\n" +" %3 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.get_dtype_of_scalar(%arg2) : (!torch.number) -> !torch.int\n" +" %4 = torch.prim.ListConstruct %0#1, %1#1, %3 : (!torch.int, !torch.int, !torch.int) -> !torch.list\n" +" %5 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %4) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %5 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.addcmul\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.tuple, %arg3: !torch.number) -> !torch.int {\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: \"\n" diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index d1794de930b4..edf51be11310 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -1895,6 +1895,35 @@ class DecomposeAtenLeakyReluBackwardOp }; } // namespace +namespace { +class DecomposeAtenLerpScalarOp : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenLerpScalarOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + auto resType = op.getType().cast(); + if (!resType.hasDtype()) { + return rewriter.notifyMatchFailure(op, "result should have dtype"); + } + Value cstOne = + rewriter.create(loc, rewriter.getI64IntegerAttr(1)); + auto start = op.getSelf(); + auto inputType = start.getType().cast(); + + auto delta = rewriter.create(loc, inputType, op.getEnd(), + start, cstOne); + + auto weightedDelta = + rewriter.create(loc, inputType, delta, op.getWeight()); + auto lerp = rewriter.create(loc, inputType, start, + weightedDelta, cstOne); + rewriter.replaceOp(op, lerp); + return success(); + } +}; +} // namespace + // Elu = scale * max(0,x) + alpha * scale * (exp(min(0,x) * input_scale) - 1) namespace { class DecomposeAtenEluOp : public OpRewritePattern { @@ -6763,6 +6792,7 @@ class DecomposeComplexOpsPass addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp index befdf808ad5b..c4259dc958b8 100644 --- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp +++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp @@ -488,6 +488,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 7c1cc16261d0..4f8c04b1d2c5 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -1116,6 +1116,8 @@ "ElementwiseLeakyReluModule_basic", "ElementwiseLeakyReluModule_basic", "ElementwiseLeakyReluStaticModule_basic", + "ElementwiseLerpScalarIntModule_basic", + "ElementwiseLerpScalarFloatModule_basic", "ElementwiseLog2Module_basic", "ElementwiseLogModule_basic", "ElementwiseLtDiffWidthScalarModule_basic", @@ -1496,6 +1498,8 @@ "ElementwiseLogitModule_basic", "ElementwiseRemainderScalarModule_Int_Float_basic", "ElementwiseRemainderScalarModule_Bool_basic", + "ElementwiseLerpScalarIntModule_basic", + "ElementwiseLerpScalarFloatModule_basic", "AtenIntTensorByteDtypeModule_basic", "AtenIntTensorCharDtypeModule_basic", "UpSampleNearest2dBackwardVec_basic", diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index 91e98d99c9ff..922b207a2c57 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -1245,6 +1245,9 @@ def aten〇nan_to_num〡shape(self: List[int], nan: Optional[float] = None, posi def aten〇lerp〇Tensor〡shape(self: List[int], end: List[int], weight: List[int]) -> List[int]: return upstream_shape_functions.broadcast(self, upstream_shape_functions.broadcast(end, weight)) +def aten〇lerp〇Scalar〡shape(self: List[int], end: List[int], weight: float) -> List[int]: + return upstream_shape_functions.broadcast(self, end) + def aten〇addcmul〡shape(self: List[int], tensor1: List[int], tensor2: List[int], value: float = 1) -> List[int]: return upstream_shape_functions.broadcast(self, upstream_shape_functions.broadcast(tensor1, tensor2)) @@ -3313,6 +3316,27 @@ def aten〇lerp〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], end_rank_dtyp dtypes = [self_dtype, end_dtype, weight_dtype] return promote_dtypes(ranks, dtypes) +@check_dtype_function( + _check_tensors_with_the_same_dtype(tensor_shapes=[(1, 1), (1, 1)], weight=0.5) + + # Different width + [Invocation(TensorOfShape(4, 3, dtype=torch.float32), + TensorOfShape(4, 3, dtype=torch.float64), + weight=0.5), + # Different type + Invocation(TensorOfShape(4, 3, dtype=torch.int32), + TensorOfShape(4, 3, dtype=torch.float32), + weight=0.5), + Invocation(TensorOfShape(4, 3, dtype=torch.float32), + TensorOfShape(4, 3, dtype=torch.float32), + weight=2)]) +def aten〇lerp〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], end_rank_dtype: Tuple[int, int], weight: Union[int, float, complex]) -> int: + self_rank, self_dtype = self_rank_dtype + end_rank, end_dtype = end_rank_dtype + + ranks: List[Optional[int]] = [self_rank, end_rank, None] + dtypes = [self_dtype, end_dtype, get_dtype_of_scalar(weight)] + return promote_dtypes(ranks, dtypes) + @check_dtype_function( _check_tensors_with_the_same_dtype(tensor_shapes=[(1, 1), (1, 1), (1, 1)], error_types={torch.bool}) + # Different width diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index a329c1ae01a4..43635bf2fa05 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -290,6 +290,7 @@ def emit_with_mutating_variants(key, **kwargs): "aten::logical_xor : (Tensor, Tensor) -> (Tensor)", "aten::logical_not : (Tensor) -> (Tensor)", "aten::lerp.Tensor : (Tensor, Tensor, Tensor) -> (Tensor)", + "aten::lerp.Scalar : (Tensor, Tensor, Scalar) -> (Tensor)", "aten::eq.Tensor : (Tensor, Tensor) -> (Tensor)", "aten::gt.Tensor : (Tensor, Tensor) -> (Tensor)", "aten::ge.Tensor : (Tensor, Tensor) -> (Tensor)", diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py index 26eac617a4a6..f711af6d4639 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py @@ -545,6 +545,48 @@ def forward(self, x): def ElementwiseLeakyReluStaticModule_basic(module, tu: TestUtils): module.forward(tu.rand(4, 5, 6, low=-1)) + +# ============================================================================== + + +class ElementwiseLerpScalarIntModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True), + ([-1, -1], torch.float32, True), + ]) + def forward(self, a, b): + return torch.ops.aten.lerp(a, b, weight=2) + +@register_test_case(module_factory=lambda: ElementwiseLerpScalarIntModule()) +def ElementwiseLerpScalarIntModule_basic(module, tu: TestUtils): + module.forward(tu.rand(5,3), tu.rand(5,3)) + + +class ElementwiseLerpScalarFloatModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True), + ([-1, -1], torch.float32, True), + ]) + def forward(self, a, b): + return torch.ops.aten.lerp(a, b, weight=0.5) + +@register_test_case(module_factory=lambda: ElementwiseLerpScalarFloatModule()) +def ElementwiseLerpScalarFloatModule_basic(module, tu: TestUtils): + module.forward(tu.rand(5,3), tu.rand(5,3)) + + # ============================================================================== From 3500523f750e069aca1239e3c9a2866b86f1a37a Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Wed, 31 Jan 2024 11:40:53 -0800 Subject: [PATCH 148/283] [onnx] Convert resources to denseattr for `onnx.constant` to `torch` (#2830) `onnx` explicitly specifies that `raw_data` is stored in `little-endian` layout. While converting to `torch` we need to convert from a known endian format to an internal format of consistent layout. This means endianness must be correct during the import of `onnx.Constant`. --------- Co-authored-by: Xida Ren (Cedar) --- .../TorchOnnxToTorch/DefaultDomainAtoF.cpp | 35 +++++++++++++++++++ .../TorchOnnxToTorch/simple_ops_a_to_f.mlir | 25 +++++++++++++ 2 files changed, 60 insertions(+) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp index bed22b08407a..e9221ed139d8 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp @@ -7,6 +7,7 @@ // //===----------------------------------------------------------------------===// +#include "mlir/IR/DialectResourceBlobManager.h" #include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" @@ -16,6 +17,20 @@ using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::onnx_c; +class Endian { +private: + static constexpr uint32_t uint32_ = 0x01020304; + static constexpr uint8_t magic_ = (const uint8_t &)uint32_; + +public: + static constexpr bool little = magic_ == 0x04; + static constexpr bool big = magic_ == 0x01; + static_assert(little || big, "Cannot determine endianness!"); + +private: + Endian() = delete; +}; + static int64_t onnxDtypeIntToTorchDtypeInt(int64_t dtypeIntOnnx) { // TODO: Add complete mapping. // Where are the ONNX and PyTorch dtype enums defined? @@ -632,6 +647,26 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( return success(); } + if (DenseResourceElementsAttr attr = + binder.op->getAttr("torch.onnx.value") + .dyn_cast_or_null()) { + // Bytes are stored in little endian order. Big endian support will + // require swizzling. + if (!Endian::little) { + binder.op->emitError( + "unimplemented: importing on big endian systems"); + return failure(); + } + + auto ty = cast(attr.getType()); + auto ptr = attr.getRawHandle().getBlob()->getData(); + DenseElementsAttr denseAttr = + DenseElementsAttr::getFromRawBuffer(ty, ptr); + rewriter.replaceOpWithNewOp( + binder.op, resultType, denseAttr); + return success(); + } + if (ElementsAttr attr = binder.op->getAttr("torch.onnx.value") .dyn_cast_or_null()) { rewriter.replaceOpWithNewOp( diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir index 2c06567bde97..797f9b6c2054 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir @@ -1282,6 +1282,31 @@ func.func @ints_constant() -> !torch.vtensor<[2], si64> attributes {torch.onnx_m // ----- +// CHECK-LABEL: @dense_constant +func.func @dense_constant() -> () attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64} { + // CHECK: torch.vtensor.literal(dense<[0, 10, 128, 17000]> : tensor<4xsi32>) : !torch.vtensor<[4],si32> + %0 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<_int32> : tensor<4xsi32>} : () -> !torch.vtensor<[4],si32> + // CHECK: torch.vtensor.literal(dense<[0.000000e+00, 1.000000e+01, 1.280000e+02, 1.700000e+04]> : tensor<4xf32>) : !torch.vtensor<[4],f32> + %1 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<_float32> : tensor<4xf32>} : () -> !torch.vtensor<[4],f32> + // CHECK: torch.vtensor.literal(dense<[-128, -1, 50, 127]> : tensor<4xsi8>) : !torch.vtensor<[4],si8> + %2 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<_int8> : tensor<4xsi8>} : () -> !torch.vtensor<[4],si8> + // CHECK: torch.vtensor.literal(dense<[128, 255, 50, 127]> : tensor<4xui8>) : !torch.vtensor<[4],ui8> + %3 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<_int8> : tensor<4xui8>} : () -> !torch.vtensor<[4],ui8> + return +} + +{-# + dialect_resources: { + builtin: { + _int8: "0x0800000080FF327F", + _int32: "0x08000000000000000a0000008000000068420000", + _float32: "0x0800000000000000000020410000004300d08446" + } + } +#-} + +// ----- + // CHECK-LABEL: @test_flatten_4d_axis_2 func.func @test_flatten_4d_axis_2(%arg0: !torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[6,20],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 2 From 54e258792c686e4bb30a8bbd1a0f1268d440d2a4 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Wed, 31 Jan 2024 11:41:06 -0800 Subject: [PATCH 149/283] [onnx] Import `onnx` constants as `onnx.Constant` instead of literals (#2831) To handle the conversion from raw bytes to `DenseElementsAttr` we need to handle the endianness conversion during `torch-onnx-to-torch`. Therefore when importing `onnx.Constant` it is better to represent using the `onnx` constant operation so that only one location requires the endianness correction. --- python/torch_mlir/extras/onnx_importer.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/python/torch_mlir/extras/onnx_importer.py b/python/torch_mlir/extras/onnx_importer.py index 59a2682bbba9..c651f79b15fe 100644 --- a/python/torch_mlir/extras/onnx_importer.py +++ b/python/torch_mlir/extras/onnx_importer.py @@ -343,10 +343,14 @@ def import_initializer(self, initializer: onnx.TensorProto, extern_name: str = N with InsertionPoint(self._b), Location.name(iname): value_attr = self._cc.tensor_proto_to_attr(initializer) vtensor_type = self._cc.tensor_proto_to_type(initializer) + attrs = { + "name": StringAttr.get(f"onnx.Constant"), + "torch.onnx.value": value_attr, + } literal_op = Operation.create( - name="torch.vtensor.literal", + name="torch.operator", results=[vtensor_type], - attributes={"value": value_attr}, + attributes=attrs, ) self._nv_map[iname] = literal_op.result return literal_op.result From 8a17c98b74b53ae71e19a9c7f7451af62dc339d9 Mon Sep 17 00:00:00 2001 From: Sambhav Jain Date: Wed, 31 Jan 2024 14:21:17 -0800 Subject: [PATCH 150/283] Bump stablehlo to openxla/stablehlo@fd52182f76cadb82f2064fe5fc49a4fb4347a826 (#2821) With the recent LLVM integrate and changes from https://github.com/llvm/llvm-project/pull/78260, we hit this build error in Stablehlo (which is quite old). ``` external/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp:1020:14: error: no member named 'startRootUpdate' in 'mlir::PatternRewriter' rewriter.startRootUpdate(op); ~~~~~~~~ ^ external/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp:1026:16: error: no member named 'finalizeRootUpdate' in 'mlir::PatternRewriter' rewriter.finalizeRootUpdate(op); ~~~~~~~~ ^ external/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp:1029:16: error: no member named 'cancelRootUpdate' in 'mlir::PatternRewriter' rewriter.cancelRootUpdate(op); ~~~~~~~~ ^ external/stablehlo/stablehlo/transforms/StablehloRefineShapes.cpp:1108:14: error: no member named 'updateRootInPlace' in 'mlir::PatternRewriter' rewriter.updateRootInPlace(op->getParentOp(), [&]() { return; }); ~~~~~~~~ ^ 4 errors generated. Target @torch-mlir//:torch-mlir-opt failed to build ``` I'm still puzzled as to how this didn't fail with the CMake merge gating CI (do we not test Stablehlo builds/tests?). In any case, bumping our submodule to https://github.com/openxla/stablehlo/pull/1918 fixes it. It exposes a new failing lit test in TorchToStablehlo though, that I have looped stablehlo developers into ([here](https://discord.com/channels/999073994483433573/999074539138990131/1201235845391331419)). ``` bazel run @torch-mlir//test/Conversion:TorchToStablehlo/scatter.mlir.test ...external/torch-mlir/test/Conversion/TorchToStablehlo/scatter.mlir within split at :1 offset :33:8: error: unexpected error: Expects non-empty reduction block for type inference %0 = torch.aten.scatter.src %arg0, %int0, %arg1, %arg2 : !torch.vtensor<[?,?],si64>, !torch.int, !torch.vtensor<[?,?],si64>, !torch.vtensor<[?,?],si64> -> !torch.vtensor<[?,?],si64> ^ LLVM ERROR: Failed to infer result type(s). ``` Bazel CI: https://github.com/sjain-stanford/torch-mlir/actions/runs/7732673480/job/21083102228 --- externals/stablehlo | 2 +- lib/Conversion/TorchToStablehlo/Basic.cpp | 18 ++--- .../TorchToStablehlo/GatherScatter.cpp | 10 ++- lib/Conversion/TorchToStablehlo/Linear.cpp | 34 +++------ lib/Conversion/TorchToStablehlo/Pooling.cpp | 73 +++++-------------- lib/Conversion/TorchToStablehlo/Reduction.cpp | 14 ++-- .../StablehloLegalizeUtils.cpp | 7 +- test/Conversion/TorchToStablehlo/pooling.mlir | 12 +-- 8 files changed, 61 insertions(+), 109 deletions(-) diff --git a/externals/stablehlo b/externals/stablehlo index ab709fe48de8..fd52182f76ca 160000 --- a/externals/stablehlo +++ b/externals/stablehlo @@ -1 +1 @@ -Subproject commit ab709fe48de88c67717abfbd7ef17425eb95ddaf +Subproject commit fd52182f76cadb82f2064fe5fc49a4fb4347a826 diff --git a/lib/Conversion/TorchToStablehlo/Basic.cpp b/lib/Conversion/TorchToStablehlo/Basic.cpp index 33db9ac9ee54..bee6c529bacb 100644 --- a/lib/Conversion/TorchToStablehlo/Basic.cpp +++ b/lib/Conversion/TorchToStablehlo/Basic.cpp @@ -377,12 +377,12 @@ class ConvertAtenAddSubOp : public OpConversionPattern { if (!skipMultiplyAlpha(op.getAlpha())) { Value alpha = hlo::scalarToStablehloTensor(rewriter, op, adaptor.getAlpha(), outElemTy); - DenseIntElementsAttr bcastDimensions; + DenseI64ArrayAttr bcastDimensions; rhs = rewriter.create(op->getLoc(), rhs, alpha, bcastDimensions); } - DenseIntElementsAttr bcastDimensions; + DenseI64ArrayAttr bcastDimensions; rewriter.replaceOpWithNewOp(op, outType, lhs, rhs, bcastDimensions); return success(); @@ -424,7 +424,7 @@ class ConvertAtenMulDivOp : public OpConversionPattern { rhs = hlo::scalarToStablehloTensor(rewriter, op, adaptor.getOther(), outElemTy); } - DenseIntElementsAttr bcastDimensions; + DenseI64ArrayAttr bcastDimensions; lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outType); rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outType); auto loc = op.getLoc(); @@ -542,7 +542,7 @@ class ConvertAtenCompareOp : public OpConversionPattern { } else { return op.emitError("operator haven't been supported"); } - DenseIntElementsAttr bcastDimensions; + DenseI64ArrayAttr bcastDimensions; rewriter.replaceOpWithNewOp( op, outType, lhs, rhs, bcastDimensions, compareDirectionAttr, compareTypeAttr); @@ -570,7 +570,7 @@ class ConvertAtenLogicalBinaryOp : public OpConversionPattern { Value rhs = hlo::promoteType(rewriter, op.getLoc(), adaptor.getOther(), outType); - DenseIntElementsAttr bcastDimensions; + DenseI64ArrayAttr bcastDimensions; rewriter.replaceOpWithNewOp(op, outType, lhs, rhs, bcastDimensions); return success(); @@ -757,7 +757,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( llvm::to_vector<4>(llvm::seq(leadingRank, totalRank)); rewriter.replaceOpWithNewOp( op, outType, self, bcastShapeTensor, - rewriter.getI64TensorAttr(dimensionNumbers)); + rewriter.getDenseI64ArrayAttr(dimensionNumbers)); } return success(); } @@ -887,7 +887,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( if (!rhsType) { rhs = hlo::scalarToStablehloTensor(rewriter, op, rhs, outElemTy); } - DenseIntElementsAttr bcastDimensions; + DenseI64ArrayAttr bcastDimensions; lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, outType); rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, outType); auto loc = op.getLoc(); @@ -1478,7 +1478,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( Value window = rewriter.create(loc, outType, resultLength, 0); - DenseIntElementsAttr broadcastDimensions; + DenseI64ArrayAttr broadcastDimensions; Value mulOut = rewriter.create(loc, window, step, broadcastDimensions); rewriter.replaceOpWithNewOp(op, mulOut, start, @@ -1721,7 +1721,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( rewriter.create(op->getLoc(), adaptor.getSelf()); Value bcastScalar = rewriter.create( op->getLoc(), outType, scalarTensor, shapeTensor, - rewriter.getI64TensorAttr({})); + rewriter.getDenseI64ArrayAttr({})); rewriter.replaceOp(op, bcastScalar); return success(); } diff --git a/lib/Conversion/TorchToStablehlo/GatherScatter.cpp b/lib/Conversion/TorchToStablehlo/GatherScatter.cpp index d2b0450cd19a..53c418da4fb9 100644 --- a/lib/Conversion/TorchToStablehlo/GatherScatter.cpp +++ b/lib/Conversion/TorchToStablehlo/GatherScatter.cpp @@ -334,7 +334,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( return failure(); auto stablehloReduceOp = rewriter.create( - op.getLoc(), gatherOutput, initValue, rewriter.getI64TensorAttr({0})); + op.getLoc(), gatherOutput, initValue, rewriter.getDenseI64ArrayAttr({0}), + elementTy); Region ®ion = stablehloReduceOp.getBody(); Block &block = region.emplaceBlock(); @@ -510,7 +511,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( rewriter.replaceOpWithNewOp( op, input, gatherIndicies, dimsAttr, - rewriter.getI64TensorAttr(sliceSizes)); + rewriter.getDenseI64ArrayAttr(sliceSizes)); return success(); } @@ -666,7 +667,8 @@ LogicalResult ConvertAtenOp::matchAndRewrite( /*indexVectorDim=*/indexVecDim); auto stablehloScatterOp = rewriter.create( - loc, input, scatterIndicies, src, scatterDimensionNumbers, false, false); + loc, inputType, input, scatterIndicies, src, scatterDimensionNumbers, + false, false); // config update computation function: just return the element from src. Block &block = stablehloScatterOp.getUpdateComputation().emplaceBlock(); @@ -833,7 +835,7 @@ LogicalResult ConvertAtenOp::matchAndRewrite( rewriter.replaceOpWithNewOp( op, resultType, input, finalIndexTensor, dimsAttr, - rewriter.getI64TensorAttr(sliceSizes)); + rewriter.getDenseI64ArrayAttr(sliceSizes)); return success(); } diff --git a/lib/Conversion/TorchToStablehlo/Linear.cpp b/lib/Conversion/TorchToStablehlo/Linear.cpp index df92317824a1..b1749ee1c074 100644 --- a/lib/Conversion/TorchToStablehlo/Linear.cpp +++ b/lib/Conversion/TorchToStablehlo/Linear.cpp @@ -39,10 +39,7 @@ Value getBroadcastTensor(PatternRewriter &rewriter, Operation *op, Value tensor, RankedTensorType outTy = RankedTensorType::get(shape, tensorTy.getElementType()); - RankedTensorType attrTy = - RankedTensorType::get({static_cast(broadcastDims.size())}, - rewriter.getIntegerType(64)); - auto broadcastAttr = DenseIntElementsAttr::get(attrTy, broadcastDims); + auto broadcastAttr = rewriter.getDenseI64ArrayAttr(broadcastDims); auto broadcast = rewriter.create( loc, outTy, tensor, stablehloShape, broadcastAttr); @@ -549,8 +546,7 @@ class ConvertAtenConvolutionOp : public ConvertAtenOp { // Prepare for transposed convolution SmallVector stablehloStrideVec(nSpatialDims, 1); - DenseIntElementsAttr stablehloStride = - rewriter.getI64TensorAttr(stablehloStrideVec); + auto stablehloStride = rewriter.getDenseI64ArrayAttr(stablehloStrideVec); SmallVector stablehloPaddingVec(nSpatialDims * 2, 0); for (int i = 0; i < nSpatialDims; ++i) { int64_t padInt = dilation[i] * (weightShape[i + 2] - 1) - padding[i]; @@ -563,15 +559,15 @@ class ConvertAtenConvolutionOp : public ConvertAtenOp { stablehloPaddingVec); SmallVector stablehloLhsDilationVec(nSpatialDims); std::copy(stride.begin(), stride.end(), stablehloLhsDilationVec.begin()); - DenseIntElementsAttr stablehloLhsDilation = - rewriter.getI64TensorAttr(stablehloLhsDilationVec); + auto stablehloLhsDilation = + rewriter.getDenseI64ArrayAttr(stablehloLhsDilationVec); SmallVector stablehloRhsDilationVec(nSpatialDims); std::copy(dilation.begin(), dilation.end(), stablehloRhsDilationVec.begin()); - DenseIntElementsAttr stablehloRhsDilation = - rewriter.getI64TensorAttr(stablehloRhsDilationVec); + auto stablehloRhsDilation = + rewriter.getDenseI64ArrayAttr(stablehloRhsDilationVec); - DenseElementsAttr windowReversal; + DenseBoolArrayAttr windowReversal; ArrayAttr precisionConfig; SmallVector spatialDims; @@ -614,10 +610,7 @@ class ConvertAtenConvolutionOp : public ConvertAtenOp { int64_t nDims = outType.getRank(); // Get stablehlo::ConvolutionOp attributes - DenseIntElementsAttr stablehloWindowStride = DenseIntElementsAttr::get( - RankedTensorType::get({static_cast(stride.size())}, - rewriter.getI64Type()), - stride); + auto stablehloWindowStride = rewriter.getDenseI64ArrayAttr(stride); std::vector stablehloPaddingVec; for (size_t i = 0; i < padding.size(); i++) { stablehloPaddingVec.emplace_back(padding[i]); @@ -628,10 +621,7 @@ class ConvertAtenConvolutionOp : public ConvertAtenOp { {static_cast(padding.size()), static_cast(2)}, rewriter.getI64Type()), stablehloPaddingVec); - DenseIntElementsAttr stablehloRhsDilation = DenseIntElementsAttr::get( - RankedTensorType::get({static_cast(dilation.size())}, - rewriter.getI64Type()), - dilation); + auto stablehloRhsDilation = rewriter.getDenseI64ArrayAttr(dilation); SmallVector spatialDimensions; for (int64_t i = 2; i < nDims; i++) { spatialDimensions.emplace_back(i); @@ -648,8 +638,8 @@ class ConvertAtenConvolutionOp : public ConvertAtenOp { /*outputSpatialDimensions=*/spatialDimensions); // stablehlo::ConvolutionOp's optional attributes, leave them as default - DenseIntElementsAttr stablehloLhsDilation; - DenseElementsAttr windowReversal; + DenseI64ArrayAttr stablehloLhsDilation; + DenseBoolArrayAttr windowReversal; ArrayAttr precisionConfig; auto stablehloConvOp = rewriter.create( @@ -781,7 +771,7 @@ class ConvertAtenConvolutionOp : public ConvertAtenOp { options.dimSizeIndexBits); bias = hlo::promoteType(rewriter, op.getLoc(), bias, outTy); - DenseIntElementsAttr bcastDimensions; + DenseI64ArrayAttr bcastDimensions; rewriter.replaceOpWithNewOp( op, outTy, stablehloConvResult, bias, bcastDimensions); return success(); diff --git a/lib/Conversion/TorchToStablehlo/Pooling.cpp b/lib/Conversion/TorchToStablehlo/Pooling.cpp index 7ef69ae6712d..40b0dd691071 100644 --- a/lib/Conversion/TorchToStablehlo/Pooling.cpp +++ b/lib/Conversion/TorchToStablehlo/Pooling.cpp @@ -136,19 +136,10 @@ LogicalResult ConvertAtenOp::matchAndRewrite( stablehloPadding[stablehloPadding.size() - 2] = padding[1]; stablehloPadding[stablehloPadding.size() - 1] = padding[1]; - DenseIntElementsAttr windowDimensions = DenseIntElementsAttr::get( - RankedTensorType::get({static_cast(stablehloKernelSize.size())}, - rewriter.getI64Type()), - stablehloKernelSize); - DenseIntElementsAttr windowStrides = DenseIntElementsAttr::get( - RankedTensorType::get({static_cast(stablehloStride.size())}, - rewriter.getI64Type()), - stablehloStride); - DenseIntElementsAttr baseDilations; - DenseIntElementsAttr windowDilations = DenseIntElementsAttr::get( - RankedTensorType::get({static_cast(stablehloDilation.size())}, - rewriter.getI64Type()), - stablehloDilation); + auto windowDimensions = rewriter.getDenseI64ArrayAttr(stablehloKernelSize); + auto windowStrides = rewriter.getDenseI64ArrayAttr(stablehloStride); + DenseI64ArrayAttr baseDilations; + auto windowDilations = rewriter.getDenseI64ArrayAttr(stablehloDilation); DenseIntElementsAttr pad = DenseIntElementsAttr::get( RankedTensorType::get( {static_cast(inputRank), static_cast(2)}, @@ -242,19 +233,10 @@ LogicalResult ConvertAtenOp::matchAndRewrite( stablehloPadding[stablehloPadding.size() - 2] = padding[1]; stablehloPadding[stablehloPadding.size() - 1] = padding[1]; - DenseIntElementsAttr windowDimensions = DenseIntElementsAttr::get( - RankedTensorType::get({static_cast(stablehloKernelSize.size())}, - rewriter.getI64Type()), - stablehloKernelSize); - DenseIntElementsAttr windowStrides = DenseIntElementsAttr::get( - RankedTensorType::get({static_cast(stablehloStride.size())}, - rewriter.getI64Type()), - stablehloStride); - DenseIntElementsAttr baseDilations; - DenseIntElementsAttr windowDilations = DenseIntElementsAttr::get( - RankedTensorType::get({static_cast(stablehloDilation.size())}, - rewriter.getI64Type()), - stablehloDilation); + auto windowDimensions = rewriter.getDenseI64ArrayAttr(stablehloKernelSize); + auto windowStrides = rewriter.getDenseI64ArrayAttr(stablehloStride); + DenseI64ArrayAttr baseDilations; + auto windowDilations = rewriter.getDenseI64ArrayAttr(stablehloDilation); DenseIntElementsAttr pad = DenseIntElementsAttr::get( RankedTensorType::get( {static_cast(inputRank), static_cast(2)}, @@ -453,20 +435,10 @@ class ConvertAtenAvgPoolOp : public ConvertAtenOp { Value initVal = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter); - DenseIntElementsAttr windowDimensions = DenseIntElementsAttr::get( - RankedTensorType::get( - {static_cast(stablehloKernelSize.size())}, - rewriter.getI64Type()), - stablehloKernelSize); - DenseIntElementsAttr windowStrides = DenseIntElementsAttr::get( - RankedTensorType::get({static_cast(stablehloStride.size())}, - rewriter.getI64Type()), - stablehloStride); - DenseIntElementsAttr baseDilations; - DenseIntElementsAttr windowDilations = DenseIntElementsAttr::get( - RankedTensorType::get({static_cast(stablehloDilation.size())}, - rewriter.getI64Type()), - stablehloDilation); + auto windowDimensions = rewriter.getDenseI64ArrayAttr(stablehloKernelSize); + auto windowStrides = rewriter.getDenseI64ArrayAttr(stablehloStride); + DenseI64ArrayAttr baseDilations; + auto windowDilations = rewriter.getDenseI64ArrayAttr(stablehloDilation); DenseIntElementsAttr pad = DenseIntElementsAttr::get( RankedTensorType::get( {static_cast(inputRank), static_cast(2)}, @@ -508,7 +480,7 @@ class ConvertAtenAvgPoolOp : public ConvertAtenOp { .value(); } divisor = hlo::promoteType(rewriter, op.getLoc(), divisor, outTy); - DenseIntElementsAttr bcastDimensions; + DenseI64ArrayAttr bcastDimensions; rewriter.replaceOpWithNewOp( op, outTy, reduceWindowSum.getResult(0), divisor, bcastDimensions); return success(); @@ -528,7 +500,7 @@ class ConvertAtenAvgPoolOp : public ConvertAtenOp { windowSizeConst = rewriter.create( op->getLoc(), RankedTensorType::get(inputTy.getShape(), outTy.getElementType()), - windowSizeConst, inputShapeTensor, rewriter.getI64TensorAttr({})); + windowSizeConst, inputShapeTensor, rewriter.getDenseI64ArrayAttr({})); Value zero = createInitialValueForAtenPoolingOp(op, inputElemTy, rewriter); auto reduceWindowSize = rewriter.create( @@ -599,19 +571,10 @@ LogicalResult ConvertAtenOp::matchAndRewrite( SmallVector stablehloPadding(inputRank * 2, 0); stablehloPadding[dim * 2] = inputShape[dim] - 1; - DenseIntElementsAttr windowDimensions = DenseIntElementsAttr::get( - RankedTensorType::get({static_cast(stablehloKernelSize.size())}, - rewriter.getI64Type()), - stablehloKernelSize); - DenseIntElementsAttr windowStrides = DenseIntElementsAttr::get( - RankedTensorType::get({static_cast(stablehloStride.size())}, - rewriter.getI64Type()), - stablehloStride); - DenseIntElementsAttr baseDilations; - DenseIntElementsAttr windowDilations = DenseIntElementsAttr::get( - RankedTensorType::get({static_cast(stablehloDilation.size())}, - rewriter.getI64Type()), - stablehloDilation); + auto windowDimensions = rewriter.getDenseI64ArrayAttr(stablehloKernelSize); + auto windowStrides = rewriter.getDenseI64ArrayAttr(stablehloStride); + DenseI64ArrayAttr baseDilations; + auto windowDilations = rewriter.getDenseI64ArrayAttr(stablehloDilation); DenseIntElementsAttr pad = DenseIntElementsAttr::get( RankedTensorType::get( {static_cast(inputRank), static_cast(2)}, diff --git a/lib/Conversion/TorchToStablehlo/Reduction.cpp b/lib/Conversion/TorchToStablehlo/Reduction.cpp index f495aa39508f..e413fe532654 100644 --- a/lib/Conversion/TorchToStablehlo/Reduction.cpp +++ b/lib/Conversion/TorchToStablehlo/Reduction.cpp @@ -130,7 +130,7 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input, initValue, initIndex, }, - rewriter.getI64TensorAttr(dim)); + rewriter.getDenseI64ArrayAttr(dim)); Block &block = stablehloReduceOp.getBody().emplaceBlock(); @@ -412,7 +412,7 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( llvm::sort(dims.begin(), dims.end()); auto stablehloReduceOp = rewriter.create( - op.getLoc(), input, initValue, rewriter.getI64TensorAttr(dims)); + op.getLoc(), input, initValue, rewriter.getDenseI64ArrayAttr(dims)); Block &block = stablehloReduceOp.getBody().emplaceBlock(); auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType()); @@ -473,7 +473,7 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( return failure(); llvm::sort(dims.begin(), dims.end()); auto stablehloReduceOp = rewriter.create( - op.getLoc(), input, initValue, rewriter.getI64TensorAttr(dims)); + op.getLoc(), input, initValue, rewriter.getDenseI64ArrayAttr(dims)); Block &block = stablehloReduceOp.getBody().emplaceBlock(); auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType()); @@ -535,7 +535,7 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( return failure(); llvm::sort(dims.begin(), dims.end()); auto stablehloReduceOp = rewriter.create( - op.getLoc(), input, initValue, rewriter.getI64TensorAttr(dims)); + op.getLoc(), input, initValue, rewriter.getDenseI64ArrayAttr(dims)); Block &block = stablehloReduceOp.getBody().emplaceBlock(); auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType()); @@ -625,7 +625,7 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( llvm::sort(dims.begin(), dims.end()); auto stablehloReduceOp = rewriter.create( - op.getLoc(), input, initValue, rewriter.getI64TensorAttr(dims)); + op.getLoc(), input, initValue, rewriter.getDenseI64ArrayAttr(dims)); Region ®ion = stablehloReduceOp.getBody(); Block &block = region.emplaceBlock(); @@ -729,7 +729,7 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( auto reduceOp = rewriter.create( op->getLoc(), squareOp.getResult(), initValue, - rewriter.getI64TensorAttr(dims)); + rewriter.getDenseI64ArrayAttr(dims)); Region ®ion = reduceOp.getBody(); Block &block = region.emplaceBlock(); @@ -848,7 +848,7 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( ord, nullptr); auto reduceOp = rewriter.create( - op->getLoc(), powValue, initValue, rewriter.getI64TensorAttr(dims)); + op->getLoc(), powValue, initValue, rewriter.getDenseI64ArrayAttr(dims)); Region ®ion = reduceOp.getBody(); Block &block = region.emplaceBlock(); diff --git a/lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.cpp b/lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.cpp index ed203cb0f91f..c3f8eff22fbc 100644 --- a/lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.cpp +++ b/lib/Conversion/TorchToStablehlo/StablehloLegalizeUtils.cpp @@ -241,10 +241,7 @@ Value promoteAndBroadcast(ConversionPatternRewriter &rewriter, Value input, if (!do_bcast) { return input; } - DenseIntElementsAttr bcast_attr = DenseIntElementsAttr::get( - RankedTensorType::get({static_cast(bcastDims.size())}, - rewriter.getI64Type()), - bcastDims); + auto bcast_attr = rewriter.getDenseI64ArrayAttr(bcastDims); auto bcast_op = rewriter.create( op->getLoc(), outType, input, bcast_attr); return bcast_op.getResult(); @@ -360,7 +357,7 @@ Value getConstantOfShape(PatternRewriter &rewriter, Location loc, auto constTensor = rewriter.create(loc, constAttr); return rewriter .create( - loc, outType, constTensor, shape, rewriter.getI64TensorAttr({})) + loc, outType, constTensor, shape, rewriter.getDenseI64ArrayAttr({})) .getResult(); } } // namespace hlo diff --git a/test/Conversion/TorchToStablehlo/pooling.mlir b/test/Conversion/TorchToStablehlo/pooling.mlir index fd531006d614..b8fc6cbd8384 100644 --- a/test/Conversion/TorchToStablehlo/pooling.mlir +++ b/test/Conversion/TorchToStablehlo/pooling.mlir @@ -18,7 +18,7 @@ // CHECK: ^bb0(%[[VAL_8:.*]]: tensor, %[[VAL_9:.*]]: tensor): // CHECK: %[[VAL_10:.*]] = stablehlo.maximum %[[VAL_8]], %[[VAL_9]] : tensor // CHECK: stablehlo.return %[[VAL_10]] : tensor -// CHECK: }) {padding = dense<0> : tensor<4x2xi64>, window_dilations = dense<[1, 1, 2, 1]> : tensor<4xi64>, window_dimensions = dense<[1, 1, 2, 2]> : tensor<4xi64>, window_strides = dense<1> : tensor<4xi64>} : (tensor, tensor) -> tensor +// CHECK: }) {padding = dense<0> : tensor<4x2xi64>, window_dilations = array, window_dimensions = array, window_strides = array} : (tensor, tensor) -> tensor // CHECK: %[[VAL_11:.*]] = torch_c.from_builtin_tensor %[[VAL_7]] : tensor -> !torch.vtensor<[?,?,?,?],f32> // CHECK: return %[[VAL_11]] : !torch.vtensor<[?,?,?,?],f32> func.func @torch.aten.max_pool2d(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> { @@ -51,7 +51,7 @@ func.func @torch.aten.max_pool2d(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch // CHECK: %[[VAL_10:.*]] = stablehlo.maximum %[[VAL_8]], %[[VAL_9]] : tensor // CHECK: stablehlo.return %[[VAL_10]] : tensor // CHECK: }) -// CHECK-SAME{LITERAL}: {padding = dense<[[0, 0], [0, 0], [2, 2], [1, 1]]> : tensor<4x2xi64>, window_dilations = dense<[1, 1, 2, 1]> : tensor<4xi64>, window_dimensions = dense<[1, 1, 2, 2]> : tensor<4xi64>, window_strides = dense<1> : tensor<4xi64>} : (tensor, tensor) -> tensor +// CHECK-SAME{LITERAL}: {padding = dense<[[0, 0], [0, 0], [2, 2], [1, 1]]> : tensor<4x2xi64>, window_dilations = array, window_dimensions = array, window_strides = array} : (tensor, tensor) -> tensor // CHECK: %[[VAL_7:.*]] = torch_c.from_builtin_tensor %[[VAL_6]] : tensor -> !torch.vtensor<[?,?,?,?],f32> // CHECK: return %[[VAL_7]] : !torch.vtensor<[?,?,?,?],f32> func.func @torch.aten.max_pool2d$padding(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> { @@ -105,7 +105,7 @@ func.func @torch.aten.max_pool2d$padding(%arg0: !torch.vtensor<[?,?,?,?],f32>) - // CHECK: %[[T20:.*]] = stablehlo.select %[[T16]], %[[ARG2]], %[[ARG4]] : tensor, tensor // CHECK: %[[T21:.*]] = stablehlo.select %[[T18]], %[[T19]], %[[T20]] : tensor, tensor // CHECK: stablehlo.return %[[T17]], %[[T21]] : tensor, tensor -// CHECK: }) {padding = dense<0> : tensor<3x2xi64>, window_dilations = dense<1> : tensor<3xi64>, window_dimensions = dense<[1, 3, 3]> : tensor<3xi64>, window_strides = dense<[1, 2, 2]> : tensor<3xi64>} : (tensor, tensor, tensor, tensor) -> (tensor, tensor) +// CHECK: }) {padding = dense<0> : tensor<3x2xi64>, window_dilations = array, window_dimensions = array, window_strides = array} : (tensor, tensor, tensor, tensor) -> (tensor, tensor) // CHECK: %[[T14:.*]] = torch_c.from_builtin_tensor %[[T13]]#0 : tensor -> !torch.vtensor<[?,?,?],f32> // CHECK: %[[T15:.*]] = torch_c.from_builtin_tensor %[[T13]]#1 : tensor -> !torch.vtensor<[?,?,?],si64> // CHECK: return %[[T14]], %[[T15]] : !torch.vtensor<[?,?,?],f32>, !torch.vtensor<[?,?,?],si64> @@ -141,7 +141,7 @@ func.func @torch.aten.max_pool2d_with_indices(%arg0: !torch.vtensor<[?,?,?],f32> // CHECK: ^bb0(%[[IVAL_0:.*]]: tensor, %[[IVAL_1:.*]]: tensor): // CHECK: %[[IVAL_2:.*]] = stablehlo.add %[[IVAL_0]], %[[IVAL_1]] : tensor // CHECK: stablehlo.return %[[IVAL_2]] : tensor -// CHECK{LITERAL}: }) {padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = dense<1> : tensor<4xi64>, window_dimensions = dense<[1, 1, 3, 3]> : tensor<4xi64>, window_strides = dense<[1, 1, 2, 2]> : tensor<4xi64>} : (tensor, tensor) -> tensor +// CHECK{LITERAL}: }) {padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = array, window_dimensions = array, window_strides = array} : (tensor, tensor) -> tensor // CHECK: %[[VAL_7:.*]] = stablehlo.constant dense<1.000000e+00> : tensor // CHECK: %[[IDX_0:.*]] = arith.constant 0 : index // CHECK: %[[VAL_8:.*]] = tensor.dim %[[VAL_1]], %[[IDX_0]] : tensor @@ -162,7 +162,7 @@ func.func @torch.aten.max_pool2d_with_indices(%arg0: !torch.vtensor<[?,?,?],f32> // CHECK: ^bb0(%[[IVAL_3:.*]]: tensor, %[[IVAL_4:.*]]: tensor): // CHECK: %[[IVAL_5:.*]] = stablehlo.add %[[IVAL_3]], %[[IVAL_4]] : tensor // CHECK: stablehlo.return %[[IVAL_5]] : tensor -// CHECK{LITERAL}: }) {padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = dense<1> : tensor<4xi64>, window_dimensions = dense<[1, 1, 3, 3]> : tensor<4xi64>, window_strides = dense<[1, 1, 2, 2]> : tensor<4xi64>} : (tensor, tensor) -> tensor +// CHECK{LITERAL}: }) {padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = array, window_dimensions = array, window_strides = array} : (tensor, tensor) -> tensor // CHECK: %[[VAL_20:.*]] = stablehlo.divide %[[VAL_6]], %[[VAL_19]] : tensor // CHECK: %[[VAL_21:.*]] = torch_c.from_builtin_tensor %[[VAL_20]] : tensor -> !torch.vtensor<[?,?,?,?],f32> // CHECK: return %[[VAL_21]] : !torch.vtensor<[?,?,?,?],f32> @@ -198,7 +198,7 @@ func.func @torch.aten.avg_pool2d(%arg0: !torch.vtensor<[?,?,?,?],f32>) -> !torch // CHECK: ^bb0(%[[ARG1:.*]]: tensor, %[[ARG2:.*]]: tensor): // CHECK: %[[T10:.*]] = stablehlo.add %[[ARG1]], %[[ARG2]] : tensor // CHECK: stablehlo.return %[[T10]] : tensor -// CHECK{LITERAL}: }) {padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = dense<1> : tensor<4xi64>, window_dimensions = dense<[1, 1, 3, 3]> : tensor<4xi64>, window_strides = dense<[1, 1, 2, 2]> : tensor<4xi64>} : (tensor, tensor) -> tensor +// CHECK{LITERAL}: }) {padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64>, window_dilations = array, window_dimensions = array, window_strides = array} : (tensor, tensor) -> tensor // CHECK: %[[T6:.*]] = stablehlo.constant dense<9> : tensor // CHECK: %[[T7:.*]] = stablehlo.convert %[[T6]] : (tensor) -> tensor // CHECK: %[[T8:.*]] = chlo.broadcast_divide %[[T5]], %[[T7]] : (tensor, tensor) -> tensor From 0114a570e3c27403feaa2bce69535f9f7fd3b084 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Wed, 31 Jan 2024 15:09:12 -0800 Subject: [PATCH 151/283] [torch] Support lowering `torch.item` to `tensor.extract` (#2835) Extracting scalar values from tensors can be implemented via a lowering to tensor.extract. --- .../TorchToTensor/TorchToTensor.cpp | 45 ++++++++++++++++++- .../TorchConversion/Transforms/Passes.cpp | 2 + projects/pt1/e2e_testing/xfail_sets.py | 4 ++ .../torch_mlir_e2e_test/test_suite/scalar.py | 38 ++++++++++++++++ 4 files changed, 88 insertions(+), 1 deletion(-) diff --git a/lib/Conversion/TorchToTensor/TorchToTensor.cpp b/lib/Conversion/TorchToTensor/TorchToTensor.cpp index 417fd17fcb86..1b5341028c6d 100644 --- a/lib/Conversion/TorchToTensor/TorchToTensor.cpp +++ b/lib/Conversion/TorchToTensor/TorchToTensor.cpp @@ -28,6 +28,47 @@ using namespace mlir::torch::Torch; namespace { +class ConvertAtenItemOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename AtenItemOp::Adaptor; + LogicalResult + matchAndRewrite(AtenItemOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto operand = adaptor.getOperands()[0]; + auto operandTy = cast(operand.getType()); + auto torchDTy = cast(op.getOperand().getType()).getDtype(); + + if (operandTy.getNumElements() != 1) + return rewriter.notifyMatchFailure(op, "expected only one item"); + + auto zeroIdx = rewriter.create(op.getLoc(), 0); + auto rank = operandTy.getRank(); + llvm::SmallVector indices(rank, zeroIdx); + + Value extract = rewriter.create( + op.getLoc(), operandTy.getElementType(), operand, indices); + auto extractTy = extract.getType(); + if (isa(extractTy) && !extractTy.isInteger(64)) { + if (torchDTy.isSignlessInteger()) { + extract = rewriter.create( + op.getLoc(), rewriter.getIntegerType(64), extract); + } else { + extract = rewriter.create( + op.getLoc(), rewriter.getIntegerType(64), extract); + } + } + + if (isa(extractTy) && !extractTy.isF64()) { + extract = rewriter.create(op.getLoc(), + rewriter.getF64Type(), extract); + } + + rewriter.replaceOp(op, extract); + return success(); + } +}; + class ConvertAtenShapeToTensorPatternOp : public OpConversionPattern { public: @@ -70,6 +111,7 @@ class ConvertTorchToTensor ConversionTarget target(*context); target.addLegalDialect(); target.addLegalDialect(); + target.addIllegalOp(); target.addIllegalOp(); TypeConverter typeConverter; @@ -77,7 +119,8 @@ class ConvertTorchToTensor TorchConversion::setupBackendTypeConversion(target, typeConverter); RewritePatternSet patterns(context); - patterns.add(typeConverter, context); + patterns.add( + typeConverter, context); if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) diff --git a/lib/Dialect/TorchConversion/Transforms/Passes.cpp b/lib/Dialect/TorchConversion/Transforms/Passes.cpp index 91d468a6941f..673d7083f585 100644 --- a/lib/Dialect/TorchConversion/Transforms/Passes.cpp +++ b/lib/Dialect/TorchConversion/Transforms/Passes.cpp @@ -20,6 +20,7 @@ #include "torch-mlir/Conversion/TorchToLinalg/TorchToLinalg.h" #include "torch-mlir/Conversion/TorchToSCF/TorchToSCF.h" #include "torch-mlir/Conversion/TorchToTMTensor/TorchToTMTensor.h" +#include "torch-mlir/Conversion/TorchToTensor/TorchToTensor.h" #include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h" #ifdef TORCH_MLIR_ENABLE_STABLEHLO @@ -76,6 +77,7 @@ void TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline( pm.addNestedPass(createConvertTorchToLinalgPass()); pm.addNestedPass(createConvertTorchToSCFPass()); pm.addNestedPass(createConvertTorchToArithPass()); + pm.addNestedPass(createConvertTorchToTensorPass()); pm.addPass(createConvertTorchConversionToMLProgramPass()); pm.addNestedPass(memref::createExpandOpsPass()); diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 4f8c04b1d2c5..d7dba54a0580 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -130,6 +130,10 @@ 'ViewCollapseDynamicWithAtenSizeIntModule_basic', # END tests failing due to: torch._dynamo.exc.Unsupported: call_function BuiltinVariable(int) [TensorVariable()] {} + # ERROR: torch._dynamo.exc.Unsupported: Tensor.item + 'AtenItemIntOpModule_basic', + 'AtenItemFpOpModule_basic', + # ERROR: torch._dynamo.exc.Unsupported: call_method ListVariable() sort [] {'reverse': ConstantVariable(bool)} 'SortIntListReverse_basic', diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/scalar.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/scalar.py index 74717d99fb4e..303c3f0a801a 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/scalar.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/scalar.py @@ -428,3 +428,41 @@ def forward(self, val): @register_test_case(module_factory=lambda: AtenIntTensorCharDtypeModule()) def AtenIntTensorCharDtypeModule_basic(module, tu: TestUtils): module.forward(tu.randint(low=-100, high=100).to(dtype=torch.int8)) + +# ============================================================================== + +class AtenItemIntOpModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([], torch.int8, True), + ]) + + def forward(self, val): + return int(val) + +@register_test_case(module_factory=lambda: AtenItemIntOpModule()) +def AtenItemIntOpModule_basic(module, tu: TestUtils): + module.forward(tu.randint(low=-100, high=100).to(dtype=torch.int8)) + +# ============================================================================== + +class AtenItemFpOpModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([], torch.float, True), + ]) + + def forward(self, val): + return float(val) + +@register_test_case(module_factory=lambda: AtenItemFpOpModule()) +def AtenItemFpOpModule_basic(module, tu: TestUtils): + module.forward(tu.rand(1)) From 34f6948533287e67801384cb00f7150edc1461a5 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Wed, 31 Jan 2024 15:09:36 -0800 Subject: [PATCH 152/283] [torch] Support `!countIncludePad` when unpadded for average pool (#2836) We do not support average pool when `countIncludePad is set to false. However if the input is unpadded then the setting of the boolean is unneeded. Extended use by checking if padding is zero before rejecting the lowering. --- lib/Conversion/TorchToLinalg/Pooling.cpp | 5 +++- .../torch_mlir_e2e_test/test_suite/pooling.py | 24 ++++++++++++++++++- 2 files changed, 27 insertions(+), 2 deletions(-) diff --git a/lib/Conversion/TorchToLinalg/Pooling.cpp b/lib/Conversion/TorchToLinalg/Pooling.cpp index eed79072d0f9..76100c2c0e71 100644 --- a/lib/Conversion/TorchToLinalg/Pooling.cpp +++ b/lib/Conversion/TorchToLinalg/Pooling.cpp @@ -557,7 +557,10 @@ class ConvertAtenAvgPoolOp : public OpConversionPattern { m_TorchConstantBool(&countIncludePad))) return rewriter.notifyMatchFailure( op, "count_include_pad must be a constant"); - if (!countIncludePad) { + + // If the padding is zero then there is no padding to include. + if (!countIncludePad && + !llvm::all_of(paddingInts, [](int64_t p) { return p == 0; })) { return rewriter.notifyMatchFailure( op, "unimplemented: count_include_pad is expected to be true"); } diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py index 1d3481196e5f..d26d9b121cf3 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py @@ -847,6 +847,28 @@ def forward(self, x): def AvgPool2dCeilModeTrueModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 4, 20, 20, low=0.5, high=1.0)) +class AvgPool2dWithoutPadModule(torch.nn.Module): + + def __init__(self): + super().__init__() + self.ap2d = torch.nn.AvgPool2d(kernel_size=[6, 8], + stride=[2, 2], + padding=[0, 0], + ceil_mode=False, + count_include_pad=False, + divisor_override=None) + + @export + @annotate_args([ + None, + ([-1, -1, -1, -1], torch.float32, True), + ]) + def forward(self, x): + return self.ap2d(x) + +@register_test_case(module_factory=lambda: AvgPool2dWithoutPadModule()) +def AvgPool2dWithoutPadModule_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 4, 20, 20, low=0.5, high=1.0)) # ============================================================================== @@ -1141,4 +1163,4 @@ def forward(self,x): module_factory=lambda: AdaptiveMaxPool2dStaticWithIndices()) def AdaptiveMaxPool2dStaticWithIndices_basic( module, tu: TestUtils): - module.forward(tu.rand(1, 512, 10, 16)) \ No newline at end of file + module.forward(tu.rand(1, 512, 10, 16)) From 04be6ba77308a836779c237be984077a09f6683d Mon Sep 17 00:00:00 2001 From: Dave Liddell <44620210+daveliddell@users.noreply.github.com> Date: Wed, 31 Jan 2024 22:58:43 -0700 Subject: [PATCH 153/283] Make the onnx importer more robust for internal/external and large models (#2794) Fix for https://github.com/llvm/torch-mlir/issues/2765 The onnx docs say that you can't do shape inference using the in-memory API for models > 2 GB. This fix replaces that API with the file-based API. Since the new API generates an intermediate file, also added a --keep switch to keep that file, which I delete by default. --------- Co-authored-by: Dave Liddell --- .../torch_mlir/tools/import_onnx/__main__.py | 106 ++++++++++++- .../python/onnx_importer/command_line_test.py | 144 ++++++++++++++++++ 2 files changed, 244 insertions(+), 6 deletions(-) create mode 100644 test/python/onnx_importer/command_line_test.py diff --git a/python/torch_mlir/tools/import_onnx/__main__.py b/python/torch_mlir/tools/import_onnx/__main__.py index b300b4100b3e..399073be1570 100644 --- a/python/torch_mlir/tools/import_onnx/__main__.py +++ b/python/torch_mlir/tools/import_onnx/__main__.py @@ -14,7 +14,9 @@ python -m torch_mlir.tools.import_onnx ... """ import argparse +import os from pathlib import Path +import shutil import sys import onnx @@ -27,8 +29,8 @@ ) -def main(args): - model_proto = load_onnx_model(args.input_file) +def main(args: argparse.Namespace): + model_proto = load_onnx_model(args) context = Context() torch_d.register_dialect(context) model_info = onnx_importer.ModelInfo(model_proto) @@ -48,13 +50,84 @@ def main(args): print(m.get_asm(assume_verified=not args.no_verify)) -def load_onnx_model(file_path: Path) -> onnx.ModelProto: - raw_model = onnx.load(file_path) - inferred_model = onnx.shape_inference.infer_shapes(raw_model) +def load_onnx_model(args: argparse.Namespace) -> onnx.ModelProto: + # Do shape inference two ways. First, attempt in-memory to avoid redundant + # loading and the need for writing a temporary file somewhere. If that + # fails, typically because of the 2 GB protobuf size limit, try again via + # files. See + # https://github.com/onnx/onnx/blob/main/docs/PythonAPIOverview.md#shape-inference-a-large-onnx-model-2gb + # for details about the file-based technique. + + # Make a temp dir for all the temp files we'll be generating as a side + # effect of infering shapes. For now, the only file is a new .onnx holding + # the revised model with shapes. + # + # TODO: If the program temp_dir is None, we should be using an ephemeral + # temp directory instead of a hard-coded path in order to avoid data races + # by default. + input_dir = os.path.dirname(os.path.abspath(args.input_file)) + temp_dir = ( + Path(input_dir if args.temp_dir is None else args.temp_dir) + / "onnx-importer-temp" + ) + shutil.rmtree(temp_dir, ignore_errors=True) + temp_dir.mkdir(exist_ok=True) + + # Load the model, with possible external data coming from the default + # location, or the location specified on the conmand line. + if args.data_dir is None: + raw_model = onnx.load(args.input_file) + else: + raw_model = onnx.load(args.input_file, load_external_data=False) + onnx.load_external_data_for_model(raw_model, args.data_dir) + + # Run the checker to test whether the file is above the threshold for + # in-memory shape inference. If not, go ahead and do the shape inference. + try: + onnx.checker.check_model(raw_model) + inferred_model = onnx.shape_inference.infer_shapes(raw_model) + return inferred_model + except ValueError: + pass + + # The following code was an attempt to work around the bug where models + # with external data produce invalid output shapes after infer_shapes_path. + # It works with small models but threw an error for llama seeming to + # indicate that the protobuf is corrupt. + # + # temp_raw_file = temp_dir / "raw.onnx" + # onnx.save(raw_model, temp_raw_file, save_as_external_data=False) + # onnx.shape_inference.infer_shapes_path(temp_raw_file, temp_inferred_file) + # inferred_model = onnx.load(temp_inferred_file) + + # Model is too big for in-memory inference: do file-based shape inference + # to a temp file. + temp_inferred_file = temp_dir / "inferred.onnx" + onnx.shape_inference.infer_shapes_path(args.input_file, temp_inferred_file) + + # Sanity check the shape-inferred model to be sure we have a good model + # for the importer. This call uses the file-based method, as the + # in-memory method (passing the loaded model) fails due to the 2 GB limit. + # + # TODO: this call throws an exception because it can't find the external + # data files, and there doesn't appear to be a way to let the checker know + # where to find them. + # + # onnx.checker.check_model(temp_inferred_file) + + # Load the temp file and the external data. + inferred_model = onnx.load(temp_inferred_file, load_external_data=False) + data_dir = Path(input_dir if args.temp_dir is None else args.data_dir) + onnx.load_external_data_for_model(inferred_model, data_dir) + + # Remove the inferred shape file unless asked to keep it + if not args.keep_temps: + shutil.rmtree(temp_dir) + return inferred_model -def parse_arguments(argv=None): +def parse_arguments(argv=None) -> argparse.Namespace: parser = argparse.ArgumentParser(description="Torch-mlir ONNX import tool") parser.add_argument("input_file", help="ONNX protobuf input", type=Path) parser.add_argument( @@ -65,6 +138,27 @@ def parse_arguments(argv=None): action="store_true", help="Disable verification prior to printing", ) + parser.add_argument( + "--keep-temps", action="store_true", help="Keep intermediate files" + ) + parser.add_argument( + "--temp-dir", + help="Pre-existing directory in which to create temporary files." + ' For example, to place temporaries under the directory "foo/bar",' + ' specify --temp-dir=foo/bar. "foo/bar" must already exist.' + " Defaults to the directory of the input file.", + type=Path, + ) + parser.add_argument( + "--data-dir", + help="Path between CWD and the base directory of the data," + " excluding the directories given in the 'location' argument of " + " convert_model_to_external_data. For example, if 'location' was" + ' "data/data.bin" and the relative path from CWD to that .bin file is' + ' a/b/data/data.bin, then set data-dir to "a/b".' + " Defaults to the directory of the input file.", + type=Path, + ) args = parser.parse_args(argv) return args diff --git a/test/python/onnx_importer/command_line_test.py b/test/python/onnx_importer/command_line_test.py new file mode 100644 index 000000000000..32dc0cbeb22f --- /dev/null +++ b/test/python/onnx_importer/command_line_test.py @@ -0,0 +1,144 @@ +# Based on code Copyright (c) Advanced Micro Devices, Inc. +# +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. + +# RUN: %PYTHON %s --output %t + +from pathlib import Path + +import logging +import shutil +import sys +import subprocess +import unittest +import unittest.mock + +import onnx + +from torch_mlir.tools.import_onnx import __main__ + +# For ONNX models + +import numpy +from onnx import numpy_helper, TensorProto +from onnx.helper import ( + make_model, make_node, make_graph, + make_tensor_value_info) +from onnx.external_data_helper import convert_model_to_external_data +from onnx.checker import check_model + +# Accept the output path on the command line or default to a sibling +# to this file. We have to pop this off explicitly or else unittest +# won't understand. +if len(sys.argv) > 1 and sys.argv[1] == "--output": + OUTPUT_PATH = Path(sys.argv[2]) + del sys.argv[1:3] +else: + OUTPUT_PATH = Path(__file__).resolve().parent / "output" + +OUTPUT_PATH.mkdir(parents=True, exist_ok=True) + + +def const_model() -> onnx.ModelProto: + # Note: data_path must be relative to model_file + + const = make_node( + 'Constant', [], ['c_shape'], 'const', + value=numpy_helper.from_array(numpy.array([4], dtype=numpy.int64))) + cofshape = make_node( + 'ConstantOfShape', ['c_shape'], ['c_out'], 'cofshape', + value=numpy_helper.from_array(numpy.array([1], dtype=numpy.int64))) + + outval = make_tensor_value_info('c_out', TensorProto.INT64, [None]) + graph = make_graph([const, cofshape], 'constgraph', [], [outval]) + + onnx_model = make_model(graph) + check_model(onnx_model) + return onnx_model + + +def linear_model() -> onnx.ModelProto: + # initializers + k_dim = 32 + value = numpy.arange(k_dim).reshape([k_dim, 1]) + value = numpy.asarray(value, dtype=numpy.float32) + A = numpy_helper.from_array(value, name='A') + + value = numpy.array([0.4], dtype=numpy.float32).reshape([1, 1]) + C = numpy_helper.from_array(value, name='C') + + # the part which does not change + X = make_tensor_value_info('X', TensorProto.FLOAT, [1, k_dim]) + Y = make_tensor_value_info('Y', TensorProto.FLOAT, [None, None]) + node1 = make_node('MatMul', ['X', 'A'], ['AX']) + node2 = make_node('Add', ['AX', 'C'], ['Y']) + graph = make_graph([node1, node2], 'lr', [X], [Y], [A, C]) + onnx_model = make_model(graph) + check_model(onnx_model) + return onnx_model + + +ALL_MODELS = [ + const_model, + linear_model +] + + +class CommandLineTest(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.test_dir = OUTPUT_PATH / "command-line" + shutil.rmtree(cls.test_dir, ignore_errors=True) + cls.test_dir.mkdir(parents=True, exist_ok=True) + + def get_run_path(self, model_name: str) -> Path: + run_path = CommandLineTest.test_dir / model_name + run_path.mkdir(exist_ok=True) + return run_path + + def run_model_intern(self, onnx_model: onnx.ModelProto, model_name: str): + run_path = self.get_run_path(model_name) + model_file = run_path / f"{model_name}-i.onnx" + mlir_file = run_path / f"{model_name}-i.torch.mlir" + onnx.save(onnx_model, model_file) + args = __main__.parse_arguments([ + str(model_file), "-o", str(mlir_file)]) + __main__.main(args) + + def run_model_extern(self, onnx_model: onnx.ModelProto, model_name: str): + run_path = self.get_run_path(model_name) + model_file = run_path / f"{model_name}-e.onnx" + mlir_file = run_path / f"{model_name}-e.torch.mlir" + data_dir_name = f"{model_name}-data" + model_data_dir = run_path / data_dir_name + model_data_dir.mkdir(exist_ok=True) + convert_model_to_external_data( + onnx_model, all_tensors_to_one_file=True, + location=data_dir_name + "/data.bin", + size_threshold=48, + convert_attribute=True) + onnx.save(onnx_model, model_file) + temp_dir = run_path / "temp" + temp_dir.mkdir(exist_ok=True) + args = __main__.parse_arguments([ + str(model_file), "-o", str(mlir_file), "--keep-temps", "--temp-dir", + str(temp_dir), "--data-dir", str(run_path)]) + __main__.main(args) + + def test_all(self): + for model_func in ALL_MODELS: + model_name = model_func.__name__ + model = model_func() + with self.subTest(f"model {model_name}", model_name=model_name): + with self.subTest("Internal data"): + self.run_model_intern(model, model_name) + with self.subTest("External data"): + self.run_model_extern(model, model_name) + + +if __name__ == "__main__": + logging.basicConfig(level=logging.DEBUG) + unittest.main() From c7d7d7f00494b588c31ac617e91354b12709009d Mon Sep 17 00:00:00 2001 From: Sambhav Jain Date: Wed, 31 Jan 2024 22:07:06 -0800 Subject: [PATCH 154/283] [Bazel] Add TorchToTensor dep to TorchMLIRTorchConversionPasses (#2847) Fixes bazel build error: ``` ERROR: /root/.cache/bazel/_bazel_root/b89349c08f7224396763d14fe35cba11/external/torch-mlir/BUILD.bazel:547:11: Compiling lib/Dialect/TorchConversion/Transforms/Passes.cpp failed: (Exit 1): clang failed: error executing command /usr/lib/llvm-16/bin/clang -U_FORTIFY_SOURCE -fstack-protector -Wall -Wthread-safety -Wself-assign -Wunused-but-set-parameter -Wno-free-nonheap-object -fcolor-diagnostics -fno-omit-frame-pointer ... (remaining 224 arguments skipped) Use --sandbox_debug to see verbose messages from the sandbox and retain the sandbox build root for debugging external/torch-mlir/lib/Dialect/TorchConversion/Transforms/Passes.cpp:23:10: fatal error: 'torch-mlir/Conversion/TorchToTensor/TorchToTensor.h' file not found #include "torch-mlir/Conversion/TorchToTensor/TorchToTensor.h" ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ 1 error generated. Target @torch-mlir//:torch-mlir-opt failed to build ``` Bazel CI: https://github.com/sjain-stanford/torch-mlir/actions/runs/7735724133/job/21091865352 --- utils/bazel/torch-mlir-overlay/BUILD.bazel | 1 + 1 file changed, 1 insertion(+) diff --git a/utils/bazel/torch-mlir-overlay/BUILD.bazel b/utils/bazel/torch-mlir-overlay/BUILD.bazel index 493383cf9161..c37023c5e31f 100644 --- a/utils/bazel/torch-mlir-overlay/BUILD.bazel +++ b/utils/bazel/torch-mlir-overlay/BUILD.bazel @@ -564,6 +564,7 @@ cc_library( ":TorchMLIRTorchToSCF", ":TorchMLIRTorchToStablehlo", ":TorchMLIRTorchToTMTensor", + ":TorchMLIRTorchToTensor", ":TorchMLIRTorchToTosa", "@llvm-project//mlir:ConversionPasses", "@llvm-project//mlir:FuncDialect", From 29baa813bd9cb7aff52e81babf6fae55e0717524 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Thu, 1 Feb 2024 14:35:21 -0800 Subject: [PATCH 155/283] [onnx] Fix `pool` lowering for non-symmetric padding (#2837) `torch` requires that padding be symmetric for pooling operations. To support non-symmetric pad we need to separately materialize out the padding operation. --------- Co-authored-by: James Newling --- .../TorchOnnxToTorch/DefaultDomainGtoP.cpp | 69 +++++++++++++-- .../TorchOnnxToTorch/simple_ops_g_to_p.mlir | 85 ++++++++++++++----- 2 files changed, 128 insertions(+), 26 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index df20a83515bf..baebb9321567 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -262,30 +262,87 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( if (!maybeRank) return rewriter.notifyMatchFailure(binder.op, "Unimplemented: unranked tensor"); - unsigned rank = *maybeRank; + int64_t rank = *maybeRank; + int64_t spatial = rank - 2; SmallVector kernel, padding, strides, dilations; if (binder.s64IntegerArrayAttr(kernel, "kernel_shape", {})) return rewriter.notifyMatchFailure(binder.op, "kernel_shape bind failure"); - if (kernel.size() != rank - 2) + if (kernel.size() != static_cast(spatial)) return rewriter.notifyMatchFailure( binder.op, "kernel list size does not match the number of axes"); - if (binder.s64IntegerArrayAttr(padding, "pads", {0})) + if (binder.s64IntegerArrayAttr(padding, "pads", {})) return rewriter.notifyMatchFailure(binder.op, "pads bind failure"); - if (padding.size() != 1 && padding.size() != 2 * (rank - 2)) + if (!padding.empty() && + padding.size() != static_cast(2 * spatial)) return rewriter.notifyMatchFailure( binder.op, "padding list must contain (begin,end) pair for each " "spatial axis"); - if (binder.s64IntegerArrayAttr(strides, "strides", {1})) + if (binder.s64IntegerArrayAttr(strides, "strides", {})) return rewriter.notifyMatchFailure(binder.op, "strides bind failure"); - if (strides.size() != 1 && strides.size() != rank - 2) + if (!strides.empty() && strides.size() != static_cast(spatial)) return rewriter.notifyMatchFailure( binder.op, "strides list size does not match the number of axes"); if (binder.s64IntegerArrayAttr(dilations, "dilations", {})) return rewriter.notifyMatchFailure(binder.op, "dilations bind failure"); + if (padding.empty()) + padding.resize(spatial, 0); + if (strides.empty()) + strides.resize(spatial, 1); + if (dilations.empty()) + dilations.resize(spatial, 1); + + // If the padding is symmetric we can push the padding operation to the + // torch operator. + if (padding.size() == static_cast(2 * spatial)) { + bool equal = true; + for (int i = 0; i < spatial; ++i) { + equal = equal && (padding[i] == padding[i + spatial]); + } + if (equal) + padding.resize(spatial); + } + + // Torch pool operators require equal padding on each size of each + // dimension so we materialize the padding behavior explicitly and set + // the padding to 0. + if (padding.size() == static_cast(2 * spatial)) { + auto operandTy = cast(operand.getType()); + llvm::SmallVector shuffledPadding(spatial * 2); + llvm::SmallVector paddedShape(operandTy.getSizes()); + shuffledPadding.resize(2 * rank); + for (int i = 0; i < spatial; ++i) { + paddedShape[i + 2] += padding[i] + padding[i + spatial]; + shuffledPadding[2 * i] = padding[i]; + shuffledPadding[2 * i + 1] = padding[i + spatial]; + } + + Value shuffledPaddingList = + createConstantIntList(binder, rewriter, padding); + Value zero; + if (resultType.getDtype().isa()) { + zero = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getF64FloatAttr( + std::numeric_limits::lowest())); + } else if (resultType.getDtype().isa()) { + zero = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr( + std::numeric_limits::lowest())); + } + + auto paddedInputTy = rewriter.getType( + paddedShape, operandTy.getDtype()); + operand = rewriter.create( + binder.getLoc(), paddedInputTy, operand, shuffledPaddingList, + zero); + padding.clear(); + padding.resize(spatial, 0); + } + Value kernelSizeList = createConstantIntList(binder, rewriter, kernel); Value paddingList = createConstantIntList(binder, rewriter, padding); Value stridesList = createConstantIntList(binder, rewriter, strides); diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir index 449b7e4feb32..92fbe86caed4 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir @@ -251,13 +251,17 @@ func.func @test_maxpool_2d_default(%arg0: !torch.vtensor<[1,3,32,32],f32>) -> !t // CHECK: %[[I2:.*]] = torch.constant.int 2 // CHECK: %[[I2_1:.*]] = torch.constant.int 2 // CHECK: %[[LIST22:.*]] = torch.prim.ListConstruct %[[I2]], %[[I2_1]] : (!torch.int, !torch.int) -> !torch.list - // CHECK: %[[I0:.*]] = torch.constant.int 0 - // CHECK: %[[LIST0:.*]] = torch.prim.ListConstruct %[[I0]] : (!torch.int) -> !torch.list - // CHECK: %[[I1:.*]] = torch.constant.int 1 - // CHECK: %[[LIST1:.*]] = torch.prim.ListConstruct %[[I1]] : (!torch.int) -> !torch.list - // CHECK: %[[LIST:.*]] = torch.prim.ListConstruct : () -> !torch.list + // CHECK: %[[I0_0:.*]] = torch.constant.int 0 + // CHECK: %[[I0_1:.*]] = torch.constant.int 0 + // CHECK: %[[LIST0:.*]] = torch.prim.ListConstruct %[[I0_0]], %[[I0_1]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[I1_0:.*]] = torch.constant.int 1 + // CHECK: %[[I1_1:.*]] = torch.constant.int 1 + // CHECK: %[[LIST1:.*]] = torch.prim.ListConstruct %[[I1_0]], %[[I1_1]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[I1_2:.*]] = torch.constant.int 1 + // CHECK: %[[I1_3:.*]] = torch.constant.int 1 + // CHECK: %[[LIST3:.*]] = torch.prim.ListConstruct %[[I1_2]], %[[I1_3]] : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[FALSE:.*]] = torch.constant.bool false - // CHECK: torch.aten.max_pool2d %arg0, %[[LIST22]], %[[LIST1]], %[[LIST0]], %[[LIST]], %[[FALSE]] : !torch.vtensor<[1,3,32,32],f32>, !torch.list, !torch.list, !torch.list, !torch.list, !torch.bool -> !torch.vtensor<[1,3,31,31],f32> + // CHECK: torch.aten.max_pool2d %arg0, %[[LIST22]], %[[LIST1]], %[[LIST0]], %[[LIST3]], %[[FALSE]] : !torch.vtensor<[1,3,32,32],f32>, !torch.list, !torch.list, !torch.list, !torch.list, !torch.bool -> !torch.vtensor<[1,3,31,31],f32> %0 = torch.operator "onnx.MaxPool"(%arg0) {torch.onnx.kernel_shape = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,3,32,32],f32>) -> !torch.vtensor<[1,3,31,31],f32> return %0 : !torch.vtensor<[1,3,31,31],f32> } @@ -269,12 +273,15 @@ func.func @test_maxpool_2d_ceil(%arg0: !torch.vtensor<[1,1,4,4],f32>) -> !torch. // CHECK: %[[I3:.*]] = torch.constant.int 3 // CHECK: %[[I3_1:.*]] = torch.constant.int 3 // CHECK: %[[LIST33:.*]] = torch.prim.ListConstruct %[[I3]], %[[I3_1]] : (!torch.int, !torch.int) -> !torch.list - // CHECK: %[[I0:.*]] = torch.constant.int 0 - // CHECK: %[[LIST0:.*]] = torch.prim.ListConstruct %[[I0]] : (!torch.int) -> !torch.list + // CHECK: %[[I0_0:.*]] = torch.constant.int 0 + // CHECK: %[[I0_1:.*]] = torch.constant.int 0 + // CHECK: %[[LIST0:.*]] = torch.prim.ListConstruct %[[I0_0]], %[[I0_1]] : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[I2:.*]] = torch.constant.int 2 // CHECK: %[[I2_1:.*]] = torch.constant.int 2 // CHECK: %[[LIST22:.*]] = torch.prim.ListConstruct %[[I2]], %[[I2_1]] : (!torch.int, !torch.int) -> !torch.list - // CHECK: %[[LIST:.*]] = torch.prim.ListConstruct : () -> !torch.list + // CHECK: %[[I1_0:.*]] = torch.constant.int 1 + // CHECK: %[[I1_1:.*]] = torch.constant.int 1 + // CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[I1_0]], %[[I1_1]] : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[TRUE:.*]] = torch.constant.bool true // CHECK: torch.aten.max_pool2d %arg0, %[[LIST33]], %[[LIST22]], %[[LIST0]], %[[LIST]], %[[TRUE]] : !torch.vtensor<[1,1,4,4],f32>, !torch.list, !torch.list, !torch.list, !torch.list, !torch.bool -> !torch.vtensor<[1,1,2,2],f32> %0 = torch.operator "onnx.MaxPool"(%arg0) {torch.onnx.ceil_mode = 1 : si64, torch.onnx.kernel_shape = [3 : si64, 3 : si64], torch.onnx.strides = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,1,4,4],f32>) -> !torch.vtensor<[1,1,2,2],f32> @@ -289,11 +296,18 @@ func.func @test_maxpool_3d_default(%arg0: !torch.vtensor<[1,3,32,32,32],f32>) -> // CHECK: %[[I2_1:.*]] = torch.constant.int 2 // CHECK: %[[I2_2:.*]] = torch.constant.int 2 // CHECK: %[[LIST222:.*]] = torch.prim.ListConstruct %[[I2]], %[[I2_1]], %[[I2_2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list - // CHECK: %[[I0:.*]] = torch.constant.int 0 - // CHECK: %[[LIST0:.*]] = torch.prim.ListConstruct %[[I0]] : (!torch.int) -> !torch.list - // CHECK: %[[I1:.*]] = torch.constant.int 1 - // CHECK: %[[LIST1:.*]] = torch.prim.ListConstruct %[[I1]] : (!torch.int) -> !torch.list - // CHECK: %[[LIST:.*]] = torch.prim.ListConstruct : () -> !torch.list + // CHECK: %[[I0_0:.*]] = torch.constant.int 0 + // CHECK: %[[I0_1:.*]] = torch.constant.int 0 + // CHECK: %[[I0_2:.*]] = torch.constant.int 0 + // CHECK: %[[LIST0:.*]] = torch.prim.ListConstruct %[[I0_0]], %[[I0_1]], %[[I0_2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[I1_0:.*]] = torch.constant.int 1 + // CHECK: %[[I1_1:.*]] = torch.constant.int 1 + // CHECK: %[[I1_2:.*]] = torch.constant.int 1 + // CHECK: %[[LIST1:.*]] = torch.prim.ListConstruct %[[I1_0]], %[[I1_1]], %[[I1_2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[I1_3:.*]] = torch.constant.int 1 + // CHECK: %[[I1_4:.*]] = torch.constant.int 1 + // CHECK: %[[I1_5:.*]] = torch.constant.int 1 + // CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[I1_3]], %[[I1_4]], %[[I1_5]] : (!torch.int, !torch.int, !torch.int) -> !torch.list // CHECK: %[[FALSE:.*]] = torch.constant.bool false // CHECK: torch.aten.max_pool3d %arg0, %[[LIST222]], %[[LIST1]], %[[LIST0]], %[[LIST]], %[[FALSE]] : !torch.vtensor<[1,3,32,32,32],f32>, !torch.list, !torch.list, !torch.list, !torch.list, !torch.bool -> !torch.vtensor<[1,3,31,31,31],f32> %0 = torch.operator "onnx.MaxPool"(%arg0) {torch.onnx.kernel_shape = [2 : si64, 2 : si64, 2 : si64]} : (!torch.vtensor<[1,3,32,32,32],f32>) -> !torch.vtensor<[1,3,31,31,31],f32> @@ -303,21 +317,52 @@ func.func @test_maxpool_3d_default(%arg0: !torch.vtensor<[1,3,32,32,32],f32>) -> // ----- // CHECK-LABEL: func.func @test_maxpool_pad -func.func @test_maxpool_pad(%arg0: !torch.vtensor<[1,64,112,112],f32>) -> !torch.vtensor<[1,64,56,56],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64} { +func.func @test_maxpool_pad(%arg0: !torch.vtensor<[1,64,111,111],f32>) -> !torch.vtensor<[1,64,56,56],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64} { + // CHECK: %[[INT1_0:.+]] = torch.constant.int 1 + // CHECK: %[[INT1_1:.+]] = torch.constant.int 1 + // CHECK: %[[INT2_0:.+]] = torch.constant.int 2 + // CHECK: %[[INT2_1:.+]] = torch.constant.int 2 + // CHECK: %[[PADI:.+]] = torch.prim.ListConstruct %[[INT1_0]], %[[INT1_1]], %[[INT2_0]], %[[INT2_1]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[MIN:.+]] = torch.constant.float -1.7976931348623157E+308 + // CHECK: %[[PADDED:.+]] = torch.aten.constant_pad_nd %arg0, %[[PADI]], %[[MIN]] : !torch.vtensor<[1,64,111,111],f32>, !torch.list, !torch.float -> !torch.vtensor<[1,64,114,114],f32> + // CHECK: %[[INT3:.*]] = torch.constant.int 3 + // CHECK: %[[INT3_0:.*]] = torch.constant.int 3 + // CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[INT3]], %[[INT3_0]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[INT0:.*]] = torch.constant.int 0 + // CHECK: %[[INT0_1:.*]] = torch.constant.int 0 + // CHECK: %[[LIST2:.*]] = torch.prim.ListConstruct %[[INT0]], %[[INT0_1]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[INT2:.*]] = torch.constant.int 2 + // CHECK: %[[INT2_4:.*]] = torch.constant.int 2 + // CHECK: %[[LIST3:.*]] = torch.prim.ListConstruct %[[INT2]], %[[INT2_4]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[INT1_0:.*]] = torch.constant.int 1 + // CHECK: %[[INT1_1:.*]] = torch.constant.int 1 + // CHECK: %[[EMPTY_LIST:.*]] = torch.prim.ListConstruct %[[INT1_0]], %[[INT1_1]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: %[[FALSE:.*]] = torch.constant.bool false + // CHECK: %[[OUT:.*]] = torch.aten.max_pool2d %[[PADDED]], %[[LIST]], %[[LIST3]], %[[LIST2]], %[[EMPTY_LIST]], %[[FALSE]] : !torch.vtensor<[1,64,114,114],f32>, !torch.list, !torch.list, !torch.list, !torch.list, !torch.bool -> !torch.vtensor<[1,64,56,56],f32> + // CHECK: return %[[OUT]] : !torch.vtensor<[1,64,56,56],f32> + %0 = torch.operator "onnx.MaxPool"(%arg0) {torch.onnx.ceil_mode = 0 : si64, torch.onnx.kernel_shape = [3 : si64, 3 : si64], torch.onnx.pads = [1 : si64, 1 : si64, 2 : si64, 2 : si64], torch.onnx.strides = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,64,111,111],f32>) -> !torch.vtensor<[1,64,56,56],f32> + return %0 : !torch.vtensor<[1,64,56,56],f32> +} + + +// ----- + +// CHECK-LABEL: func.func @test_maxpool_symmetric_pad +func.func @test_maxpool_symmetric_pad(%arg0: !torch.vtensor<[1,64,112,112],f32>) -> !torch.vtensor<[1,64,56,56],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64} { // CHECK: %[[INT3:.*]] = torch.constant.int 3 // CHECK: %[[INT3_0:.*]] = torch.constant.int 3 // CHECK: %[[LIST:.*]] = torch.prim.ListConstruct %[[INT3]], %[[INT3_0]] : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[INT1:.*]] = torch.constant.int 1 // CHECK: %[[INT1_1:.*]] = torch.constant.int 1 - // CHECK: %[[INT1_2:.*]] = torch.constant.int 1 - // CHECK: %[[INT1_3:.*]] = torch.constant.int 1 - // CHECK: %[[LIST2:.*]] = torch.prim.ListConstruct %[[INT1]], %[[INT1_1]], %[[INT1_2]], %[[INT1_3]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: %[[LIST2:.*]] = torch.prim.ListConstruct %[[INT1]], %[[INT1_1]] : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[INT2:.*]] = torch.constant.int 2 // CHECK: %[[INT2_4:.*]] = torch.constant.int 2 // CHECK: %[[LIST3:.*]] = torch.prim.ListConstruct %[[INT2]], %[[INT2_4]] : (!torch.int, !torch.int) -> !torch.list - // CHECK: %[[EMPTY_LIST:.*]] = torch.prim.ListConstruct : () -> !torch.list + // CHECK: %[[INT1_2:.*]] = torch.constant.int 1 + // CHECK: %[[INT1_3:.*]] = torch.constant.int 1 + // CHECK: %[[DILATION:.*]] = torch.prim.ListConstruct %[[INT1_2]], %[[INT1_3]] : (!torch.int, !torch.int) -> !torch.list // CHECK: %[[FALSE:.*]] = torch.constant.bool false - // CHECK: %[[OUT:.*]] = torch.aten.max_pool2d %arg0, %[[LIST]], %[[LIST3]], %[[LIST2]], %[[EMPTY_LIST]], %[[FALSE]] : !torch.vtensor<[1,64,112,112],f32>, !torch.list, !torch.list, !torch.list, !torch.list, !torch.bool -> !torch.vtensor<[1,64,56,56],f32> + // CHECK: %[[OUT:.*]] = torch.aten.max_pool2d %arg0, %[[LIST]], %[[LIST3]], %[[LIST2]], %[[DILATION]], %[[FALSE]] : !torch.vtensor<[1,64,112,112],f32>, !torch.list, !torch.list, !torch.list, !torch.list, !torch.bool -> !torch.vtensor<[1,64,56,56],f32> // CHECK: return %[[OUT]] : !torch.vtensor<[1,64,56,56],f32> %0 = torch.operator "onnx.MaxPool"(%arg0) {torch.onnx.ceil_mode = 0 : si64, torch.onnx.kernel_shape = [3 : si64, 3 : si64], torch.onnx.pads = [1 : si64, 1 : si64, 1 : si64, 1 : si64], torch.onnx.strides = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,64,112,112],f32>) -> !torch.vtensor<[1,64,56,56],f32> return %0 : !torch.vtensor<[1,64,56,56],f32> From 962d5143085b2bea7db0c0e9bdc26bf5ea8db2b5 Mon Sep 17 00:00:00 2001 From: Ben Vanik Date: Thu, 1 Feb 2024 22:02:44 -0800 Subject: [PATCH 156/283] Fixing implicit double->float conversion warning. (#2850) `[build] D:\Dev\iree\third_party\torch-mlir\lib\Conversion\TorchOnnxToTorch\DefaultDomainGtoP.cpp(734): warning C4305: 'argument': truncation from 'double' to 'float'` --- lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index baebb9321567..5ebba10c9ebd 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -788,7 +788,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( binder.tensorOperandAtIndex(b, 2) || binder.tensorResultTypeAtIndex(yType, 0) || binder.s64IntegerAttr(axis, "axis", -1) || - binder.f32FloatAttr(epsilon, "epsilon", 0.00001) || + binder.f32FloatAttr(epsilon, "epsilon", 0.00001f) || binder.s64IntegerAttr(stashType, "stash_type", 1)) return failure(); Value constEpsilon = rewriter.create( From 24b8c8672ae4366025aa8cb3155dedb58c8e3450 Mon Sep 17 00:00:00 2001 From: "Xida Ren (Cedar)" Date: Fri, 2 Feb 2024 13:46:33 -0500 Subject: [PATCH 157/283] [torch] Add folders for `torch.fill`, `torch.ones`, `torch.zeros` and `aten.getItem` (#2849) So that the CumSum Op in OPT can get the constant that it requires to be lowered to TMTensor --------- Co-authored-by: Rob Suderman Co-authored-by: Xida Ren --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 4 + .../TorchOnnxToTorch/DefaultDomainAtoF.cpp | 8 +- lib/Dialect/Torch/IR/TorchOps.cpp | 148 +++++++++++++++++- projects/pt1/e2e_testing/xfail_sets.py | 3 + .../build_tools/torch_ods_gen.py | 8 +- .../torch_mlir_e2e_test/test_suite/basic.py | 8 +- test/Dialect/Torch/canonicalize.mlir | 40 +++++ 7 files changed, 208 insertions(+), 11 deletions(-) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 0ae45798d188..a0ec9663b3e9 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -8416,6 +8416,7 @@ def Torch_AtenOnesOp : Torch_Op<"aten.ones", [ printDefaultTorchOp(printer, *this, 5, 1); } }]; + let hasFolder = 1; } def Torch_AtenNewOnesOp : Torch_Op<"aten.new_ones", [ @@ -8471,6 +8472,7 @@ def Torch_AtenZerosOp : Torch_Op<"aten.zeros", [ printDefaultTorchOp(printer, *this, 5, 1); } }]; + let hasFolder = 1; } def Torch_AtenNewZerosOp : Torch_Op<"aten.new_zeros", [ @@ -9858,6 +9860,7 @@ def Torch_AtenItemOp : Torch_Op<"aten.item", [ printDefaultTorchOp(printer, *this, 1, 1); } }]; + let hasFolder = 1; } def Torch_AtenMaskedSelectOp : Torch_Op<"aten.masked_select", [ @@ -11202,6 +11205,7 @@ def Torch_AtenFullOp : Torch_Op<"aten.full", [ printDefaultTorchOp(printer, *this, 6, 1); } }]; + let hasFolder = 1; } def Torch_AtenFullLikeOp : Torch_Op<"aten.full_like", [ diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp index e9221ed139d8..1161b981c09f 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp @@ -1089,11 +1089,9 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( binder.op, "expected result type to have a dtype"); } // resultTensorType.print(llvm::outs()); - Value resultDType = Torch::getDtypeIntValueForType( - rewriter, loc, resultTensorType.getDtype()); - - rewriter.replaceOpWithNewOp( - binder.op, resultType, operand, dim, resultDType); + Value none = rewriter.create(loc); + rewriter.replaceOpWithNewOp(binder.op, resultType, + operand, dim, none); return success(); }); patterns.onOp( diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index 5877a35495e0..98de4f85b62b 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -6,9 +6,10 @@ // Also available under a BSD-style license. See LICENSE. // //===----------------------------------------------------------------------===// - +#define DEBUG_TYPE "torch-mlir-torch-dialect" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" +#include "llvm/Support/Debug.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/IR/Builders.h" @@ -2813,6 +2814,151 @@ OpFoldResult AtenDivIntOp::fold(FoldAdaptor adaptor) { return nullptr; } +//===----------------------------------------------------------------------===// +// AtenItemOp +//===----------------------------------------------------------------------===// + +OpFoldResult AtenItemOp::fold(FoldAdaptor adaptor) { + // see if we have a constant tensor + DenseElementsAttr attr; + if (matchPattern(getOperand(), m_Constant(&attr))) { + auto splat = attr.getSplatValue(); + if (auto intAttr = dyn_cast(splat)) { + return getI64IntegerAttr(getContext(), intAttr.getSInt()); + } + if (auto floatAttr = dyn_cast(splat)) { + return getF64FloatAttr(getContext(), floatAttr.getValueAsDouble()); + } + return nullptr; + } + + return nullptr; +} + +//===----------------------------------------------------------------------===// +// AtenOnesOp, AtenZerosOp, AtenFullOp +//===----------------------------------------------------------------------===// +OpFoldResult AtenOnesOp::fold(FoldAdaptor adaptor) { + SmallVector sizes; + if (!matchPattern(getSize(), m_TorchListOfConstantInts(sizes))) { + LLVM_DEBUG(llvm::dbgs() << "Failing to fold AtenOnesOp: size operand is " + "not a list of constant integers.\n"); + return nullptr; + } + + Type resultType = getResult().getType(); + BaseTensorType resultTensorType = resultType.dyn_cast(); + if (!resultTensorType || !resultTensorType.hasDtype()) { + LLVM_DEBUG(llvm::dbgs() << "Failing to fold AtenOnesOp: result type is not " + "a tensor type or does not have a dtype.\n"); + return nullptr; + } + + ShapedType shapedty = + mlir::RankedTensorType::get( // convert Torch type to builtin ShapedType + sizes, resultTensorType.getDtype()); + if (!shapedty) { + LLVM_DEBUG(llvm::dbgs() + << "Failing to fold AtenOnesOp: ShapedType cast failed.\n"); + return nullptr; + } + auto elementType = shapedty.getElementType(); + if (elementType.isa()) { + Attribute attribute = IntegerAttr::get(elementType, 1); + return DenseElementsAttr::get(shapedty, attribute); + } + if (elementType.isa()) { + Attribute attribute = FloatAttr::get(elementType, 1.0); + return DenseElementsAttr::get(shapedty, attribute); + } + LLVM_DEBUG(llvm::dbgs() << "Failing to fold AtenOnesOp: element type is " + "not integer or float.\n"); + return nullptr; +} + +OpFoldResult AtenZerosOp::fold(FoldAdaptor adaptor) { + SmallVector sizes; + if (!matchPattern(getSize(), m_TorchListOfConstantInts(sizes))) { + LLVM_DEBUG(llvm::dbgs() << "Failing to fold AtenZerosOp: size operand is " + "not a list of constant integers.\n"); + return nullptr; + } + + Type resultType = getResult().getType(); + BaseTensorType resultTensorType = resultType.dyn_cast(); + if (!resultTensorType || !resultTensorType.hasDtype()) { + LLVM_DEBUG(llvm::dbgs() << "Failing to fold AtenZerosOp: result type is " + "not a tensor type or does not have a dtype.\n"); + return nullptr; + } + + ShapedType shapedty = + mlir::RankedTensorType::get( // convert Torch type to builtin ShapedType + sizes, resultTensorType.getDtype()); + if (!shapedty) { + LLVM_DEBUG(llvm::dbgs() + << "Failing to fold AtenZerosOp: ShapedType cast failed.\n"); + return nullptr; + } + + auto elementType = shapedty.getElementType(); + if (elementType.isa()) { + Attribute attribute = IntegerAttr::get(elementType, 0); + return DenseElementsAttr::get(shapedty, attribute); + } + if (elementType.isa()) { + Attribute attribute = FloatAttr::get(elementType, 0.0); + return DenseElementsAttr::get(shapedty, attribute); + } + + LLVM_DEBUG(llvm::dbgs() << "Failing to fold AtenZerosOp: element type is " + "not integer or float.\n"); + return nullptr; +} + +OpFoldResult AtenFullOp::fold(FoldAdaptor adaptor) { + SmallVector sizes; + if (!matchPattern(getSize(), m_TorchListOfConstantInts(sizes))) { + LLVM_DEBUG(llvm::dbgs() << "Failing to fold AtenFullOp: size operand is " + "not a list of constant integers.\n"); + return nullptr; + } + + Type resultType = getResult().getType(); + BaseTensorType resultTensorType = resultType.dyn_cast(); + if (!resultTensorType || !resultTensorType.hasDtype()) { + LLVM_DEBUG(llvm::dbgs() << "Failing to fold AtenFullOp: result type is not " + "a tensor type or does not have a dtype.\n"); + return nullptr; + } + + ShapedType shapedty = + mlir::RankedTensorType::get( // convert Torch type to builtin ShapedType + sizes, resultTensorType.getDtype()); + if (!shapedty) { + LLVM_DEBUG(llvm::dbgs() + << "Failing to fold AtenFullOp: ShapedType cast failed.\n"); + return nullptr; + } + auto elementType = shapedty.getElementType(); + if (elementType.isa()) { + int64_t value = 0; + if (matchPattern(getFillValue(), m_TorchConstantInt(&value))) { + Attribute attribute = IntegerAttr::get(elementType, value); + return DenseElementsAttr::get(shapedty, attribute); + } + } + if (elementType.isa()) { + double value = 0.0; + if (matchPattern(getFillValue(), m_TorchConstantFloat(&value))) { + Attribute attribute = FloatAttr::get(elementType, value); + return DenseElementsAttr::get(shapedty, attribute); + } + } + LLVM_DEBUG(llvm::dbgs() << "Failing to fold AtenFullOp: element type is " + "not integer or float.\n"); + return nullptr; +} //===----------------------------------------------------------------------===// // AtenCeilFloatOp //===----------------------------------------------------------------------===// diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index d7dba54a0580..2ee5d279a9d3 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -26,6 +26,9 @@ TORCHDYNAMO_XFAIL_SET = { #### General TorchDynamo/PyTorch errors + # torch._dynamo.exc.Unsupported: Tensor.item + "CumsumModule_basic", + # TypeError: new_empty(): argument 'size' (position 1) must be tuple of ints, but found element of type NoneType at pos 0 # RuntimeError: Failed running call_function aten.convolution_backward(... # https://github.com/pytorch/pytorch/issues/89629 diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index 43635bf2fa05..41a297ba62b8 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -564,9 +564,9 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::size : (Tensor) -> (int[])", has_canonicalizer=True) emit("aten::Bool.Tensor : (Tensor) -> (bool)") emit("aten::is_floating_point : (Tensor) -> (bool)", has_folder=True) - emit("aten::ones : (int[], int?, int?, Device?, bool?) -> (Tensor)") + emit("aten::ones : (int[], int?, int?, Device?, bool?) -> (Tensor)", has_folder=True) emit("aten::new_ones : (Tensor, int[], int?, int?, Device?, bool?) -> (Tensor)") - emit("aten::zeros : (int[], int?, int?, Device?, bool?) -> (Tensor)") + emit("aten::zeros : (int[], int?, int?, Device?, bool?) -> (Tensor)", has_folder=True) emit("aten::new_zeros : (Tensor, int[], int?, int?, Device?, bool?) -> (Tensor)") emit("aten::eye : (int, int?, int?, Device?, bool?) -> (Tensor)") emit("aten::eye.m : (int, int, int?, int?, Device?, bool?) -> (Tensor)") @@ -618,7 +618,7 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::index.Tensor_hacked_twin : (Tensor, Tensor[]) -> (Tensor)") emit("aten::index_select : (Tensor, int, Tensor) -> (Tensor)") emit_with_mutating_variants("aten::_index_put_impl : (Tensor, Tensor?[], Tensor, bool, bool) -> (Tensor)") - emit("aten::item : (Tensor) -> (Scalar)") + emit("aten::item : (Tensor) -> (Scalar)", has_folder=True) emit("aten::masked_select : (Tensor, Tensor) -> (Tensor)") emit("aten::numel : (Tensor) -> (int)", has_canonicalizer=True) emit("aten::repeat : (Tensor, int[]) -> (Tensor)") @@ -669,7 +669,7 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::native_dropout : (Tensor, float, bool?) -> (Tensor, Tensor)") emit("aten::t : (Tensor) -> (Tensor)") emit("aten::numpy_T : (Tensor) -> (Tensor)") - emit("aten::full : (int[], Scalar, int?, int?, Device?, bool?) -> (Tensor)") + emit("aten::full : (int[], Scalar, int?, int?, Device?, bool?) -> (Tensor)", has_folder=True) emit("aten::full_like : (Tensor, Scalar, int?, int?, Device?, bool?, int?) -> (Tensor)") emit("aten::new_full : (Tensor, int[], Scalar, int?, int?, Device?, bool?) -> (Tensor)") emit_with_mutating_variants("aten::baddbmm : (Tensor, Tensor, Tensor, Scalar, Scalar) -> (Tensor)") diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py index 91c3112135d0..c73d706f25cf 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py @@ -4092,7 +4092,13 @@ def __init__(self): ([-1, -1, -1], torch.float32, True), ]) def forward(self, val): - return torch.ops.aten.cumsum(val, 1) + # the onnx cumsum op uses a constant 1d tensor + # to specify the dimension along which to do cumsum + # we replicate that here to ensure that cumsum correctly + # trigger the relevant folders and provides TMTensor + # with a constant dimension + ones = torch.ones([1], dtype=torch.int32) + return torch.ops.aten.cumsum(val, ones.item()) @register_test_case(module_factory=lambda: CumsumModule()) def CumsumModule_basic(module, tu: TestUtils): diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir index 3cf82d9ed6a7..cb2ec2d14a54 100644 --- a/test/Dialect/Torch/canonicalize.mlir +++ b/test/Dialect/Torch/canonicalize.mlir @@ -29,6 +29,46 @@ func.func @torch.runtime.assert() { return } +// CHECK-LABEL: func.func @torch.aten.ones_item +// CHECK: %[[CONST:.*]] = torch.constant.int 1 +// CHECK: return %[[CONST]] : !torch.int +func.func @torch.aten.ones_item() -> !torch.int { + %int1 = torch.constant.int 1 + %int3 = torch.constant.int 3 + %none = torch.constant.none + %0 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list + %1 = torch.aten.ones %0, %int3, %none, %none, %none : !torch.list, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[1],si32> + %2 = torch.aten.item %1 : !torch.vtensor<[1],si32> -> !torch.int + return %2 : !torch.int +} + +// CHECK-LABEL: func.func @torch.aten.zeros_item +// CHECK: %[[CONST:.*]] = torch.constant.int 0 +// CHECK: return %[[CONST]] : !torch.int +func.func @torch.aten.zeros_item() -> !torch.int { + %int1 = torch.constant.int 1 + %int3 = torch.constant.int 3 + %none = torch.constant.none + %0 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list + %1 = torch.aten.zeros %0, %int3, %none, %none, %none : !torch.list, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[1],si32> + %2 = torch.aten.item %1 : !torch.vtensor<[1],si32> -> !torch.int + return %2 : !torch.int +} + +// CHECK-LABEL: func.func @torch.aten.full_item +// CHECK: %[[CONST:.*]] = torch.constant.int 1337 +// CHECK: return %[[CONST]] : !torch.int +func.func @torch.aten.full_item() -> !torch.int { + %int1 = torch.constant.int 1 + %int3 = torch.constant.int 1337 + %int5 = torch.constant.int 5 + %none = torch.constant.none + %0 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list + %1 = torch.aten.full %0, %int3, %int5, %none, %none, %none : !torch.list, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[1],si32> + %2 = torch.aten.item %1 : !torch.vtensor<[1],si32> -> !torch.int + return %2 : !torch.int +} + // CHECK-LABEL: func.func @torch.aten.is_floating_point$fold_true // CHECK: %[[TRUE:.*]] = torch.constant.bool true // CHECK: return %[[TRUE]] : !torch.bool From d1cd117998d5d9b6d0f68784119d38a6b191463a Mon Sep 17 00:00:00 2001 From: Aart Bik <39774503+aartbik@users.noreply.github.com> Date: Fri, 2 Feb 2024 11:02:53 -0800 Subject: [PATCH 158/283] [torch-mlir] remove trailing whitespace from md documentation (#2853) --- docs/add_ops.md | 6 ++--- ...dding_abstract_interpretation_functions.md | 10 ++++----- docs/architecture.md | 4 ++-- docs/importers/onnx_importer.md | 22 +++++++++---------- docs/ltc_backend.md | 4 ++-- docs/ltc_examples.md | 6 ++--- 6 files changed, 26 insertions(+), 26 deletions(-) diff --git a/docs/add_ops.md b/docs/add_ops.md index 0809283bbeae..225458cec631 100644 --- a/docs/add_ops.md +++ b/docs/add_ops.md @@ -1,17 +1,17 @@ # How to Add Ops to Torch-Mlir -Collected links and contacts for how to add ops to torch-mlir. +Collected links and contacts for how to add ops to torch-mlir.
Turbine Camp: Start Here -This document was previously known as `turbine-camp.md` to Nod.ai. "Turbine Camp" is part of Nod.ai's onboarding process. Welcome to turbine camp. This document originated at Nod.ai as a part of onboardding process, where new nod-ai folks learn about the architecture of our work by adding support for 2 ops to torch-mlir. I decided to put this into torch mlir because a lot of this is about torch-mlir. +This document was previously known as `turbine-camp.md` to Nod.ai. "Turbine Camp" is part of Nod.ai's onboarding process. Welcome to turbine camp. This document originated at Nod.ai as a part of onboardding process, where new nod-ai folks learn about the architecture of our work by adding support for 2 ops to torch-mlir. I decided to put this into torch mlir because a lot of this is about torch-mlir. Written & maintained by @renxida Guides by other folks that were used during the creation of this document: - [Chi Liu](https://gist.github.com/AmosLewis/dd31ab37517977b1c499d06495b4adc2) -- [Sunsoon](https://docs.google.com/document/d/1H79DwW_wnVzUU81EogwY5ueXgnl-QzKet1p2lnqPar4/edit?pli=1) +- [Sunsoon](https://docs.google.com/document/d/1H79DwW_wnVzUU81EogwY5ueXgnl-QzKet1p2lnqPar4/edit?pli=1) ## Before you begin... diff --git a/docs/adding_abstract_interpretation_functions.md b/docs/adding_abstract_interpretation_functions.md index b5e427e1adfd..eeebb9c315fa 100644 --- a/docs/adding_abstract_interpretation_functions.md +++ b/docs/adding_abstract_interpretation_functions.md @@ -4,7 +4,7 @@ As part of adding support for a Torch operator in Torch-MLIR, it is usually necessary to define a shape and dtype function so that the compiler can infer -the shapes and dtypes of result tensors for the operator. We use the +the shapes and dtypes of result tensors for the operator. We use the [abstract interpretation library](abstract_interp_lib.md) for this process. ## Step-by-step guide @@ -19,7 +19,7 @@ We will use the example of adding support for the `torch.aten.tanh` op. file is the "rosetta stone" that allows translating between e.g. `torch.aten.tanh`, `AtenTanhOp`, and the shape and dtype function signatures are: - + - `def aten〇tanh〡shape(self: List[int]) -> List[int]:` - `def aten〇tanh〡dtype(self_rank_dtype: Tuple[int, int]) -> int:` @@ -39,10 +39,10 @@ We will use the example of adding support for the `torch.aten.tanh` op. But in general, you will need to write the function and test it (see the comments about "Shape, dtype, and decomposition function testing infrastructure" in `testing_framework.py`). New shape - functions should be added upstream following the example of [this PR](https://github.com/pytorch/pytorch/pull/76889), - though it can be useful to iterate locally in `abstract_interp_lib_gen.py` + functions should be added upstream following the example of [this PR](https://github.com/pytorch/pytorch/pull/76889), + though it can be useful to iterate locally in `abstract_interp_lib_gen.py` first. - + Similarly, dtype functions should ideally just be a call to the helper `promote_dtypes` defined in `library_generator.py`. However, some ops will require some extra logic to calculate the right result types. While dtypes diff --git a/docs/architecture.md b/docs/architecture.md index 8ee6bfda8a0a..4c102e140d7a 100644 --- a/docs/architecture.md +++ b/docs/architecture.md @@ -442,5 +442,5 @@ characteristics. ### Presentations and Talks -* 2021-10-07: MLIR ODM: Introduction to Torch-MLIR. ([recording](https://www.youtube.com/watch?v=QbNkex-gizs) and [slides](https://docs.google.com/presentation/d/1ZhzfE4EK6XV7AdQTYicrsE_OYjkER_yiB0vBeszRfzY/edit#slide=id.gf56404f79c_1_55)) -* 2022-08-20: Overview of Torch-MLIR passes. ([recording](https://www.youtube.com/watch?v=ZpwlVxsD9_U) and [slides](https://drive.google.com/file/d/1ZSlk1HGttRuVhJSxtP6spWt_hxClit2T/view)) +* 2021-10-07: MLIR ODM: Introduction to Torch-MLIR. ([recording](https://www.youtube.com/watch?v=QbNkex-gizs) and [slides](https://docs.google.com/presentation/d/1ZhzfE4EK6XV7AdQTYicrsE_OYjkER_yiB0vBeszRfzY/edit#slide=id.gf56404f79c_1_55)) +* 2022-08-20: Overview of Torch-MLIR passes. ([recording](https://www.youtube.com/watch?v=ZpwlVxsD9_U) and [slides](https://drive.google.com/file/d/1ZSlk1HGttRuVhJSxtP6spWt_hxClit2T/view)) diff --git a/docs/importers/onnx_importer.md b/docs/importers/onnx_importer.md index 796beba1f045..a0b861d6d9cb 100644 --- a/docs/importers/onnx_importer.md +++ b/docs/importers/onnx_importer.md @@ -11,8 +11,8 @@ for the reference importer which complies with the rules below. With the exception of certain special or complicated ONNX operators, most are relatively straight-forward to map, following this general procedure: -* Plan the ops you wish to support by consulting the - [ONNX operator database](https://onnx.ai/onnx/operators/). +* Plan the ops you wish to support by consulting the + [ONNX operator database](https://onnx.ai/onnx/operators/). * This database has detailed diffs wrt different support versions but at the level of detail we operate, most version diffs are inconsequential and just require a bit more pattern support. @@ -24,7 +24,7 @@ are relatively straight-forward to map, following this general procedure: corresponding with the alphabetic sort of the op and add a conversion. * Generate successful test cases: * All `onnx_importer.py` tests are dumped to the test temp dir (success - or failure). This is typically located under + or failure). This is typically located under `tools/torch-mlir/test/python/onnx_importer/Output`. The `.mlir` files under there should provide good variants to drive lit test coverage of conversion. @@ -34,25 +34,25 @@ are relatively straight-forward to map, following this general procedure: * There are often many variants of tests for checking conformance of different historic ONNX encodings, but these are often not load bearing at the MLIR level. - * Pick a handful of test cases and add them to + * Pick a handful of test cases and add them to `test/Conversion/TorchOnnxToTorch/simple_ops_x_to_y.mlir` corresponding to an alphabetic breakdown. At this time, ignore tests that are not exercising useful differences in the pattern implementations. - * (Optionally) Use `torch-mlir-opt` to validate the outputs of the new op. - First, build the project using + * (Optionally) Use `torch-mlir-opt` to validate the outputs of the new op. + First, build the project using `cmake --build build --target tools/torch-mlir/all`. This will generate the conversion binary, `torch-mlir-opt`. Then call `torch-mlir-opt` with the MLIR pass `convert-torch-onnx-to-torch`: ``` build/bin/torch-mlir-opt -convert-torch-onnx-to-torch \ -split-input-file [DESIRED_ONNX_FILE].mlir - ``` + ``` * Generate failure test cases: * Some ops have forms that do not (easily) map to torch-mlir. If you leave an op under-implemented, add a failing test case to `test/Conversion/TorchOnnxToTorch/unsupported_simple_ops.mlir`. -* Optional but recommended: Use your test case files to fuzz against the - torch-mlir backend of your choice by running a backend conversion pipeline +* Optional but recommended: Use your test case files to fuzz against the + torch-mlir backend of your choice by running a backend conversion pipeline and fixing any crashes/issues. * Send a patch with your changes. @@ -115,7 +115,7 @@ not yet implemented. The `IsolatedFromAbove` parent of the ops can contain the following metadata: -* `torch.onnx_meta.ir_version`: 64bit `IntegerAttr` corresponding to +* `torch.onnx_meta.ir_version`: 64bit `IntegerAttr` corresponding to `ModelProto.ir_version`. * `torch.onnx_meta.producer_name`: `StringAttr` corresponding to `ModelProto.producer_name`. @@ -135,7 +135,7 @@ are only minor variations of an op. Major variations should use ### Special op forms -Certain ONNX operators map to different structural components of +Certain ONNX operators map to different structural components of torch-mlir's representation: * `ConstantOfShape`: Mapped to `torch.vtensor.literal` with diff --git a/docs/ltc_backend.md b/docs/ltc_backend.md index b0177542899b..d047bbf9d812 100644 --- a/docs/ltc_backend.md +++ b/docs/ltc_backend.md @@ -103,7 +103,7 @@ At some point, the tensors will be synced in order to execute the computation -- >>> torch._lazy.mark_step() ``` -This triggers a call to `LazyGraphExecutor::SyncLiveTensorsGraph` somewhere in the guts of LTC, which collects all the `TorchMlirNode`s (technically `torch::lazy::Node`s at this point) from the current trace and +This triggers a call to `LazyGraphExecutor::SyncLiveTensorsGraph` somewhere in the guts of LTC, which collects all the `TorchMlirNode`s (technically `torch::lazy::Node`s at this point) from the current trace and creates an instance of `TorchMlirLoweringContext`. Here, the `TorchMlirNode`s are lowered to JIT via `mlir_node_lowering.cpp` and inserted into a `jit::Graph`. Next, `TorchMlirLoweringContext::Build` is executed and the final `jit::Graph` is sent to `torch_mlir::importJitFunctionAsFuncOp` to generate MLIR using the existing infrastructure from Torch-MLIR. @@ -121,7 +121,7 @@ Finally, the compiled computation is sent to `TorchMlirBackendImpl::ExecuteCompu ## Implementing a custom backend -A reference implementation of a custom backend is available [here](../python/torch_mlir/csrc/reference_lazy_backend/). +A reference implementation of a custom backend is available [here](../python/torch_mlir/csrc/reference_lazy_backend/). All the work involved with generating MLIR is handled in the base LTC backend, so vendors only need to worry about implementing `Compile`, `ExecuteComputation`, and some other minor methods to interface with the device. A pybind is needed to invoke C++ code to register the autogen PyTorch kernels and the custom backend itself. diff --git a/docs/ltc_examples.md b/docs/ltc_examples.md index b9306edce492..217761a51ebd 100644 --- a/docs/ltc_examples.md +++ b/docs/ltc_examples.md @@ -33,18 +33,18 @@ Received 1 arguments, and returned 2 results during ExecuteCompile! Results: tensor([[0.7616, 0.9640, 0.9951, 0.9993, 0.9999]], device='lazy:0') -JIT Graph: +JIT Graph: graph(%p0 : Float(1, 5)): %1 : Float(1, 5) = aten::tanh(%p0) return (%p0, %1) -MLIR: +MLIR: func.func @graph(%arg0: !torch.vtensor<[1,5],f32>) -> (!torch.vtensor<[1,5],f32>, !torch.vtensor<[1,5],f32>) { %0 = torch.aten.tanh %arg0 : !torch.vtensor<[1,5],f32> -> !torch.vtensor<[1,5],f32> return %arg0, %0 : !torch.vtensor<[1,5],f32>, !torch.vtensor<[1,5],f32> } -Input/Output Alias Mapping: +Input/Output Alias Mapping: Output: 0 -> Input param: 0 In Mark Step: true From f4562a8eaa3d13e17768f50164ede38383f7983e Mon Sep 17 00:00:00 2001 From: Gaurav Shukla Date: Mon, 5 Feb 2024 23:46:58 +0530 Subject: [PATCH 159/283] [ONNX] Fix the lowering of onnx.expand op (#2861) Signed-off-by: Gaurav Shukla --- .../TorchOnnxToTorch/DefaultDomainAtoF.cpp | 1 - .../TorchOnnxToTorch/simple_ops_a_to_f.mlir | 36 +++++++++---------- 2 files changed, 18 insertions(+), 19 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp index 1161b981c09f..05a1e5fcb15d 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp @@ -1387,7 +1387,6 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( Torch::BaseTensorType shapeType = shape.getType().cast(); SmallVector selectSizes; - selectSizes.push_back(1); Type selectResultType = shapeType.getWithSizesAndDtype( llvm::ArrayRef(selectSizes), shapeType.getOptionalDtype()); // Variable to store 1-D onnx shape tensor, shapeSizes[0] has the diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir index 797f9b6c2054..e757e3776d1b 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir @@ -1040,11 +1040,11 @@ func.func @test_expand_dim2_shape2(%arg0: !torch.vtensor<[1,4],f32>, %arg1: !tor -> !torch.vtensor<[3,4],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT0:.+]] = torch.constant.int 0 // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 - // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[2],si32>, !torch.int, !torch.int -> !torch.vtensor<[1],si32> - // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si32> -> !torch.int + // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[2],si32>, !torch.int, !torch.int -> !torch.vtensor<[],si32> + // CHECK: torch.aten.item %0 : !torch.vtensor<[],si32> -> !torch.int // CHECK: %[[INT1:.+]] = torch.constant.int 1 - // CHECK: torch.aten.select.int %arg1, %int0, %int1 : !torch.vtensor<[2],si32>, !torch.int, !torch.int -> !torch.vtensor<[1],si32> - // CHECK: torch.aten.item %2 : !torch.vtensor<[1],si32> -> !torch.int + // CHECK: torch.aten.select.int %arg1, %int0, %int1 : !torch.vtensor<[2],si32>, !torch.int, !torch.int -> !torch.vtensor<[],si32> + // CHECK: torch.aten.item %2 : !torch.vtensor<[],si32> -> !torch.int // CHECK: torch.prim.ListConstruct %1, %3 : (!torch.int, !torch.int) -> !torch.list // CHECK: torch.aten.broadcast_to %arg0, %4 : !torch.vtensor<[1,4],f32>, !torch.list -> !torch.vtensor<[3,4],f32> %0 = torch.operator "onnx.Expand"(%arg0, %arg1) : (!torch.vtensor<[1,4],f32>, !torch.vtensor<[2],si32>) -> !torch.vtensor<[3,4],f32> @@ -1057,14 +1057,14 @@ func.func @test_expand_dim2_shape2(%arg0: !torch.vtensor<[1,4],f32>, %arg1: !tor func.func @test_expand_dim2_shape3(%arg0: !torch.vtensor<[3,1],f32>, %arg1: !torch.vtensor<[3],si64>) -> !torch.vtensor<[2,3,6],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT0:.+]] = torch.constant.int 0 // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 - // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> + // CHECK: torch.aten.item %0 : !torch.vtensor<[],si64> -> !torch.int // CHECK: %[[INT1:.+]] = torch.constant.int 1 - // CHECK: torch.aten.select.int %arg1, %int0, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: torch.aten.item %2 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: torch.aten.select.int %arg1, %int0, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> + // CHECK: torch.aten.item %2 : !torch.vtensor<[],si64> -> !torch.int // CHECK: %[[INT2:.+]] = torch.constant.int 2 - // CHECK: torch.aten.select.int %arg1, %int0, %int2 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: torch.aten.item %4 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: torch.aten.select.int %arg1, %int0, %int2 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> + // CHECK: torch.aten.item %4 : !torch.vtensor<[],si64> -> !torch.int // CHECK: torch.prim.ListConstruct %1, %3, %5 : (!torch.int, !torch.int, !torch.int) -> !torch.list // CHECK: torch.aten.broadcast_to %arg0, %6 : !torch.vtensor<[3,1],f32>, !torch.list -> !torch.vtensor<[2,3,6],f32> %0 = torch.operator "onnx.Expand"(%arg0, %arg1) : (!torch.vtensor<[3,1],f32>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[2,3,6],f32> @@ -1077,17 +1077,17 @@ func.func @test_expand_dim2_shape3(%arg0: !torch.vtensor<[3,1],f32>, %arg1: !tor func.func @test_expand_dim3_shape4(%arg0: !torch.vtensor<[1,3,1],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[3,3,3,3],f32> attributes {torch.onnx_meta.ir_version = 4 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT0:.+]] = torch.constant.int 0 // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 - // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> + // CHECK: torch.aten.item %0 : !torch.vtensor<[],si64> -> !torch.int // CHECK: %[[INT1:.+]] = torch.constant.int 1 - // CHECK: torch.aten.select.int %arg1, %int0, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: torch.aten.item %2 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: torch.aten.select.int %arg1, %int0, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> + // CHECK: torch.aten.item %2 : !torch.vtensor<[],si64> -> !torch.int // CHECK: %[[INT2:.+]] = torch.constant.int 2 - // CHECK: torch.aten.select.int %arg1, %int0, %int2 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: torch.aten.item %4 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: torch.aten.select.int %arg1, %int0, %int2 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> + // CHECK: torch.aten.item %4 : !torch.vtensor<[],si64> -> !torch.int // CHECK: %[[INT3:.+]] = torch.constant.int 3 - // CHECK: torch.aten.select.int %arg1, %int0, %int3 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: torch.aten.item %6 : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: torch.aten.select.int %arg1, %int0, %int3 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> + // CHECK: torch.aten.item %6 : !torch.vtensor<[],si64> -> !torch.int // CHECK: torch.prim.ListConstruct %1, %3, %5, %7 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list // CHECK: %9 = torch.aten.broadcast_to %arg0, %8 : !torch.vtensor<[1,3,1],f32>, !torch.list -> !torch.vtensor<[3,3,3,3],f32> %0 = torch.operator "onnx.Expand"(%arg0, %arg1) : (!torch.vtensor<[1,3,1],f32>, !torch.vtensor<[4],si64>) -> !torch.vtensor<[3,3,3,3],f32> From b3a56c0711fcd49698ebaa73173fc7fcd986cf34 Mon Sep 17 00:00:00 2001 From: "Xida Ren (Cedar)" Date: Mon, 5 Feb 2024 15:13:43 -0500 Subject: [PATCH 160/283] Update add_ops to mention llvm-project/mlir/utils/generate-test-checks.py (#2862) --- docs/add_ops.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/docs/add_ops.md b/docs/add_ops.md index 225458cec631..1805f1700b47 100644 --- a/docs/add_ops.md +++ b/docs/add_ops.md @@ -70,6 +70,11 @@ Helpful examples: - [A Dec 2023 example where an ONNX op is implemented](https://github.com/llvm/torch-mlir/pull/2641/files#diff-b584b152020af6d2e5dbf62a08b2f25ed5afc2c299228383b9651d22d44b5af4R493) - [Vivek's example of ONNX op lowering](https://github.com/llvm/torch-mlir/commit/dc9ea08db5ac295b4b3f91fc776fef6a702900b9) +## List of Tools you may need to use (this will be incorporated into the above instructions later) + +- Generate FILECHECK tests from MLIR test cases: `torch-mlir-opt -convert- /tmp/your_awesome_testcase.mlir | externals/llvm-project/mlir/utils/generate-test-checks.py +`. Please don't just paste the generated tests - reference them to write your own + ## Contacts People who've worked on this for a while - Vivek (@vivek97 on discord) From cb52c4b3cc81f39bc2306d064ae2789a088e441f Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Mon, 5 Feb 2024 14:23:46 -0800 Subject: [PATCH 161/283] [onnx] Fix `onnx-to-torch` lowering for flatten shape (#2834) The existing `flatten` lowering did not define what the intermediate shape was. This could result in failures to lower further to linalg as the intermediate shape was unknown. Added a shape refinement section. --- .../TorchOnnxToTorch/DefaultDomainAtoF.cpp | 25 ++++++++--- .../TorchOnnxToTorch/simple_ops_a_to_f.mlir | 42 +++++++++---------- 2 files changed, 40 insertions(+), 27 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp index 05a1e5fcb15d..06a9662c9c2d 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp @@ -501,7 +501,6 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( auto message = llvm::formatv("unimplemented support for the given " "dtype conversion (onnx 'type' = {0})", dtypeIntOnnx); - llvm::errs() << message << "\n"; auto y = rewriter.notifyMatchFailure(binder.op, message); return y; @@ -1444,16 +1443,30 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( binder.tensorResultType(resultType)) return failure(); + auto operandTy = cast(operand.getType()); + llvm::SmallVector shape(operandTy.getSizes()); + int64_t rank = shape.size(); + // If axis is negative, count from the right instead of left - int64_t rank = - cast(operand.getType()).getSizes().size(); if (axis < 0) axis = rank + axis; - Value collapsedRight; - auto baseType = Torch::ValueTensorType::getWithLeastStaticInformation( - binder.op->getContext()); + // We collapse in the dimensions to the right of the axis. + for (int i = axis + 1; i < rank; ++i) { + bool dynamic = shape[axis] == Torch::kUnknownSize || + shape[i] == Torch::kUnknownSize; + if (dynamic) { + shape[axis] = Torch::kUnknownSize; + } else { + shape[axis] = shape[axis] * shape[i]; + } + } + shape.resize(axis + 1, 1); + + auto baseType = rewriter.getType( + shape, operandTy.getDtype()); + Value collapsedRight; if (axis >= rank) { // If the right range is empty, add a dim of size 1 to the // right side of the shape: diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir index e757e3776d1b..a71d4e428e18 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir @@ -1311,23 +1311,23 @@ func.func @dense_constant() -> () attributes {torch.onnx_meta.ir_version = 8 : s func.func @test_flatten_4d_axis_2(%arg0: !torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[6,20],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 2 // CHECK-DAG: %[[RIGHT_END:.*]] = torch.constant.int 3 - // CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2,3,4,5],f32>, !torch.int, !torch.int -> !torch.vtensor + // CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2,3,4,5],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,3,20],f32> // CHECK-DAG: %[[LEFT_START:.*]] = torch.constant.int 0 // CHECK-DAG: %[[LEFT_END:.*]] = torch.constant.int 1 - // CHECK: torch.prims.collapse %[[CR]], %[[LEFT_START]], %[[LEFT_END]] : !torch.vtensor, !torch.int, !torch.int -> !torch.vtensor<[6,20],f32> + // CHECK: torch.prims.collapse %[[CR]], %[[LEFT_START]], %[[LEFT_END]] : !torch.vtensor<[2,3,20],f32>, !torch.int, !torch.int -> !torch.vtensor<[6,20],f32> %0 = torch.operator "onnx.Flatten"(%arg0) {torch.onnx.axis = 2 : si64} : (!torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[6,20],f32> return %0 : !torch.vtensor<[6,20],f32> } // ----- -// CHECK-LABEL: @test_flatten_4d_axis_0 +// // CHECK-LABEL: @test_flatten_4d_axis_0 func.func @test_flatten_4d_axis_0(%arg0: !torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[1,120],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 0 // CHECK-DAG: %[[RIGHT_END:.*]] = torch.constant.int 3 - // CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2,3,4,5],f32>, !torch.int, !torch.int -> !torch.vtensor + // CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2,3,4,5],f32>, !torch.int, !torch.int -> !torch.vtensor<[120],f32> // CHECK-DAG: %[[LEFT_INDEX:.*]] = torch.constant.int 0 - // CHECK: torch.aten.unsqueeze %[[CR]], %[[LEFT_INDEX]] : !torch.vtensor, !torch.int -> !torch.vtensor<[1,120],f32> + // CHECK: torch.aten.unsqueeze %[[CR]], %[[LEFT_INDEX]] : !torch.vtensor<[120],f32>, !torch.int -> !torch.vtensor<[1,120],f32> %0 = torch.operator "onnx.Flatten"(%arg0) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[1,120],f32> return %0 : !torch.vtensor<[1,120],f32> } @@ -1337,10 +1337,10 @@ func.func @test_flatten_4d_axis_0(%arg0: !torch.vtensor<[2,3,4,5],f32>) -> !torc // CHECK-LABEL: @test_flatten_4d_axis_4 func.func @test_flatten_4d_axis_4(%arg0: !torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[120,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK-DAG: %[[RIGHT_INDEX:.*]] = torch.constant.int 4 - // CHECK-DAG: %[[CR:.*]] = torch.aten.unsqueeze %arg0, %[[RIGHT_INDEX]] : !torch.vtensor<[2,3,4,5],f32>, !torch.int -> !torch.vtensor + // CHECK-DAG: %[[CR:.*]] = torch.aten.unsqueeze %arg0, %[[RIGHT_INDEX]] : !torch.vtensor<[2,3,4,5],f32>, !torch.int -> !torch.vtensor<[2,3,4,5,1],f32> // CHECK-DAG: %[[LEFT_START:.*]] = torch.constant.int 0 // CHECK-DAG: %[[LEFT_END:.*]] = torch.constant.int 3 - // CHECK: torch.prims.collapse %[[CR]], %[[LEFT_START]], %[[LEFT_END]] : !torch.vtensor, !torch.int, !torch.int -> !torch.vtensor<[120,1],f32> + // CHECK: torch.prims.collapse %[[CR]], %[[LEFT_START]], %[[LEFT_END]] : !torch.vtensor<[2,3,4,5,1],f32>, !torch.int, !torch.int -> !torch.vtensor<[120,1],f32> %0 = torch.operator "onnx.Flatten"(%arg0) {torch.onnx.axis = 4 : si64} : (!torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[120,1],f32> return %0 : !torch.vtensor<[120,1],f32> } @@ -1351,10 +1351,10 @@ func.func @test_flatten_4d_axis_4(%arg0: !torch.vtensor<[2,3,4,5],f32>) -> !torc func.func @test_flatten_4d_axis_negative_2(%arg0: !torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[6,20],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 2 // CHECK-DAG: %[[RIGHT_END:.*]] = torch.constant.int 3 - // CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2,3,4,5],f32>, !torch.int, !torch.int -> !torch.vtensor + // CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2,3,4,5],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,3,20],f32> // CHECK-DAG: %[[LEFT_START:.*]] = torch.constant.int 0 // CHECK-DAG: %[[LEFT_END:.*]] = torch.constant.int 1 - // CHECK: torch.prims.collapse %[[CR]], %[[LEFT_START]], %[[LEFT_END]] : !torch.vtensor, !torch.int, !torch.int -> !torch.vtensor<[6,20],f32> + // CHECK: torch.prims.collapse %[[CR]], %[[LEFT_START]], %[[LEFT_END]] : !torch.vtensor<[2,3,20],f32>, !torch.int, !torch.int -> !torch.vtensor<[6,20],f32> %0 = torch.operator "onnx.Flatten"(%arg0) {torch.onnx.axis = -2 : si64} : (!torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[6,20],f32> return %0 : !torch.vtensor<[6,20],f32> } @@ -1365,10 +1365,10 @@ func.func @test_flatten_4d_axis_negative_2(%arg0: !torch.vtensor<[2,3,4,5],f32>) func.func @test_flatten_4d_axis_negative_1(%arg0: !torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[24,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 3 // CHECK-DAG: %[[RIGHT_END:.*]] = torch.constant.int 3 - // CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2,3,4,5],f32>, !torch.int, !torch.int -> !torch.vtensor + // CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2,3,4,5],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,3,4,5],f32> // CHECK-DAG: %[[LEFT_START:.*]] = torch.constant.int 0 // CHECK-DAG: %[[LEFT_END:.*]] = torch.constant.int 2 - // CHECK: torch.prims.collapse %[[CR]], %[[LEFT_START]], %[[LEFT_END]] : !torch.vtensor, !torch.int, !torch.int -> !torch.vtensor<[24,5],f32> + // CHECK: torch.prims.collapse %[[CR]], %[[LEFT_START]], %[[LEFT_END]] : !torch.vtensor<[2,3,4,5],f32>, !torch.int, !torch.int -> !torch.vtensor<[24,5],f32> %0 = torch.operator "onnx.Flatten"(%arg0) {torch.onnx.axis = -1 : si64} : (!torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[24,5],f32> return %0 : !torch.vtensor<[24,5],f32> } @@ -1379,9 +1379,9 @@ func.func @test_flatten_4d_axis_negative_1(%arg0: !torch.vtensor<[2,3,4,5],f32>) func.func @test_flatten_4d_axis_negative_4(%arg0: !torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[1,120],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 0 // CHECK-DAG: %[[RIGHT_END:.*]] = torch.constant.int 3 - // CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2,3,4,5],f32>, !torch.int, !torch.int -> !torch.vtensor + // CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2,3,4,5],f32>, !torch.int, !torch.int -> !torch.vtensor<[120],f32> // CHECK-DAG: %[[LEFT_INDEX:.*]] = torch.constant.int 0 - // CHECK: torch.aten.unsqueeze %[[CR]], %[[LEFT_INDEX]] : !torch.vtensor, !torch.int -> !torch.vtensor<[1,120],f32> + // CHECK: torch.aten.unsqueeze %[[CR]], %[[LEFT_INDEX]] : !torch.vtensor<[120],f32>, !torch.int -> !torch.vtensor<[1,120],f32> %0 = torch.operator "onnx.Flatten"(%arg0) {torch.onnx.axis = -4 : si64} : (!torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[1,120],f32> return %0 : !torch.vtensor<[1,120],f32> } @@ -1392,10 +1392,10 @@ func.func @test_flatten_4d_axis_negative_4(%arg0: !torch.vtensor<[2,3,4,5],f32>) func.func @test_flatten_2d_axis_1(%arg0: !torch.vtensor<[2,3],f32>) -> !torch.vtensor<[2,3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 1 // CHECK-DAG: %[[RIGHT_END:.*]] = torch.constant.int 1 - // CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2,3],f32>, !torch.int, !torch.int -> !torch.vtensor + // CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2,3],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,3],f32> // CHECK-DAG: %[[LEFT_START:.*]] = torch.constant.int 0 // CHECK-DAG: %[[LEFT_END:.*]] = torch.constant.int 0 - // CHECK: torch.prims.collapse %[[CR]], %[[LEFT_START]], %[[LEFT_END]] : !torch.vtensor, !torch.int, !torch.int -> !torch.vtensor<[2,3],f32> + // CHECK: torch.prims.collapse %[[CR]], %[[LEFT_START]], %[[LEFT_END]] : !torch.vtensor<[2,3],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,3],f32> %0 = torch.operator "onnx.Flatten"(%arg0) {torch.onnx.axis = 1 : si64} : (!torch.vtensor<[2,3],f32>) -> !torch.vtensor<[2,3],f32> return %0 : !torch.vtensor<[2,3],f32> } @@ -1406,9 +1406,9 @@ func.func @test_flatten_2d_axis_1(%arg0: !torch.vtensor<[2,3],f32>) -> !torch.vt func.func @test_flatten_1d_axis_0(%arg0: !torch.vtensor<[2],f32>) -> !torch.vtensor<[1,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 0 // CHECK-DAG: %[[RIGHT_END:.*]] = torch.constant.int 0 - // CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2],f32>, !torch.int, !torch.int -> !torch.vtensor + // CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2],f32>, !torch.int, !torch.int -> !torch.vtensor<[2],f32> // CHECK-DAG: %[[LEFT_INDEX:.*]] = torch.constant.int 0 - // CHECK: torch.aten.unsqueeze %[[CR]], %[[LEFT_INDEX]] : !torch.vtensor, !torch.int -> !torch.vtensor<[1,2],f32> + // CHECK: torch.aten.unsqueeze %[[CR]], %[[LEFT_INDEX]] : !torch.vtensor<[2],f32>, !torch.int -> !torch.vtensor<[1,2],f32> %0 = torch.operator "onnx.Flatten"(%arg0) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[2],f32>) -> !torch.vtensor<[1,2],f32> return %0 : !torch.vtensor<[1,2],f32> } @@ -1419,9 +1419,9 @@ func.func @test_flatten_1d_axis_0(%arg0: !torch.vtensor<[2],f32>) -> !torch.vten func.func @test_flatten_1d_axis_negative_1(%arg0: !torch.vtensor<[2],f32>) -> !torch.vtensor<[1,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 0 // CHECK-DAG: %[[RIGHT_END:.*]] = torch.constant.int 0 - // CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2],f32>, !torch.int, !torch.int -> !torch.vtensor + // CHECK-DAG: %[[CR:.*]] = torch.prims.collapse %arg0, %[[RIGHT_START]], %[[RIGHT_END]] : !torch.vtensor<[2],f32>, !torch.int, !torch.int -> !torch.vtensor<[2],f32> // CHECK-DAG: %[[LEFT_INDEX:.*]] = torch.constant.int 0 - // CHECK: torch.aten.unsqueeze %[[CR]], %[[LEFT_INDEX]] : !torch.vtensor, !torch.int -> !torch.vtensor<[1,2],f32> + // CHECK: torch.aten.unsqueeze %[[CR]], %[[LEFT_INDEX]] : !torch.vtensor<[2],f32>, !torch.int -> !torch.vtensor<[1,2],f32> %0 = torch.operator "onnx.Flatten"(%arg0) {torch.onnx.axis = -1 : si64} : (!torch.vtensor<[2],f32>) -> !torch.vtensor<[1,2],f32> return %0 : !torch.vtensor<[1,2],f32> } @@ -1431,10 +1431,10 @@ func.func @test_flatten_1d_axis_negative_1(%arg0: !torch.vtensor<[2],f32>) -> !t // COM: CHECK-LABEL: @test_flatten_1d_axis_1 func.func @test_flatten_1d_axis_1(%arg0: !torch.vtensor<[2],f32>) -> !torch.vtensor<[2,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK-DAG: %[[RIGHT_INDEX:.*]] = torch.constant.int 1 - // CHECK-DAG: %[[CR:.*]] = torch.aten.unsqueeze %arg0, %[[RIGHT_INDEX]] : !torch.vtensor<[2],f32>, !torch.int -> !torch.vtensor + // CHECK-DAG: %[[CR:.*]] = torch.aten.unsqueeze %arg0, %[[RIGHT_INDEX]] : !torch.vtensor<[2],f32>, !torch.int -> !torch.vtensor<[2,1],f32> // CHECK-DAG: %[[LEFT_START:.*]] = torch.constant.int 0 // CHECK-DAG: %[[LEFT_END:.*]] = torch.constant.int 0 - // CHECK: torch.prims.collapse %[[CR]], %[[LEFT_START]], %[[LEFT_END]] : !torch.vtensor, !torch.int, !torch.int -> !torch.vtensor<[2,1],f32> + // CHECK: torch.prims.collapse %[[CR]], %[[LEFT_START]], %[[LEFT_END]] : !torch.vtensor<[2,1],f32>, !torch.int, !torch.int -> !torch.vtensor<[2,1],f32> %0 = torch.operator "onnx.Flatten"(%arg0) {torch.onnx.axis = 1 : si64} : (!torch.vtensor<[2],f32>) -> !torch.vtensor<[2,1],f32> return %0 : !torch.vtensor<[2,1],f32> } From e3faef522460054aafe7911469abef17545adedf Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Mon, 5 Feb 2024 16:09:41 -0800 Subject: [PATCH 162/283] [onnx] Convert `onnx.QLinearConv` to `torch` (#2851) Leaning on the QDQ functionality in torch we can support the QLinearConv operation by piggybacking through `torch.Convolution`. This includes some changes such as allowing the `onnx` rewriter to run recursively. Doing so allows `QLinearConv` to decopmose to `onnx.Convolution` which is then lowered to `torch`. --- .../Conversion/TorchOnnxToTorch/Patterns.h | 5 +- .../Conversion/TorchOnnxToTorch/Utils.h | 2 + .../TorchOnnxToTorch/DefaultDomainAtoF.cpp | 2 +- .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 139 +++++++++++++++--- lib/Conversion/TorchOnnxToTorch/Utils.cpp | 21 +++ lib/Conversion/TorchToLinalg/Linear.cpp | 2 +- .../Torch/Transforms/FuseQuantizedOps.cpp | 45 +++--- .../TorchOnnxToTorch/simple_ops_q_to_z.mlir | 77 ++++++++++ test/Dialect/Torch/fuse-quantized-ops.mlir | 39 ++++- 9 files changed, 285 insertions(+), 47 deletions(-) diff --git a/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h b/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h index 2df6f95c8ad6..261b4df3bd09 100644 --- a/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h +++ b/include/torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h @@ -251,7 +251,10 @@ class OnnxCustomOpConversionPattern OnnxCustomOpConversionPattern(MLIRContext *context, std::string domainPrefix, int64_t domainVersion) : OpConversionPattern(context), domainPrefix(std::move(domainPrefix)), - domainVersion(domainVersion) {} + domainVersion(domainVersion) { + // Onnx lowerings could produce other Onnx operations during the rewrite. + setHasBoundedRewriteRecursion(); + } LogicalResult matchAndRewrite(Torch::OperatorOp op, OpAdaptor adaptor, diff --git a/include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h b/include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h index 058fee4da4a2..afc14a95ef13 100644 --- a/include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h +++ b/include/torch-mlir/Conversion/TorchOnnxToTorch/Utils.h @@ -18,6 +18,8 @@ Value createConstantIntList(OpBinder binder, ConversionPatternRewriter &rewriter, SmallVector cstInput); +Type getQTorchTypeFromTorchIntType(Type ty); + } // namespace mlir::torch::onnx_c #endif // TORCHMLIR_CONVERSION_TORCHONNXTOTORCH_UTILS_H diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp index 06a9662c9c2d..9550e982b8c4 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp @@ -690,7 +690,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( return failure(); }); patterns.onOp( - "Conv", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + "Conv", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { std::string autoPad; if (binder.customOpNameStringAttr(autoPad, "auto_pad", "NOTSET")) return failure(); diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index d54c1e1b9dd0..8227514b5cf5 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -8,6 +8,7 @@ //===----------------------------------------------------------------------===// #include "torch-mlir/Conversion/TorchOnnxToTorch/Patterns.h" +#include "torch-mlir/Conversion/TorchOnnxToTorch/Utils.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" @@ -99,6 +100,117 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( return failure(); }); + patterns.onOp( + "QLinearConv", 1, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + llvm::SmallVector operands; + if ((binder.tensorOperands(operands, 8) && + binder.tensorOperands(operands, 9)) || + binder.tensorResultType(resultType)) + return failure(); + Value a = operands[0]; + Value aScale = operands[1]; + Value aZp = operands[2]; + Value b = operands[3]; + Value bScale = operands[4]; + Value bZp = operands[5]; + Value cScale = operands[6]; + Value cZp = operands[7]; + Value c = operands.size() == 9 ? operands[8] : nullptr; + + auto check = [](Value v) { + auto vTy = v.getType().cast(); + return llvm::all_of(vTy.getSizes(), [](int64_t d) { return d == 1; }); + }; + if (!check(aScale) || !check(aZp) || !check(bScale) || !check(bZp) || + !check(cScale) || !check(cScale)) + return rewriter.notifyMatchFailure( + binder.op, "not supported for non per-tensor quantization"); + + auto extract = [&rewriter, &binder](Value v) { + auto vTy = v.getType().cast(); + Type extractTy = rewriter.getType(); + if (isa(vTy.getDtype())) + extractTy = rewriter.getType(); + + return rewriter.create(binder.getLoc(), extractTy, + v); + }; + + aZp = extract(aZp); + bZp = extract(bZp); + cZp = extract(cZp); + aScale = extract(aScale); + bScale = extract(bScale); + cScale = extract(cScale); + + auto make = [&rewriter, &binder](Value v, Value scale, + Value zp) -> Value { + auto ty = v.getType().cast(); + auto newTy = getQTorchTypeFromTorchIntType(ty); + return rewriter.create( + binder.getLoc(), newTy, v, scale, zp); + }; + + a = make(a, aScale, aZp); + b = make(b, bScale, bZp); + + auto cTy = rewriter.getType( + resultType.getOptionalSizes(), + rewriter.getIntegerType(32, /*issigned=*/true)); + + // TODO(suderman): insert convolution operator. + llvm::SmallVector newOperands = {a, b}; + if (c) + newOperands.push_back(c); + + cTy = rewriter.getType( + resultType.getOptionalSizes(), + rewriter.getType()); + + llvm::SmallVector newAttributes; + newAttributes.push_back( + rewriter.getNamedAttr("name", rewriter.getStringAttr("onnx.Conv"))); + for (auto namedAttr : binder.op->getAttrDictionary()) { + if (namedAttr.getName().getValue().compare("name") == 0) + continue; + llvm::errs() << namedAttr.getName() << "\n"; + newAttributes.push_back(namedAttr); + } + + c = rewriter + .create(binder.getLoc(), cTy, newOperands, + newAttributes) + .getResult(0); + + Value outScale = rewriter.create( + binder.getLoc(), rewriter.getType(), aScale, + bScale); + Value outZp = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); + c = rewriter.create( + binder.getLoc(), cTy, c, outScale, outZp); + cTy = rewriter.getType( + resultType.getOptionalSizes(), rewriter.getF32Type()); + + c = rewriter.create(binder.getLoc(), cTy, + c); + cTy = dyn_cast( + getQTorchTypeFromTorchIntType(resultType)); + Value dtyVal = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr( + rewriter.getIntegerType(64), + static_cast( + Torch::getScalarTypeForType(cTy.getDtype())))); + c = rewriter.create( + binder.getLoc(), cTy, c, cScale, cZp, dtyVal); + rewriter.replaceOpWithNewOp(binder.op, resultType, + c); + return success(); + }); patterns.onOp( "QLinearMatMul", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { @@ -157,28 +269,10 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( bScale = extract(bScale); cScale = extract(cScale); - auto getQTy = - [&rewriter](Torch::ValueTensorType ty) -> Torch::ValueTensorType { - auto dt = ty.getDtype(); - Type newDt; - if (dt.isUnsignedInteger(8)) { - newDt = rewriter.getType(); - } else if (dt.isSignedInteger(8)) { - newDt = rewriter.getType(); - } else if (dt.isSignedInteger(32)) { - newDt = rewriter.getType(); - } else { - return nullptr; - } - - return rewriter.getType(ty.getOptionalSizes(), - newDt); - }; - - auto make = [&rewriter, &binder, &getQTy](Value v, Value scale, - Value zp) -> Value { + auto make = [&rewriter, &binder](Value v, Value scale, + Value zp) -> Value { auto ty = v.getType().cast(); - auto newTy = getQTy(ty); + auto newTy = getQTorchTypeFromTorchIntType(ty); return rewriter.create( binder.getLoc(), newTy, v, scale, zp); }; @@ -214,7 +308,8 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( c = rewriter.create(binder.getLoc(), cTy, c); - cTy = getQTy(resultType); + cTy = dyn_cast( + getQTorchTypeFromTorchIntType(resultType)); Value dtyVal = rewriter.create( binder.getLoc(), rewriter.getType(), rewriter.getIntegerAttr( diff --git a/lib/Conversion/TorchOnnxToTorch/Utils.cpp b/lib/Conversion/TorchOnnxToTorch/Utils.cpp index 8f5a2e67c0cb..ef3da8b3b3fa 100644 --- a/lib/Conversion/TorchOnnxToTorch/Utils.cpp +++ b/lib/Conversion/TorchOnnxToTorch/Utils.cpp @@ -8,6 +8,7 @@ //===----------------------------------------------------------------------===// #include "torch-mlir/Conversion/TorchOnnxToTorch/Utils.h" +#include "torch-mlir/Dialect/Torch/IR/TorchTypes.h" using namespace mlir; using namespace mlir::torch; @@ -26,3 +27,23 @@ Value mlir::torch::onnx_c::createConstantIntList( Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), cstValue); } + +Type mlir::torch::onnx_c::getQTorchTypeFromTorchIntType(Type ty) { + Torch::ValueTensorType tty = dyn_cast(ty); + if (!tty) + return nullptr; + + auto ctx = ty.getContext(); + Type dty = tty.getDtype(); + + if (dty.isUnsignedInteger(8)) + dty = Torch::QUInt8Type::get(ctx); + if (dty.isSignedInteger(8)) + dty = Torch::QInt8Type::get(ctx); + if (dty.isSignedInteger(32)) + dty = Torch::QInt32Type::get(ctx); + + if (!dty) + return nullptr; + return Torch::ValueTensorType::get(ctx, tty.getOptionalSizes(), dty); +} diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp index 4523febb9b9d..3557b27a2eb2 100644 --- a/lib/Conversion/TorchToLinalg/Linear.cpp +++ b/lib/Conversion/TorchToLinalg/Linear.cpp @@ -653,7 +653,7 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { op, "lhs and rhs of convolution must either be both int or fp"); } - if (inputZp && weightZp) { + if (inputZp && weightZp && !isa(bias.getType())) { auto biasDTy = bias.getType().cast().getElementType(); if (!biasDTy.isInteger(32)) { return rewriter.notifyMatchFailure( diff --git a/lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp b/lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp index 15d5ec105ed4..6bc8a8ba084a 100644 --- a/lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp +++ b/lib/Dialect/Torch/Transforms/FuseQuantizedOps.cpp @@ -64,10 +64,6 @@ template class QuantizeBias : public OpRewritePattern { if (operands.size() < 3) return failure(); - Value bias = operands[2]; - if (bias.getDefiningOp()) - return failure(); - Value lhsScale; if (auto qLhs = operands[0].getDefiningOp()) @@ -82,11 +78,18 @@ template class QuantizeBias : public OpRewritePattern { return failure(); auto resultTy = cast(op.getType()); - auto biasTy = bias.getType().cast(); - auto biasETy = biasTy.getOptionalDtype(); - if (!biasETy || !isa(biasETy)) + if (!isa(resultTy.getDtype())) return failure(); + Value bias = operands[2]; + auto biasTy = bias.getType().dyn_cast(); + + if (biasTy) { + auto biasETy = biasTy.getOptionalDtype(); + if (!biasETy || !isa(biasETy)) + return failure(); + } + Value biasScale = rewriter.create( op.getLoc(), lhsScale.getType(), lhsScale, rhsScale); @@ -95,19 +98,21 @@ template class QuantizeBias : public OpRewritePattern { rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); auto qi32Ty = rewriter.getType(); - auto newBiasTy = - rewriter.getType(biasTy.getOptionalSizes(), qi32Ty); - Value dtype = getDtypeIntValueForType(rewriter, op.getLoc(), qi32Ty); - bias = rewriter.create( - op.getLoc(), newBiasTy, bias, biasScale, zero, dtype); - bias = rewriter.create( - op.getLoc(), - rewriter.getType( - biasTy.getOptionalSizes(), - rewriter.getIntegerType(32, IntegerType::Signed)), - bias); - - operands[2] = bias; + + if (biasTy) { + auto newBiasTy = + rewriter.getType(biasTy.getOptionalSizes(), qi32Ty); + Value dtype = getDtypeIntValueForType(rewriter, op.getLoc(), qi32Ty); + bias = rewriter.create( + op.getLoc(), newBiasTy, bias, biasScale, zero, dtype); + bias = rewriter.create( + op.getLoc(), + rewriter.getType( + biasTy.getOptionalSizes(), + rewriter.getIntegerType(32, IntegerType::Signed)), + bias); + operands[2] = bias; + } auto convTy = rewriter.getType( resultTy.getOptionalSizes(), diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir index 9d947dce5ce5..ae36661bdd43 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir @@ -47,6 +47,83 @@ func.func @test_quantizelinear_i32(%arg0: !torch.vtensor<[6],f32>, %arg1: !torch // ----- +// CHECK-LABEL: @test_qlinearconv_nobias +func.func @test_qlinearconv_nobias(%arg0: !torch.vtensor<[1,1,7,7],ui8>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],ui8>, %arg3: !torch.vtensor<[1,1,1,1],ui8>, %arg4: !torch.vtensor<[1],f32>, %arg5: !torch.vtensor<[1],ui8>, %arg6: !torch.vtensor<[],f32>, %arg7: !torch.vtensor<[],ui8>) -> !torch.vtensor<[1,1,7,7],ui8> attributes {torch.onnx_meta.ir_version = 5 : si64, torch.onnx_meta.opset_version = 10 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + %0 = torch.operator "onnx.QLinearConv"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7) : (!torch.vtensor<[1,1,7,7],ui8>, !torch.vtensor<[],f32>, !torch.vtensor<[],ui8>, !torch.vtensor<[1,1,1,1],ui8>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],ui8>, !torch.vtensor<[],f32>, !torch.vtensor<[],ui8>) -> !torch.vtensor<[1,1,7,7],ui8> + // CHECK: %[[aZp:.+]] = torch.aten.item %arg2 : !torch.vtensor<[],ui8> -> !torch.int + // CHECK: %[[bZp:.+]] = torch.aten.item %arg5 : !torch.vtensor<[1],ui8> -> !torch.int + // CHECK: %[[cZp:.+]] = torch.aten.item %arg7 : !torch.vtensor<[],ui8> -> !torch.int + // CHECK: %[[aScale:.+]] = torch.aten.item %arg1 : !torch.vtensor<[],f32> -> !torch.float + // CHECK: %[[bScale:.+]] = torch.aten.item %arg4 : !torch.vtensor<[1],f32> -> !torch.float + // CHECK: %[[cScale:.+]] = torch.aten.item %arg6 : !torch.vtensor<[],f32> -> !torch.float + // CHECK: %[[A:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[aScale]], %[[aZp]] : !torch.vtensor<[1,1,7,7],ui8>, !torch.float, !torch.int -> !torch.vtensor<[1,1,7,7],!torch.quint8> + // CHECK: %[[B:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg3, %[[bScale]], %[[bZp]] : !torch.vtensor<[1,1,1,1],ui8>, !torch.float, !torch.int -> !torch.vtensor<[1,1,1,1],!torch.quint8> + // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 + // CHECK: %[[INT0_1:.+]] = torch.constant.int 0 + // CHECK: %[[INT1_0:.+]] = torch.constant.int 1 + // CHECK: %[[INT1_1:.+]] = torch.constant.int 1 + // CHECK: %[[INT1_2:.+]] = torch.constant.int 1 + // CHECK: %[[INT1_3:.+]] = torch.constant.int 1 + // CHECK: %[[INT0_2:.+]] = torch.constant.int 0 + // CHECK: %[[PAD:.+]] = torch.prim.ListConstruct %[[INT0_0]], %[[INT0_1]] + // CHECK: %[[KERNEL:.+]] = torch.prim.ListConstruct %[[INT1_0]], %[[INT1_1]] + // CHECK: %[[DILATION:.+]] = torch.prim.ListConstruct %[[INT1_2]], %[[INT1_3]] + // CHECK: %[[STRIDE:.+]] = torch.prim.ListConstruct %[[INT0_2]], %[[INT0_2]] + // CHECK: %[[FALSE:.+]] = torch.constant.bool false + // CHECK: %[[NONE:.+]] = torch.constant.none + // CHECK: %[[INT1_5:.+]] = torch.constant.int 1 + // CHECK: %[[CONV:.+]] = torch.aten.convolution %[[A]], %[[B]], %[[NONE]], %[[DILATION]], %[[PAD]], %[[KERNEL]], %[[FALSE]], %[[STRIDE]], %[[INT1_5]] : !torch.vtensor<[1,1,7,7],!torch.quint8>, !torch.vtensor<[1,1,1,1],!torch.quint8>, !torch.none, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[1,1,7,7],!torch.qint32> + // CHECK: %[[convScale:.+]] = torch.aten.mul.float %[[aScale]], %[[bScale]] : !torch.float, !torch.float -> !torch.float + // CHECK: %[[INT0_6:.+]] = torch.constant.int 0 + // CHECK: %[[C:.+]] = torch.aten._make_per_tensor_quantized_tensor %[[CONV]], %[[convScale]], %[[INT0_6]] : !torch.vtensor<[1,1,7,7],!torch.qint32>, !torch.float, !torch.int -> !torch.vtensor<[1,1,7,7],!torch.qint32> + // CHECK: %[[DEQ:.+]] = torch.aten.dequantize.self %[[C]] : !torch.vtensor<[1,1,7,7],!torch.qint32> -> !torch.vtensor<[1,1,7,7],f32> + // CHECK: %[[INT13:.+]] = torch.constant.int 13 + // CHECK: %[[QUANT:.+]] = torch.aten.quantize_per_tensor %[[DEQ]], %[[cScale]], %[[cZp]], %[[INT13]] : !torch.vtensor<[1,1,7,7],f32>, !torch.float, !torch.int, !torch.int -> !torch.vtensor<[1,1,7,7],!torch.quint8> + // CHECK: %[[INT:.+]] = torch.aten.int_repr %[[QUANT]] : !torch.vtensor<[1,1,7,7],!torch.quint8> -> !torch.vtensor<[1,1,7,7],ui8> + // CHECK: return %[[INT]] : !torch.vtensor<[1,1,7,7],ui8> + return %0 : !torch.vtensor<[1,1,7,7],ui8> +} + +// ----- + +// CHECK-LABEL: @test_qlinearconv_bias +func.func @test_qlinearconv_bias(%arg0: !torch.vtensor<[1,1,7,7],ui8>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],ui8>, %arg3: !torch.vtensor<[1,1,1,1],ui8>, %arg4: !torch.vtensor<[1],f32>, %arg5: !torch.vtensor<[1],ui8>, %arg6: !torch.vtensor<[],f32>, %arg7: !torch.vtensor<[],ui8>, %arg8 : !torch.vtensor<[7],si32>) -> !torch.vtensor<[1,1,7,7],ui8> attributes {torch.onnx_meta.ir_version = 5 : si64, torch.onnx_meta.opset_version = 10 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + %0 = torch.operator "onnx.QLinearConv"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8) : (!torch.vtensor<[1,1,7,7],ui8>, !torch.vtensor<[],f32>, !torch.vtensor<[],ui8>, !torch.vtensor<[1,1,1,1],ui8>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],ui8>, !torch.vtensor<[],f32>, !torch.vtensor<[],ui8>, !torch.vtensor<[7],si32>) -> !torch.vtensor<[1,1,7,7],ui8> + // CHECK: %[[aZp:.+]] = torch.aten.item %arg2 : !torch.vtensor<[],ui8> -> !torch.int + // CHECK: %[[bZp:.+]] = torch.aten.item %arg5 : !torch.vtensor<[1],ui8> -> !torch.int + // CHECK: %[[cZp:.+]] = torch.aten.item %arg7 : !torch.vtensor<[],ui8> -> !torch.int + // CHECK: %[[aScale:.+]] = torch.aten.item %arg1 : !torch.vtensor<[],f32> -> !torch.float + // CHECK: %[[bScale:.+]] = torch.aten.item %arg4 : !torch.vtensor<[1],f32> -> !torch.float + // CHECK: %[[cScale:.+]] = torch.aten.item %arg6 : !torch.vtensor<[],f32> -> !torch.float + // CHECK: %[[A:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[aScale]], %[[aZp]] : !torch.vtensor<[1,1,7,7],ui8>, !torch.float, !torch.int -> !torch.vtensor<[1,1,7,7],!torch.quint8> + // CHECK: %[[B:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg3, %[[bScale]], %[[bZp]] : !torch.vtensor<[1,1,1,1],ui8>, !torch.float, !torch.int -> !torch.vtensor<[1,1,1,1],!torch.quint8> + // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 + // CHECK: %[[INT0_1:.+]] = torch.constant.int 0 + // CHECK: %[[INT1_0:.+]] = torch.constant.int 1 + // CHECK: %[[INT1_1:.+]] = torch.constant.int 1 + // CHECK: %[[INT1_2:.+]] = torch.constant.int 1 + // CHECK: %[[INT1_3:.+]] = torch.constant.int 1 + // CHECK: %[[INT0_2:.+]] = torch.constant.int 0 + // CHECK: %[[PAD:.+]] = torch.prim.ListConstruct %[[INT0_0]], %[[INT0_1]] + // CHECK: %[[KERNEL:.+]] = torch.prim.ListConstruct %[[INT1_0]], %[[INT1_1]] + // CHECK: %[[DILATION:.+]] = torch.prim.ListConstruct %[[INT1_2]], %[[INT1_3]] + // CHECK: %[[STRIDE:.+]] = torch.prim.ListConstruct %[[INT0_2]], %[[INT0_2]] + // CHECK: %[[FALSE:.+]] = torch.constant.bool false + // CHECK: %[[INT1_5:.+]] = torch.constant.int 1 + // CHECK: %[[CONV:.+]] = torch.aten.convolution %[[A]], %[[B]], %arg8, %[[DILATION]], %[[PAD]], %[[KERNEL]], %[[FALSE]], %[[STRIDE]], %[[INT1_5]] : !torch.vtensor<[1,1,7,7],!torch.quint8>, !torch.vtensor<[1,1,1,1],!torch.quint8>, !torch.vtensor<[7],si32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[1,1,7,7],!torch.qint32> + // CHECK: %[[convScale:.+]] = torch.aten.mul.float %[[aScale]], %[[bScale]] : !torch.float, !torch.float -> !torch.float + // CHECK: %[[INT0_6:.+]] = torch.constant.int 0 + // CHECK: %[[C:.+]] = torch.aten._make_per_tensor_quantized_tensor %[[CONV]], %[[convScale]], %[[INT0_6]] : !torch.vtensor<[1,1,7,7],!torch.qint32>, !torch.float, !torch.int -> !torch.vtensor<[1,1,7,7],!torch.qint32> + // CHECK: %[[DEQ:.+]] = torch.aten.dequantize.self %[[C]] : !torch.vtensor<[1,1,7,7],!torch.qint32> -> !torch.vtensor<[1,1,7,7],f32> + // CHECK: %[[INT13:.+]] = torch.constant.int 13 + // CHECK: %[[QUANT:.+]] = torch.aten.quantize_per_tensor %[[DEQ]], %[[cScale]], %[[cZp]], %[[INT13]] : !torch.vtensor<[1,1,7,7],f32>, !torch.float, !torch.int, !torch.int -> !torch.vtensor<[1,1,7,7],!torch.quint8> + // CHECK: %[[INT:.+]] = torch.aten.int_repr %[[QUANT]] : !torch.vtensor<[1,1,7,7],!torch.quint8> -> !torch.vtensor<[1,1,7,7],ui8> + // CHECK: return %[[INT]] : !torch.vtensor<[1,1,7,7],ui8> + return %0 : !torch.vtensor<[1,1,7,7],ui8> +} + +// ----- + // CHECK-LABEL: @test_qlinearmatmul_2D func.func @test_qlinearmatmul_2D(%arg0: !torch.vtensor<[2,4],ui8>, %arg1: !torch.vtensor<[1],f32>, %arg2: !torch.vtensor<[1],ui8>, %arg3: !torch.vtensor<[4,3],ui8>, %arg4: !torch.vtensor<[1],f32>, %arg5: !torch.vtensor<[1],ui8>, %arg6: !torch.vtensor<[1],f32>, %arg7: !torch.vtensor<[1],ui8>) -> !torch.vtensor<[2,3],ui8> attributes {torch.onnx_meta.ir_version = 5 : si64, torch.onnx_meta.opset_version = 10 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { %0 = torch.operator "onnx.QLinearMatMul"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7) : (!torch.vtensor<[2,4],ui8>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],ui8>, !torch.vtensor<[4,3],ui8>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],ui8>, !torch.vtensor<[1],f32>, !torch.vtensor<[1],ui8>) -> !torch.vtensor<[2,3],ui8> diff --git a/test/Dialect/Torch/fuse-quantized-ops.mlir b/test/Dialect/Torch/fuse-quantized-ops.mlir index 1aaeb9ce1cd8..f98cb842f5d3 100644 --- a/test/Dialect/Torch/fuse-quantized-ops.mlir +++ b/test/Dialect/Torch/fuse-quantized-ops.mlir @@ -28,8 +28,8 @@ func.func @mm(%arg0: !torch.vtensor<[4, 4],si8>, %arg1: !torch.vtensor<[4, 4],si // ----- -// CHECK-LABEL: @convolution -func.func @convolution(%arg0: !torch.vtensor<[1,3,8,8],si8>, %arg1: !torch.vtensor<[3,3,2,2],si8>, %arg2 : !torch.vtensor<[3], f32>) -> !torch.vtensor<[1,3,7,7],f32> { +// CHECK-LABEL: @convolution_bias +func.func @convolution_bias(%arg0: !torch.vtensor<[1,3,8,8],si8>, %arg1: !torch.vtensor<[3,3,2,2],si8>, %arg2 : !torch.vtensor<[3], f32>) -> !torch.vtensor<[1,3,7,7],f32> { %scale = torch.constant.float 0.5 %false = torch.constant.bool false %zero = torch.constant.int 0 @@ -60,3 +60,38 @@ func.func @convolution(%arg0: !torch.vtensor<[1,3,8,8],si8>, %arg1: !torch.vtens // CHECK: %[[FOUT:.+]] = torch.aten.dequantize.tensor %[[QOUT]] : !torch.vtensor<[1,3,7,7],!torch.qint32> -> !torch.vtensor<[1,3,7,7],f32> return %16 : !torch.vtensor<[1,3,7,7],f32> } + + +// ----- + +// CHECK-LABEL: @convolution_nobias +func.func @convolution_nobias(%arg0: !torch.vtensor<[1,3,8,8],si8>, %arg1: !torch.vtensor<[3,3,2,2],si8>) -> !torch.vtensor<[1,3,7,7],f32> { + %scale = torch.constant.float 0.5 + %false = torch.constant.bool false + %zero = torch.constant.int 0 + %one = torch.constant.int 1 + %zp = torch.constant.int -128 + %none = torch.constant.none + %6 = torch.aten._make_per_tensor_quantized_tensor %arg0, %scale, %one : !torch.vtensor<[1,3,8,8],si8>, !torch.float, !torch.int -> !torch.vtensor<[1,3,8,8],!torch.qint8> + %7 = torch.aten.dequantize.tensor %6 : !torch.vtensor<[1,3,8,8],!torch.qint8> -> !torch.vtensor<[1,3,8,8],f32> + %12 = torch.aten._make_per_tensor_quantized_tensor %arg1, %scale, %zero : !torch.vtensor<[3,3,2,2],si8>, !torch.float, !torch.int -> !torch.vtensor<[3,3,2,2],!torch.qint8> + %13 = torch.aten.dequantize.tensor %12 : !torch.vtensor<[3,3,2,2],!torch.qint8> -> !torch.vtensor<[3,3,2,2],f32> + %14 = torch.prim.ListConstruct %one, %one : (!torch.int, !torch.int) -> !torch.list + %15 = torch.prim.ListConstruct %zero, %zero : (!torch.int, !torch.int) -> !torch.list + %16 = torch.aten.convolution %7, %13, %none, %14, %15, %14, %false, %15, %one : !torch.vtensor<[1,3,8,8],f32>, !torch.vtensor<[3,3,2,2],f32>, !torch.none, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[1,3,7,7],f32> + + // CHECK-DAG: %[[SCALEO:.+]] = torch.constant.float 2.500000e-01 + // CHECK-DAG: %[[HALF:.+]] = torch.constant.float 5.000000e-01 + // CHECK-DAG: %[[FALSE:.+]] = torch.constant.bool false + // CHECK-DAG: %[[ZERO:.+]] = torch.constant.int 0 + // CHECK-DAG: %[[ONE:.+]] = torch.constant.int 1 + // CHECK-DAG: %[[NONE:.+]] = torch.constant.none + // CHECK-DAG: %[[QLHS:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg0, %[[HALF]], %[[ONE]] : !torch.vtensor<[1,3,8,8],si8>, !torch.float, !torch.int -> !torch.vtensor<[1,3,8,8],!torch.qint8> + // CHECK-DAG: %[[QRHS:.+]] = torch.aten._make_per_tensor_quantized_tensor %arg1, %[[HALF]], %[[ZERO]] : !torch.vtensor<[3,3,2,2],si8>, !torch.float, !torch.int -> !torch.vtensor<[3,3,2,2],!torch.qint8> + // CHECK-DAG: %[[ONES:.+]] = torch.prim.ListConstruct %[[ONE]], %[[ONE]] : (!torch.int, !torch.int) -> !torch.list + // CHECK-DAG: %[[ZEROS:.+]] = torch.prim.ListConstruct %[[ZERO]], %[[ZERO]] : (!torch.int, !torch.int) -> !torch.list + // CHECK-DAG: %[[CONV:.+]] = torch.aten.convolution %[[QLHS]], %[[QRHS]], %[[NONE]], %[[ONES]], %[[ZEROS]], %[[ONES]], %[[FALSE]], %[[ZEROS]], %[[ONE]] : !torch.vtensor<[1,3,8,8],!torch.qint8>, !torch.vtensor<[3,3,2,2],!torch.qint8>, !torch.none, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.list, !torch.int -> !torch.vtensor<[1,3,7,7],si32> + // CHECK-DAG: %[[QOUT:.+]] = torch.aten._make_per_tensor_quantized_tensor %[[CONV]], %[[SCALEO]], %[[ZERO]] : !torch.vtensor<[1,3,7,7],si32>, !torch.float, !torch.int -> !torch.vtensor<[1,3,7,7],!torch.qint32> + // CHECK: %[[FOUT:.+]] = torch.aten.dequantize.tensor %[[QOUT]] : !torch.vtensor<[1,3,7,7],!torch.qint32> -> !torch.vtensor<[1,3,7,7],f32> + return %16 : !torch.vtensor<[1,3,7,7],f32> +} From 041a54ae0c29e04703f0d9616bc4effebf2e6998 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Mon, 5 Feb 2024 16:23:04 -0800 Subject: [PATCH 163/283] [torch] Supporting `torch.aten.mul.float` lowering to `arith` (#2833) Simple missing scalar operation for multiply floats was missing. --- lib/Conversion/TorchToArith/TorchToArith.cpp | 4 +++- projects/pt1/e2e_testing/xfail_sets.py | 3 +++ .../torch_mlir_e2e_test/test_suite/scalar.py | 22 +++++++++++++++++++ 3 files changed, 28 insertions(+), 1 deletion(-) diff --git a/lib/Conversion/TorchToArith/TorchToArith.cpp b/lib/Conversion/TorchToArith/TorchToArith.cpp index d2000d7fc3d2..0ca2d108a5e3 100644 --- a/lib/Conversion/TorchToArith/TorchToArith.cpp +++ b/lib/Conversion/TorchToArith/TorchToArith.cpp @@ -443,9 +443,11 @@ class ConvertTorchToArith typeConverter, context); patterns.add>( typeConverter, context); - target.addIllegalOp(); + target.addIllegalOp(); patterns.add>( typeConverter, context); + patterns.add>( + typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 2ee5d279a9d3..973f75a2637a 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -100,6 +100,7 @@ # START tests failing due to: torch._dynamo.exc.Unsupported: call_function BuiltinVariable(float) [TensorVariable()] {} 'AtenSubFloatModule_basic', + 'AtenMulFloatModule_basic', 'BoolFloatFalseModule_basic', 'BoolFloatTrueModule_basic', 'CeilFloatModule_basic', @@ -109,6 +110,7 @@ 'GtFloatIntModule_basic', 'NeFloatIntModule_basic', 'SubFloatModule_basic', + 'MulFloatModule_basic', 'TensorToFloatZeroRank_basic', 'TensorToFloat_basic', # END tests failing due to: torch._dynamo.exc.Unsupported: call_function BuiltinVariable(float) [TensorVariable()] {} @@ -1489,6 +1491,7 @@ "SliceStartEqEndModule_basic", "SqrtIntModule_basic", "SubFloatModule_basic", + "MulFloatModule_basic", "SubIntModule_basic", "TensorsStackPromoteDTypeModule_basic", "TensorToBoolZeroRank_basic", diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/scalar.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/scalar.py index 303c3f0a801a..51b9fb993088 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/scalar.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/scalar.py @@ -78,6 +78,28 @@ def SubFloatModule_basic(module, tu: TestUtils): module.forward(tu.rand().double(), tu.rand().double()) +# ============================================================================== + +class MulFloatModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([], torch.float64, True), + ([], torch.float64, True), + ]) + def forward(self, lhs, rhs): + return float(lhs) * float(rhs) + + +@register_test_case(module_factory=lambda: MulFloatModule()) +def MulFloatModule_basic(module, tu: TestUtils): + module.forward(tu.rand().double(), tu.rand().double()) + + # ============================================================================== From 1cb14f6879914f84c4e9fcae9a6af550f77be953 Mon Sep 17 00:00:00 2001 From: Dave Liddell <44620210+daveliddell@users.noreply.github.com> Date: Mon, 5 Feb 2024 18:10:42 -0700 Subject: [PATCH 164/283] Rob's atenTensor folder (#2867) If a tensor is initialized by a list with a single constant integer, this folder turns it into a torch.vtensor.literal --------- Co-authored-by: Dave Liddell --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 1 + lib/Dialect/Torch/IR/TorchOps.cpp | 21 +++++++++++++++++++ .../build_tools/torch_ods_gen.py | 2 +- test/Dialect/Torch/canonicalize.mlir | 11 ++++++++++ 4 files changed, 34 insertions(+), 1 deletion(-) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index a0ec9663b3e9..fad589576314 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -8582,6 +8582,7 @@ def Torch_AtenTensorOp : Torch_Op<"aten.tensor", [ printDefaultTorchOp(printer, *this, 4, 1); } }]; + let hasFolder = 1; } def Torch_AtenTensorBoolOp : Torch_Op<"aten.tensor.bool", [ diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index 98de4f85b62b..c557d2595598 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -2758,6 +2758,27 @@ void AtenDeviceWithIndexOp::getCanonicalizationPatterns( }); } +//===----------------------------------------------------------------------===// +// AtenTensorOp +//===----------------------------------------------------------------------===// + +OpFoldResult AtenTensorOp::fold(FoldAdaptor adaptor) { + // If a torch.aten.tensor op is initialized by a list with a constant, single + // element, fold it into a torch.vtensor.literal + auto resultTy = dyn_cast(getType()); + Type eTy = resultTy.getDtype(); + ShapedType shapedTy = resultTy.toBuiltinTensor().clone(eTy); + + SmallVector data; + if (matchPattern(getData(), m_TorchListOfConstantInts(data)) && + data.size() == 1) { + Attribute attribute = IntegerAttr::get(eTy, data[0]); + return DenseElementsAttr::get(shapedTy, attribute); + } + + return nullptr; +} + //===----------------------------------------------------------------------===// // AtenIntTensorOp //===----------------------------------------------------------------------===// diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index 41a297ba62b8..cb9c484b7244 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -570,7 +570,7 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::new_zeros : (Tensor, int[], int?, int?, Device?, bool?) -> (Tensor)") emit("aten::eye : (int, int?, int?, Device?, bool?) -> (Tensor)") emit("aten::eye.m : (int, int, int?, int?, Device?, bool?) -> (Tensor)") - emit("aten::tensor : (t[], int?, Device?, bool) -> (Tensor)") + emit("aten::tensor : (t[], int?, Device?, bool) -> (Tensor)", has_folder=True) emit("aten::tensor.bool : (bool, int?, Device?, bool) -> (Tensor)") emit("aten::tensor.int : (int, int?, Device?, bool) -> (Tensor)") emit("aten::scalar_tensor : (Scalar, int?, int?, Device?, bool?) -> (Tensor)") diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir index cb2ec2d14a54..83055f3be89d 100644 --- a/test/Dialect/Torch/canonicalize.mlir +++ b/test/Dialect/Torch/canonicalize.mlir @@ -1461,6 +1461,17 @@ func.func @torch.aten.squeeze.dim$zero_rank(%arg0: !torch.tensor<[],f32>) -> !to return %0 : !torch.tensor<[],f32> } +// CHECK-LABEL: func.func @torch.aten.tensor$one_elem( +// CHECK-NEXT: torch.vtensor.literal(dense<42> : tensor<1xsi64>) : !torch.vtensor<[1],si64> +func.func @torch.aten.tensor$one_elem() -> (!torch.vtensor<[1],si64>) { + %none = torch.constant.none + %false = torch.constant.bool false + %int42 = torch.constant.int 42 + %66 = torch.prim.ListConstruct %int42 : (!torch.int) -> !torch.list + %67 = torch.aten.tensor %66, %none, %none, %false : !torch.list, !torch.none, !torch.none, !torch.bool -> !torch.vtensor<[1],si64> + return %67 : !torch.vtensor<[1],si64> +} + // CHECK-LABEL: func.func @torch.aten.to.dtype$same_dtype( // CHECK-SAME: %[[ARG:.*]]: !torch.tensor<*,f32>) -> !torch.tensor<*,f32> { // CHECK-NEXT: return %[[ARG]] : !torch.tensor<*,f32> From faf7d4aaa5e493243fa3632cc288160bf0caab45 Mon Sep 17 00:00:00 2001 From: Daniel Garvey <34486624+dan-garvey@users.noreply.github.com> Date: Tue, 6 Feb 2024 00:19:31 -0600 Subject: [PATCH 165/283] [fx_importer] Add support for 0D tensors (#2870) Adds an escape hatch from creating a DenseResourceElementsAttr for single value tensors into DenseElementsAttr. For 0d or 1element, splats are better as DenseElementsAttr. Don't use DenseResourceElementsAttr for it --- python/torch_mlir/extras/fx_importer.py | 100 +++++++++++++++--------- test/python/fx_importer/basic_test.py | 3 +- 2 files changed, 62 insertions(+), 41 deletions(-) diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py index 8cffcb1ea935..5328e8730cc3 100644 --- a/python/torch_mlir/extras/fx_importer.py +++ b/python/torch_mlir/extras/fx_importer.py @@ -51,6 +51,7 @@ Attribute, Block, Context, + DenseElementsAttr, DenseResourceElementsAttr, FloatAttr, BF16Type, @@ -207,28 +208,28 @@ } -def sparsity_encoding(shape: torch.Size, sparse_layout : torch.layout) -> str: - """Returns sparse tensor encoding for the given sparse layout as string. +def sparsity_encoding(shape: torch.Size, sparse_layout: torch.layout) -> str: + """Returns sparse tensor encoding for the given sparse layout as string. - The method currently just supports 2-dim sparse formats. This should be - generalized to the torch.sparse encodings for prefix dense batch dimensions - and suffix dense subtensor dimensions. Since MLIR supports a superset of what - is currently implememented in torch.sparse, this should not a be problem. - """ + The method currently just supports 2-dim sparse formats. This should be + generalized to the torch.sparse encodings for prefix dense batch dimensions + and suffix dense subtensor dimensions. Since MLIR supports a superset of what + is currently implememented in torch.sparse, this should not a be problem. + """ - # TODO: any rank - if len(shape) != 2: - raise RuntimeError(f"Unsupported sparse rank {len(shape)}") + # TODO: any rank + if len(shape) != 2: + raise RuntimeError(f"Unsupported sparse rank {len(shape)}") - if sparse_layout is torch.sparse_coo: - return '#sparse_tensor.encoding<{map=(i,j)->(i:compressed(nonunique),j:singleton)}>' - if sparse_layout is torch.sparse_csr: - return '#sparse_tensor.encoding<{map=(i,j)->(i:dense,j:compressed)}>' - if sparse_layout is torch.sparse_csc: - return '#sparse_tensor.encoding<{map=(i,j)->(j:dense,i:compressed)}>' - # TODO: block format (derive block size!) + if sparse_layout is torch.sparse_coo: + return "#sparse_tensor.encoding<{map=(i,j)->(i:compressed(nonunique),j:singleton)}>" + if sparse_layout is torch.sparse_csr: + return "#sparse_tensor.encoding<{map=(i,j)->(i:dense,j:compressed)}>" + if sparse_layout is torch.sparse_csc: + return "#sparse_tensor.encoding<{map=(i,j)->(j:dense,i:compressed)}>" + # TODO: block format (derive block size!) - raise RuntimeError(f"Unsupported sparse layout {sparse_layout}") + raise RuntimeError(f"Unsupported sparse layout {sparse_layout}") def is_symbolic(obj: Any) -> bool: @@ -477,15 +478,20 @@ def format_asm_shape(self, shape: torch.Size) -> str: """Return IrType for !torch.vtensor with the given shape and dtype""" - def get_vtensor_type(self, shape: torch.Size, dtype: torch.dtype, sparse_layout: torch.layout = None): + def get_vtensor_type( + self, shape: torch.Size, dtype: torch.dtype, sparse_layout: torch.layout = None + ): shape_asm = self.format_asm_shape(shape) mlir_dtype = str(self.dtype_to_type(dtype)) if sparse_layout is not None: - sparsity = sparsity_encoding(shape, sparse_layout) - return IrType.parse( - f"!torch.vtensor<[{shape_asm}],{str(mlir_dtype)},{sparsity}>", context=self._c) + sparsity = sparsity_encoding(shape, sparse_layout) + return IrType.parse( + f"!torch.vtensor<[{shape_asm}],{str(mlir_dtype)},{sparsity}>", + context=self._c, + ) return IrType.parse( - f"!torch.vtensor<[{shape_asm}],{str(mlir_dtype)}>", context=self._c) + f"!torch.vtensor<[{shape_asm}],{str(mlir_dtype)}>", context=self._c + ) def node_val_to_type(self, node: torch_fx.Node) -> IrType: try: @@ -521,7 +527,9 @@ def node_val_to_type(self, node: torch_fx.Node) -> IrType: f"FIXME: Illegal access to torch.fx.Node.meta: {e} ({node.meta.keys()} : {node.meta})" ) - def tensor_metadata_to_type(self, tm: TensorMetadata, sparse_layout : torch.layout = None) -> IrType: + def tensor_metadata_to_type( + self, tm: TensorMetadata, sparse_layout: torch.layout = None + ) -> IrType: tm_shape = tuple( item.node if is_symbolic(item) else item for item in list(tm.shape) ) @@ -686,9 +694,11 @@ def _import_symbolic_torch_op( # operations on symbolic arguments as regular python expressions rather than as torch ops if is_builtin_function_or_method(target): arg_types = [ - arg.meta["val"].node.pytype - if isinstance(arg, torch.fx.Node) - else type(arg) + ( + arg.meta["val"].node.pytype + if isinstance(arg, torch.fx.Node) + else type(arg) + ) for arg in node.args ] is_int = [item == int for item in arg_types] @@ -1018,7 +1028,7 @@ def create_mlir_tensor_type(tensor: torch.Tensor) -> IrType: tensor_type = RankedTensorType.get(tuple(tensor.size()), element_type) return tensor_type except KeyError: - raise TypeError(f"Could not map Torch dtype {dtype} to an IREE type") + raise TypeError(f"Could not map Torch dtype {dtype} to an MLIR type") def _make_vtensor_literal_op( @@ -1038,15 +1048,28 @@ def _make_vtensor_literal_op( # buffer is via the indirection: Tensor -> list -> numpy array. This allows us to create a vtensor literal as # desired, but also limits which data types we can support in this function (see TORCH_DTYPE_TO_NPY_TYPE above) np_tensor = np.array(tensor.tolist()).astype(npy_dtype) - bytes_view = memoryview(np_tensor) - tensor_type = create_mlir_tensor_type(tensor) - shape_desc = "_".join([str(d) for d in tensor.shape]) - blob_name = f"torch_tensor_{shape_desc}_{str(tensor.dtype)}" - elements_attr = DenseResourceElementsAttr.get_from_buffer( - bytes_view, - blob_name, - tensor_type, - ) + # One element constants are more optimizable as splat DenseElementsAttr. DenseResourceElementsAttr does not + # support splats, so don't use it for that case. In addition, at the time of writing, it has bugs with handling + # 0d tensors. + if np_tensor.size == 1: + try: + dtype = tensor.dtype + element_type = TORCH_DTYPE_TO_MLIR_TYPE[dtype]() + except KeyError: + raise TypeError(f"Could not map Torch dtype {dtype} to an MLIR type") + elements_attr = DenseElementsAttr.get( + type=element_type, array=np_tensor, shape=np_tensor.shape + ) + else: + bytes_view = memoryview(np_tensor) + tensor_type = create_mlir_tensor_type(tensor) + shape_desc = "_".join([str(d) for d in tensor.shape]) + blob_name = f"torch_tensor_{shape_desc}_{str(tensor.dtype)}" + elements_attr = DenseResourceElementsAttr.get_from_buffer( + bytes_view, + blob_name, + tensor_type, + ) mapping.value = elements_attr else: elements_attr = mapping.value @@ -1105,8 +1128,7 @@ def lookup(self, t: type) -> Any: # Opaque value to indicate something is empty. Used in cases where 'None' # may have a different meaning. -class EmptyType: - ... +class EmptyType: ... Empty = EmptyType() diff --git a/test/python/fx_importer/basic_test.py b/test/python/fx_importer/basic_test.py index 62d3b1203e03..acd2a559fa52 100644 --- a/test/python/fx_importer/basic_test.py +++ b/test/python/fx_importer/basic_test.py @@ -47,7 +47,7 @@ def run(f): # CHECK: func.func @main(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> # CHECK-DAG: %[[a:.+]] = torch.vtensor.literal(dense_resource : tensor<1x4xf32>) : !torch.vtensor<[1,4],f32> # CHECK-DAG: %[[b:.+]] = torch.vtensor.literal(dense_resource : tensor<3x1xf32>) : !torch.vtensor<[3,1],f32> -# CHECK-DAG: %[[p:.+]] = torch.vtensor.literal(dense_resource : tensor<1x1xf32>) : !torch.vtensor<[1,1],f32> +# CHECK-DAG: %[[p:.+]] = torch.vtensor.literal(dense<{{.*>+}} : tensor<1x1xf32>) : !torch.vtensor<[1,1],f32> # CHECK-DAG: %[[tanh:.+]] = torch.aten.tanh %[[ARG0]] # CHECK-DAG: %[[mul_a:.+]] = torch.aten.mul.Tensor %[[tanh]], %[[a]] # CHECK-DAG: %[[mul_b:.+]] = torch.aten.mul.Tensor %[[mul_a]], %[[b]] @@ -58,7 +58,6 @@ def run(f): # CHECK: dialect_resources: # CHECK-DAG: torch_tensor_1_4_torch.float32 # CHECK-DAG: torch_tensor_3_1_torch.float32 -# CHECK-DAG: torch_tensor_1_1_torch.float32 def test_import_frozen_exported_program(): # Tests the basic structural premises of import_frozen_exported_program, # namely that free tensors (buffers) and parameters are treated as From cc06391630dc6d2f189787389873a6310212fba6 Mon Sep 17 00:00:00 2001 From: "Xida Ren (Cedar)" Date: Tue, 6 Feb 2024 16:12:12 -0500 Subject: [PATCH 166/283] AtenSortOp Folder (#2864) A chunk off https://github.com/llvm/torch-mlir/pull/2856 https://github.com/llvm/torch-mlir/pull/2860 --------- Co-authored-by: Xida Ren Co-authored-by: Rob Suderman --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 1 + lib/Dialect/Torch/IR/TorchOps.cpp | 46 +++++++++++++++++++ .../build_tools/torch_ods_gen.py | 2 +- test/Dialect/Torch/canonicalize.mlir | 29 ++++++++++++ 4 files changed, 77 insertions(+), 1 deletion(-) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index fad589576314..7f0f5af7e43f 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -12559,6 +12559,7 @@ def Torch_AtenSortOp : Torch_Op<"aten.sort", [ printDefaultTorchOp(printer, *this, 3, 2); } }]; + let hasFolder = 1; } def Torch_AtenSplitTensorOp : Torch_Op<"aten.split.Tensor", [ diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index c557d2595598..1857aff4dbbd 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -1710,6 +1710,52 @@ void AtenSortIntOp::getCanonicalizationPatterns(RewritePatternSet &patterns, }); } +//===----------------------------------------------------------------------===// +// AtenSortOp +//===----------------------------------------------------------------------===// + +LogicalResult AtenSortOp::fold(FoldAdaptor adaptor, + SmallVectorImpl &results) { + auto operand = getSelf(); + auto operandType = dyn_cast(operand.getType()); + if (!operandType || !operandType.hasSizes()) + return failure(); + + // only ValueTensorType has toBuiltinTensor + auto indicesTensorType = dyn_cast(getResult(1).getType()); + if (!indicesTensorType) + return failure(); + + if (!indicesTensorType.hasDtype()) + return failure(); + auto indicesType = + indicesTensorType.toBuiltinTensor().clone(indicesTensorType.getDtype()); + if (!indicesType || !indicesType.hasStaticShape()) + return failure(); + + bool unaryDim = false; + IntegerAttr dimAttribute = dyn_cast_if_present(adaptor.getDim()); + if (!dimAttribute) + return failure(); + int64_t dimInt = dimAttribute.getValue().getSExtValue(); + if (dimInt < 0) + dimInt += operandType.getSizes().size(); + if (dimAttribute) { + unaryDim = operandType.getSizes()[dimInt] == 1; + } + + OpBuilder builder(getContext()); + if (unaryDim || llvm::all_of(operandType.getSizes(), + [](int64_t dim) { return dim == 1; })) { + results.push_back(operand); + results.push_back(DenseElementsAttr::get( + indicesType, builder.getZeroAttr(indicesType.getElementType()))); + return success(); + } + + return failure(); +} + //===----------------------------------------------------------------------===// // NonValueTensorLiteralOp //===----------------------------------------------------------------------===// diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index cb9c484b7244..7893c26db947 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -728,7 +728,7 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::ne.int_list : (int[], int[]) -> (bool)") emit("aten::any.bool : (bool[]) -> (bool)", has_folder=True) emit("aten::sort.int : (int[], bool) -> ()", has_canonicalizer=True) - emit("aten::sort : (Tensor, int, bool) -> (Tensor, Tensor)") + emit("aten::sort : (Tensor, int, bool) -> (Tensor, Tensor)", has_folder=True) emit("aten::split.Tensor : (Tensor, int, int) -> (Tensor[])") emit("aten::split_with_sizes : (Tensor, int[], int) -> (Tensor[])") emit("aten::unbind.int : (Tensor, int) -> (Tensor[])") diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir index 83055f3be89d..77a9e8ad3d61 100644 --- a/test/Dialect/Torch/canonicalize.mlir +++ b/test/Dialect/Torch/canonicalize.mlir @@ -2012,6 +2012,35 @@ func.func @torch.aten.sort.int$reverse_true() -> !torch.list { return %0 : !torch.list } +// CHECK-LABEL: @torch.aten.sort$unary_element +// CHECK : %[[INDICES:.*]] = torch.vtensor.literal(dense<0> : tensor<1xsi64>) : !torch.vtensor<[1],si64> +// CHECK-NOT : torch.aten.sort %arg +// CHECK : return %arg0, %[[INDICES]] : !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64> +func.func @torch.aten.sort$unary_element(%arg0 : !torch.vtensor<[1],si64>, %arg1 : !torch.int, %arg2 : !torch.bool) -> (!torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) { + %0, %1 = torch.aten.sort %arg0, %arg1, %arg2 : !torch.vtensor<[1],si64>, !torch.int, !torch.bool -> !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64> + return %0, %1 : !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64> +} + + +// CHECK-LABEL: @torch.aten.sort$unary_dim +// CHECK : %[[INDICES:.*]] = torch.vtensor.literal(dense<1> : tensor<1xsi64>) : !torch.vtensor<[1],si64> +// CHECK-NOT : torch.aten.sort %arg +// CHECK : return %arg0, %[[INDICES]] : !torch.vtensor<[3, 1,4],si64>, !torch.vtensor<[1],si64> +func.func @torch.aten.sort$unary_dim(%arg0 : !torch.vtensor<[3, 1, 4],si64>, %arg1 : !torch.bool) -> (!torch.vtensor<[3, 1, 4],si64>, !torch.vtensor<[1],si64>) { + %dim = torch.constant.int 1 + %0, %1 = torch.aten.sort %arg0, %dim, %arg1 : !torch.vtensor<[3, 1, 4],si64>, !torch.int, !torch.bool -> !torch.vtensor<[3, 1, 4],si64>, !torch.vtensor<[1],si64> + return %0, %1 : !torch.vtensor<[3, 1,4],si64>, !torch.vtensor<[1],si64> +} + +// CHECK-LABEL: @torch.aten.sort$nofold +// CHECK : torch.aten.sort %arg +func.func @torch.aten.sort$nofold (%arg0 : !torch.vtensor<[3, 1, 4],si64>, %arg1 : !torch.bool) -> (!torch.vtensor<[3, 1, 4],si64>, !torch.vtensor<[3],si64>) { + %dim = torch.constant.int 0 + %0, %1 = torch.aten.sort %arg0, %dim, %arg1 : !torch.vtensor<[3, 1, 4],si64>, !torch.int, !torch.bool -> !torch.vtensor<[3, 1, 4],si64>, !torch.vtensor<[3],si64> + return %0, %1 : !torch.vtensor<[3, 1, 4],si64>, !torch.vtensor<[3],si64> +} + + // CHECK-LABEL: @torch.aten.cat$fold_single_operand // CHECK-SAME: %[[ARG0:.+]]: !torch.tensor // CHECK: return %[[ARG0]] : !torch.tensor From bfcf93ea2191a760ed14bbefc76f3cb34806806a Mon Sep 17 00:00:00 2001 From: saienduri <77521230+saienduri@users.noreply.github.com> Date: Tue, 6 Feb 2024 19:07:59 -0800 Subject: [PATCH 167/283] Rename torch_mlir.compile APIs and introduce FX based analogs (#2842) Link to related RFC: https://discourse.llvm.org/t/rfc-rename-torch-mlir-compile-apis-and-introduce-fx-based-analogs/76646 This commit updates the documentation, tests, CMake files, and API for the proposed changes in the RFC. There is a new torch_mlir/fx.py for user level APIs related to importing modules and a corresponding test for this path can be found at test/python/fx_importer/basic_test.py. --------- Co-authored-by: MaheshRavishankar --- docs/architecture.md | 2 +- docs/development.md | 25 ++++++++++++++----- docs/{long_term_roadmap.md => roadmap.md} | 16 ++++++++++++ projects/pt1/examples/torchdynamo_resnet18.py | 4 +-- projects/pt1/examples/torchscript_resnet18.py | 4 +-- .../torchscript_resnet18_all_output_types.py | 8 +++--- .../torchscript_resnet_inference.ipynb | 4 +-- .../torchscript_stablehlo_backend_resnet.py | 4 +-- .../torchscript_stablehlo_backend_tinybert.py | 4 +-- projects/pt1/python/CMakeLists.txt | 2 +- .../test/compile_api/already_scripted.py | 8 +++--- .../python/test/compile_api/already_traced.py | 8 +++--- .../test/compile_api/backend_legal_ops.py | 6 ++--- projects/pt1/python/test/compile_api/basic.py | 20 +++++++-------- .../pt1/python/test/compile_api/make_fx.py | 4 +-- .../test/compile_api/multiple_methods.py | 10 ++++---- .../test/compile_api/output_type_spec.py | 6 ++--- .../pt1/python/test/compile_api/tracing.py | 20 +++++++-------- projects/pt1/python/torch_mlir/dynamo.py | 2 +- .../{__init__.py => torchscript.py} | 6 ++--- .../configs/linalg_on_tensors_backend.py | 4 +-- .../configs/stablehlo_backend.py | 4 +-- .../configs/torchdynamo.py | 2 +- .../configs/tosa_backend.py | 4 +-- .../pt1/python/torch_mlir_e2e_test/utils.py | 4 +-- .../test/python/custom_op_shape_dtype_fn.py | 4 +-- .../jit_ir/node_import/unimplemented.py | 6 ++--- python/CMakeLists.txt | 7 ++++++ python/torch_mlir/fx.py | 25 +++++++++++++++++++ test/python/compile.py | 4 +-- test/python/fx_importer/basic_test.py | 25 ++----------------- 31 files changed, 146 insertions(+), 106 deletions(-) rename docs/{long_term_roadmap.md => roadmap.md} (94%) rename projects/pt1/python/torch_mlir/{__init__.py => torchscript.py} (99%) create mode 100644 python/torch_mlir/fx.py diff --git a/docs/architecture.md b/docs/architecture.md index 4c102e140d7a..e2ef378bd99c 100644 --- a/docs/architecture.md +++ b/docs/architecture.md @@ -184,7 +184,7 @@ semantics. And often users want to erase the shapes in the trace to allow dynamic shapes for the trace. Additionally, the Python-level data structures and APIs are very parallel between `torch.jit.script` and `torch.jit.trace`, so we consider both of those as the same from the perspective of the responsibilities -of the compiler. Both are accessed via the `torch_mlir.compile` Python API. +of the compiler. Both are accessed via the `torch_mlir.torchscript.compile` Python API. ### Modeling the `torch.nn.Module` object (`IValue`) hierarchy for TorchScript diff --git a/docs/development.md b/docs/development.md index 782058a63ea7..3e9192f5fa8e 100644 --- a/docs/development.md +++ b/docs/development.md @@ -120,37 +120,50 @@ cmake --build build ### Linux and macOS ```shell -export PYTHONPATH=`pwd`/build/tools/torch-mlir/python_packages/torch_mlir:`pwd`/projects/pt1/examples +export PYTHONPATH=`pwd`/build/tools/torch-mlir/python_packages/torch_mlir:`pwd`/test/python/fx_importer ``` ### Windows PowerShell ```shell -$env:PYTHONPATH = "$PWD/build/tools/torch-mlir/python_packages/torch_mlir;$PWD/projects/pt1/examples" +$env:PYTHONPATH = "$PWD/build/tools/torch-mlir/python_packages/torch_mlir;$PWD/test/python/fx_importer" ``` ## Testing MLIR output in various dialects -To test the compiler's output to the different MLIR dialects, you can use the example `projects/pt1/examples/torchscript_resnet18_all_output_types.py`. +To test the MLIR output to torch dialect, you can use `test/python/fx_importer/basic_test.py`. Make sure you have activated the virtualenv and set the `PYTHONPATH` above (if running on Windows, modify the environment variable as shown above): ```shell source mlir_venv/bin/activate +export PYTHONPATH=`pwd`/build/tools/torch-mlir/python_packages/torch_mlir:`pwd`/test/python/fx_importer +python test/python/fx_importer/basic_test.py +``` + +This will display the basic example in TORCH dialect. + +To test the compiler's output to the different MLIR dialects, you can also use the deprecated path +using torchscript with the example `projects/pt1/examples/torchscript_resnet18_all_output_types.py`. +This path doesn't give access to the current generation work that is being driven via the fx_importer +and may lead to errors. + +Same as above, but with different python path and example: +```shell export PYTHONPATH=`pwd`/build/tools/torch-mlir/python_packages/torch_mlir:`pwd`/projects/pt1/examples python projects/pt1/examples/torchscript_resnet18_all_output_types.py ``` This will display the Resnet18 network example in three dialects: TORCH, LINALG on TENSORS and TOSA. -The main functionality is on `torch_mlir.compile()`'s `output_type`. +The main functionality is on `torch_mlir.torchscript.compile()`'s `output_type`. Ex: ```python -module = torch_mlir.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="torch") +module = torch_mlir.torchscript.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="torch") ``` -Currently, `output_type` can be: `TORCH`, `TOSA`, `LINALG_ON_TENSORS`, `RAW` and `STABLEHLO`. +`output_type` can be: `TORCH`, `TOSA`, `LINALG_ON_TENSORS`, `RAW` and `STABLEHLO`. ## Jupyter diff --git a/docs/long_term_roadmap.md b/docs/roadmap.md similarity index 94% rename from docs/long_term_roadmap.md rename to docs/roadmap.md index 0f0940efc32d..f60502a52423 100644 --- a/docs/long_term_roadmap.md +++ b/docs/roadmap.md @@ -51,6 +51,22 @@ the ecosystem are: Most of this document describes long-term ecosystem changes that will address these, drastically improving Torch-MLIR's ability to meet its goals. +## Current API Paths + +Currently, there are two main API paths for the torch-mlir project: + +- The first path is part of the legacy project pt1 code + (torch_mlir.torchscript.compile). This allows users to test the compiler's + output to the different MLIR dialects (`TORCH`, `TOSA`, `LINALG_ON_TENSORS`, + `RAW` and `STABLEHLO`). This path is deprecated and doesn’t give access to + the current generation work that is being driven via the fx_importer. It is + tied to the old Torchscript path. +- The second path (torch_mlir.fx.export_and_import) allows users to import a + consolidated torch.export.ExportedProgram instance of an arbitrary Python + callable (an nn.Module, a function or a method) and output to torch dialect + mlir module. This path is aligned with PyTorch's roadmap, but the path is + not fully functional yet. + ## Roadmap ### Refactoring the frontend diff --git a/projects/pt1/examples/torchdynamo_resnet18.py b/projects/pt1/examples/torchdynamo_resnet18.py index d7abd80da665..377c632da36f 100644 --- a/projects/pt1/examples/torchdynamo_resnet18.py +++ b/projects/pt1/examples/torchdynamo_resnet18.py @@ -14,7 +14,7 @@ import torchvision.models as models from torchvision import transforms -import torch_mlir +from torch_mlir import torchscript from torch_mlir.dynamo import make_simple_dynamo_backend from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend @@ -71,7 +71,7 @@ def predictions(torch_func, jit_func, img, labels): @make_simple_dynamo_backend def refbackend_torchdynamo_backend(fx_graph: torch.fx.GraphModule, example_inputs: List[torch.Tensor]): - mlir_module = torch_mlir.compile( + mlir_module = torchscript.compile( fx_graph, example_inputs, output_type="linalg-on-tensors") backend = refbackend.RefBackendLinalgOnTensorsBackend() compiled = backend.compile(mlir_module) diff --git a/projects/pt1/examples/torchscript_resnet18.py b/projects/pt1/examples/torchscript_resnet18.py index ac46e6f4523b..62e5eda6cc83 100644 --- a/projects/pt1/examples/torchscript_resnet18.py +++ b/projects/pt1/examples/torchscript_resnet18.py @@ -12,7 +12,7 @@ import torchvision.models as models from torchvision import transforms -import torch_mlir +from torch_mlir import torchscript from torch_mlir_e2e_test.linalg_on_tensors_backends import refbackend @@ -67,7 +67,7 @@ def predictions(torch_func, jit_func, img, labels): resnet18 = models.resnet18(pretrained=True) resnet18.train(False) -module = torch_mlir.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="linalg-on-tensors") +module = torchscript.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="linalg-on-tensors") backend = refbackend.RefBackendLinalgOnTensorsBackend() compiled = backend.compile(module) jit_module = backend.load(compiled) diff --git a/projects/pt1/examples/torchscript_resnet18_all_output_types.py b/projects/pt1/examples/torchscript_resnet18_all_output_types.py index a17fa40521d3..70a920550b2d 100644 --- a/projects/pt1/examples/torchscript_resnet18_all_output_types.py +++ b/projects/pt1/examples/torchscript_resnet18_all_output_types.py @@ -6,15 +6,15 @@ import torch import torchvision -import torch_mlir +from torch_mlir import torchscript resnet18 = torchvision.models.resnet18(pretrained=True) resnet18.eval() -module = torch_mlir.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="torch") +module = torchscript.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="torch") print("TORCH OutputType\n", module.operation.get_asm(large_elements_limit=10)) -module = torch_mlir.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="linalg-on-tensors") +module = torchscript.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="linalg-on-tensors") print("LINALG_ON_TENSORS OutputType\n", module.operation.get_asm(large_elements_limit=10)) # TODO: Debug why this is so slow. -module = torch_mlir.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="tosa") +module = torchscript.compile(resnet18, torch.ones(1, 3, 224, 224), output_type="tosa") print("TOSA OutputType\n", module.operation.get_asm(large_elements_limit=10)) diff --git a/projects/pt1/examples/torchscript_resnet_inference.ipynb b/projects/pt1/examples/torchscript_resnet_inference.ipynb index 3ab7cc64dadb..9970f90b8bb2 100644 --- a/projects/pt1/examples/torchscript_resnet_inference.ipynb +++ b/projects/pt1/examples/torchscript_resnet_inference.ipynb @@ -184,7 +184,7 @@ "\n", "# Compile the model with an example input.\n", "# We lower to the linalg-on-tensors form that the reference backend supports.\n", - "compiled = torch_mlir.compile(TanhModule(), torch.ones(3), output_type=torch_mlir.OutputType.LINALG_ON_TENSORS)\n", + "compiled = torch_mlir.torchscript.compile(TanhModule(), torch.ones(3), output_type=torch_mlir.OutputType.LINALG_ON_TENSORS)\n", "# Load it on the reference backend.\n", "jit_module = compile_and_load_on_refbackend(compiled)\n", "# Run it!\n", @@ -326,7 +326,7 @@ "source": [ "resnet18 = torchvision.models.resnet18(weights=torchvision.models.ResNet18_Weights.DEFAULT)\n", "resnet18.eval()\n", - "compiled = torch_mlir.compile(resnet18, torch.ones(1, 3, 224, 224), output_type=\"linalg-on-tensors\")\n", + "compiled = torch_mlir.torchscript.compile(resnet18, torch.ones(1, 3, 224, 224), output_type=\"linalg-on-tensors\")\n", "jit_module = compile_and_load_on_refbackend(compiled)" ] }, diff --git a/projects/pt1/examples/torchscript_stablehlo_backend_resnet.py b/projects/pt1/examples/torchscript_stablehlo_backend_resnet.py index 7a97359cff62..e42828ed776e 100644 --- a/projects/pt1/examples/torchscript_stablehlo_backend_resnet.py +++ b/projects/pt1/examples/torchscript_stablehlo_backend_resnet.py @@ -1,13 +1,13 @@ import torch import torchvision.models as models -import torch_mlir +from torch_mlir import torchscript model = models.resnet18(pretrained=True) model.eval() data = torch.randn(2,3,200,200) out_stablehlo_mlir_path = "./resnet18_stablehlo.mlir" -module = torch_mlir.compile(model, data, output_type=torch_mlir.OutputType.STABLEHLO, use_tracing=False) +module = torchscript.compile(model, data, output_type=torchscript.OutputType.STABLEHLO, use_tracing=False) with open(out_stablehlo_mlir_path, "w", encoding="utf-8") as outf: outf.write(str(module)) diff --git a/projects/pt1/examples/torchscript_stablehlo_backend_tinybert.py b/projects/pt1/examples/torchscript_stablehlo_backend_tinybert.py index c035be3a54fe..c68daf12dd86 100644 --- a/projects/pt1/examples/torchscript_stablehlo_backend_tinybert.py +++ b/projects/pt1/examples/torchscript_stablehlo_backend_tinybert.py @@ -1,5 +1,5 @@ import torch -import torch_mlir +from torch_mlir import torchscript from transformers import BertForMaskedLM @@ -17,7 +17,7 @@ def forward(self, data): data = torch.randint(30522, (2, 128)) out_stablehlo_mlir_path = "./bert_tiny_stablehlo.mlir" -module = torch_mlir.compile(model, data, output_type=torch_mlir.OutputType.STABLEHLO, use_tracing=True) +module = torchscript.compile(model, data, output_type=torchscript.OutputType.STABLEHLO, use_tracing=True) with open(out_stablehlo_mlir_path, "w", encoding="utf-8") as outf: outf.write(str(module)) diff --git a/projects/pt1/python/CMakeLists.txt b/projects/pt1/python/CMakeLists.txt index ce40426988a7..642b86b50490 100644 --- a/projects/pt1/python/CMakeLists.txt +++ b/projects/pt1/python/CMakeLists.txt @@ -18,7 +18,7 @@ declare_mlir_python_sources(TorchMLIRPythonSources.TopLevel ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}" ADD_TO_PARENT TorchMLIRPythonTorchExtensionsSources SOURCES - __init__.py + torchscript.py _dynamo_fx_importer.py compiler_utils.py dynamo.py diff --git a/projects/pt1/python/test/compile_api/already_scripted.py b/projects/pt1/python/test/compile_api/already_scripted.py index 367170081228..7d9720727a38 100644 --- a/projects/pt1/python/test/compile_api/already_scripted.py +++ b/projects/pt1/python/test/compile_api/already_scripted.py @@ -6,7 +6,7 @@ # RUN: %PYTHON %s | FileCheck %s import torch -import torch_mlir +from torch_mlir import torchscript class BasicModule(torch.nn.Module): @@ -15,17 +15,17 @@ def sin(self, x): return torch.ops.aten.sin(x) -example_args = torch_mlir.ExampleArgs() +example_args = torchscript.ExampleArgs() example_args.add_method("sin", torch.ones(2, 3)) scripted = torch.jit.script(BasicModule()) -print(torch_mlir.compile(scripted, example_args)) +print(torchscript.compile(scripted, example_args)) # CHECK: module # CHECK-DAG: func.func @sin scripted = torch.jit.script(BasicModule()) try: # CHECK: Model does not have exported method 'nonexistent', requested in `example_args`. Consider adding `@torch.jit.export` to the method definition. - torch_mlir.compile(scripted, torch_mlir.ExampleArgs().add_method("nonexistent", torch.ones(2, 3))) + torchscript.compile(scripted, torchscript.ExampleArgs().add_method("nonexistent", torch.ones(2, 3))) except Exception as e: print(e) diff --git a/projects/pt1/python/test/compile_api/already_traced.py b/projects/pt1/python/test/compile_api/already_traced.py index a719eb743c73..32f7b5653fca 100644 --- a/projects/pt1/python/test/compile_api/already_traced.py +++ b/projects/pt1/python/test/compile_api/already_traced.py @@ -6,23 +6,23 @@ # RUN: %PYTHON %s | FileCheck %s import torch -import torch_mlir +from torch_mlir import torchscript class BasicModule(torch.nn.Module): def forward(self, x): return torch.ops.aten.sin(x) example_arg = torch.ones(2, 3) -example_args = torch_mlir.ExampleArgs.get(example_arg) +example_args = torchscript.ExampleArgs.get(example_arg) traced = torch.jit.trace(BasicModule(), example_arg) -print(torch_mlir.compile(traced, example_args)) +print(torchscript.compile(traced, example_args)) # CHECK: module # CHECK-DAG: func.func @forward traced = torch.jit.trace(BasicModule(), example_arg) try: # CHECK: Model does not have exported method 'nonexistent', requested in `example_args`. Consider adding `@torch.jit.export` to the method definition. - torch_mlir.compile(traced, torch_mlir.ExampleArgs().add_method("nonexistent", example_arg)) + torchscript.compile(traced, torchscript.ExampleArgs().add_method("nonexistent", example_arg)) except Exception as e: print(e) diff --git a/projects/pt1/python/test/compile_api/backend_legal_ops.py b/projects/pt1/python/test/compile_api/backend_legal_ops.py index 98c034930243..64ebf7a522fa 100644 --- a/projects/pt1/python/test/compile_api/backend_legal_ops.py +++ b/projects/pt1/python/test/compile_api/backend_legal_ops.py @@ -7,7 +7,7 @@ import torch -import torch_mlir +from torch_mlir import torchscript class AddmmModule(torch.nn.Module): def __init__(self): @@ -15,9 +15,9 @@ def __init__(self): def forward(self, x, y, z): return torch.ops.aten.addmm(x, y, z) -example_args = 3 * [torch_mlir.TensorPlaceholder([-1, -1], torch.float32)] +example_args = 3 * [torchscript.TensorPlaceholder([-1, -1], torch.float32)] -print(torch_mlir.compile(AddmmModule(), example_args, +print(torchscript.compile(AddmmModule(), example_args, output_type="torch", backend_legal_ops=["aten.addmm"])) # CHECK-LABEL: @forward # CHECK: torch.aten.addmm diff --git a/projects/pt1/python/test/compile_api/basic.py b/projects/pt1/python/test/compile_api/basic.py index 999d2fe4a820..0c516b620863 100644 --- a/projects/pt1/python/test/compile_api/basic.py +++ b/projects/pt1/python/test/compile_api/basic.py @@ -7,7 +7,7 @@ import torch -import torch_mlir +from torch_mlir import torchscript class TanhModule(torch.nn.Module): def __init__(self): @@ -18,24 +18,24 @@ def forward(self, x): tanh_example_input = torch.ones(2, 3) # Simplest case: One example argument. -print(torch_mlir.compile(TanhModule(), tanh_example_input)) +print(torchscript.compile(TanhModule(), tanh_example_input)) # CHECK-LABEL: @forward # CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[2,3],f32> -> !torch.vtensor<[2,3],f32> # Use a TensorPlaceholder to represent dynamic axes. -placeholder = torch_mlir.TensorPlaceholder.like(tanh_example_input, dynamic_axes=[1]) -print(torch_mlir.compile(TanhModule(), placeholder)) +placeholder = torchscript.TensorPlaceholder.like(tanh_example_input, dynamic_axes=[1]) +print(torchscript.compile(TanhModule(), placeholder)) # CHECK-LABEL: @forward # CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[2,?],f32> -> !torch.vtensor<[2,?],f32> # Explicitly construct a TensorPlaceholder. -placeholder = torch_mlir.TensorPlaceholder([-1, 2], torch.float32) -print(torch_mlir.compile(TanhModule(), placeholder)) +placeholder = torchscript.TensorPlaceholder([-1, 2], torch.float32) +print(torchscript.compile(TanhModule(), placeholder)) # CHECK-LABEL: @forward # CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[?,2],f32> -> !torch.vtensor<[?,2],f32> # Basic smoke test for the raw output type. -print(torch_mlir.compile(TanhModule(), tanh_example_input, output_type=torch_mlir.OutputType.RAW)) +print(torchscript.compile(TanhModule(), tanh_example_input, output_type=torchscript.OutputType.RAW)) # CHECK: torch.nn_module { # CHECK: } : !torch.nn.Module<"{{.*}}.TanhModule"> @@ -47,12 +47,12 @@ def forward(self, lhs, rhs ): # N > 1 inputs. mm_example_inputs = [torch.ones(2, 3), torch.ones(3, 4)] -print(torch_mlir.compile(MmModule(), mm_example_inputs)) +print(torchscript.compile(MmModule(), mm_example_inputs)) # CHECK-LABEL: @forward # CHECK: torch.aten.mm %{{.*}}, %{{.*}} : !torch.vtensor<[2,3],f32>, !torch.vtensor<[3,4],f32> -> !torch.vtensor<[2,4],f32> # Mixes Tensor's and TensorPlaceholder's. -mm_dynamic_inputs = [mm_example_inputs[0], torch_mlir.TensorPlaceholder.like(mm_example_inputs[1], dynamic_axes=[1])] -print(torch_mlir.compile(MmModule(), mm_dynamic_inputs)) +mm_dynamic_inputs = [mm_example_inputs[0], torchscript.TensorPlaceholder.like(mm_example_inputs[1], dynamic_axes=[1])] +print(torchscript.compile(MmModule(), mm_dynamic_inputs)) # CHECK-LABEL: @forward # CHECK: torch.aten.mm %{{.*}}, %{{.*}} : !torch.vtensor<[2,3],f32>, !torch.vtensor<[3,?],f32> -> !torch.vtensor<[2,?],f32> diff --git a/projects/pt1/python/test/compile_api/make_fx.py b/projects/pt1/python/test/compile_api/make_fx.py index 62add20a576b..ec859d86e369 100644 --- a/projects/pt1/python/test/compile_api/make_fx.py +++ b/projects/pt1/python/test/compile_api/make_fx.py @@ -8,7 +8,7 @@ import functorch import torch -import torch_mlir +from torch_mlir import torchscript def simple(x): return x * x @@ -17,6 +17,6 @@ def simple(x): graph = functorch.make_fx(simple)(torch.randn(1,)) # Simplest case: One example argument. -print(torch_mlir.compile(graph, example_input)) +print(torchscript.compile(graph, example_input)) # CHECK-LABEL: @forward # CHECK: torch.aten.mul.Tensor %{{.*}} : !torch.vtensor<[1],f32>, !torch.vtensor<[1],f32> -> !torch.vtensor<[1],f32> \ No newline at end of file diff --git a/projects/pt1/python/test/compile_api/multiple_methods.py b/projects/pt1/python/test/compile_api/multiple_methods.py index f70b14ab68ab..067e775bfc71 100644 --- a/projects/pt1/python/test/compile_api/multiple_methods.py +++ b/projects/pt1/python/test/compile_api/multiple_methods.py @@ -6,7 +6,7 @@ # RUN: %PYTHON %s | FileCheck %s import torch -import torch_mlir +from torch_mlir import torchscript class TwoMethodsModule(torch.nn.Module): @@ -17,14 +17,14 @@ def cos(self, x): return torch.ops.aten.cos(x) -example_args = torch_mlir.ExampleArgs() +example_args = torchscript.ExampleArgs() example_args.add_method("sin", torch.ones(2, 3)) example_args.add_method("cos", torch.ones(2, 4)) # Note: Due to https://github.com/pytorch/pytorch/issues/88735 we need to # check the `use_tracing` case first. -print(torch_mlir.compile(TwoMethodsModule(), example_args, use_tracing=True)) +print(torchscript.compile(TwoMethodsModule(), example_args, use_tracing=True)) # CHECK: module # CHECK-DAG: func.func @sin # CHECK-DAG: func.func @cos @@ -34,8 +34,8 @@ def cos(self, x): # Otherwise the user would have to do this manually, which is tedious. This # technically mutates the user input model which is not great but probably okay # for this kind of API sugar. Users can always take full control of the process -# by scripting the model themselves before passing it to `torch_mlir.compile`. -print(torch_mlir.compile(TwoMethodsModule(), example_args)) +# by scripting the model themselves before passing it to `torchscript.compile`. +print(torchscript.compile(TwoMethodsModule(), example_args)) # CHECK: module # CHECK-DAG: func.func @sin # CHECK-DAG: func.func @cos diff --git a/projects/pt1/python/test/compile_api/output_type_spec.py b/projects/pt1/python/test/compile_api/output_type_spec.py index b975c2b5c0ae..92ed1e425d8d 100644 --- a/projects/pt1/python/test/compile_api/output_type_spec.py +++ b/projects/pt1/python/test/compile_api/output_type_spec.py @@ -7,7 +7,7 @@ import torch -import torch_mlir +from torch_mlir import torchscript class TanhModule(torch.nn.Module): def __init__(self): @@ -17,9 +17,9 @@ def forward(self, x): tanh_example_input = torch.ones(2, 3) -print(torch_mlir.compile(TanhModule(), tanh_example_input, output_type=torch_mlir.OutputType.TORCH)) +print(torchscript.compile(TanhModule(), tanh_example_input, output_type=torchscript.OutputType.TORCH)) # CHECK-LABEL: @forward # CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[2,3],f32> -> !torch.vtensor<[2,3],f32> -print(torch_mlir.compile(TanhModule(), tanh_example_input, output_type="torch")) +print(torchscript.compile(TanhModule(), tanh_example_input, output_type="torch")) # CHECK-LABEL: @forward # CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[2,3],f32> -> !torch.vtensor<[2,3],f32> diff --git a/projects/pt1/python/test/compile_api/tracing.py b/projects/pt1/python/test/compile_api/tracing.py index ea74fea12ab4..bbf652f07a28 100644 --- a/projects/pt1/python/test/compile_api/tracing.py +++ b/projects/pt1/python/test/compile_api/tracing.py @@ -7,7 +7,7 @@ import torch -import torch_mlir +from torch_mlir import torchscript class TanhModule(torch.nn.Module): @@ -17,38 +17,38 @@ def forward(self, x): tanh_example_input = torch.ones(2, 3) # Simplest case: One example argument. -print(torch_mlir.compile(TanhModule(), tanh_example_input, use_tracing=True)) +print(torchscript.compile(TanhModule(), tanh_example_input, use_tracing=True)) # CHECK-LABEL: @forward # CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[2,3],f32> -> !torch.vtensor<[2,3],f32> # Simplest case: Passed as a tuple. -print(torch_mlir.compile(TanhModule(), (tanh_example_input,), use_tracing=True)) +print(torchscript.compile(TanhModule(), (tanh_example_input,), use_tracing=True)) # CHECK-LABEL: @forward # CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[2,3],f32> -> !torch.vtensor<[2,3],f32> # Simplest case: Passed as a list. -print(torch_mlir.compile(TanhModule(), [tanh_example_input], use_tracing=True)) +print(torchscript.compile(TanhModule(), [tanh_example_input], use_tracing=True)) # CHECK-LABEL: @forward # CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[2,3],f32> -> !torch.vtensor<[2,3],f32> # TensorPlaceholder support. -placeholder = torch_mlir.TensorPlaceholder.like( +placeholder = torchscript.TensorPlaceholder.like( tanh_example_input, dynamic_axes=[1]) -print(torch_mlir.compile(TanhModule(), [placeholder], +print(torchscript.compile(TanhModule(), [placeholder], use_tracing=True, ignore_traced_shapes=True)) # CHECK-LABEL: @forward # CHECK: torch.aten.tanh %{{.*}} : !torch.vtensor<[2,?],f32> -> !torch.vtensor<[2,?],f32> try: # CHECK: `ignore_traced_shapes` requires `use_tracing` - torch_mlir.compile(TanhModule(), [placeholder], ignore_traced_shapes=True) + torchscript.compile(TanhModule(), [placeholder], ignore_traced_shapes=True) except Exception as e: print(e) try: # CHECK: TensorPlaceholder can only be used with tracing when `ignore_traced_shapes=True` - torch_mlir.compile(TanhModule(), [placeholder], use_tracing=True) + torchscript.compile(TanhModule(), [placeholder], use_tracing=True) except Exception as e: print(e) @@ -60,13 +60,13 @@ def forward(self, x): try: # CHECK: Only Tensor's, TensorPlaceholder's, or sequences of Tensor's and TensorPlaceholder's are supported as example args for method inputs. Got '{'a': tensor(3.)}' - torch_mlir.compile(DictModule(), {'a': torch.tensor(3.0)}, use_tracing=True) + torchscript.compile(DictModule(), {'a': torch.tensor(3.0)}, use_tracing=True) except Exception as e: print(e) try: # CHECK: Only Tensor's, TensorPlaceholder's, or sequences of Tensor's and TensorPlaceholder's are supported as example args for method inputs. Got '{'a': tensor(3.)}' - torch_mlir.compile(DictModule(), [{'a': torch.tensor(3.0)}], use_tracing=True) + torchscript.compile(DictModule(), [{'a': torch.tensor(3.0)}], use_tracing=True) except Exception as e: print(e) diff --git a/projects/pt1/python/torch_mlir/dynamo.py b/projects/pt1/python/torch_mlir/dynamo.py index d3d7978bbfee..fa00bb9a847f 100644 --- a/projects/pt1/python/torch_mlir/dynamo.py +++ b/projects/pt1/python/torch_mlir/dynamo.py @@ -125,7 +125,7 @@ def make_simple_dynamo_backend(user_backend): Args: user_backend: A function with the signature used by ordinary TorchDynamo backends. But the torch.fx.GraphModule passed to it - will be normalized for consumption by `torch_mlir.compile`. + will be normalized for consumption by `torchscript.compile`. Returns: A function with the signature used by TorchDynamo backends. """ diff --git a/projects/pt1/python/torch_mlir/__init__.py b/projects/pt1/python/torch_mlir/torchscript.py similarity index 99% rename from projects/pt1/python/torch_mlir/__init__.py rename to projects/pt1/python/torch_mlir/torchscript.py index c916043c2cdd..f3412b83addb 100644 --- a/projects/pt1/python/torch_mlir/__init__.py +++ b/projects/pt1/python/torch_mlir/torchscript.py @@ -22,7 +22,7 @@ class OutputType(Enum): - """The kind of output that `torch_mlir.compile` can produce. + """The kind of output that `torchscript.compile` can produce. In MLIR terminology, this describes the mix of dialects that will be produced by the conversion process. @@ -392,13 +392,13 @@ def compile(model: torch.nn.Module, strip_overloads(model) # Get the model as JIT IR (TorchScript) for import. - # TODO: Longer-term, we probably need to split `torch_mlir.compile`. + # TODO: Longer-term, we probably need to split `torchscript.compile`. # There should be an "acquisition" step that does # tracing/scripting/importing from FX/using torchdynamo.export/etc. # + any lowering to the backend contract. Then there should be a # "backend lowering" step that does the actual lowering to each # backend. This separation should be visible at the Python API level, and - # we can implement a deliberately simplified API like `torch_mlir.compile` + # we can implement a deliberately simplified API like `torchscript.compile` # on top of those building blocks. if isinstance(model, torch.jit.ScriptModule): # If the user already converted the model to JIT IR themselves, just diff --git a/projects/pt1/python/torch_mlir_e2e_test/configs/linalg_on_tensors_backend.py b/projects/pt1/python/torch_mlir_e2e_test/configs/linalg_on_tensors_backend.py index 6ad41dd6dccb..8c99278b0ec3 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/configs/linalg_on_tensors_backend.py +++ b/projects/pt1/python/torch_mlir_e2e_test/configs/linalg_on_tensors_backend.py @@ -6,7 +6,7 @@ from typing import Any import torch -import torch_mlir +from torch_mlir import torchscript from torch_mlir_e2e_test.linalg_on_tensors_backends.abc import LinalgOnTensorsBackend from torch_mlir_e2e_test.framework import TestConfig, Trace, TraceItem @@ -30,7 +30,7 @@ def __init__(self, backend: LinalgOnTensorsBackend): def compile(self, program: torch.nn.Module) -> Any: example_args = convert_annotations_to_placeholders(program.forward) - module = torch_mlir.compile( + module = torchscript.compile( program, example_args, output_type="linalg-on-tensors") return self.backend.compile(module) diff --git a/projects/pt1/python/torch_mlir_e2e_test/configs/stablehlo_backend.py b/projects/pt1/python/torch_mlir_e2e_test/configs/stablehlo_backend.py index 8a244b756e6c..1ab8a8d22b4f 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/configs/stablehlo_backend.py +++ b/projects/pt1/python/torch_mlir_e2e_test/configs/stablehlo_backend.py @@ -6,7 +6,7 @@ from typing import Any import torch -import torch_mlir +from torch_mlir import torchscript from torch_mlir_e2e_test.stablehlo_backends.abc import StablehloBackend from torch_mlir_e2e_test.framework import TestConfig, Trace, TraceItem @@ -30,7 +30,7 @@ def __init__(self, backend: StablehloBackend): def compile(self, program: torch.nn.Module) -> Any: example_args = convert_annotations_to_placeholders(program.forward) - module = torch_mlir.compile(program, example_args, output_type="stablehlo") + module = torchscript.compile(program, example_args, output_type="stablehlo") return self.backend.compile(module) diff --git a/projects/pt1/python/torch_mlir_e2e_test/configs/torchdynamo.py b/projects/pt1/python/torch_mlir_e2e_test/configs/torchdynamo.py index c53227acf36a..e5c2475c7669 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/configs/torchdynamo.py +++ b/projects/pt1/python/torch_mlir_e2e_test/configs/torchdynamo.py @@ -17,7 +17,7 @@ from torch_mlir._dynamo_fx_importer import import_fx_graph_as_func from torch_mlir.dynamo import _get_decomposition_table -from torch_mlir import ( +from torch_mlir.torchscript import ( _example_args, OutputType, BACKEND_LEGAL_OPS, diff --git a/projects/pt1/python/torch_mlir_e2e_test/configs/tosa_backend.py b/projects/pt1/python/torch_mlir_e2e_test/configs/tosa_backend.py index 8efab87a2bfe..8aa2d0e63eb6 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/configs/tosa_backend.py +++ b/projects/pt1/python/torch_mlir_e2e_test/configs/tosa_backend.py @@ -6,7 +6,7 @@ from typing import Any import torch -import torch_mlir +from torch_mlir import torchscript from torch_mlir_e2e_test.tosa_backends.abc import TosaBackend from torch_mlir_e2e_test.framework import TestConfig, Trace, TraceItem @@ -30,7 +30,7 @@ def __init__(self, backend: TosaBackend, use_make_fx: bool = False): def compile(self, program: torch.nn.Module) -> Any: example_args = convert_annotations_to_placeholders(program.forward) - module = torch_mlir.compile( + module = torchscript.compile( program, example_args, output_type="tosa", use_make_fx=self.use_make_fx) return self.backend.compile(module) diff --git a/projects/pt1/python/torch_mlir_e2e_test/utils.py b/projects/pt1/python/torch_mlir_e2e_test/utils.py index 403c455cba64..e3a76581f668 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/utils.py +++ b/projects/pt1/python/torch_mlir_e2e_test/utils.py @@ -3,13 +3,13 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception # Also available under a BSD-style license. See LICENSE. -from torch_mlir import TensorPlaceholder +from torch_mlir.torchscript import TensorPlaceholder from torch_mlir_e2e_test.annotations import TORCH_MLIR_ARG_ANNOTATIONS_ATTR_NAME def convert_annotations_to_placeholders(forward_method): """Converts the annotations on a forward method into tensor placeholders. - These placeholders are suitable for being passed to `torch_mlir.compile`. + These placeholders are suitable for being passed to `torchscript.compile`. """ annotations = getattr(forward_method, TORCH_MLIR_ARG_ANNOTATIONS_ATTR_NAME) placeholders = [] diff --git a/projects/pt1/test/python/custom_op_shape_dtype_fn.py b/projects/pt1/test/python/custom_op_shape_dtype_fn.py index a46f1c594031..a3a2b965d655 100644 --- a/projects/pt1/test/python/custom_op_shape_dtype_fn.py +++ b/projects/pt1/test/python/custom_op_shape_dtype_fn.py @@ -5,7 +5,7 @@ import torch import torch.multiprocessing as mp import torch.utils.cpp_extension -import torch_mlir +from torch_mlir import torchscript from torch_mlir_e2e_test.annotations import export, annotate_args @@ -56,7 +56,7 @@ def run(): mod = CustomOpExampleModule() mod.eval() - module = torch_mlir.compile( + module = torchscript.compile( mod, torch.ones(3, 4), output_type="torch", diff --git a/projects/pt1/test/python/importer/jit_ir/node_import/unimplemented.py b/projects/pt1/test/python/importer/jit_ir/node_import/unimplemented.py index eb6bb2f09ff3..533ef7586748 100644 --- a/projects/pt1/test/python/importer/jit_ir/node_import/unimplemented.py +++ b/projects/pt1/test/python/importer/jit_ir/node_import/unimplemented.py @@ -1,5 +1,5 @@ import torch -import torch_mlir +from torch_mlir import torchscript # RUN: %PYTHON %s | torch-mlir-opt | FileCheck %s @@ -39,6 +39,6 @@ def forward(self, data): with torch.no_grad(): return data -output_type = torch_mlir.OutputType.RAW -mod = torch_mlir.compile(Model(), [torch.tensor([0, 1, 2, 3])], output_type) +output_type = torchscript.OutputType.RAW +mod = torchscript.compile(Model(), [torch.tensor([0, 1, 2, 3])], output_type) print(mod) diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt index d725aae6c584..6300df01e4ec 100644 --- a/python/CMakeLists.txt +++ b/python/CMakeLists.txt @@ -39,6 +39,13 @@ declare_mlir_python_sources(TorchMLIRPythonSources.Importers extras/onnx_importer.py ) +declare_mlir_python_sources(TorchMLIRPythonSources.PublicAPI + ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}" + ADD_TO_PARENT TorchMLIRPythonSources + SOURCES + fx.py +) + declare_mlir_python_sources(TorchMLIRPythonSources.Tools ROOT_DIR "${TORCH_MLIR_PYTHON_ROOT_DIR}" ADD_TO_PARENT TorchMLIRPythonSources diff --git a/python/torch_mlir/fx.py b/python/torch_mlir/fx.py new file mode 100644 index 000000000000..78b46cc3ea29 --- /dev/null +++ b/python/torch_mlir/fx.py @@ -0,0 +1,25 @@ +from typing import Optional + +import torch +import torch.export +import torch.nn as nn + +from torch_mlir.extras.fx_importer import FxImporter +from torch_mlir import ir +from torch_mlir.dialects import torch as torch_d + +def export_and_import( + f, + *args, + fx_importer: Optional[FxImporter] = None, + constraints: Optional[torch.export.Constraint] = None, + **kwargs, +): + context = ir.Context() + torch_d.register_dialect(context) + + if fx_importer is None: + fx_importer = FxImporter(context=context) + prog = torch.export.export(f, args, kwargs, constraints=constraints) + fx_importer.import_frozen_exported_program(prog) + return fx_importer.module_op diff --git a/test/python/compile.py b/test/python/compile.py index fc2917e9c76a..990738085020 100644 --- a/test/python/compile.py +++ b/test/python/compile.py @@ -3,7 +3,7 @@ import gc import sys import torch -import torch_mlir +from torch_mlir import torchscript def run_test(f): @@ -26,7 +26,7 @@ def forward(self, x): # CHECK-LABEL: TEST: test_enable_ir_printing @run_test def test_enable_ir_printing(): - torch_mlir.compile(TinyModel(), + torchscript.compile(TinyModel(), torch.ones(1, 3, 20, 20), output_type="linalg-on-tensors", enable_ir_printing=True) diff --git a/test/python/fx_importer/basic_test.py b/test/python/fx_importer/basic_test.py index acd2a559fa52..36c554862506 100644 --- a/test/python/fx_importer/basic_test.py +++ b/test/python/fx_importer/basic_test.py @@ -1,5 +1,3 @@ -# Copyright 2023 Advanced Micro Devices, Inc -# # Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. # See https://llvm.org/LICENSE.txt for license information. # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception @@ -13,26 +11,7 @@ import torch.export import torch.nn as nn -from torch_mlir.extras.fx_importer import FxImporter -from torch_mlir import ir -from torch_mlir.dialects import torch as torch_d - - -def export_and_import( - f, - *args, - fx_importer: Optional[FxImporter] = None, - constraints: Optional[torch.export.Constraint] = None, - **kwargs, -): - context = ir.Context() - torch_d.register_dialect(context) - - if fx_importer is None: - fx_importer = FxImporter(context=context) - prog = torch.export.export(f, args, kwargs, constraints=constraints) - fx_importer.import_frozen_exported_program(prog) - return fx_importer.module_op +from torch_mlir import fx def run(f): @@ -75,5 +54,5 @@ def __init__(self): def forward(self, x): return torch.tanh(x) * get_a() * self.b * self.p - m = export_and_import(Basic(), torch.randn(3, 4)) + m = fx.export_and_import(Basic(), torch.randn(3, 4)) print(m) From 723b8b1d285638b8b0e594345df821a6f1f8c468 Mon Sep 17 00:00:00 2001 From: James Newling Date: Wed, 7 Feb 2024 11:55:38 +0000 Subject: [PATCH 168/283] Fix dev docs error/typo (#2880) Just a one line change in a .md file --- docs/development.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/development.md b/docs/development.md index 3e9192f5fa8e..27ff8b7bfaad 100644 --- a/docs/development.md +++ b/docs/development.md @@ -77,7 +77,7 @@ By default we download the latest version of libtorch. We have an experimental p * Enabling `--debug` and `--debug-only` flags (see [MLIR docs](https://mlir.llvm.org/getting_started/Debugging/)) for the `torch-mlir-opt` tool ```shell -DCMAKE_BUILD_TYPE=RelWithDebInfo \ # or =Debug - -DIREE_ENABLE_ASSERTIONS=ON \ + -DLLVM_ENABLE_ASSERTIONS=ON \ ``` ### Building against a pre-built LLVM From fc04bc7ee9732428baf6fdba2a728b2b3bee021a Mon Sep 17 00:00:00 2001 From: "Xida Ren (Cedar)" Date: Wed, 7 Feb 2024 14:00:46 -0500 Subject: [PATCH 169/283] [torch] AtenSliceOp folder that produces splat results (#2869) Includes `slice` folder and lit tests --------- Co-authored-by: Xida Ren --- lib/Dialect/Torch/IR/TorchOps.cpp | 50 ++++++++++++++++++++++------ test/Dialect/Torch/canonicalize.mlir | 47 ++++++++++++++++++++++++++ 2 files changed, 87 insertions(+), 10 deletions(-) diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index 1857aff4dbbd..98d0369d9f51 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -2552,22 +2552,52 @@ OpFoldResult AtenBroadcastToOp::fold(FoldAdaptor adaptor) { //===----------------------------------------------------------------------===// OpFoldResult AtenSliceTensorOp::fold(FoldAdaptor adaptor) { - int64_t start, end, step; - if (matchPattern(getStart(), m_TorchConstantInt(&start)) && - matchPattern(getEnd(), m_TorchConstantInt(&end)) && - matchPattern(getStep(), m_TorchConstantInt(&step)) && step == 1 && - start == 0 && end == std::numeric_limits::max()) + DenseElementsAttr input = + dyn_cast_or_null(adaptor.getSelf()); + IntegerAttr start = dyn_cast_or_null(adaptor.getStart()); + IntegerAttr end = dyn_cast_or_null(adaptor.getEnd()); + IntegerAttr step = dyn_cast_or_null(adaptor.getStep()); + IntegerAttr dim = dyn_cast_or_null(adaptor.getDim()); + + if (start && end && step && step.getValue().getSExtValue() == 1 && + start.getValue().getSExtValue() == 0 && + end.getValue().getSExtValue() == std::numeric_limits::max()) return getOperand(0); - auto inType = getOperand(0).getType().dyn_cast(); - auto outType = getResult().getType().dyn_cast(); - if (inType != outType) - return nullptr; - if (!inType || !outType || !inType.hasSizes() || !outType.hasSizes()) + auto inType = getOperand(0).getType().dyn_cast(); + auto outType = getResult().getType().dyn_cast(); + if (!inType || !outType || !inType.hasSizes() || !outType.hasSizes() || + !inType.hasDtype() || !outType.hasDtype() || + inType.getDtype() != outType.getDtype()) return nullptr; + if (inType.getSizes().size() != outType.getSizes().size() || !inType.areAllSizesKnown() || !outType.areAllSizesKnown()) return nullptr; + + if (input && input.isSplat()) + return DenseElementsAttr::get( + outType.toBuiltinTensor().clone(inType.getDtype()), + input.getSplatValue()); + + // If the output is a single value we can index into a constant input and grab + // that single value: + if (input && start && dim && + llvm::all_of(outType.getSizes(), [](int64_t dim) { return dim == 1; })) { + bool unaryNonDim = true; + int64_t dimInt = dim.getValue().getSExtValue(); + for (int i = 0, s = inType.getSizes().size(); i < s; ++i) { + unaryNonDim &= inType.getSizes()[i] == 1 || i == dimInt; + } + if (unaryNonDim) { + Attribute value = + input.getValues()[start.getValue().getSExtValue()]; + return DenseElementsAttr::get( + outType.toBuiltinTensor().clone(inType.getDtype()), value); + } + } + + // If the input and output shapes are the same we can just fold: for (size_t i = 0; i < inType.getSizes().size(); ++i) { if (inType.getSizes()[i] != outType.getSizes()[i]) return nullptr; diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir index 77a9e8ad3d61..1f0d7971abd3 100644 --- a/test/Dialect/Torch/canonicalize.mlir +++ b/test/Dialect/Torch/canonicalize.mlir @@ -2104,6 +2104,53 @@ func.func @torch.aten.slice.tensor$no_fold_step(%arg0: !torch.vtensor<[?],f32>, return %0 : !torch.vtensor<[?],f32> } +// ----- +// CHECK-LABEL: func.func @torch.aten.slice.tensor$fold_dim_1() -> (!torch.vtensor<[1,1],si64>, !torch.vtensor<[1,1],si64>) { +// CHECK-NOT: torch.aten.slice.Tensor +// CHECK: %[[RET_0:.*]] = torch.vtensor.literal(dense<50> : tensor<1x1xsi64>) : !torch.vtensor<[1,1],si64> +// CHECK-NOT: torch.aten.slice.Tensor +// CHECK: %[[RET_1:.*]] = torch.vtensor.literal(dense<70> : tensor<1x1xsi64>) : !torch.vtensor<[1,1],si64> +// CHECK-NOT: torch.aten.slice.Tensor +// CHECK: return %[[RET_0]], %[[RET_1]] +func.func @torch.aten.slice.tensor$fold_dim_1() -> (!torch.vtensor<[1, 1],si64>, !torch.vtensor<[1, 1],si64>) { + %tensor = torch.vtensor.literal(dense<[[10,20,30,40,50,60,70,80,90,100]]> : tensor<1x10xsi64>) : !torch.vtensor<[1, 10],si64> + %int0 = torch.constant.int 0 + %int1 = torch.constant.int 1 + %int4 = torch.constant.int 4 + %int5 = torch.constant.int 5 + %int6 = torch.constant.int 6 + %int7 = torch.constant.int 7 + %dim = torch.constant.int 1 + %0 = torch.aten.slice.Tensor %tensor, %dim, %int4, %int5, %int1 : !torch.vtensor<[1, 10], si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1, 1], si64> + %1 = torch.aten.slice.Tensor %tensor, %dim, %int6, %int7, %int1 : !torch.vtensor<[1, 10], si64>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1, 1], si64> + return %0, %1 : !torch.vtensor<[1,1],si64>, !torch.vtensor<[1,1],si64> +} + + +// ----- +// CHECK-LABEL: func.func @torch.aten.slice.tensor$fold_dim_0() -> (!torch.vtensor<[1,1],f32>, !torch.vtensor<[1,1],f32>) { +// CHECK-NOT: torch.aten.slice.Tensor +// CHECK: %[[RET_0:.*]] = torch.vtensor.literal(dense<1.600000e+01> : tensor<1x1xf32>) : !torch.vtensor<[1,1],f32> +// CHECK-NOT: torch.aten.slice.Tensor +// CHECK: %[[RET_1:.*]] = torch.vtensor.literal(dense<6.400000e+01> : tensor<1x1xf32>) : !torch.vtensor<[1,1],f32> +// CHECK-NOT: torch.aten.slice.Tensor +// CHECK: return %[[RET_0]], %[[RET_1]] : !torch.vtensor<[1,1],f32>, !torch.vtensor<[1,1],f32> +func.func @torch.aten.slice.tensor$fold_dim_0() -> (!torch.vtensor<[1, 1],f32>, !torch.vtensor<[1, 1],f32>) { + %tensor = torch.vtensor.literal(dense<[[2.0],[4.0],[8.0],[16.0],[32.0],[64.0],[128.0],[256.0],[512.0],[1024.0]]> : tensor<10x1xf32>) : !torch.vtensor<[10, 1],f32> + %int0 = torch.constant.int 0 + %int1 = torch.constant.int 1 + %int3 = torch.constant.int 3 + %int4 = torch.constant.int 4 + %int5 = torch.constant.int 5 + %int6 = torch.constant.int 6 + %dim = torch.constant.int 0 + %0 = torch.aten.slice.Tensor %tensor, %dim, %int3, %int4, %int1 : !torch.vtensor<[10, 1], f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1, 1], f32> + %1 = torch.aten.slice.Tensor %tensor, %dim, %int5, %int6, %int1 : !torch.vtensor<[10, 1], f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1, 1], f32> + return %0, %1 : !torch.vtensor<[1, 1],f32>, !torch.vtensor<[1, 1], f32> +} + + + // CHECK-LABEL: func.func @torch.aten.rsub.Scalar$canonicalize_literal_0d() -> !torch.vtensor<[],si64> { // CHECK: %int-1 = torch.constant.int -1 // CHECK: %[[VAL_0:.*]] = torch.prim.NumToTensor.Scalar %int-1 : !torch.int -> !torch.vtensor<[],si64> From 32dbf99ce24cd8ebf17db0dcba97a46f5f7a70d4 Mon Sep 17 00:00:00 2001 From: mmakevic <150796284+mmakevic@users.noreply.github.com> Date: Wed, 7 Feb 2024 21:34:52 +0100 Subject: [PATCH 170/283] Implement lowering of torch.aten.all.dim (#2873) Lowering of torch.aten.all.dim to linalg. Per PyTorch documentation: > This function matches the behaviour of NumPy in returning output of dtype bool for all supported dtypes except uint8. For uint8 the dtype of output is uint8 itself. Since there is no support for ui8 in torch-mlir currently (https://github.com/llvm/torch-mlir/pull/1384#issuecomment-1260011334) implementation returns failure for that case. --- lib/Conversion/TorchToLinalg/Reduction.cpp | 16 +++++ .../Transforms/AbstractInterpLibrary.cpp | 17 +++++ .../build_tools/abstract_interp_lib_gen.py | 10 +++ .../test_suite/reduction.py | 72 +++++++++++++++++++ 4 files changed, 115 insertions(+) diff --git a/lib/Conversion/TorchToLinalg/Reduction.cpp b/lib/Conversion/TorchToLinalg/Reduction.cpp index da5ee799a566..50fa1f8e610d 100644 --- a/lib/Conversion/TorchToLinalg/Reduction.cpp +++ b/lib/Conversion/TorchToLinalg/Reduction.cpp @@ -277,6 +277,10 @@ static Value createInitElementForReduceOp(OpBuilder &b, Location loc, if (isa(op) || isa(op)) return b.create(loc, b.getZeroAttr(elementType)); + if (isa(op)) { + return b.create(loc, b.getBoolAttr(true)); + } + op->emitError("unimplemented lowering in createInitElementForReduceOp"); return nullptr; } @@ -357,6 +361,11 @@ static Value createLinalgPayloadForReduceOp(OpBuilder &b, Location loc, auto ord = b.create(loc, twoAttr); auto pow = b.create(loc, abs, ord); return b.create(loc, pow, result); + } else if (isa(op)) { + Value elem = payloadArgs[0]; + Value result = payloadArgs[1]; + Value self = convertScalarToDtype(b, loc, elem, resultElementType); + return b.create(loc, self, result); } op->emitError("unimplemented lowering in createLinalgPayloadForReduceOp"); return nullptr; @@ -447,6 +456,9 @@ class ConvertReductionOp : public ConversionPattern { if (auto normOp = dyn_cast(op)) return computeReductionOpInfoForDimVariantOp(normOp, operands, rewriter); + if (auto allOp = dyn_cast(op)) + return computeReductionOpInfoForDimVariantOp(allOp, operands, rewriter); + return rewriter.notifyMatchFailure(op, "not a supported reduce op"); } @@ -535,6 +547,9 @@ class ConvertReductionOp : public ConversionPattern { !elemType.isa()) return rewriter.notifyMatchFailure( op, "only float types are valid for vector norm ops"); + if (isa(op) && elemType.isa() && + elemType.getIntOrFloatBitWidth() == 8) + return rewriter.notifyMatchFailure(op, "uint8 is not supported"); // No checks for all other reduction operations return success(); } @@ -610,6 +625,7 @@ void mlir::torch::torch_to_linalg::populateReductionPatternsAndLegality( target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); patterns.add(typeConverter, context); diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index 57ece8cfdd7e..4290ce23c44b 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -7006,6 +7006,11 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %1 = call @__torch__.torch.jit._shape_functions.argmax(%arg0, %0, %arg2) : (!torch.list, !torch.optional, !torch.bool) -> !torch.list\n" " return %1 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.all.dim\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.list {\n" +" %0 = torch.derefine %arg1 : !torch.int to !torch.optional\n" +" %1 = call @__torch__.torch.jit._shape_functions.argmax(%arg0, %0, %arg2) : (!torch.list, !torch.optional, !torch.bool) -> !torch.list\n" +" return %1 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.max.dim\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.tuple, list> {\n" " %0 = torch.derefine %arg1 : !torch.int to !torch.optional\n" " %1 = call @__torch__.torch.jit._shape_functions.argmax(%arg0, %0, %arg2) : (!torch.list, !torch.optional, !torch.bool) -> !torch.list\n" @@ -11809,6 +11814,18 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %2 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.all.dim\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.bool) -> !torch.int {\n" +" %int11 = torch.constant.int 11\n" +" %int0 = torch.constant.int 0\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = torch.aten.eq.int %0#1, %int0 : !torch.int, !torch.int -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %int11 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.min\"(%arg0: !torch.tuple) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index 922b207a2c57..dadd87a15a08 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -543,6 +543,9 @@ def aten〇one_hot〡shape(self: List[int], num_classes: int = -1) -> List[int]: def aten〇any〇dim〡shape(self: List[int], dim: int, keepdim: bool = False) -> List[int]: return upstream_shape_functions.argmax(self, dim, keepdim) +def aten〇all〇dim〡shape(self: List[int], dim: int, keepdim: bool = False) -> List[int]: + return upstream_shape_functions.argmax(self, dim, keepdim) + def aten〇max〇dim〡shape(self: List[int], dim: int, keepdim: bool = False) -> Tuple[List[int], List[int]]: reduced_shape = upstream_shape_functions.argmax(self, dim, keepdim) return reduced_shape, reduced_shape @@ -3766,6 +3769,13 @@ def aten〇any〇dim〡dtype(self_rank_dtype: Tuple[int, int], dim: int, keepdim return self_dtype return torch.bool +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0)) +def aten〇all〇dim〡dtype(self_rank_dtype: Tuple[int, int], dim: int, keepdim: bool = False) -> int: + self_rank, self_dtype = self_rank_dtype + if self_dtype == torch.uint8: + return self_dtype + return torch.bool + @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) def aten〇min〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py index 8418d1ae8f5a..ea2ff1609f3a 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py @@ -316,6 +316,78 @@ def ReduceProdDimIntFloatModule_basic(module, tu: TestUtils): # ============================================================================== +class ReduceAllDimEmpty(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1], torch.float32, True), + ]) + def forward(self, a): + return torch.ops.aten.all(a, dim=0, keepdim=False) + +@register_test_case(module_factory=lambda: ReduceAllDimEmpty()) +def ReduceAllDimEmpty_basic(module, tu: TestUtils): + module.forward(torch.tensor([])) + +# ============================================================================== + +class ReduceAllDimFloat(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1,-1], torch.float32, True), + ]) + def forward(self, a): + return torch.ops.aten.all(a, dim=1, keepdim=True) + +@register_test_case(module_factory=lambda: ReduceAllDimFloat()) +def ReduceAllDimFloat_basic(module, tu: TestUtils): + module.forward(torch.tensor([[5.0,1e-6,-5.0],[0,5.0,0]])) + +# ============================================================================== + +class ReduceAllDimInt(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1,-1], torch.int32, True), + ]) + def forward(self, a): + return torch.ops.aten.all(a, dim=1, keepdim=True) + +@register_test_case(module_factory=lambda: ReduceAllDimInt()) +def ReduceAllDimInt_basic(module, tu: TestUtils): + module.forward(torch.tensor([[5,-5,0],[5,1e10,5]]).to(torch.int32)) + +# ============================================================================== + +class ReduceAllDimBool(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1,-1], torch.bool, True), + ]) + def forward(self, a): + return torch.ops.aten.all(a, dim=1, keepdim=False) + +@register_test_case(module_factory=lambda: ReduceAllDimBool()) +def ReduceAllDimBool_basic(module, tu: TestUtils): + module.forward(torch.tensor([[True, False, True], [True, True, True]])) + +# ============================================================================== + class ReduceMaxAlongDim(torch.nn.Module): def __init__(self): super().__init__() From 23647ab2d178f8f2edc3bb4a0c207cca7d347ca5 Mon Sep 17 00:00:00 2001 From: Dave Liddell <44620210+daveliddell@users.noreply.github.com> Date: Wed, 7 Feb 2024 17:17:15 -0700 Subject: [PATCH 171/283] [torhc] aten.index_select folder (#2871) Folds aten::index_select ops under the following conditions: 1. If the input and output are the same shape, the indexing operation is a NOP, so just return the input. 2. If the input has shape <1x1x...xNx...x1> (all 1's except for one dim), and the output shape is <1x1x...x1> (all 1's), then there is a single index, so extract the single element value and return a tensor with that value. --------- Co-authored-by: Dave Liddell --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 1 + .../torch-mlir/Dialect/Torch/IR/TorchOps.h | 38 +++++++++ lib/Dialect/Torch/IR/TorchOps.cpp | 85 +++++++++++++++++++ .../build_tools/torch_ods_gen.py | 2 +- test/Dialect/Torch/canonicalize.mlir | 52 ++++++++++++ 5 files changed, 177 insertions(+), 1 deletion(-) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 7f0f5af7e43f..885d367a987e 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -9785,6 +9785,7 @@ def Torch_AtenIndexSelectOp : Torch_Op<"aten.index_select", [ printDefaultTorchOp(printer, *this, 3, 1); } }]; + let hasFolder = 1; } def Torch_Aten_IndexPutImplOp : Torch_Op<"aten._index_put_impl", [ diff --git a/include/torch-mlir/Dialect/Torch/IR/TorchOps.h b/include/torch-mlir/Dialect/Torch/IR/TorchOps.h index 64b70e097c39..c82a98cc5aba 100644 --- a/include/torch-mlir/Dialect/Torch/IR/TorchOps.h +++ b/include/torch-mlir/Dialect/Torch/IR/TorchOps.h @@ -294,6 +294,44 @@ bool isListPotentiallyMutated(Value list); /// the list. bool potentiallyMutatesListOperands(Operation *op); +/// Returns the value from an `IntegerAttr` as an `int64_t`. +/// +/// @param intAttr the `IntegerAttr` from which to extract the value +/// @return the value as an `int64_t` +/// +/// Regardless of the signed-ness of the attribute, this function returns +/// the value as a signed integer, which implies that if the attribute has +/// a 64-bit unsigned value, it will be converted to an int64_t in the manner +/// that uint64_t is cast to int64_t in C++. +inline int64_t getIntAttrAsSigned(IntegerAttr intAttr) { + if (intAttr.getType().isUnsignedInteger()) + return intAttr.getValue().getZExtValue(); + return intAttr.getValue().getSExtValue(); +} + +/// Returns the value from an `IntegerAttr` as an integral index. +/// +/// @param intAttr the `IntegerAttr` from which to extract the index +/// @param dimSize the size of the dimension that the attribute indexes into +/// @return the index value +/// +/// Use this function when the given `IntegerAttr` represents an index into +/// a range, such as an index into a tensor dimension. If `dimSize` is given, +/// negative index values are converted into positive vales by counting +/// elements from the "right" side of the dimension, as in python, numpy, etc. +/// For example, an index of -2 and a dimSize of 10 returns 8 because 8 is the +/// 2nd index from the high end of the range 0 to 9. If `dimSize` is not +/// given, any negative indices are returned as negative numbers. +/// +/// No bounds checking is performed on the index to ensure that it is within +/// the legal range for `dimSize`. +inline int64_t getIntAttrAsIndex(IntegerAttr intAttr, int dimSize = -1) { + int64_t signedIndex = getIntAttrAsSigned(intAttr); + if (dimSize < 0 || signedIndex > 0) + return signedIndex; + return dimSize + signedIndex; // count backwards from dimSize +} + } // namespace Torch } // namespace torch } // namespace mlir diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index 98d0369d9f51..9d100b11dc27 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -2911,6 +2911,91 @@ OpFoldResult AtenDivIntOp::fold(FoldAdaptor adaptor) { return nullptr; } +//===----------------------------------------------------------------------===// +// AtenIndexSelectOp +//===----------------------------------------------------------------------===// + +OpFoldResult AtenIndexSelectOp::fold(FoldAdaptor adaptor) { + auto self = getSelf(); + auto index = getIndex(); + auto selfTy = dyn_cast(self.getType()); + auto indexTy = dyn_cast(index.getType()); + auto resultTy = dyn_cast(getType()); + if (!selfTy || !indexTy || !resultTy || !selfTy.hasSizes() || + !indexTy.hasSizes() || !resultTy.hasSizes() || !selfTy.hasDtype() || + !indexTy.hasDtype() || !resultTy.hasDtype()) + return nullptr; + + auto selfSizes = selfTy.getSizes(); + auto indexSizes = indexTy.getSizes(); + auto resultSizes = resultTy.getSizes(); + + if (selfTy.getDtype() != resultTy.getDtype() || + selfSizes.size() != resultSizes.size() || indexSizes.size() != 1) + return nullptr; + + // If the selection results in a tensor of the same dimensions as the + // input, the selection must have specified every index of the input, + // so the result is exactly the same as the input. + + bool fullTensor = true; + for (int i = 0, s = selfSizes.size(); i < s; ++i) { + fullTensor &= selfSizes[i] == resultSizes[i]; + fullTensor &= selfSizes[i] != Torch::kUnknownSize; + fullTensor &= resultSizes[i] != Torch::kUnknownSize; + } + + if (fullTensor && indexSizes[0] == 1) + return self; + + // If the input tensor, index dimension, or indexes are non-constant, + // can't fold. + + auto selfAttr = dyn_cast_or_null(adaptor.getSelf()); + auto dimAttr = dyn_cast_or_null(adaptor.getDim()); + auto indexAttr = dyn_cast_or_null(adaptor.getIndex()); + + if (!selfAttr || !dimAttr || !indexAttr) + return {}; + + // If the input's dimensions are all 1 except for one dimension, and if + // there is a single index in the index list (as detected by the result + // dimension being 1), then fold to a <1x1x...x1> tensor literal containing + // a single element. Handles float and int types. + + int64_t dimInt = dimAttr.getInt(); + // If the selected dim is negative, count backwards from the last dim + if (dimInt < 0) + dimInt = selfSizes.size() + dimInt; + assert(uint64_t(dimInt) < selfSizes.size() && + "Selected dim > number of dims"); + + for (int i = 0, s = selfSizes.size(); i < s; ++i) { + if ((selfSizes[i] != 1 && i != dimInt) || resultSizes[i] != 1) + return nullptr; + } + + // Get the single index value for the selected dimension + auto splatValue = indexAttr.getSplatValue(); + int64_t indexInt = getIntAttrAsIndex(splatValue, selfSizes[dimInt]); + + // Extract the single constant value from the input tensor and turn the + // extracted value into a single-element tensor of the output shape and dtype + auto splattr = selfAttr.getValues()[indexInt]; + + auto dty = resultTy.getDtype(); + auto attrTy = resultTy.toBuiltinTensor().clone(dty); + if (auto floatAttr = dyn_cast(splattr)) + return DenseElementsAttr::get( + attrTy, FloatAttr::get(dty, floatAttr.getValueAsDouble())); + + if (auto intAttr = dyn_cast(splattr)) { + return DenseElementsAttr::get(attrTy, + IntegerAttr::get(dty, intAttr.getValue())); + } + return nullptr; +} + //===----------------------------------------------------------------------===// // AtenItemOp //===----------------------------------------------------------------------===// diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index 7893c26db947..6ee3a3e34cae 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -616,7 +616,7 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::broadcast_to : (Tensor, int[]) -> (Tensor)", has_folder=True) emit("aten::index.Tensor : (Tensor, Tensor?[]) -> (Tensor)") emit("aten::index.Tensor_hacked_twin : (Tensor, Tensor[]) -> (Tensor)") - emit("aten::index_select : (Tensor, int, Tensor) -> (Tensor)") + emit("aten::index_select : (Tensor, int, Tensor) -> (Tensor)", has_folder=True) emit_with_mutating_variants("aten::_index_put_impl : (Tensor, Tensor?[], Tensor, bool, bool) -> (Tensor)") emit("aten::item : (Tensor) -> (Scalar)", has_folder=True) emit("aten::masked_select : (Tensor, Tensor) -> (Tensor)") diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir index 1f0d7971abd3..ac1797b4c398 100644 --- a/test/Dialect/Torch/canonicalize.mlir +++ b/test/Dialect/Torch/canonicalize.mlir @@ -2280,3 +2280,55 @@ func.func @torch.aten.detach$canonicalize(%arg0: !torch.tensor<[1],f32>) -> !tor %1 = torch.aten.detach %arg0 : !torch.tensor<[1],f32> -> !torch.tensor return %1 : !torch.tensor } + +// CHECK-LABEL: func.func @torch.aten.index_select$noop( +// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[1,2,3],si64> +// CHECK-NEXT: return %[[ARG]] : !torch.vtensor<[1,2,3],si64> +func.func @torch.aten.index_select$noop(%arg0 : !torch.vtensor<[1,2,3],si64>, %arg1 : !torch.int, %arg2 : !torch.vtensor<[1],si64>) -> !torch.vtensor<[1,2,3],si64> { + %0 = torch.aten.index_select %arg0, %arg1, %arg2 : !torch.vtensor<[1,2,3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1,2,3],si64> + return %0 : !torch.vtensor<[1,2,3],si64> +} + +// CHECK-LABEL: func.func @torch.aten.index_select$const_si_si( +// CHECK-NEXT: %[[RES:.*]] = torch.vtensor.literal(dense<60> : tensor<1xsi64>) : !torch.vtensor<[1],si64> +// CHECK-NEXT: return %[[RES]] : !torch.vtensor<[1],si64> +func.func @torch.aten.index_select$const_si_si() -> !torch.vtensor<[1],si64> { + %tensor = torch.vtensor.literal(dense<[10,20,30,40,50,60,70,80,90,100]> : tensor<10xsi64>) : !torch.vtensor<[10],si64> + %dim = torch.constant.int 0 + %index = torch.vtensor.literal(dense<5> : tensor<1xsi64>) : !torch.vtensor<[1],si64> + %0 = torch.aten.index_select %tensor, %dim, %index : !torch.vtensor<[10],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> + return %0 : !torch.vtensor<[1],si64> +} + +// CHECK-LABEL: func.func @torch.aten.index_select$const_si_ui( +// CHECK-NEXT: %[[RES:.*]] = torch.vtensor.literal(dense<60> : tensor<1xsi64>) : !torch.vtensor<[1],si64> +// CHECK-NEXT: return %[[RES]] : !torch.vtensor<[1],si64> +func.func @torch.aten.index_select$const_si_ui() -> !torch.vtensor<[1],si64> { + %tensor = torch.vtensor.literal(dense<[10,20,30,40,50,60,70,80,90,100]> : tensor<10xsi64>) : !torch.vtensor<[10],si64> + %dim = torch.constant.int 0 + %index = torch.vtensor.literal(dense<5> : tensor<1xui64>) : !torch.vtensor<[1],ui64> + %0 = torch.aten.index_select %tensor, %dim, %index : !torch.vtensor<[10],si64>, !torch.int, !torch.vtensor<[1],ui64> -> !torch.vtensor<[1],si64> + return %0 : !torch.vtensor<[1],si64> +} + +// CHECK-LABEL: func.func @torch.aten.index_select$const_f32_ui( +// CHECK-NEXT: %[[RES:.*]] = torch.vtensor.literal(dense<6.6{{.*}}> : tensor<1xf32>) : !torch.vtensor<[1],f32> +// CHECK-NEXT: return %[[RES]] : !torch.vtensor<[1],f32> +func.func @torch.aten.index_select$const_f32_ui() -> !torch.vtensor<[1],f32> { + %tensor = torch.vtensor.literal(dense<[1.1,2.2,3.3,4.4,5.5,6.6,7.7,8.8,9.9,10.0]> : tensor<10xf32>) : !torch.vtensor<[10],f32> + %dim = torch.constant.int 0 + %index = torch.vtensor.literal(dense<5> : tensor<1xui64>) : !torch.vtensor<[1],ui64> + %0 = torch.aten.index_select %tensor, %dim, %index : !torch.vtensor<[10],f32>, !torch.int, !torch.vtensor<[1],ui64> -> !torch.vtensor<[1],f32> + return %0 : !torch.vtensor<[1],f32> +} + +// CHECK-LABEL: func.func @torch.aten.index_select$const_f32_si_neg( +// CHECK-NEXT: %[[RES:.*]] = torch.vtensor.literal(dense<7.{{.*}}> : tensor<1xf32>) : !torch.vtensor<[1],f32> +// CHECK-NEXT: return %[[RES]] : !torch.vtensor<[1],f32> +func.func @torch.aten.index_select$const_f32_si_neg() -> !torch.vtensor<[1],f32> { + %tensor = torch.vtensor.literal(dense<[1.1,2.2,3.3,4.4,5.5,6.6,7.7,8.8,9.9,10.0]> : tensor<10xf32>) : !torch.vtensor<[10],f32> + %dim = torch.constant.int -1 + %index = torch.vtensor.literal(dense<-4> : tensor<1xsi64>) : !torch.vtensor<[1],si64> + %0 = torch.aten.index_select %tensor, %dim, %index : !torch.vtensor<[10],f32>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],f32> + return %0 : !torch.vtensor<[1],f32> +} From a8aad2a5ab8fefd22c5d76413489f53a7a7044c1 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Wed, 7 Feb 2024 16:43:31 -0800 Subject: [PATCH 172/283] [torch] Add `torch.aten.where.*` folders (#2886) Where operation can be statically computed when involving splats of known value. Added handling these cases with multiple tests. --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 4 + lib/Dialect/Torch/IR/TorchOps.cpp | 120 ++++++++++++++++++ .../build_tools/torch_ods_gen.py | 8 +- 3 files changed, 128 insertions(+), 4 deletions(-) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 885d367a987e..726c166849f7 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -10635,6 +10635,7 @@ def Torch_AtenWhereSelfOp : Torch_Op<"aten.where.self", [ printDefaultTorchOp(printer, *this, 3, 1); } }]; + let hasFolder = 1; } def Torch_AtenWhereScalarOp : Torch_Op<"aten.where.Scalar", [ @@ -10660,6 +10661,7 @@ def Torch_AtenWhereScalarOp : Torch_Op<"aten.where.Scalar", [ printDefaultTorchOp(printer, *this, 3, 1); } }]; + let hasFolder = 1; } def Torch_AtenWhereScalarOtherOp : Torch_Op<"aten.where.ScalarOther", [ @@ -10685,6 +10687,7 @@ def Torch_AtenWhereScalarOtherOp : Torch_Op<"aten.where.ScalarOther", [ printDefaultTorchOp(printer, *this, 3, 1); } }]; + let hasFolder = 1; } def Torch_AtenWhereScalarSelfOp : Torch_Op<"aten.where.ScalarSelf", [ @@ -10710,6 +10713,7 @@ def Torch_AtenWhereScalarSelfOp : Torch_Op<"aten.where.ScalarSelf", [ printDefaultTorchOp(printer, *this, 3, 1); } }]; + let hasFolder = 1; } def Torch_AtenNanToNumOp : Torch_Op<"aten.nan_to_num", [ diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index 9d100b11dc27..4cbd49843d00 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -3152,6 +3152,126 @@ OpFoldResult AtenCeilFloatOp::fold(FoldAdaptor adaptor) { return nullptr; } +//===----------------------------------------------------------------------===// +// AtenWhereSelfOp +//===----------------------------------------------------------------------===// + +static Attribute getBroadcastedAttr(Attribute attr, ValueTensorType ty) { + if (!attr || !ty.hasDtype() || !ty.hasSizes()) + return nullptr; + + auto dty = ty.getDtype(); + + if (auto valueDense = dyn_cast(attr)) { + if (!valueDense.isSplat()) + return nullptr; + auto splattr = valueDense.getSplatValue(); + auto attrty = ty.toBuiltinTensor().clone(dty); + return DenseElementsAttr::get(attrty, splattr); + } + + if (auto intAttr = dyn_cast_or_null(attr)) { + if (!isa(dty)) + return nullptr; + int64_t intval = intAttr.getInt(); + auto attrty = ty.toBuiltinTensor().clone(dty); + return DenseElementsAttr::get(attrty, IntegerAttr::get(dty, intval)); + } + + if (auto fpAttr = dyn_cast_or_null(attr)) { + if (!isa(dty)) + return nullptr; + double dblval = fpAttr.getValueAsDouble(); + auto attrty = ty.toBuiltinTensor().clone(dty); + return DenseElementsAttr::get(attrty, FloatAttr::get(dty, dblval)); + } + + return nullptr; +} + +OpFoldResult AtenWhereSelfOp::fold(FoldAdaptor adaptor) { + auto dense = dyn_cast_or_null(adaptor.getCondition()); + auto resultTy = dyn_cast(getType()); + if (!resultTy || !resultTy.hasDtype() || !resultTy.hasSizes() || !dense || + !dense.isSplat()) + return nullptr; + + auto condattr = dense.getSplatValue(); + auto value = getSelf(); + auto valueAttr = adaptor.getSelf(); + if (condattr.isZero()) { + value = getOther(); + valueAttr = adaptor.getOther(); + } + + auto valueTy = dyn_cast(value.getType()); + if (valueTy && valueTy.hasSizes() && valueTy.hasDtype() && + valueTy == resultTy) + return value; + + return getBroadcastedAttr(valueAttr, resultTy); +} + +//===----------------------------------------------------------------------===// +// AtenWhereScalarOp +//===----------------------------------------------------------------------===// + +OpFoldResult AtenWhereScalarOp::fold(FoldAdaptor adaptor) { + auto dense = dyn_cast_or_null(adaptor.getCondition()); + auto resultTy = dyn_cast(getType()); + if (!resultTy || !resultTy.hasDtype() || !resultTy.hasSizes() || !dense || + !dense.isSplat()) + return nullptr; + + auto condattr = dense.getSplatValue(); + auto valueAttr = adaptor.getSelf(); + if (condattr.isZero()) { + valueAttr = adaptor.getOther(); + } + + return getBroadcastedAttr(valueAttr, resultTy); +} + +//===----------------------------------------------------------------------===// +// AtenWhereScalarOtherOp +//===----------------------------------------------------------------------===// + +OpFoldResult AtenWhereScalarOtherOp::fold(FoldAdaptor adaptor) { + auto dense = dyn_cast_or_null(adaptor.getCondition()); + auto resultTy = dyn_cast(getType()); + if (!resultTy || !resultTy.hasDtype() || !resultTy.hasSizes() || !dense || + !dense.isSplat()) + return nullptr; + + auto condattr = dense.getSplatValue(); + auto valueAttr = adaptor.getSelf(); + if (condattr.isZero()) { + valueAttr = adaptor.getOther(); + } + + return getBroadcastedAttr(valueAttr, resultTy); +} + +//===----------------------------------------------------------------------===// +// AtenWhereScalarSelfOp +//===----------------------------------------------------------------------===// + +OpFoldResult AtenWhereScalarSelfOp::fold(FoldAdaptor adaptor) { + auto dense = dyn_cast_or_null(adaptor.getCondition()); + auto resultTy = dyn_cast(getType()); + if (!resultTy || !resultTy.hasDtype() || !resultTy.hasSizes() || !dense || + !dense.isSplat()) + return nullptr; + + auto condattr = dense.getSplatValue(); + auto valueAttr = adaptor.getSelf(); + if (condattr.isZero()) { + valueAttr = adaptor.getOther(); + } + + return getBroadcastedAttr(valueAttr, resultTy); +} + //===----------------------------------------------------------------------===// // PrimMaxIntOp //===----------------------------------------------------------------------===// diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index 6ee3a3e34cae..065331dfb54b 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -649,10 +649,10 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::type_as : (Tensor, Tensor) -> (Tensor)") emit("aten::view : (Tensor, int[]) -> (Tensor)", has_folder=True) emit("aten::_unsafe_view : (Tensor, int[]) -> (Tensor)") - emit("aten::where.self : (Tensor, Tensor, Tensor) -> (Tensor)") - emit("aten::where.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)") - emit("aten::where.ScalarOther : (Tensor, Tensor, Scalar) -> (Tensor)") - emit("aten::where.ScalarSelf : (Tensor, Scalar, Tensor) -> (Tensor)") + emit("aten::where.self : (Tensor, Tensor, Tensor) -> (Tensor)", has_folder=True) + emit("aten::where.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)", has_folder=True) + emit("aten::where.ScalarOther : (Tensor, Tensor, Scalar) -> (Tensor)", has_folder=True) + emit("aten::where.ScalarSelf : (Tensor, Scalar, Tensor) -> (Tensor)", has_folder=True) emit("aten::nan_to_num : (Tensor, float?, float?, float?) -> (Tensor)") emit("aten::slice.Tensor : (Tensor, int, int?, int?, int) -> (Tensor)", has_folder=True) emit("aten::len.Tensor : (Tensor) -> (int)") From 4df96616dba72400071535c75188d94df7e44184 Mon Sep 17 00:00:00 2001 From: Vivek Khandelwal Date: Thu, 8 Feb 2024 07:14:07 +0530 Subject: [PATCH 173/283] [MLIR][TORCH] Modify Onnx.Reshape lowering for static shape cases (#2852) This commit modifies the OnnxToTorch lowering of Onnx.Reshape op by creating the result shape list for the aten.reshape using the result shape values inferred from the op's result shape. Signed-Off By: Vivek Khandelwal --- .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 23 +++ .../TorchOnnxToTorch/simple_ops_q_to_z.mlir | 161 +++--------------- 2 files changed, 42 insertions(+), 142 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 8227514b5cf5..764cfc247d25 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -1656,6 +1656,29 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( binder.tensorResultType(resultType) || binder.s64IntegerAttr(allowzero, "allowzero", 0)) return failure(); + + // If the result shape is static then we can create a result shape list + // directly using the result shape values (integers). + if (resultType.hasSizes()) { + bool hasStaticShape = resultType.areAllSizesKnown(); + ArrayRef resultShapeInt = resultType.getSizes(); + if (hasStaticShape) { + SmallVector resultShape; + for (int64_t dim : resultShapeInt) { + resultShape.push_back(rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(dim))); + } + Value resultShapeList = rewriter.create( + binder.getLoc(), + Torch::ListType::get( + Torch::IntType::get(binder.op->getContext())), + resultShape); + rewriter.replaceOpWithNewOp( + binder.op, resultType, data, resultShapeList); + return success(); + } + } + Torch::BaseTensorType shapeType = shape.getType().cast(); SmallVector dimList; diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir index ae36661bdd43..8a5d5b1efea8 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir @@ -1256,33 +1256,11 @@ func.func @test_slice_default_steps(%arg0: !torch.vtensor<[20,10,5],f32>, %arg1: // CHECK-LABEL: func.func @test_reshape_negative_dim func.func @test_reshape_negative_dim(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torch.vtensor<[3],si64>) -> !torch.vtensor<[2,6,2],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[INT0:.+]] = torch.constant.int 0 - // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 - // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int - // CHECK: torch.aten.eq.int %1, %int0 : !torch.int, !torch.int -> !torch.bool - // CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int // CHECK: %[[INT2:.+]] = torch.constant.int 2 - // CHECK: torch.aten.mul.int %3, %int2 : !torch.int, !torch.int -> !torch.int - // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int - // CHECK: %[[INT1:.+]] = torch.constant.int 1 - // CHECK: torch.aten.select.int %arg1, %int0, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: torch.aten.item %6 : !torch.vtensor<[1],si64> -> !torch.int - // CHECK: torch.aten.eq.int %7, %int0 : !torch.int, !torch.int -> !torch.bool - // CHECK: torch.aten.Int.bool %8 : !torch.bool -> !torch.int - // CHECK: %[[INT3:.+]] = torch.constant.int 3 - // CHECK: torch.aten.mul.int %9, %int3 : !torch.int, !torch.int -> !torch.int - // CHECK: torch.aten.add.int %7, %10 : !torch.int, !torch.int -> !torch.int - // CHECK: %[[INT2_1:.+]] = torch.constant.int 2 - // CHECK: torch.aten.select.int %arg1, %int0, %int2_1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: torch.aten.item %12 : !torch.vtensor<[1],si64> -> !torch.int - // CHECK: torch.aten.eq.int %13, %int0 : !torch.int, !torch.int -> !torch.bool - // CHECK: torch.aten.Int.bool %14 : !torch.bool -> !torch.int - // CHECK: %[[INT4:.+]] = torch.constant.int 4 - // CHECK: torch.aten.mul.int %15, %int4 : !torch.int, !torch.int -> !torch.int - // CHECK: torch.aten.add.int %13, %16 : !torch.int, !torch.int -> !torch.int - // CHECK: torch.prim.ListConstruct %5, %11, %17 : (!torch.int, !torch.int, !torch.int) -> !torch.list - // CHECK: torch.aten.reshape %arg0, %18 : !torch.vtensor<[2,3,4],f32>, !torch.list -> !torch.vtensor<[2,6,2],f32> + // CHECK: %[[INT6:.+]] = torch.constant.int 6 + // CHECK: %[[INT2_0:.+]] = torch.constant.int 2 + // CHECK: %[[RESULT_SHAPE:.+]] = torch.prim.ListConstruct %[[INT2]], %[[INT6]], %[[INT2_0]] : (!torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: torch.aten.reshape %arg0, %[[RESULT_SHAPE]] : !torch.vtensor<[2,3,4],f32>, !torch.list -> !torch.vtensor<[2,6,2],f32> %0 = torch.operator "onnx.Reshape"(%arg0, %arg1) : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[2,6,2],f32> return %0 : !torch.vtensor<[2,6,2],f32> } @@ -1291,40 +1269,12 @@ func.func @test_reshape_negative_dim(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: // CHECK-LABEL: func.func @test_reshape_negative_extended_dims func.func @test_reshape_negative_extended_dims(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[1,2,3,4],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[INT0:.+]] = torch.constant.int 0 - // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 - // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int - // CHECK: torch.aten.eq.int %1, %int0 : !torch.int, !torch.int -> !torch.bool - // CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int - // CHECK: %[[INT2:.+]] = torch.constant.int 2 - // CHECK: torch.aten.mul.int %3, %int2 : !torch.int, !torch.int -> !torch.int - // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int // CHECK: %[[INT1:.+]] = torch.constant.int 1 - // CHECK: torch.aten.select.int %arg1, %int0, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: torch.aten.item %6 : !torch.vtensor<[1],si64> -> !torch.int - // CHECK: torch.aten.eq.int %7, %int0 : !torch.int, !torch.int -> !torch.bool - // CHECK: torch.aten.Int.bool %8 : !torch.bool -> !torch.int + // CHECK: %[[INT2:.+]] = torch.constant.int 2 // CHECK: %[[INT3:.+]] = torch.constant.int 3 - // CHECK: torch.aten.mul.int %9, %int3 : !torch.int, !torch.int -> !torch.int - // CHECK: torch.aten.add.int %7, %10 : !torch.int, !torch.int -> !torch.int - // CHECK: %[[INT2_1:.+]] = torch.constant.int 2 - // CHECK: torch.aten.select.int %arg1, %int0, %int2_1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: torch.aten.item %12 : !torch.vtensor<[1],si64> -> !torch.int - // CHECK: torch.aten.eq.int %13, %int0 : !torch.int, !torch.int -> !torch.bool - // CHECK: torch.aten.Int.bool %14 : !torch.bool -> !torch.int // CHECK: %[[INT4:.+]] = torch.constant.int 4 - // CHECK: torch.aten.mul.int %15, %int4 : !torch.int, !torch.int -> !torch.int - // CHECK: torch.aten.add.int %13, %16 : !torch.int, !torch.int -> !torch.int - // CHECK: %[[INT3_2:.+]] = torch.constant.int 3 - // CHECK: torch.aten.select.int %arg1, %int0, %int3_2 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: torch.aten.item %18 : !torch.vtensor<[1],si64> -> !torch.int - // CHECK: torch.aten.eq.int %19, %int0 : !torch.int, !torch.int -> !torch.bool - // CHECK: torch.aten.Int.bool %20 : !torch.bool -> !torch.int - // CHECK: torch.aten.mul.int %21, %int0 : !torch.int, !torch.int -> !torch.int - // CHECK: torch.aten.add.int %19, %22 : !torch.int, !torch.int -> !torch.int - // CHECK: torch.prim.ListConstruct %5, %11, %17, %23 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - // CHECK: torch.aten.reshape %arg0, %24 : !torch.vtensor<[2,3,4],f32>, !torch.list -> !torch.vtensor<[1,2,3,4],f32> + // CHECK: %[[RESULT_SHAPE:.+]] = torch.prim.ListConstruct %[[INT1]], %[[INT2]], %[[INT3]], %[[INT4]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: torch.aten.reshape %arg0, %[[RESULT_SHAPE]] : !torch.vtensor<[2,3,4],f32>, !torch.list -> !torch.vtensor<[1,2,3,4],f32> %0 = torch.operator "onnx.Reshape"(%arg0, %arg1) : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[4],si64>) -> !torch.vtensor<[1,2,3,4],f32> return %0 : !torch.vtensor<[1,2,3,4],f32> } @@ -1333,17 +1283,9 @@ func.func @test_reshape_negative_extended_dims(%arg0: !torch.vtensor<[2,3,4],f32 // CHECK-LABEL: func.func @test_reshape_one_dim func.func @test_reshape_one_dim(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[24],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[INT0:.+]] = torch.constant.int 0 - // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 - // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int - // CHECK: torch.aten.eq.int %1, %int0 : !torch.int, !torch.int -> !torch.bool - // CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int - // CHECK: %[[INT2:.+]] = torch.constant.int 2 - // CHECK: torch.aten.mul.int %3, %int2 : !torch.int, !torch.int -> !torch.int - // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int - // CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list - // CHECK: torch.aten.reshape %arg0, %6 : !torch.vtensor<[2,3,4],f32>, !torch.list -> !torch.vtensor<[24],f32> + // CHECK: %[[INT24:.+]] = torch.constant.int 24 + // CHECK: %[[RESULT_SHAPE:.+]] = torch.prim.ListConstruct %[[INT24]] : (!torch.int) -> !torch.list + // CHECK: torch.aten.reshape %arg0, %[[RESULT_SHAPE]] : !torch.vtensor<[2,3,4],f32>, !torch.list -> !torch.vtensor<[24],f32> %0 = torch.operator "onnx.Reshape"(%arg0, %arg1) : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[24],f32> return %0 : !torch.vtensor<[24],f32> } @@ -1352,25 +1294,10 @@ func.func @test_reshape_one_dim(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torc // CHECK-LABEL: func.func @test_reshape_reduced_dims func.func @test_reshape_reduced_dims(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torch.vtensor<[2],si64>) -> !torch.vtensor<[2,12],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[INT0:.+]] = torch.constant.int 0 - // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 - // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[2],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int - // CHECK: torch.aten.eq.int %1, %int0 : !torch.int, !torch.int -> !torch.bool - // CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int // CHECK: %[[INT2:.+]] = torch.constant.int 2 - // CHECK: torch.aten.mul.int %3, %int2 : !torch.int, !torch.int -> !torch.int - // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int - // CHECK: %[[INT1:.+]] = torch.constant.int 1 - // CHECK: torch.aten.select.int %arg1, %int0, %int1 : !torch.vtensor<[2],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: torch.aten.item %6 : !torch.vtensor<[1],si64> -> !torch.int - // CHECK: torch.aten.eq.int %7, %int0 : !torch.int, !torch.int -> !torch.bool - // CHECK: torch.aten.Int.bool %8 : !torch.bool -> !torch.int - // CHECK: %[[INT3:.+]] = torch.constant.int 3 - // CHECK: torch.aten.mul.int %9, %int3 : !torch.int, !torch.int -> !torch.int - // CHECK: torch.aten.add.int %7, %10 : !torch.int, !torch.int -> !torch.int - // CHECK: torch.prim.ListConstruct %5, %11 : (!torch.int, !torch.int) -> !torch.list - // CHECK: torch.aten.reshape %arg0, %12 : !torch.vtensor<[2,3,4],f32>, !torch.list -> !torch.vtensor<[2,12],f32> + // CHECK: %[[INT12:.+]] = torch.constant.int 12 + // CHECK: %[[RESULT_SHAPE:.+]] = torch.prim.ListConstruct %[[INT2]], %[[INT12]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: torch.aten.reshape %arg0, %[[RESULT_SHAPE]] : !torch.vtensor<[2,3,4],f32>, !torch.list -> !torch.vtensor<[2,12],f32> %0 = torch.operator "onnx.Reshape"(%arg0, %arg1) : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[2],si64>) -> !torch.vtensor<[2,12],f32> return %0 : !torch.vtensor<[2,12],f32> } @@ -1379,33 +1306,11 @@ func.func @test_reshape_reduced_dims(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: // CHECK-LABEL: func.func @test_reshape_reordered_all_dims func.func @test_reshape_reordered_all_dims(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torch.vtensor<[3],si64>) -> !torch.vtensor<[4,2,3],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[INT0:.+]] = torch.constant.int 0 - // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 - // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int - // CHECK: torch.aten.eq.int %1, %int0 : !torch.int, !torch.int -> !torch.bool - // CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int + // CHECK: %[[INT4:.+]] = torch.constant.int 4 // CHECK: %[[INT2:.+]] = torch.constant.int 2 - // CHECK: torch.aten.mul.int %3, %int2 : !torch.int, !torch.int -> !torch.int - // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int - // CHECK: %[[INT1:.+]] = torch.constant.int 1 - // CHECK: torch.aten.select.int %arg1, %int0, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: torch.aten.item %6 : !torch.vtensor<[1],si64> -> !torch.int - // CHECK: torch.aten.eq.int %7, %int0 : !torch.int, !torch.int -> !torch.bool - // CHECK: torch.aten.Int.bool %8 : !torch.bool -> !torch.int // CHECK: %[[INT3:.+]] = torch.constant.int 3 - // CHECK: torch.aten.mul.int %9, %int3 : !torch.int, !torch.int -> !torch.int - // CHECK: torch.aten.add.int %7, %10 : !torch.int, !torch.int -> !torch.int - // CHECK: %[[INT2_1:.+]] = torch.constant.int 2 - // CHECK: torch.aten.select.int %arg1, %int0, %int2_1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: torch.aten.item %12 : !torch.vtensor<[1],si64> -> !torch.int - // CHECK: torch.aten.eq.int %13, %int0 : !torch.int, !torch.int -> !torch.bool - // CHECK: torch.aten.Int.bool %14 : !torch.bool -> !torch.int - // CHECK: %[[INT4:.+]] = torch.constant.int 4 - // CHECK: torch.aten.mul.int %15, %int4 : !torch.int, !torch.int -> !torch.int - // CHECK: torch.aten.add.int %13, %16 : !torch.int, !torch.int -> !torch.int - // CHECK: torch.prim.ListConstruct %5, %11, %17 : (!torch.int, !torch.int, !torch.int) -> !torch.list - // CHECK: torch.aten.reshape %arg0, %18 : !torch.vtensor<[2,3,4],f32>, !torch.list -> !torch.vtensor<[4,2,3],f32> + // CHECK: %[[RESULT_SHAPE:.+]] = torch.prim.ListConstruct %[[INT4]], %[[INT2]], %[[INT3]] : (!torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: torch.aten.reshape %arg0, %[[RESULT_SHAPE]] : !torch.vtensor<[2,3,4],f32>, !torch.list -> !torch.vtensor<[4,2,3],f32> %0 = torch.operator "onnx.Reshape"(%arg0, %arg1) : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[4,2,3],f32> return %0 : !torch.vtensor<[4,2,3],f32> } @@ -1414,40 +1319,12 @@ func.func @test_reshape_reordered_all_dims(%arg0: !torch.vtensor<[2,3,4],f32>, % // CHECK-LABEL: func.func @test_reshape_zero_and_negative_dim func.func @test_reshape_zero_and_negative_dim(%arg0: !torch.vtensor<[2,3,4],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[2,3,1,4],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[INT0:.+]] = torch.constant.int 0 - // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 - // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int - // CHECK: torch.aten.eq.int %1, %int0 : !torch.int, !torch.int -> !torch.bool - // CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int // CHECK: %[[INT2:.+]] = torch.constant.int 2 - // CHECK: torch.aten.mul.int %3, %int2 : !torch.int, !torch.int -> !torch.int - // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int - // CHECK: %[[INT1:.+]] = torch.constant.int 1 - // CHECK: torch.aten.select.int %arg1, %int0, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: torch.aten.item %6 : !torch.vtensor<[1],si64> -> !torch.int - // CHECK: torch.aten.eq.int %7, %int0 : !torch.int, !torch.int -> !torch.bool - // CHECK: torch.aten.Int.bool %8 : !torch.bool -> !torch.int // CHECK: %[[INT3:.+]] = torch.constant.int 3 - // CHECK: torch.aten.mul.int %9, %int3 : !torch.int, !torch.int -> !torch.int - // CHECK: torch.aten.add.int %7, %10 : !torch.int, !torch.int -> !torch.int - // CHECK: %[[INT2_1:.+]] = torch.constant.int 2 - // CHECK: torch.aten.select.int %arg1, %int0, %int2_1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: torch.aten.item %12 : !torch.vtensor<[1],si64> -> !torch.int - // CHECK: torch.aten.eq.int %13, %int0 : !torch.int, !torch.int -> !torch.bool - // CHECK: torch.aten.Int.bool %14 : !torch.bool -> !torch.int + // CHECK: %[[INT1:.+]] = torch.constant.int 1 // CHECK: %[[INT4:.+]] = torch.constant.int 4 - // CHECK: torch.aten.mul.int %15, %int4 : !torch.int, !torch.int -> !torch.int - // CHECK: torch.aten.add.int %13, %16 : !torch.int, !torch.int -> !torch.int - // CHECK: %[[INT3_2:.+]] = torch.constant.int 3 - // CHECK: torch.aten.select.int %arg1, %int0, %int3_2 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: torch.aten.item %18 : !torch.vtensor<[1],si64> -> !torch.int - // CHECK: torch.aten.eq.int %19, %int0 : !torch.int, !torch.int -> !torch.bool - // CHECK: torch.aten.Int.bool %20 : !torch.bool -> !torch.int - // CHECK: torch.aten.mul.int %21, %int0 : !torch.int, !torch.int -> !torch.int - // CHECK: torch.aten.add.int %19, %22 : !torch.int, !torch.int -> !torch.int - // CHECK: torch.prim.ListConstruct %5, %11, %17, %23 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - // CHECK: torch.aten.reshape %arg0, %24 : !torch.vtensor<[2,3,4],f32>, !torch.list -> !torch.vtensor<[2,3,1,4],f32> + // CHECK: %[[RESULT_SHAPE:.+]] = torch.prim.ListConstruct %[[INT2]], %[[INT3]], %[[INT1]], %[[INT4]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list + // CHECK: torch.aten.reshape %arg0, %[[RESULT_SHAPE]] : !torch.vtensor<[2,3,4],f32>, !torch.list -> !torch.vtensor<[2,3,1,4],f32> %0 = torch.operator "onnx.Reshape"(%arg0, %arg1) : (!torch.vtensor<[2,3,4],f32>, !torch.vtensor<[4],si64>) -> !torch.vtensor<[2,3,1,4],f32> return %0 : !torch.vtensor<[2,3,1,4],f32> } From 21f070e95fd219a01ac663c697f0d36c1fa279e8 Mon Sep 17 00:00:00 2001 From: Ashay Rane Date: Wed, 7 Feb 2024 21:19:27 -0800 Subject: [PATCH 174/283] onnx: fix checks in TorchOnnxToTorch pass to match the ONNX spec (#2848) This PR contains three commits to update the validation checks in the ONNX -> Torch conversion pass for the AveragePool, Pad, and Slice operators: > onnx: fix preconditions for lowering AveragePool ops > > The `pads` attribute of the AveragePool operator specifies the value to > pad at both the beginning as well as the end of the axis (see > https://onnx.ai/onnx/operators/onnx__AveragePool.html#attributes), so > the size of this attribute should be twice the rank of the input tensor. > However, our TorchOnnxToTorch bails out early since it incorrectly > compares the pads attribute with the rank (not twice the rank) of the > input tensor. > > This patch fixes the code to match the spec and adds a lit test. > onnx: allow optional constant value for Pad operator > > The `constant_value` input of the onnx.Pad operator is optional (see > https://onnx.ai/onnx/operators/onnx__Pad.html#inputs), but the existing > logic for lowering the operator into the Torch dialect assumes that it > is mandatory. > > This patch makes the attribute optional and constructs a default value > (a list of zeros the size of the input tensor) if the attribute was not > specified. > onnx: fix checks for axes and steps inputs of Slice operator > > The ONNX Spec for the Slice operator allows the `starts` and `ends` > inputs to have fewer indices that the dimensions of the `data` tensor > (see https://onnx.ai/onnx/operators/onnx__Slice.html), but our code > expects these inputs to be as many as the `data` tensor's dimensions. > > More precisely, the spec requires that the `starts` and `ends` inputs > are only as long as the `axes` input, but since the `axes` input is > optional, the default type for the `axes` input has to match the type > for the `starts` and `ends` inputs. Moreover, the number of indices in > the `steps` input also has to match those in the `axes` inputs (instad > of matching the dimensions of the `data` input). > > This patch fixes the checks in the TorchOnnxToTorch conversion so that > they match the ONNX spec. --- .../TorchOnnxToTorch/DefaultDomainAtoF.cpp | 8 +++-- .../TorchOnnxToTorch/DefaultDomainGtoP.cpp | 33 +++++++++++++++++-- .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 22 ++++--------- .../TorchOnnxToTorch/simple_ops_a_to_f.mlir | 14 +++++++- .../TorchOnnxToTorch/simple_ops_g_to_p.mlir | 18 ++++++++++ .../TorchOnnxToTorch/simple_ops_q_to_z.mlir | 25 ++++++++++++++ 6 files changed, 99 insertions(+), 21 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp index 9550e982b8c4..e39c42b50422 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp @@ -308,12 +308,14 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( return rewriter.notifyMatchFailure( binder.op, "kernel list size does not match the number of axes"); } - if (binder.s64IntegerArrayAttr(padding, "pads", {0})) { + SmallVector defaultPadding(2 * (rank - 2), 0); + if (binder.s64IntegerArrayAttr(padding, "pads", defaultPadding)) { return failure(); } - if (padding.size() != 1 && padding.size() != rank - 2) { + if (padding.size() != 2 * (rank - 2)) { return rewriter.notifyMatchFailure( - binder.op, "padding list size does not match the number of axes"); + binder.op, + "padding list size does not match twice the number of axes"); } if (binder.s64IntegerArrayAttr(strides, "strides", {1})) { return failure(); diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index 5ebba10c9ebd..1760a0a20672 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -861,7 +861,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( patterns.onOp( "Pad", 19, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; - Value data, pads, constantValue, axes; + Value data, pads, axes; std::string mode; // TODO: The `axes` parameter is not supported yet. @@ -871,12 +871,41 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( } if (binder.tensorOperandAtIndex(data, 0) || binder.tensorOperandAtIndex(pads, 1) || - binder.tensorOperandAtIndex(constantValue, 2) || binder.tensorResultType(resultType) || binder.customOpNameStringAttr(mode, "mode", "constant")) return failure(); Location loc = binder.getLoc(); + Value constantValue; + if (binder.getNumOperands() >= 3) { + if (binder.tensorOperandAtIndex(constantValue, 2)) { + llvm::errs() << "failed to bind to index 2\n"; + return failure(); + } + } else { + auto dataTensorType = data.getType().cast(); + + auto maybeZeroAttr = [&]() -> std::optional { + if (dataTensorType.getDtype().isa()) { + return rewriter.getI64IntegerAttr(0); + } + if (dataTensorType.getDtype().isa()) { + return rewriter.getFloatAttr(dataTensorType.getDtype(), 0.0f); + } + return std::nullopt; + }(); + + if (!maybeZeroAttr) { + return rewriter.notifyMatchFailure( + binder.op, "expected integer or float data tensor"); + } + + auto shapedType = dataTensorType.toBuiltinTensor(); + auto splat = SplatElementsAttr::get(shapedType, *maybeZeroAttr); + constantValue = rewriter.create( + loc, dataTensorType, splat); + } + // Get pads shape and rank. The pads tensor is expected to be 1-D // tensor. auto padsTensorType = pads.getType().cast(); diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 764cfc247d25..8e46aa9ec7ed 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -1531,18 +1531,14 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( return failure(); } } else { - // The default axes value is the range from 0 to the number of - // dimensions + // The default axes value is the range from 0 to the size of first + // dimension of `starts` and `ends`. Value none = rewriter.create(loc); - auto defaultAxesType = Torch::ValueTensorType::get( - context, ArrayRef{operandTy.getRank()}, - rewriter.getIntegerType(64, /*signed*/ 1)); Value arangeLength = rewriter.create( loc, rewriter.getType(), - rewriter.getIntegerAttr(rewriter.getIntegerType(64), - operandTy.getRank())); + rewriter.getIntegerAttr(rewriter.getIntegerType(64), startSize)); axes = rewriter.create( - loc, defaultAxesType, arangeLength, none, none, none, none); + loc, startsTorchTy, arangeLength, none, none, none, none); } // Binding `steps` from its arguments or through a default value @@ -1553,22 +1549,18 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( } } else { // The default `steps` value is a 1d tensor filled with ones with a - // size of the dimension of the operand + // size equal to the size of `starts` and `ends`. Value none = rewriter.create(loc); - auto defaultStepsType = Torch::ValueTensorType::get( - context, ArrayRef{operandTy.getRank()}, - rewriter.getIntegerType(64, /*signed*/ 1)); Value sizeStepInput = rewriter.create( loc, rewriter.getType(), - rewriter.getIntegerAttr(rewriter.getIntegerType(64), - operandTy.getRank())); + rewriter.getIntegerAttr(rewriter.getIntegerType(64), startSize)); Value sizeStepsInput = rewriter.create( loc, Torch::ListType::get( Torch::IntType::get(binder.op->getContext())), sizeStepInput); steps = rewriter.create( - loc, defaultStepsType, sizeStepsInput, none, none, none, none); + loc, startsTorchTy, sizeStepsInput, none, none, none, none); } if (!(endsTy.getRank() == 1 && startsTy.getRank() == 1 && diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir index a71d4e428e18..2ee21c1e3841 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir @@ -699,13 +699,25 @@ func.func @test_averagepool_2d_ceil(%arg0: !torch.vtensor<[1,1,4,4],f32>) -> !to // CHECK-LABEL: @test_averagepool_3d_default func.func @test_averagepool_3d_default(%arg0: !torch.vtensor<[1,3,32,32,32],f32>) -> !torch.vtensor<[1,3,31,31,31],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: torch.aten.avg_pool3d %arg0, %0, %2, %1, %false, %false_2, %none : !torch.vtensor<[1,3,32,32,32],f32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,3,31,31,31],f32> + // CHECK: torch.aten.avg_pool3d %arg0, %0, %2, %1, %false, %false{{.*}}, %none : !torch.vtensor<[1,3,32,32,32],f32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,3,31,31,31],f32> %0 = torch.operator "onnx.AveragePool"(%arg0) {torch.onnx.kernel_shape = [2 : si64, 2 : si64, 2 : si64]} : (!torch.vtensor<[1,3,32,32,32],f32>) -> !torch.vtensor<[1,3,31,31,31],f32> return %0 : !torch.vtensor<[1,3,31,31,31],f32> } // ----- +// CHECK-LABEL: @test_averagepool_with_padding +// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[1,20,64,48],f32> +// CHECK: torch.aten.avg_pool2d %[[ARG]], {{.*}} : !torch.vtensor<[1,20,64,48],f32>, !torch.list, !torch.list, !torch.list, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,20,32,24],f32> + +func.func @test_averagepool_with_padding(%arg0: !torch.vtensor<[1,20,64,48],f32>) -> !torch.vtensor<[1,20,32,24],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 19 : si64} { + + %0 = torch.operator "onnx.AveragePool"(%arg0) {torch.onnx.ceil_mode = 0 : si64, torch.onnx.kernel_shape = [2 : si64, 2 : si64], torch.onnx.pads = [0 : si64, 0 : si64, 0 : si64, 0 : si64], torch.onnx.strides = [2 : si64, 2 : si64]} : (!torch.vtensor<[1,20,64,48],f32>) -> !torch.vtensor<[1,20,32,24],f32> + return %0 : !torch.vtensor<[1,20,32,24],f32> +} + +// ----- + // CHECK-LABEL: @test_conv_with_strides_no_padding func.func @test_conv_with_strides_no_padding(%arg0: !torch.vtensor<[1,1,7,5],f32>, %arg1: !torch.vtensor<[1,1,3,3],f32>) -> !torch.vtensor<[1,1,3,2],f32> attributes {torch.onnx_meta.ir_version = 6 : si64, torch.onnx_meta.opset_version = 11 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[C0:.*]] = torch.constant.int 0 diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir index 92fbe86caed4..bbef289ff5f2 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir @@ -447,6 +447,24 @@ func.func @test_pad(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[4], // ----- +// CHECK-LABEL: @test_pad_optional_constant +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[3,4],f32> +// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[4],si64> +// CHECK: %[[CONST_STR:.*]] = torch.constant.str "constant" +// CHECK: %[[NONE:.*]] = torch.constant.none +// CHECK: %[[FALSE:.*]] = torch.constant.bool false +// CHECK: %[[SEVEN:.*]] = torch.constant.int 7 +// CHECK: %[[DTYPE:.*]] = torch.aten.to.dtype %0, %[[SEVEN]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,4],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f64> +// CHECK: %[[ITEM:.*]] = torch.aten.item %[[DTYPE]] : !torch.vtensor<[],f64> -> !torch.float +// CHECK: torch.aten.pad %[[ARG0]], %{{.*}}, %[[CONST_STR]], %[[ITEM]] : !torch.vtensor<[3,4],f32>, !torch.list, !torch.str, !torch.float -> !torch.vtensor<[5,4],f32> + +func.func @test_pad_optional_constant(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[4], si64>) -> !torch.vtensor<[5,4],f32> attributes {torch.onnx_meta.opset_version = 19 : si64} { + %0 = torch.operator "onnx.Pad"(%arg0, %arg1) {torch.onnx.mode = "constant"} : (!torch.vtensor<[3,4],f32>, !torch.vtensor<[4], si64>) -> !torch.vtensor<[5,4],f32> + return %0 : !torch.vtensor<[5,4],f32> +} + +// ----- + // CHECK-LABEL: func.func @test_pow func.func @test_pow(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 15 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.pow.Tensor_Tensor %arg0, %arg1 : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir index 8a5d5b1efea8..9f2354d13e39 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir @@ -1205,6 +1205,31 @@ func.func @test_slice_default_axes_and_slices(%arg0: !torch.vtensor<[20,10,5],f3 // ----- +// CHECK-LABEL: @test_slice_default_axes_and_steps +// CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[20,10,5],f32>, +// CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[1],si64>, +// CHECK-SAME: %[[ARG2:.*]]: !torch.vtensor<[1],si64> + +// CHECK: %[[ZERO0:.*]] = torch.constant.int 0 +// CHECK: %[[ZERO1:.*]] = torch.constant.int 0 +// CHECK: %[[SCALAR:.*]] = torch.prim.NumToTensor.Scalar %[[ZERO1]] : !torch.int -> !torch.vtensor<[1],si64> +// CHECK: %[[SELECT0:.*]] = torch.aten.index_select %[[ARG1]], %[[ZERO]], %[[SCALAR]] : !torch.vtensor<[1],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> +// CHECK: %[[ITEM0:.*]] = torch.aten.item %[[SELECT0]] : !torch.vtensor<[1],si64> -> !torch.int +// CHECK: %[[SELECT1:.*]] = torch.aten.index_select %[[ARG2]], %[[ZERO]], %[[SCALAR]] : !torch.vtensor<[1],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> +// CHECK: %[[ITEM1:.*]] = torch.aten.item %[[SELECT1]] : !torch.vtensor<[1],si64> -> !torch.int +// CHECK: %[[SELECT2:.*]] = torch.aten.index_select %{{.*}}, %[[ZERO]], %[[SCALAR]] : !torch.vtensor<[1],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> +// CHECK: %[[ITEM2:.*]] = torch.aten.item %[[SELECT2]] : !torch.vtensor<[1],si64> -> !torch.int +// CHECK: %[[SELECT3:.*]] = torch.aten.index_select %{{.*}}, %[[ZERO]], %[[SCALAR]] : !torch.vtensor<[1],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> +// CHECK: %[[ITEM3:.*]] = torch.aten.item %[[SELECT3]] : !torch.vtensor<[1],si64> -> !torch.int +// CHECK: torch.aten.slice.Tensor %[[ARG0]], %[[ITEM2]], %[[ITEM0]], %[[ITEM1]], %[[ITEM3]] : !torch.vtensor<[20,10,5],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[20,10,1],f32> + +func.func @test_slice_default_axes_and_steps(%arg0: !torch.vtensor<[20,10,5],f32>, %arg1: !torch.vtensor<[1],si64>, %arg2: !torch.vtensor<[1],si64>) -> !torch.vtensor<[20,10,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64} { + %0 = torch.operator "onnx.Slice"(%arg0, %arg1, %arg2) : (!torch.vtensor<[20,10,5],f32>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[20,10,1],f32> + return %0 : !torch.vtensor<[20,10,1],f32> +} + +// ----- + // CHECK-LABEL: func.func @test_slice_default_steps func.func @test_slice_default_steps(%arg0: !torch.vtensor<[20,10,5],f32>, %arg1: !torch.vtensor<[3],si64>, %arg2: !torch.vtensor<[3],si64>, %arg3: !torch.vtensor<[3],si64>) -> !torch.vtensor<[20,10,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { //CHECK: %[[NONE:.*]] = torch.constant.none From 44f8f8982687564924379f5fc9f197f767f421bf Mon Sep 17 00:00:00 2001 From: Aart Bik <39774503+aartbik@users.noreply.github.com> Date: Thu, 8 Feb 2024 09:37:31 -0800 Subject: [PATCH 175/283] [torch-mlir][sparse] add sparsification to linalg reference backend (#2887) This adds a few passes that will ensure linalg with sparse tensors are properly lowered to loops and can run using the ExecutionEngine for testing (a few details on parameter passing from PyTorch still TBD) Test results: $ ./tools/e2e_test.sh --config linalg Summary: Passed: 1144 Expectedly Failed: 8 $ python -m e2e_testing.main --config=torchdynamo -v Summary: Passed: 960 Expectedly Failed: 163 Filed issue: https://github.com/pytorch/pytorch/issues/119407 --- projects/pt1/e2e_testing/xfail_sets.py | 5 +++++ .../linalg_on_tensors_backends/refbackend.py | 9 ++++++++- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 973f75a2637a..0d789a22db0e 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -35,6 +35,11 @@ "ConvolutionBackwardModule2DPadded_basic", "ConvolutionBackwardModule2D_basic", + # Size result mismatch (exposed by downstream canonicalizer + # on incompatabile casts). + # https://github.com/pytorch/pytorch/issues/119407 + "ConvolutionBackwardModule2DStrided_basic", + # RuntimeError: Index tensor must have the same number of dimensions as self tensor # RuntimeError: Failed running call_function aten.nll_loss_backward(... # https://github.com/pytorch/pytorch/issues/89630 diff --git a/projects/pt1/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py b/projects/pt1/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py index 266459e00b0c..0b7b28e9df71 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py +++ b/projects/pt1/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py @@ -131,8 +131,13 @@ def invoke(*args): # This is likely because if things are naturally fusable we usually already # emit things in that form from the high level (e.g. single linalg-generic). # Other backends are likely to benefit more. + "func.func(linalg-generalize-named-ops)", "func.func(linalg-fuse-elementwise-ops)", "convert-shape-to-std", + # MLIR Sparsifier mini-pipeline. Note that this is the bare minimum + # to ensure operations on sparse tensors are lowered to loops. + "sparsification-and-bufferization", + "sparse-storage-specifier-to-llvm", # Bufferize. "func.func(scf-bufferize)", "func.func(tm-tensor-bufferize)", @@ -199,7 +204,9 @@ def compile(self, imported_module: Module): run_pipeline_with_repro_report( imported_module, LOWERING_PIPELINE, - "Lowering Linalg-on-Tensors IR to LLVM with RefBackend") + "Lowering Linalg-on-Tensors IR to LLVM with RefBackend", + enable_ir_printing=False, + ) return imported_module def load(self, module) -> RefBackendInvoker: From 9659a436d1374612d7d2c7518a74dfd9ae821bc0 Mon Sep 17 00:00:00 2001 From: Avinash Sharma Date: Thu, 8 Feb 2024 14:53:40 -0800 Subject: [PATCH 176/283] Add lowering support for math::AbsIOp (#2875) There is no lowering support for math::AbsIOp, so if the operand is an integer type, it will fail to lower to math::AbsFOp since the op operand #0 must be floating-point-like. --- .../TorchToLinalg/Uncategorized.cpp | 5 +++- projects/pt1/e2e_testing/xfail_sets.py | 6 ++-- .../test_suite/elementwise.py | 30 ++++++++++++++++--- 3 files changed, 34 insertions(+), 7 deletions(-) diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index 794b755998fe..479bc1c0d620 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -424,8 +424,11 @@ static Value createLinalgPayloadCalculationForElementwiseOp( b.create(loc, b.getFloatAttr(floatDtype, 0)); return createEqual(b, loc, floatDtype, self, zero); } - if (isa(op)) + if (isa(op)) { + if (payloadArgs[0].getType().isa()) + return b.create(loc, payloadArgs[0]); return b.create(loc, payloadArgs[0]); + } if (isa(op)) { Value abs = b.create(loc, payloadArgs[0]); Value infinity = b.create( diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 0d789a22db0e..26f3e843954f 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -579,7 +579,8 @@ "ElementwiseSubScalarFloatModule_basic", "ElementwiseSubScalarIntModule_basic", "ElementwiseWhereScalarModule_basic", - "ElementwiseAbsModule_basic", + "ElementwiseAbsFloatModule_basic", + "ElementwiseAbsIntModule_basic", "EmbeddingModule1DIndices_basic", "EmbeddingModuleI32Static_basic", "EmbeddingModuleI32_basic", @@ -1060,7 +1061,8 @@ "EinsumStaticContractRhsModule_basic", "EinsumStaticFourDimensionModule_basic", "EinsumStaticModule_basic", - "ElementwiseAbsModule_basic", + "ElementwiseAbsFloatModule_basic", + "ElementwiseAbsIntModule_basic", "ElementwiseAddModule_basic", "ElementwiseAddScalarFloatModule_basic", "ElementwiseAddScalarInt64Module_basic", diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py index f711af6d4639..c1a827ffe108 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py @@ -2113,7 +2113,7 @@ def ElementwiseRsqrtIntModule_basic(module, tu: TestUtils): # ============================================================================== -class ElementwiseAbsModule(torch.nn.Module): +class ElementwiseAbsFloatModule(torch.nn.Module): def __init__(self): super().__init__() @@ -2127,9 +2127,31 @@ def forward(self, a): return torch.abs(a) -@register_test_case(module_factory=lambda: ElementwiseAbsModule()) -def ElementwiseAbsModule_basic(module, tu: TestUtils): - module.forward(tu.rand(3, 4, 5, low=-1.0, high=1.0)) +@register_test_case(module_factory=lambda: ElementwiseAbsFloatModule()) +def ElementwiseAbsFloatModule_basic(module, tu: TestUtils): + module.forward(torch.tensor([[[-1.0, 0.0, 1.0]]])) + + +# ============================================================================== + + +class ElementwiseAbsIntModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.int64, True), + ]) + def forward(self, a): + return torch.abs(a) + + +@register_test_case(module_factory=lambda: ElementwiseAbsIntModule()) +def ElementwiseAbsIntModule_basic(module, tu: TestUtils): + module.forward(torch.tensor([[[-1, 0, 1]]])) # ============================================================================== From 4cc62aeb24e28b3ff60df6ff4a0fd97cc045efc1 Mon Sep 17 00:00:00 2001 From: Franz Haniel <77495327+frafranz@users.noreply.github.com> Date: Fri, 9 Feb 2024 17:00:24 +0100 Subject: [PATCH 177/283] Implement trace (#2790) The lowering decomposes AtenTraceOp into an AtenDiagonalOp followed by AtenSumOp. The progress is tracked in https://github.com/nod-ai/SHARK-Turbine/issues/333. --------- Co-authored-by: Franz Haniel --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 23 ++++++++ .../Transforms/AbstractInterpLibrary.cpp | 26 +++++++++ .../Torch/Transforms/DecomposeComplexOps.cpp | 49 +++++++++++++++++ .../Transforms/LowerToBackendContract.cpp | 1 + .../build_tools/abstract_interp_lib_gen.py | 16 ++++++ .../build_tools/torch_ods_gen.py | 1 + .../test_suite/reduction.py | 53 +++++++++++++++++++ 7 files changed, 169 insertions(+) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 726c166849f7..68262ee2368d 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -9105,6 +9105,29 @@ def Torch_AtenEinsumOp : Torch_Op<"aten.einsum", [ }]; } +def Torch_AtenTraceOp : Torch_Op<"aten.trace", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::trace : (Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenTraceOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenTraceOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + def Torch_AtenBucketizeTensorOp : Torch_Op<"aten.bucketize.Tensor", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index 4290ce23c44b..320f53f0b7b6 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -6978,6 +6978,21 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg1, %arg3, %0) : (!torch.list, !torch.optional>, !torch.bool, !torch.any) -> !torch.list\n" " return %1 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.trace\"(%arg0: !torch.list) -> !torch.list {\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: input must have rank 2\"\n" +" %int2 = torch.constant.int 2\n" +" %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %1 = torch.aten.eq.int %0, %int2 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %1 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %2 = torch.prim.ListConstruct : () -> !torch.list\n" +" return %2 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.argmax\"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.bool) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.argmax(%arg0, %arg1, %arg2) : (!torch.list, !torch.optional, !torch.bool) -> !torch.list\n" " return %0 : !torch.list\n" @@ -12520,6 +12535,17 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %5 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%0, %1) : (!torch.list>, !torch.list) -> !torch.int\n" " return %5 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.trace\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %int4 = torch.constant.int 4\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %int4 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten._shape_as_tensor\"(%arg0: !torch.tuple) -> !torch.int {\n" " %int4 = torch.constant.int 4\n" " return %int4 : !torch.int\n" diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index edf51be11310..bc5276dca6a7 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -1056,6 +1056,54 @@ class DecomposeAtenEinsumOp : public OpRewritePattern { }; } // namespace +namespace { +// Calculate the trace of the input tensor as the sum over its diagonal +// elements. This computation is performed as: +// +// Step1: Obtain the diagonal using AtenDiagonalOp +// Step2: Compute the trace using AtenSumOp. +// +// It is verified that the input tensor has rank two. +class DecomposeAtenTraceOp : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenTraceOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value self = op.getSelf(); + std::optional inRank = getTensorRank(self); + if (inRank != 2) + return rewriter.notifyMatchFailure( + op, "Expected input tensor to have rank 2."); + + Value none = rewriter.create(loc); + Value zero = + rewriter.create(loc, rewriter.getI64IntegerAttr(0)); + Value one = + rewriter.create(loc, rewriter.getI64IntegerAttr(1)); + BaseTensorType inputType = self.getType().cast(); + + Value output = op.getResult(); + BaseTensorType outputType = output.getType().cast(); + + ArrayRef inputShape = inputType.getSizes(); + int64_t diagonalSize = std::min(inputShape[0], inputShape[1]); + SmallVector diagonalShape{diagonalSize}; + Type elementType = inputType.getOptionalDtype(); + Type diagonalType = inputType.getWithSizesAndDtype( + llvm::ArrayRef(diagonalShape), elementType); + + Value diagonal = rewriter.create( + loc, diagonalType, /*input=*/self, /*offset=*/zero, /*dim1=*/zero, + /*dim2=*/one); + Value sum = rewriter.create(loc, outputType, /*self=*/diagonal, + /*dtype=*/none); + rewriter.replaceOp(op, sum); + return success(); + } +}; +} // namespace + // Calculates the softmax function on the given `input` tensor. Softmax(x) = // exp(x)/sum(exp(x)). // To avoid overflow we use the following decomposition rule: @@ -6727,6 +6775,7 @@ class DecomposeComplexOpsPass addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp index c4259dc958b8..306b2446adb6 100644 --- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp +++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp @@ -389,6 +389,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index dadd87a15a08..7fe6e8457fe8 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -518,6 +518,15 @@ def aten〇std〇dim〡shape(self: List[int], dim: Optional[List[int]], unbiased def aten〇std〇correction〡shape(self: List[int], dim: Optional[List[int]] = None, correction: Optional[float] = None, keepdim: bool = False) -> List[int]: return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, None) +@check_shape_function([ + Invocation(TensorOfShape(2, 3)), # Basic case. + ErrorInvocation(TensorOfShape(2, 3, 4)), # Too many dimensions. + ErrorInvocation(TensorOfShape(2)), # Too few dimensions. +]) +def aten〇trace〡shape(self: List[int]) -> List[int]: + assert len(self) == 2, "input must have rank 2" + return [] + @check_shape_function([ Invocation(TensorOfShape(2, 3, 4)), # Basic case. Invocation(TensorOfShape(2, 3, 4), dim=0), # Test explicit `dim`. @@ -4219,6 +4228,13 @@ def aten〇einsum〡dtype(equation: str, tensors_rank_dtype: List[Tuple[int, int dtypes.append(tensor_dtype) return promote_dtypes(ranks, dtypes) +@check_dtype_function(_check_tensors_with_the_same_dtype(tensor_shapes=[(2, 3)])) +def aten〇trace〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + if is_integer_dtype(self_dtype): + return torch.int64 + return self_dtype + @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) def aten〇_shape_as_tensor〡dtype(self_rank_dtype: Tuple[int, int]) -> int: return torch.int64 diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index 065331dfb54b..2f91075e96c2 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -592,6 +592,7 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::argmin : (Tensor, int?, bool) -> (Tensor)") emit("aten::one_hot : (Tensor, int) -> (Tensor)") emit("aten::einsum : (str, Tensor[], int[]?) -> (Tensor)") + emit("aten::trace : (Tensor) -> (Tensor)") emit("aten::bucketize.Tensor : (Tensor, Tensor, bool, bool) -> (Tensor)") emit("aten::clone : (Tensor, int?) -> (Tensor)", has_folder=True) emit("aten::lift_fresh_copy : (Tensor) -> (Tensor)") diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py index ea2ff1609f3a..804476e6a686 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py @@ -1302,3 +1302,56 @@ def forward(self, input, target): @register_test_case(module_factory=lambda: CrossEntropyLossNoReductionModule()) def CrossEntropyLossNoReductionModule_basic(module, tu: TestUtils): module.forward(tu.rand(8, 2), tu.randint(8, high=2)) + +# ============================================================================== + +class TraceModule(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True), + ]) + def forward(self, a): + return torch.ops.aten.trace(a) + +@register_test_case(module_factory=lambda: TraceModule()) +def TraceModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 3)) + +@register_test_case(module_factory=lambda: TraceModule()) +def TraceModule_nonsquare(module, tu: TestUtils): + module.forward(tu.rand(3, 4)) + +@register_test_case(module_factory=lambda: TraceModule()) +def TraceModule_empty(module, tu: TestUtils): + module.forward(torch.empty(0,0)) + +# ============================================================================== + +class TraceIntModule(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int64, True), + ]) + def forward(self, a): + return torch.ops.aten.trace(a) + +@register_test_case(module_factory=lambda: TraceIntModule()) +def TraceSignedIntModule_basic(module, tu: TestUtils): + module.forward(tu.randint(2, 2, low=-10, high=10)) + +@register_test_case(module_factory=lambda: TraceIntModule()) +def TraceUnsignedIntModule_basic(module, tu: TestUtils): + module.forward(tu.randint(2, 2, low=0, high=10)) + +@register_test_case(module_factory=lambda: TraceIntModule()) +def TraceUnsignedIntModule_empty(module, tu: TestUtils): + module.forward(tu.randint(0, 0, low=0, high=10)) + From 7d33ba69ac71981fb39f839d698300bcecf5353a Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Fri, 9 Feb 2024 14:02:54 -0800 Subject: [PATCH 178/283] [torch] Folder for torch.aten.select.int for splat cases (#2890) If the input or result is a splat value we can just constant fold the result. This is common for shape computations and can help with shape inference. --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 1 + lib/Dialect/Torch/IR/TorchOps.cpp | 35 +++ .../build_tools/torch_ods_gen.py | 2 +- test/Dialect/Torch/canonicalize.mlir | 213 ++++++++++++++++++ 4 files changed, 250 insertions(+), 1 deletion(-) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 68262ee2368d..e3a7526d9114 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -10124,6 +10124,7 @@ def Torch_AtenSelectIntOp : Torch_Op<"aten.select.int", [ printDefaultTorchOp(printer, *this, 3, 1); } }]; + let hasFolder = 1; } def Torch_AtenSizeIntOp : Torch_Op<"aten.size.int", [ diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index 4cbd49843d00..76d5fc03688a 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -1327,6 +1327,41 @@ void AtenSizeOp::getCanonicalizationPatterns(RewritePatternSet &patterns, }); } +//===----------------------------------------------------------------------===// +// AtenSelectIntOp +//===----------------------------------------------------------------------===// + +OpFoldResult AtenSelectIntOp::fold(FoldAdaptor adaptor) { + auto self = dyn_cast_or_null(adaptor.getSelf()); + auto ty = dyn_cast(getType()); + if (!self || !ty || !ty.hasDtype() || !ty.hasSizes()) + return nullptr; + + auto selfTy = cast(self.getType()); + auto bty = ty.toBuiltinTensor().clone(ty.getDtype()); + if (!bty.hasStaticShape()) + return nullptr; + + if (self.isSplat()) + return DenseElementsAttr::get(bty, self.getSplatValue()); + + auto dimAttr = dyn_cast_or_null(adaptor.getDim()); + auto indexAttr = dyn_cast_or_null(adaptor.getIndex()); + if (!dimAttr || !indexAttr || bty.getNumElements() != 1) + return nullptr; + + auto dim = dimAttr.getInt(); + auto index = indexAttr.getInt(); + + for (int i = 0, s = selfTy.getRank(); i < s; ++i) { + if (i != dim && selfTy.getDimSize(i) != 1) + return nullptr; + } + + auto splattr = self.getValues()[index]; + return DenseElementsAttr::get(bty, splattr); +} + //===----------------------------------------------------------------------===// // AtenSizeIntOp //===----------------------------------------------------------------------===// diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index 2f91075e96c2..5d0644381612 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -629,7 +629,7 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::_reshape_alias : (Tensor, int[], int[]) -> (Tensor)") emit("aten::resize : (Tensor, int[], int?) -> (Tensor)") emit("aten::resize_ : (Tensor, int[], int?) -> (Tensor)") - emit("aten::select.int : (Tensor, int, int) -> (Tensor)") + emit("aten::select.int : (Tensor, int, int) -> (Tensor)", has_folder=1) emit("aten::size.int : (Tensor, int) -> (int)", has_folder=True) emit("aten::sum : (Tensor, int?) -> (Tensor)") emit("aten::sum.dim_IntList : (Tensor, int[]?, bool, int?) -> (Tensor)") diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir index ac1797b4c398..28b4f6933c5a 100644 --- a/test/Dialect/Torch/canonicalize.mlir +++ b/test/Dialect/Torch/canonicalize.mlir @@ -2332,3 +2332,216 @@ func.func @torch.aten.index_select$const_f32_si_neg() -> !torch.vtensor<[1],f32> %0 = torch.aten.index_select %tensor, %dim, %index : !torch.vtensor<[10],f32>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],f32> return %0 : !torch.vtensor<[1],f32> } + +// ----- + +// CHECK-LABEL: @fold_aten_where_true_attr +func.func @fold_aten_where_true_attr() -> !torch.vtensor<[4],si64> { + // CHECK: %[[RET:.+]] = torch.vtensor.literal(dense<7> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + // CHECK: return %[[RET]] + %bool = torch.vtensor.literal(dense<1> : tensor<4xi1>) : !torch.vtensor<[4],i1> + %lhs = torch.vtensor.literal(dense<7> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + %rhs = torch.vtensor.literal(dense<11> : tensor) : !torch.vtensor<[],si64> + %where = torch.aten.where.self %bool, %lhs, %rhs : !torch.vtensor<[4],i1>, !torch.vtensor<[4],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[4],si64> + return %where : !torch.vtensor<[4],si64> +} + +// ----- + +// CHECK-LABEL: @fold_aten_where_false_attr +func.func @fold_aten_where_false_attr() -> !torch.vtensor<[4],si64> { + // CHECK: %[[RET:.+]] = torch.vtensor.literal(dense<11> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + // CHECK: return %[[RET]] + %bool = torch.vtensor.literal(dense<0> : tensor<4xi1>) : !torch.vtensor<[4],i1> + %lhs = torch.vtensor.literal(dense<7> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + %rhs = torch.vtensor.literal(dense<11> : tensor) : !torch.vtensor<[],si64> + %where = torch.aten.where.self %bool, %lhs, %rhs : !torch.vtensor<[4],i1>, !torch.vtensor<[4],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[4],si64> + return %where : !torch.vtensor<[4],si64> +} + +// ----- + +// CHECK-LABEL: @fold_aten_where_true_value +func.func @fold_aten_where_true_value(%arg0 : !torch.vtensor<[4],si64>, %arg1 : !torch.vtensor<[4],si64>) -> !torch.vtensor<[4],si64> { + // CHECK: return %arg0 + %bool = torch.vtensor.literal(dense<1> : tensor<4xi1>) : !torch.vtensor<[4],i1> + %where = torch.aten.where.self %bool, %arg0, %arg1 : !torch.vtensor<[4],i1>, !torch.vtensor<[4],si64>, !torch.vtensor<[4],si64> -> !torch.vtensor<[4],si64> + return %where : !torch.vtensor<[4],si64> +} + +// ----- + +// CHECK-LABEL: @fold_aten_where_false_value +func.func @fold_aten_where_false_value(%arg0 : !torch.vtensor<[4],si64>, %arg1 : !torch.vtensor<[4],si64>) -> !torch.vtensor<[4],si64> { + // CHECK: return %arg1 + %bool = torch.vtensor.literal(dense<0> : tensor<4xi1>) : !torch.vtensor<[4],i1> + %where = torch.aten.where.self %bool, %arg0, %arg1 : !torch.vtensor<[4],i1>, !torch.vtensor<[4],si64>, !torch.vtensor<[4],si64> -> !torch.vtensor<[4],si64> + return %where : !torch.vtensor<[4],si64> +} + + +// ----- + +// CHECK-LABEL: @fold_aten_where_true_value_nofold +func.func @fold_aten_where_true_value_nofold(%arg0 : !torch.vtensor<[],si64>, %arg1 : !torch.vtensor<[4],si64>) -> !torch.vtensor<[4],si64> { + // CHECK: torch.aten.where.self + %bool = torch.vtensor.literal(dense<1> : tensor<4xi1>) : !torch.vtensor<[4],i1> + %where = torch.aten.where.self %bool, %arg0, %arg1 : !torch.vtensor<[4],i1>, !torch.vtensor<[],si64>, !torch.vtensor<[4],si64> -> !torch.vtensor<[4],si64> + return %where : !torch.vtensor<[4],si64> +} + +// ----- + +// CHECK-LABEL: @fold_aten_where_true_scalar_int +func.func @fold_aten_where_true_scalar_int() -> !torch.vtensor<[4],si64> { + // CHECK: %[[RET:.+]] = torch.vtensor.literal(dense<7> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + // CHECK: return %[[RET]] + %bool = torch.vtensor.literal(dense<1> : tensor<4xi1>) : !torch.vtensor<[4],i1> + %lhs = torch.constant.int 7 + %rhs = torch.constant.int 11 + %where = torch.aten.where.Scalar %bool, %lhs, %rhs : !torch.vtensor<[4],i1>, !torch.int, !torch.int -> !torch.vtensor<[4],si64> + return %where : !torch.vtensor<[4],si64> +} + +// ----- + +// CHECK-LABEL: @fold_aten_where_false_scalar_int +func.func @fold_aten_where_false_scalar_int() -> !torch.vtensor<[4],ui8> { + // CHECK: %[[RET:.+]] = torch.vtensor.literal(dense<11> : tensor<4xui8>) : !torch.vtensor<[4],ui8> + // CHECK: return %[[RET]] + %bool = torch.vtensor.literal(dense<0> : tensor<4xi1>) : !torch.vtensor<[4],i1> + %lhs = torch.constant.int 7 + %rhs = torch.constant.int 11 + %where = torch.aten.where.Scalar %bool, %lhs, %rhs : !torch.vtensor<[4],i1>, !torch.int, !torch.int -> !torch.vtensor<[4],ui8> + return %where : !torch.vtensor<[4],ui8> +} + +// ----- + +// CHECK-LABEL: @fold_aten_where_false_scalar_fp +func.func @fold_aten_where_false_scalar_fp() -> !torch.vtensor<[4],f32> { + // CHECK: %[[RET:.+]] = torch.vtensor.literal(dense<1.100000e+01> : tensor<4xf32>) : !torch.vtensor<[4],f32> + // CHECK: return %[[RET]] + %bool = torch.vtensor.literal(dense<0> : tensor<4xi1>) : !torch.vtensor<[4],i1> + %lhs = torch.constant.float 7.0 + %rhs = torch.constant.float 11.0 + %where = torch.aten.where.Scalar %bool, %lhs, %rhs : !torch.vtensor<[4],i1>, !torch.float, !torch.float -> !torch.vtensor<[4],f32> + return %where : !torch.vtensor<[4],f32> +} + +// ----- + +// CHECK-LABEL: @fold_aten_where_true_sother_int +func.func @fold_aten_where_true_sother_int() -> !torch.vtensor<[4],si64> { + // CHECK: %[[RET:.+]] = torch.vtensor.literal(dense<7> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + // CHECK: %[[RET]] + %bool = torch.vtensor.literal(dense<1> : tensor<4xi1>) : !torch.vtensor<[4],i1> + %lhs = torch.vtensor.literal(dense<7> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + %rhs = torch.constant.int 11 + %where = torch.aten.where.ScalarOther %bool, %lhs, %rhs : !torch.vtensor<[4],i1>, !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> + return %where : !torch.vtensor<[4],si64> +} + + +// ----- + +// CHECK-LABEL: @fold_aten_where_false_sother_int +func.func @fold_aten_where_false_sother_int() -> !torch.vtensor<[4],ui8> { + // CHECK: %[[RET:.+]] = torch.vtensor.literal(dense<11> : tensor<4xui8>) : !torch.vtensor<[4],ui8> + // CHECK: return %[[RET]] + %bool = torch.vtensor.literal(dense<0> : tensor<4xi1>) : !torch.vtensor<[4],i1> + %lhs = torch.vtensor.literal(dense<7> : tensor) : !torch.vtensor<[],ui8> + %rhs = torch.constant.int 11 + %where = torch.aten.where.ScalarOther %bool, %lhs, %rhs : !torch.vtensor<[4],i1>, !torch.vtensor<[],ui8>, !torch.int -> !torch.vtensor<[4],ui8> + return %where : !torch.vtensor<[4],ui8> +} + +// ----- + +// CHECK-LABEL: @fold_aten_where_false_sother_fp +func.func @fold_aten_where_false_sother_fp() -> !torch.vtensor<[4],f32> { + // CHECK: %[[RET:.+]] = torch.vtensor.literal(dense<1.100000e+01> : tensor<4xf32>) : !torch.vtensor<[4],f32> + // CHECK: %[[RET]] + %bool = torch.vtensor.literal(dense<0> : tensor<4xi1>) : !torch.vtensor<[4],i1> + %lhs = torch.vtensor.literal(dense<7.0> : tensor<4xf32>) : !torch.vtensor<[4],f32> + %rhs = torch.constant.float 11.0 + %where = torch.aten.where.ScalarOther %bool, %lhs, %rhs : !torch.vtensor<[4],i1>, !torch.vtensor<[4],f32>, !torch.float -> !torch.vtensor<[4],f32> + return %where : !torch.vtensor<[4],f32> +} + + +// ----- + +// CHECK-LABEL: @fold_aten_where_true_sself_int +func.func @fold_aten_where_true_sself_int() -> !torch.vtensor<[4],si64> { + // CHECK: %[[RET:.+]] = torch.vtensor.literal(dense<7> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + // CHECK: %[[RET]] + %bool = torch.vtensor.literal(dense<1> : tensor<4xi1>) : !torch.vtensor<[4],i1> + %lhs = torch.constant.int 7 + %rhs = torch.vtensor.literal(dense<11> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + %where = torch.aten.where.ScalarSelf %bool, %lhs, %rhs : !torch.vtensor<[4],i1>, !torch.int, !torch.vtensor<[4],si64> -> !torch.vtensor<[4],si64> + return %where : !torch.vtensor<[4],si64> +} + +// ----- + +// CHECK-LABEL: @fold_aten_where_false_sself_int +func.func @fold_aten_where_false_sself_int() -> !torch.vtensor<[4],ui8> { + // CHECK: %[[RET:.+]] = torch.vtensor.literal(dense<11> : tensor<4xui8>) : !torch.vtensor<[4],ui8> + // CHECK: return %[[RET]] + %bool = torch.vtensor.literal(dense<0> : tensor<4xi1>) : !torch.vtensor<[4],i1> + %lhs = torch.constant.int 7 + %rhs = torch.vtensor.literal(dense<11> : tensor) : !torch.vtensor<[],ui8> + %where = torch.aten.where.ScalarSelf %bool, %lhs, %rhs : !torch.vtensor<[4],i1>, !torch.int, !torch.vtensor<[],ui8> -> !torch.vtensor<[4],ui8> + return %where : !torch.vtensor<[4],ui8> +} + +// ----- + +// CHECK-LABEL: @fold_aten_where_false_sself_fp +func.func @fold_aten_where_false_sself_fp() -> !torch.vtensor<[4],f32> { + // CHECK: %[[RET:.+]] = torch.vtensor.literal(dense<1.100000e+01> : tensor<4xf32>) : !torch.vtensor<[4],f32> + // CHECK: %[[RET]] + %bool = torch.vtensor.literal(dense<0> : tensor<4xi1>) : !torch.vtensor<[4],i1> + %lhs = torch.constant.float 7.0 + %rhs = torch.vtensor.literal(dense<11.0> : tensor<4xf32>) : !torch.vtensor<[4],f32> + %where = torch.aten.where.ScalarSelf %bool, %lhs, %rhs : !torch.vtensor<[4],i1>, !torch.float, !torch.vtensor<[4],f32> -> !torch.vtensor<[4],f32> + return %where : !torch.vtensor<[4],f32> +} + +// ----- + +// CHECK-LABEL: @aten_select_int_fold_splat +func.func @aten_select_int_fold_splat(%arg0 : !torch.int, %arg1 : !torch.int) -> !torch.vtensor<[1],si64> { + %splat = torch.vtensor.literal(dense<4> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + %select = torch.aten.select.int %splat, %arg0, %arg1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: %[[RET:.+]] = torch.vtensor.literal(dense<4> : tensor<1xsi64>) : !torch.vtensor<[1],si64> + // CHECK: return %[[RET]] + return %select : !torch.vtensor<[1],si64> +} + +// ----- + +// CHECK-LABEL: @aten_select_int_fold_1D +func.func @aten_select_int_fold_1D() -> !torch.vtensor<[1],si64> { + %index = torch.constant.int 1 + %dim = torch.constant.int 0 + %splat = torch.vtensor.literal(dense<[5,6,7,8]> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + %select = torch.aten.select.int %splat, %dim, %index : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: %[[RET:.+]] = torch.vtensor.literal(dense<6> : tensor<1xsi64>) : !torch.vtensor<[1],si64> + // CHECK: return %[[RET]] + return %select : !torch.vtensor<[1],si64> +} + +// ----- + +// CHECK-LABEL: @aten_select_int_fold_3D +func.func @aten_select_int_fold_3D() -> !torch.vtensor<[1, 1, 1],si64> { + %index = torch.constant.int 2 + %dim = torch.constant.int 2 + %splat = torch.vtensor.literal(dense<[[[5,6,7,8]]]> : tensor<1x1x4xsi64>) : !torch.vtensor<[1,1,4],si64> + %select = torch.aten.select.int %splat, %dim, %index : !torch.vtensor<[1,1,4],si64>, !torch.int, !torch.int -> !torch.vtensor<[1,1,1],si64> + // CHECK: %[[RET:.+]] = torch.vtensor.literal(dense<7> : tensor<1x1x1xsi64>) : !torch.vtensor<[1,1,1],si64> + // CHECK: return %[[RET]] + return %select : !torch.vtensor<[1,1,1],si64> +} From d83b576c6e15cf7ebeefc1dbd65fd9061227c278 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Fri, 9 Feb 2024 14:07:49 -0800 Subject: [PATCH 179/283] Bump LLVM to llvm/llvm-project@bb180856ec28efe305dc77ca4bb3db12d8932edf (#2895) Includes some minor first for `AffineMap::inferFromExprList` --- externals/llvm-project | 2 +- .../TorchToLinalg/IndirectDataMovement.cpp | 3 ++- lib/Conversion/TorchToLinalg/Linear.cpp | 4 ++-- lib/Conversion/TorchToLinalg/Pooling.cpp | 14 +++++++------- lib/Conversion/TorchToLinalg/Reduction.cpp | 3 ++- lib/Conversion/TorchToLinalg/Utils.cpp | 3 ++- lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp | 3 ++- 7 files changed, 18 insertions(+), 14 deletions(-) diff --git a/externals/llvm-project b/externals/llvm-project index 70eb0e37a867..bb180856ec28 160000 --- a/externals/llvm-project +++ b/externals/llvm-project @@ -1 +1 @@ -Subproject commit 70eb0e37a86747f9266e4c8380baa89746f5e23b +Subproject commit bb180856ec28efe305dc77ca4bb3db12d8932edf diff --git a/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp b/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp index bfbe45afe167..b8754a306711 100644 --- a/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/IndirectDataMovement.cpp @@ -498,7 +498,8 @@ class ConvertAtenIndexSelectOp : public OpConversionPattern { resultExpr.push_back(rewriter.getAffineDimExpr(i)); } - auto indexingMaps = AffineMap::inferFromExprList({indicesExpr, resultExpr}); + auto indexingMaps = AffineMap::inferFromExprList({indicesExpr, resultExpr}, + rewriter.getContext()); Value finalRes = rewriter diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp index 3557b27a2eb2..6c04dd12f55a 100644 --- a/lib/Conversion/TorchToLinalg/Linear.cpp +++ b/lib/Conversion/TorchToLinalg/Linear.cpp @@ -512,8 +512,8 @@ class ConvertAtenMatmulOp : public OpConversionPattern { resultShape.insert(resultShape.end(), {lhsDim0, rhsDim1}); Value zeroTensor = createZeroInitTensor(rewriter, loc, resultShape, elementType); - auto indexingMaps = - AffineMap::inferFromExprList({lhsExpr, rhsExpr, outExpr}); + auto indexingMaps = AffineMap::inferFromExprList( + {lhsExpr, rhsExpr, outExpr}, rewriter.getContext()); iteratorTypes.insert(iteratorTypes.end(), {utils::IteratorType::parallel, utils::IteratorType::reduction, diff --git a/lib/Conversion/TorchToLinalg/Pooling.cpp b/lib/Conversion/TorchToLinalg/Pooling.cpp index 76100c2c0e71..e795d2ea9fb8 100644 --- a/lib/Conversion/TorchToLinalg/Pooling.cpp +++ b/lib/Conversion/TorchToLinalg/Pooling.cpp @@ -442,8 +442,8 @@ class ConvertAtenMaxPool2dWithIndicesOp // Here we have six dimensions, each corresponding to N, C, Hout, Wout, kH, // and kW, respectively, as described in the algorithm above. - SmallVector indexingMaps = - AffineMap::inferFromExprList({inputExprs, kernelExprs, outputExprs}); + SmallVector indexingMaps = AffineMap::inferFromExprList( + {inputExprs, kernelExprs, outputExprs}, rewriter.getContext()); SmallVector iteratorTypes( 4, utils::IteratorType::parallel); iteratorTypes.push_back(utils::IteratorType::reduction); @@ -724,7 +724,7 @@ class ConvertAtenAdaptiveAvgPool1dOp kSizeTensorExprs.push_back(rewriter.getAffineDimExpr(2)); kIterExprs.push_back(rewriter.getAffineDimExpr(3)); SmallVector indexingMaps = AffineMap::inferFromExprList( - {kIterExprs, outputExprs, kSizeTensorExprs}); + {kIterExprs, outputExprs, kSizeTensorExprs}, rewriter.getContext()); SmallVector iteratorTypes( 3, utils::IteratorType::parallel); iteratorTypes.push_back(utils::IteratorType::reduction); @@ -774,8 +774,8 @@ class ConvertAtenAdaptiveAvgPool1dOp // make a linalg generic to divide each element by the corresponding // Kernel Width. This step is only necessary for avg pooling. - SmallVector indexingMaps1 = - AffineMap::inferFromExprList({kSizeTensorExprs, outputExprs}); + SmallVector indexingMaps1 = AffineMap::inferFromExprList( + {kSizeTensorExprs, outputExprs}, rewriter.getContext()); SmallVector iteratorTypes1( 3, utils::IteratorType::parallel); auto output = rewriter.create( @@ -916,8 +916,8 @@ class ConvertAtenAdaptiveMaxPool2dOp for (unsigned i = rank; i < 2 * rank - 2; i++) { kIterExprs.push_back(rewriter.getAffineDimExpr(i)); } - SmallVector indexingMaps = - AffineMap::inferFromExprList({kIterExprs, outputExprs, auxTensorExprs}); + SmallVector indexingMaps = AffineMap::inferFromExprList( + {kIterExprs, outputExprs, auxTensorExprs}, rewriter.getContext()); SmallVector iteratorTypes( rank, utils::IteratorType::parallel); for (unsigned i = 0; i < rank - 2; i++) { diff --git a/lib/Conversion/TorchToLinalg/Reduction.cpp b/lib/Conversion/TorchToLinalg/Reduction.cpp index 50fa1f8e610d..a21615ad84c4 100644 --- a/lib/Conversion/TorchToLinalg/Reduction.cpp +++ b/lib/Conversion/TorchToLinalg/Reduction.cpp @@ -167,7 +167,8 @@ class ConvertAtenMinMaxDimOp : public OpConversionPattern { resultExprs.push_back(rewriter.getAffineDimExpr(size.index())); } } - auto maps = AffineMap::inferFromExprList({exprs, resultExprs, resultExprs}); + auto maps = AffineMap::inferFromExprList({exprs, resultExprs, resultExprs}, + rewriter.getContext()); auto linalgOp = rewriter.create( loc, ArrayRef({filledTensorVal.getType(), filledTensorIdx.getType()}), diff --git a/lib/Conversion/TorchToLinalg/Utils.cpp b/lib/Conversion/TorchToLinalg/Utils.cpp index 20b32cd1fe73..366f5492aa6d 100644 --- a/lib/Conversion/TorchToLinalg/Utils.cpp +++ b/lib/Conversion/TorchToLinalg/Utils.cpp @@ -197,7 +197,8 @@ Value torch_to_linalg::createReductionLinalgGeneric( } } - auto indexingMaps = AffineMap::inferFromExprList({exprs, resultExprs}); + auto indexingMaps = + AffineMap::inferFromExprList({exprs, resultExprs}, b.getContext()); Value accumulator = createInitTensor(b, loc, resultShape, initElem.getType(), initElem); diff --git a/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp b/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp index d11a5524af7d..c669e8b6b8cc 100644 --- a/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp +++ b/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp @@ -1064,7 +1064,8 @@ class ConvertAtenMaxPool2dWithIndicesBackwardOp rewriter.getAffineDimExpr(tensorOperandRank)); SmallVector indexingMaps = AffineMap::inferFromExprList( - {originalIndicesDimExprs, updatedIndicesDimExprs}); + {originalIndicesDimExprs, updatedIndicesDimExprs}, + rewriter.getContext()); SmallVector iteratorTypes( tensorOperandRank + 1, utils::IteratorType::parallel); From c0f139be0f8371569c38b45ab2c925deb16292d2 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Fri, 9 Feb 2024 15:02:20 -0800 Subject: [PATCH 180/283] [torch] Add `torch.aten.eq.Tensor` comparison folder (#2889) Added a folded for a equals operator. This allows an equivalent comparison folder, primarily for when shape computations occur small size tensor. --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 95 +++++++-------- lib/Dialect/Torch/IR/TorchOps.cpp | 102 +++++++++++++++++ .../build_tools/torch_ods_gen.py | 2 +- test/Dialect/Torch/canonicalize.mlir | 108 ++++++++++++++++++ 4 files changed, 259 insertions(+), 48 deletions(-) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index e3a7526d9114..adf5e8396751 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -1669,53 +1669,6 @@ def Torch_AtenLerp_ScalarOp : Torch_Op<"aten.lerp_.Scalar", [ }]; } -def Torch_AtenEqTensorOp : Torch_Op<"aten.eq.Tensor", [ - AllowsTypeRefinement, - HasValueSemantics, - ReadOnly - ]> { - let summary = "Generated op for `aten::eq.Tensor : (Tensor, Tensor) -> (Tensor)`"; - let arguments = (ins - AnyTorchTensorType:$self, - AnyTorchTensorType:$other - ); - let results = (outs - AnyTorchTensorType:$result - ); - let hasCustomAssemblyFormat = 1; - let extraClassDefinition = [{ - ParseResult AtenEqTensorOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); - } - void AtenEqTensorOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); - } - }]; -} - -def Torch_AtenEq_TensorOp : Torch_Op<"aten.eq_.Tensor", [ - IsTrailingUnderscoreInplaceVariant, - AllowsTypeRefinement - ]> { - let summary = "Generated op for `aten::eq_.Tensor : (Tensor, Tensor) -> (Tensor)`"; - let arguments = (ins - Torch_NonValueTensorType:$self, - Torch_NonValueTensorType:$other - ); - let results = (outs - Torch_NonValueTensorType:$result - ); - let hasCustomAssemblyFormat = 1; - let extraClassDefinition = [{ - ParseResult AtenEq_TensorOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); - } - void AtenEq_TensorOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); - } - }]; -} - def Torch_AtenGtTensorOp : Torch_Op<"aten.gt.Tensor", [ AllowsTypeRefinement, HasValueSemantics, @@ -3931,6 +3884,54 @@ def Torch_AtenMul_ScalarOp : Torch_Op<"aten.mul_.Scalar", [ }]; } +def Torch_AtenEqTensorOp : Torch_Op<"aten.eq.Tensor", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::eq.Tensor : (Tensor, Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$other + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenEqTensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenEqTensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; + let hasFolder = 1; +} + +def Torch_AtenEq_TensorOp : Torch_Op<"aten.eq_.Tensor", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::eq_.Tensor : (Tensor, Tensor) -> (Tensor)`"; + let arguments = (ins + Torch_NonValueTensorType:$self, + Torch_NonValueTensorType:$other + ); + let results = (outs + Torch_NonValueTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenEq_TensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenEq_TensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; +} + def Torch_AtenFloorOp : Torch_Op<"aten.floor", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index 76d5fc03688a..d831b70767c4 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -1166,6 +1166,108 @@ void AtenMulTensorOp::getCanonicalizationPatterns(RewritePatternSet &patterns, }); } +//===----------------------------------------------------------------------===// +// AtenEqTensorOp +//===----------------------------------------------------------------------===// + +OpFoldResult AtenEqTensorOp::fold(FoldAdaptor adaptor) { + constexpr int64_t kMaxFold = 16; + auto ty = dyn_cast(getType()); + if (!ty || !ty.hasDtype() || !ty.hasSizes()) + return nullptr; + + auto bty = ty.toBuiltinTensor().clone(ty.getDtype()); + if (!bty.hasStaticShape()) + return nullptr; + + if (getSelf() == getOther()) + return DenseElementsAttr::get(bty, + IntegerAttr::get(bty.getElementType(), 1)); + + auto self = dyn_cast_or_null(adaptor.getSelf()); + auto other = dyn_cast_or_null(adaptor.getOther()); + if (!self || !other) + return nullptr; + + auto selfTy = dyn_cast(self.getType()); + auto otherTy = dyn_cast(other.getType()); + if (!selfTy || !otherTy || + selfTy.getElementType() != otherTy.getElementType()) + return nullptr; + + // If both values are splats we can just compute the output value as a splat. + if (self.isSplat() && other.isSplat()) { + if (isa(selfTy.getElementType())) { + APFloat lhsFp = self.getSplatValue(); + APFloat rhsFp = other.getSplatValue(); + bool eq = lhsFp.compare(rhsFp) == APFloat::cmpEqual; + return DenseElementsAttr::get(bty, eq); + } + + if (isa(selfTy.getElementType())) { + APInt lhsInt = self.getSplatValue(); + APInt rhsInt = other.getSplatValue(); + bool eq = lhsInt == rhsInt; + return DenseElementsAttr::get(bty, eq); + } + + return nullptr; + } + + if (selfTy != otherTy || bty.getNumElements() > kMaxFold) + return nullptr; + + if (isa(selfTy.getElementType())) { + auto extract = [bty](DenseElementsAttr attr) { + llvm::SmallVector vals; + if (attr.isSplat()) { + vals.resize(bty.getNumElements(), attr.getSplatValue()); + return vals; + } + + for (auto fp : attr.getValues()) { + vals.push_back(fp); + } + return vals; + }; + + llvm::SmallVector lhsFp = extract(self); + llvm::SmallVector rhsFp = extract(other); + llvm::SmallVector vals(bty.getNumElements()); + for (int i = 0, s = bty.getNumElements(); i < s; ++i) { + vals[i] = lhsFp[i].compare(rhsFp[i]) == APFloat::cmpEqual; + } + + return DenseElementsAttr::get(bty, vals); + } + + if (isa(selfTy.getElementType())) { + auto extract = [bty](DenseElementsAttr attr) { + llvm::SmallVector vals; + if (attr.isSplat()) { + vals.resize(bty.getNumElements(), attr.getSplatValue()); + return vals; + } + + for (auto fp : attr.getValues()) { + vals.push_back(fp); + } + return vals; + }; + + llvm::SmallVector lhsInt = extract(self); + llvm::SmallVector rhsInt = extract(other); + llvm::SmallVector vals(bty.getNumElements()); + for (int i = 0, s = bty.getNumElements(); i < s; ++i) { + vals[i] = lhsInt[i] == rhsInt[i]; + } + + return DenseElementsAttr::get(bty, vals); + } + + return nullptr; +} + //===----------------------------------------------------------------------===// // AtenFloorOp //===----------------------------------------------------------------------===// diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index 5d0644381612..6f674601393d 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -291,7 +291,6 @@ def emit_with_mutating_variants(key, **kwargs): "aten::logical_not : (Tensor) -> (Tensor)", "aten::lerp.Tensor : (Tensor, Tensor, Tensor) -> (Tensor)", "aten::lerp.Scalar : (Tensor, Tensor, Scalar) -> (Tensor)", - "aten::eq.Tensor : (Tensor, Tensor) -> (Tensor)", "aten::gt.Tensor : (Tensor, Tensor) -> (Tensor)", "aten::ge.Tensor : (Tensor, Tensor) -> (Tensor)", "aten::lt.Tensor : (Tensor, Tensor) -> (Tensor)", @@ -343,6 +342,7 @@ def emit_with_mutating_variants(key, **kwargs): emit_with_mutating_variants("aten::add.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)", has_canonicalizer=True) emit_with_mutating_variants("aten::sub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)", has_canonicalizer=True) emit_with_mutating_variants("aten::mul.Scalar : (Tensor, Scalar) -> (Tensor)", has_canonicalizer=True) + emit_with_mutating_variants("aten::eq.Tensor : (Tensor, Tensor) -> (Tensor)", has_folder=True) emit_with_mutating_variants("aten::floor : (Tensor) -> (Tensor)", has_canonicalizer=True) emit_with_mutating_variants("aten::masked_fill.Tensor : (Tensor, Tensor, Tensor) -> (Tensor)", has_canonicalizer=True) diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir index 28b4f6933c5a..03eeaaeb525b 100644 --- a/test/Dialect/Torch/canonicalize.mlir +++ b/test/Dialect/Torch/canonicalize.mlir @@ -2545,3 +2545,111 @@ func.func @aten_select_int_fold_3D() -> !torch.vtensor<[1, 1, 1],si64> { // CHECK: return %[[RET]] return %select : !torch.vtensor<[1,1,1],si64> } + +// ----- + + +// CHECK-LABEL: @aten_eq_tensor_args +func.func @aten_eq_tensor_args(%arg0 : !torch.vtensor<[4],si64>) -> !torch.vtensor<[4],i1> { + // CHECK: %[[RET:.+]] = torch.vtensor.literal(dense : tensor<4xi1>) : !torch.vtensor<[4],i1> + // CHECK: return %[[RET]] + %0 = torch.aten.eq.Tensor %arg0, %arg0 : !torch.vtensor<[4],si64>, !torch.vtensor<[4],si64> -> !torch.vtensor<[4],i1> + return %0 : !torch.vtensor<[4],i1> +} + +// ----- + +// CHECK-LABEL: @aten_eq_tensor_splats_int_false +func.func @aten_eq_tensor_splats_int_false() -> !torch.vtensor<[4],i1> { + // CHECK: %[[RET:.+]] = torch.vtensor.literal(dense : tensor<4xi1>) : !torch.vtensor<[4],i1> + // CHECK: return %[[RET]] + %lhs = torch.vtensor.literal(dense<4> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + %rhs = torch.vtensor.literal(dense<5> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + %0 = torch.aten.eq.Tensor %lhs, %rhs : !torch.vtensor<[4],si64>, !torch.vtensor<[4],si64> -> !torch.vtensor<[4],i1> + return %0 : !torch.vtensor<[4],i1> +} + +// ----- + +// CHECK-LABEL: @aten_eq_tensor_splats_int_true +func.func @aten_eq_tensor_splats_int_true() -> !torch.vtensor<[4],i1> { + // CHECK: %[[RET:.+]] = torch.vtensor.literal(dense : tensor<4xi1>) : !torch.vtensor<[4],i1> + // CHECK: return %[[RET]] + %lhs = torch.vtensor.literal(dense<5> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + %rhs = torch.vtensor.literal(dense<5> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + %0 = torch.aten.eq.Tensor %lhs, %rhs : !torch.vtensor<[4],si64>, !torch.vtensor<[4],si64> -> !torch.vtensor<[4],i1> + return %0 : !torch.vtensor<[4],i1> +} + +// ----- + +// CHECK-LABEL: @aten_eq_tensor_splats_fp_false +func.func @aten_eq_tensor_splats_fp_false() -> !torch.vtensor<[4],i1> { + // CHECK: %[[RET:.+]] = torch.vtensor.literal(dense : tensor<4xi1>) : !torch.vtensor<[4],i1> + // CHECK: return %[[RET]] + %lhs = torch.vtensor.literal(dense<4.0> : tensor<4xf32>) : !torch.vtensor<[4],f32> + %rhs = torch.vtensor.literal(dense<5.0> : tensor<4xf32>) : !torch.vtensor<[4],f32> + %0 = torch.aten.eq.Tensor %lhs, %rhs : !torch.vtensor<[4],f32>, !torch.vtensor<[4],f32> -> !torch.vtensor<[4],i1> + return %0 : !torch.vtensor<[4],i1> +} + +// ----- + +// CHECK-LABEL: @aten_eq_tensor_splats_fp_true +func.func @aten_eq_tensor_splats_fp_true() -> !torch.vtensor<[4],i1> { + // CHECK: %[[RET:.+]] = torch.vtensor.literal(dense : tensor<4xi1>) : !torch.vtensor<[4],i1> + // CHECK: return %[[RET]] + %lhs = torch.vtensor.literal(dense<5.0> : tensor<4xf32>) : !torch.vtensor<[4],f32> + %rhs = torch.vtensor.literal(dense<5.0> : tensor<4xf32>) : !torch.vtensor<[4],f32> + %0 = torch.aten.eq.Tensor %lhs, %rhs : !torch.vtensor<[4],f32>, !torch.vtensor<[4],f32> -> !torch.vtensor<[4],i1> + return %0 : !torch.vtensor<[4],i1> +} + +// ----- + +// CHECK-LABEL: @aten_eq_tensor_splat_dense_fp +func.func @aten_eq_tensor_splat_dense_fp() -> !torch.vtensor<[4],i1> { + // CHECK: %[[RET:.+]] = torch.vtensor.literal(dense<[false, true, false, true]> : tensor<4xi1>) : !torch.vtensor<[4],i1> + // CHECK: return %[[RET]] + %lhs = torch.vtensor.literal(dense<5.0> : tensor<4xf32>) : !torch.vtensor<[4],f32> + %rhs = torch.vtensor.literal(dense<[4.0, 5.0, 6.0, 5.0]> : tensor<4xf32>) : !torch.vtensor<[4],f32> + %0 = torch.aten.eq.Tensor %lhs, %rhs : !torch.vtensor<[4],f32>, !torch.vtensor<[4],f32> -> !torch.vtensor<[4],i1> + return %0 : !torch.vtensor<[4],i1> +} + +// ----- + +// CHECK-LABEL: @aten_eq_tensor_dense_fp +func.func @aten_eq_tensor_dense_fp() -> !torch.vtensor<[4],i1> { + // CHECK: %[[RET:.+]] = torch.vtensor.literal(dense<[true, false, true, false]> : tensor<4xi1>) : !torch.vtensor<[4],i1> + // CHECK: return %[[RET]] + %lhs = torch.vtensor.literal(dense<[4.0, 5.5, 6.0, 6.4]> : tensor<4xf32>) : !torch.vtensor<[4],f32> + %rhs = torch.vtensor.literal(dense<[4.0, 5.0, 6.0, 5.0]> : tensor<4xf32>) : !torch.vtensor<[4],f32> + %0 = torch.aten.eq.Tensor %lhs, %rhs : !torch.vtensor<[4],f32>, !torch.vtensor<[4],f32> -> !torch.vtensor<[4],i1> + return %0 : !torch.vtensor<[4],i1> +} + +// ----- + +// CHECK-LABEL: @aten_eq_tensor_splat_dense_int +func.func @aten_eq_tensor_splat_dense_int() -> !torch.vtensor<[4],i1> { + // CHECK: %[[RET:.+]] = torch.vtensor.literal(dense<[false, true, false, true]> : tensor<4xi1>) : !torch.vtensor<[4],i1> + // CHECK: return %[[RET]] + %lhs = torch.vtensor.literal(dense<5> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + %rhs = torch.vtensor.literal(dense<[4, 5, 6, 5]> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + %0 = torch.aten.eq.Tensor %lhs, %rhs : !torch.vtensor<[4],si64>, !torch.vtensor<[4],si64> -> !torch.vtensor<[4],i1> + return %0 : !torch.vtensor<[4],i1> +} + +// ----- + +// CHECK-LABEL: @aten_eq_tensor_dense_int +func.func @aten_eq_tensor_dense_int() -> !torch.vtensor<[4],i1> { + // CHECK: %[[RET:.+]] = torch.vtensor.literal(dense<[true, true, true, false]> : tensor<4xi1>) : !torch.vtensor<[4],i1> + // CHECK: return %[[RET]] + %lhs = torch.vtensor.literal(dense<[4, 5, 6, 6]> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + %rhs = torch.vtensor.literal(dense<[4, 5, 6, 5]> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + %0 = torch.aten.eq.Tensor %lhs, %rhs : !torch.vtensor<[4],si64>, !torch.vtensor<[4],si64> -> !torch.vtensor<[4],i1> + return %0 : !torch.vtensor<[4],i1> +} + From b8c48cf283c076c0b55998702ea380d74c1322a8 Mon Sep 17 00:00:00 2001 From: Yuanqiang Liu Date: Mon, 12 Feb 2024 01:05:00 +0800 Subject: [PATCH 181/283] =?UTF-8?q?Bump=20stablehlo=20to=20openxla/stableh?= =?UTF-8?q?lo@e191eb4c3c3f3144503a8a117d760de5d=E2=80=A6=20(#2891)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …dcc7e89. * to involve `chlo-legalize-to-stablehlo` pass. --- externals/stablehlo | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/externals/stablehlo b/externals/stablehlo index fd52182f76ca..e191eb4c3c3f 160000 --- a/externals/stablehlo +++ b/externals/stablehlo @@ -1 +1 @@ -Subproject commit fd52182f76cadb82f2064fe5fc49a4fb4347a826 +Subproject commit e191eb4c3c3f3144503a8a117d760de5ddcc7e89 From bfb93cb99f259a614b4be7f6a4d04f6a07e4d395 Mon Sep 17 00:00:00 2001 From: "Xida Ren (Cedar)" Date: Mon, 12 Feb 2024 09:19:39 -0800 Subject: [PATCH 182/283] Fix test_add_uint8 failure to lower to linalg (#2893) By updating convertScalarToDtype invocation pass original source and destination datatypes for the add op. Also fixes a potential problem with the sub op. --------- Co-authored-by: Xida Ren --- lib/Conversion/TorchToLinalg/Uncategorized.cpp | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index 479bc1c0d620..e4b683d41cee 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -575,7 +575,9 @@ static Value createLinalgPayloadCalculationForElementwiseOp( .getElementType(); Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype); Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype); - Value alpha = convertScalarToDtype(b, loc, adaptor.getAlpha(), dtype); + Value alpha = convertScalarToDtype(b, loc, adaptor.getAlpha(), dtype, + /*srcOriginalDtype=*/std::nullopt, + /*dstOriginalDtype=*/dtype); if (dtype.isa()) { Value scaled = b.create(loc, rhs, alpha); return b.create(loc, lhs, scaled); @@ -613,7 +615,9 @@ static Value createLinalgPayloadCalculationForElementwiseOp( .getElementType(); Value self = convertScalarToDtype(b, loc, payloadArgs[0], dtype); Value other = convertScalarToDtype(b, loc, operands[1], dtype); - Value alpha = convertScalarToDtype(b, loc, operands[2], dtype); + Value alpha = convertScalarToDtype( + b, loc, operands[2], dtype, /*srcOriginalDtype=*/operands[2].getType(), + /*dstOriginalDtype=*/dtype); if (dtype.isa()) { Value mult = b.create(loc, other, alpha); return b.create(loc, self, mult); @@ -1118,7 +1122,9 @@ static Value createLinalgPayloadCalculationForElementwiseOp( .getElementType(); Value self = convertScalarToDtype(b, loc, payloadArgs[0], dtype); Value other = convertScalarToDtype(b, loc, operands[1], dtype); - Value alpha = convertScalarToDtype(b, loc, operands[2], dtype); + Value alpha = convertScalarToDtype( + b, loc, operands[2], dtype, /*srcOriginalDtype=*/operands[2].getType(), + /*dstOriginalDtype=*/dtype); if (dtype.isa()) { Value mult = b.create(loc, self, alpha); return b.create(loc, other, mult); From be8375d35037ca4ca496d0de4052745f5472277b Mon Sep 17 00:00:00 2001 From: Aart Bik <39774503+aartbik@users.noreply.github.com> Date: Mon, 12 Feb 2024 10:04:54 -0800 Subject: [PATCH 183/283] [torch-mlir][sparse] implement first sparse_jit end-to-end path (#2894) This PR introduces a sparse_jit wrapper that can run simple models with sparse tensor inputs end-to-end. The implementation shows all required components on modifying sparse tensor types with a 1:N relation on the call sites. Two tests shows that the JIT runs end-to-end while computing the correct results. More details to follow (generalizing to COO and different ranks, as well as support for *output* sparse tensors), but the general concepts are all here now. **_Update: Thanks to Rob, bump to proper LLVM/MLIR hash is done!_** _**NOTE that all parameter passing changes are nicely done "downstream" in MLIR, so very little changes are required in torch-mlir code proper**_ --------- Co-authored-by: Franz Haniel <77495327+frafranz@users.noreply.github.com> Co-authored-by: Franz Haniel --- .../linalg_on_tensors_backends/refbackend.py | 1 + python/torch_mlir/extras/fx_importer.py | 55 +++--- test/python/fx_importer/sparse_test.py | 157 +++++++++++++++--- 3 files changed, 168 insertions(+), 45 deletions(-) diff --git a/projects/pt1/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py b/projects/pt1/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py index 0b7b28e9df71..c0b5eabf5f04 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py +++ b/projects/pt1/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py @@ -136,6 +136,7 @@ def invoke(*args): "convert-shape-to-std", # MLIR Sparsifier mini-pipeline. Note that this is the bare minimum # to ensure operations on sparse tensors are lowered to loops. + "sparse-assembler", "sparsification-and-bufferization", "sparse-storage-specifier-to-llvm", # Bufferize. diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py index 5328e8730cc3..2e9ba233367f 100644 --- a/python/torch_mlir/extras/fx_importer.py +++ b/python/torch_mlir/extras/fx_importer.py @@ -208,7 +208,9 @@ } -def sparsity_encoding(shape: torch.Size, sparse_layout: torch.layout) -> str: +def sparsity_encoding( + shape: torch.Size, sparsity: tuple[torch.layout, int, int] +) -> str: """Returns sparse tensor encoding for the given sparse layout as string. The method currently just supports 2-dim sparse formats. This should be @@ -216,20 +218,24 @@ def sparsity_encoding(shape: torch.Size, sparse_layout: torch.layout) -> str: and suffix dense subtensor dimensions. Since MLIR supports a superset of what is currently implememented in torch.sparse, this should not a be problem. """ + assert sparsity is not None + sparse_layout, posw, crdw = sparsity # TODO: any rank if len(shape) != 2: raise RuntimeError(f"Unsupported sparse rank {len(shape)}") if sparse_layout is torch.sparse_coo: - return "#sparse_tensor.encoding<{map=(i,j)->(i:compressed(nonunique),j:singleton)}>" - if sparse_layout is torch.sparse_csr: - return "#sparse_tensor.encoding<{map=(i,j)->(i:dense,j:compressed)}>" - if sparse_layout is torch.sparse_csc: - return "#sparse_tensor.encoding<{map=(i,j)->(j:dense,i:compressed)}>" - # TODO: block format (derive block size!) + smap = f"(i,j)->(i:compressed(nonunique),j:singleton)" + elif sparse_layout is torch.sparse_csr: + smap = f"(i,j)->(i:dense,j:compressed)" + elif sparse_layout is torch.sparse_csc: + smap = f"(i,j)->(j:dense,i:compressed)" + else: + # TODO: block format (derive block size!) + raise RuntimeError(f"Unsupported sparse layout {sparse_layout}") - raise RuntimeError(f"Unsupported sparse layout {sparse_layout}") + return f"#sparse_tensor.encoding<{{map={smap},posWidth={posw},crdWidth={crdw}}}>" def is_symbolic(obj: Any) -> bool: @@ -479,14 +485,19 @@ def format_asm_shape(self, shape: torch.Size) -> str: """Return IrType for !torch.vtensor with the given shape and dtype""" def get_vtensor_type( - self, shape: torch.Size, dtype: torch.dtype, sparse_layout: torch.layout = None + self, + shape: torch.Size, + dtype: torch.dtype, + *, + sparsity: Optional[tuple[torch.layout, int, int]] = None, # keyword-only ): shape_asm = self.format_asm_shape(shape) mlir_dtype = str(self.dtype_to_type(dtype)) - if sparse_layout is not None: - sparsity = sparsity_encoding(shape, sparse_layout) + if sparsity is not None: + encoding = sparsity_encoding(shape, sparsity) + assert encoding is not None return IrType.parse( - f"!torch.vtensor<[{shape_asm}],{str(mlir_dtype)},{sparsity}>", + f"!torch.vtensor<[{shape_asm}],{str(mlir_dtype)},{encoding}>", context=self._c, ) return IrType.parse( @@ -497,7 +508,7 @@ def node_val_to_type(self, node: torch_fx.Node) -> IrType: try: tensor_meta = node.meta.get("tensor_meta") val = node.meta.get("val") - sparse_layout = node.meta.get("sparsity", None) + sparsity = node.meta.get("sparsity", None) if tensor_meta is not None: assert isinstance(tensor_meta, TensorMetadata) # Quantized tensor meta data is not preserved in our lowering, @@ -507,12 +518,14 @@ def node_val_to_type(self, node: torch_fx.Node) -> IrType: f"Quantized tensor meta data is not supported." ) else: - return self.tensor_metadata_to_type(tensor_meta, sparse_layout) + return self.tensor_metadata_to_type(tensor_meta, sparsity=sparsity) elif val is not None: # some nodes with symbolic inputs pass a 'val' attribute rather than # tensor_meta if isinstance(val, TorchFakeTensor): - return self.get_vtensor_type(val.size(), val.dtype, sparse_layout) + return self.get_vtensor_type( + val.size(), val.dtype, sparsity=sparsity + ) t = SCALAR_TYPE_TO_TORCH_MLIR_TYPE.get(type(val)) if t is not None: @@ -528,16 +541,19 @@ def node_val_to_type(self, node: torch_fx.Node) -> IrType: ) def tensor_metadata_to_type( - self, tm: TensorMetadata, sparse_layout: torch.layout = None + self, + tm: TensorMetadata, + *, + sparsity: Optional[tuple[torch.layout, int, int]] = None, # keyword-only ) -> IrType: tm_shape = tuple( item.node if is_symbolic(item) else item for item in list(tm.shape) ) - key = (tm_shape, tm.dtype, sparse_layout) + key = (tm_shape, tm.dtype, sparsity) t = self._tensor_metadata_cache.get(key) if t is None: - t = self.get_vtensor_type(tm.shape, tm.dtype, sparse_layout) + t = self.get_vtensor_type(tm.shape, tm.dtype, sparsity=sparsity) self._tensor_metadata_cache[key] = t return t @@ -1128,7 +1144,8 @@ def lookup(self, t: type) -> Any: # Opaque value to indicate something is empty. Used in cases where 'None' # may have a different meaning. -class EmptyType: ... +class EmptyType: + ... Empty = EmptyType() diff --git a/test/python/fx_importer/sparse_test.py b/test/python/fx_importer/sparse_test.py index 1490c160c3f1..f62dea11e222 100644 --- a/test/python/fx_importer/sparse_test.py +++ b/test/python/fx_importer/sparse_test.py @@ -5,7 +5,7 @@ # RUN: %PYTHON %s | FileCheck %s -from typing import Any, Callable, Dict, Optional, Tuple +from typing import Any, Callable, Optional import torch import torch.export @@ -14,6 +14,10 @@ from torch_mlir.extras.fx_importer import FxImporter from torch_mlir import ir from torch_mlir.dialects import torch as torch_d +from torch_mlir.compiler_utils import run_pipeline_with_repro_report +from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import ( + RefBackendLinalgOnTensorsBackend, +) # All sparse layouts currently supported in torch.sparse. @@ -22,13 +26,50 @@ torch.sparse_csr, torch.sparse_csc, torch.sparse_bsr, - torch.sparse_bsc + torch.sparse_bsc, ] -def sparse_export(f: Callable, - args: Tuple[Any, ...], - kwargs: Optional[Dict[str, Any]] = None) -> torch.export.ExportedProgram: +def sparse_overhead_width(d: torch.dtype) -> int: + """Returns bit-width for admissible overhead type.""" + if d is torch.int64: + return 64 + if d is torch.int32: + return 32 + if d is torch.int16: + return 16 + if d is torch.int8: + return 8 + raise RuntimeError(f"Unsupported overhead type {d}") + + +def sparse_metadata(a: torch.Tensor) -> tuple[torch.layout, int, int]: + """Returns a meta data tuple for the given sparse tensor.""" + if a.layout is torch.sparse_coo: + return ( + a.layout, + sparse_overhead_width(a.indices().dtype), + sparse_overhead_width(a.indices().dtype), + ) + elif a.layout is torch.sparse_csr or a.layout is torch.sparse_bsr: + return ( + a.layout, + sparse_overhead_width(a.crow_indices().dtype), + sparse_overhead_width(a.col_indices().dtype), + ) + elif a.layout is torch.sparse_csc or a.layout is torch.sparse_bsc: + return ( + a.layout, + sparse_overhead_width(a.ccol_indices().dtype), + sparse_overhead_width(a.row_indices().dtype), + ) + else: + raise RuntimeError(f"Unsupported sparse layout for {a}") + + +def sparse_export( + f: Callable, args: tuple[Any, ...], kwargs: Optional[dict[str, Any]] = None +) -> torch.export.ExportedProgram: """ This is a ***temporary*** wrapper around `torch.export.export` that eventually should be removed and simply replaced by the @@ -47,17 +88,16 @@ def sparse_export(f: Callable, resovled. """ # Convert all arguments to dense. - dargs = tuple( a.to_dense() if a.layout in SPARSE_LAYOUTS else a for a in args ) - mask = [ a.layout in SPARSE_LAYOUTS for a in args ] + dargs = tuple(a.to_dense() if a.layout in SPARSE_LAYOUTS else a for a in args) + mask = [a.layout in SPARSE_LAYOUTS for a in args] # Build the regular FX traced graph with only dense arguments # (the current version would crash otherwise, see issue above). prog = torch.export.export(f, dargs, kwargs, constraints=None) # Annotate sparse arguments in the graph. alen = len(args) for i, node in enumerate(prog.graph.nodes): - if node.op == "placeholder" and i < alen and mask[i]: - node.meta['sparsity'] = args[i].layout - # TODO: annotate inputs to change calling conventions! + if node.op == "placeholder" and i < alen and mask[i]: + node.meta["sparsity"] = sparse_metadata(args[i]) return prog @@ -68,7 +108,46 @@ def export_and_import(f, *args, **kwargs): fx_importer = FxImporter(context=context) prog = sparse_export(f, args, kwargs) fx_importer.import_frozen_exported_program(prog) - return fx_importer.module_op + return fx_importer.module + + +def sparse_jit(f, *args, **kwargs): + """This method compiles and runs the given callable using linalg backend.""" + # Import module and lower into Linalg IR. + module = export_and_import(f, *args, *kwargs) + run_pipeline_with_repro_report( + module, + "builtin.module(torch-backend-to-linalg-on-tensors-backend-pipeline)", + "Lowering TorchFX IR -> Linalg IR", + enable_ir_printing=False, + ) + # Compile with reference Linalg backend. + backend = RefBackendLinalgOnTensorsBackend() + compiled = backend.compile(module) + invoker = backend.load(compiled) + # Prepare input parameters. Sparse input tensors are split into + # their composite tensors. All PyTorch tensors are converted + # to their backing numpy arrays. + # + # TODO: sparse output tensors + # + xargs = [] + for a in args: + if a.layout is torch.sparse_coo: + xargs.append(a.values().numpy()) + xargs.append(a.indices().numpy()) + elif a.layout is torch.sparse_csr or a.layout is torch.sparse_bsr: + xargs.append(a.values().numpy()) + xargs.append(a.crow_indices().numpy()) + xargs.append(a.col_indices().numpy()) + elif a.layout is torch.sparse_csc or a.layout is torch.sparse_bsc: + xargs.append(a.values().numpy()) + xargs.append(a.ccol_indices().numpy()) + xargs.append(a.row_indices().numpy()) + else: + xargs.append(a.numpy()) + # Invoke. + return invoker.main(*xargs) def run(f): @@ -80,51 +159,77 @@ def run(f): @run # CHECK-LABEL: test_sparse_sum -# CHECK: #[[$CSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed) }> +# CHECK: #[[$CSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : dense, d1 : compressed), posWidth = 64, crdWidth = 64 }> # CHECK: func.func @main( # CHECK-SAME: %[[A:.*]]: !torch.vtensor<[64,64],f32,#[[$CSR]]>) -> !torch.vtensor<[],f32> { # CHECK: %[[N:.*]] = torch.constant.none # CHECK: %[[R:.*]] = torch.aten.sum %[[A]], %[[N]] : !torch.vtensor<[64,64],f32,#[[$CSR]]>, !torch.none -> !torch.vtensor<[],f32> # CHECK: return %[[R]] : !torch.vtensor<[],f32> # CHECK: } +# +# CHECK: torch.sparse = tensor(4096.) +# CHECK: torch.mlir = 4096.0 +# def test_sparse_sum(): - class SumNet(torch.nn.Module): - def __init__(self): super(SumNet, self).__init__() def forward(self, x): return x.sum() - - dense_input = torch.ones(64, 64) + dense_input = torch.ones(64, 64) sparse_input = dense_input.to_sparse_csr() m = export_and_import(SumNet(), sparse_input) print(m) + # Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit. + net = SumNet() + res1 = net(sparse_input) + res2 = sparse_jit(net, sparse_input) + print("torch.sparse =", res1) + print("torch.mlir =", res2) + @run # CHECK-LABEL: test_sparse_SpMM -# CHECK: #[[$COO:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton) }> +# CHECK: #[[$COO:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton), posWidth = 64, crdWidth = 64 }> # CHECK: func.func @main( -# CHECK-SAME: %[[A:.*0]]: !torch.vtensor<[64,64],f32,#[[$COO]]>, -# CHECK-SAME: %[[B:.*1]]: !torch.vtensor<[64,64],f32>) -> !torch.vtensor<[64,64],f32> { -# CHECK: %[[R:.*]] = torch.aten.mm %[[A]], %[[B]] : !torch.vtensor<[64,64],f32,#[[$COO]]>, !torch.vtensor<[64,64],f32> -> !torch.vtensor<[64,64],f32> -# CHECK: return %[[R]] : !torch.vtensor<[64,64],f32> +# CHECK-SAME: %[[A:.*0]]: !torch.vtensor<[8,8],f32,#[[$COO]]>, +# CHECK-SAME: %[[B:.*1]]: !torch.vtensor<[8,8],f32>) -> !torch.vtensor<[8,8],f32> { +# CHECK: %[[R:.*]] = torch.aten.mm %[[A]], %[[B]] : !torch.vtensor<[8,8],f32,#[[$COO]]>, !torch.vtensor<[8,8],f32> -> !torch.vtensor<[8,8],f32> +# CHECK: return %[[R]] : !torch.vtensor<[8,8],f32> # CHECK: } +# +# CHECK: torch.sparse +# CHECK: tensor({{\[}}[8., 8., 8., 8., 8., 8., 8., 8.], +# CHECK-COUNT-6: [8., 8., 8., 8., 8., 8., 8., 8.], +# CHECK: [8., 8., 8., 8., 8., 8., 8., 8.]{{\]}}) +# CHECK: torch.mlir +# CHECK: {{\[}}[8. 8. 8. 8. 8. 8. 8. 8.] +# CHECK-COUNT-6: [8. 8. 8. 8. 8. 8. 8. 8.] +# CHECK: [8. 8. 8. 8. 8. 8. 8. 8.]{{\]}} +# def test_sparse_SpMM(): - class MatMulNet(torch.nn.Module): - def __init__(self): super(MatMulNet, self).__init__() def forward(self, x, y): - return torch.matmul(x, y) - + return torch.matmul(x, y) - dense_input = torch.ones(64, 64) + dense_input = torch.ones(8, 8) sparse_input = dense_input.to_sparse_coo() m = export_and_import(MatMulNet(), sparse_input, dense_input) print(m) + + # Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit. + # TODO: run with COO, right now only CSR works + sparse_input = dense_input.to_sparse_csr() + net = MatMulNet() + res1 = net(sparse_input, dense_input) + res2 = sparse_jit(net, sparse_input, dense_input) + print("torch.sparse") + print(res1) + print("torch.mlir") + print(res2) From 370d6ac9a2f723de7c7609d4128e58ac6d363b00 Mon Sep 17 00:00:00 2001 From: Ashay Rane Date: Mon, 12 Feb 2024 17:31:41 -0600 Subject: [PATCH 184/283] build: find Protobuf using config mode search (#2900) This patch makes the Protobuf package mandatory in addition to forcing a config mode search. The (default) module mode search looks for the CMake-provided FindProtobuf.cmake file, but this file does not list Abseil as a dependency, causing linker issues like the one below: ``` ld: Undefined symbols: absl::lts_20230802::log_internal::LogMessageFatal::LogMessageFatal(char const*, int, std::__1::basic_string_view>), referenced from: google::protobuf::RepeatedPtrField, std::__1::allocator>>::TypeHandler::Type const& google::protobuf::internal::RepeatedPtrFieldBase::Get, std::__1::allocator>>::TypeHandler>(int) const (.cold.1) in OnnxImporter.cpp.o ``` By forcing a config mode search, CMake looks for the file that is installed as part of the protobuf package and which does contain the Abseil dependency. This workaround is also mentioned in a GitHub issue for Protobuf: https://github.com/protocolbuffers/protobuf/issues/12292#issuecomment-1529680040. --- projects/onnx_c_importer/CMakeLists.txt | 12 +----------- 1 file changed, 1 insertion(+), 11 deletions(-) diff --git a/projects/onnx_c_importer/CMakeLists.txt b/projects/onnx_c_importer/CMakeLists.txt index b685c732f5dc..681ca14feafc 100644 --- a/projects/onnx_c_importer/CMakeLists.txt +++ b/projects/onnx_c_importer/CMakeLists.txt @@ -2,17 +2,7 @@ message(STATUS "Enabling onnx_c_importer...") include(FetchContent) -find_package(Protobuf) -if(NOT Protobuf_FOUND) - message(FATAL_ERROR - "In order to build C ONNX support, the Protobuf package must be installed " - "on the system. Without this ONNX will attempt to build it in the project " - "and the dependent ABSEIL build system is incompatible. " - "On Ubuntu, install with: " - "apt install libprotobuf-dev protobuf-compiler\n\n" - "(or this entire component can be disabled with " - "-DTORCH_MLIR_ENABLE_ONNX_C_IMPORTER=OFF)") -endif() +find_package(Protobuf REQUIRED CONFIG) option(ONNX_DISABLE_EXCEPTIONS "For compatibility with LLVM build" ON) From b6f4ca512ea93eaa34aad7b16a2bf6ff8d01350b Mon Sep 17 00:00:00 2001 From: Aart Bik <39774503+aartbik@users.noreply.github.com> Date: Mon, 12 Feb 2024 16:10:57 -0800 Subject: [PATCH 185/283] [torch-mlir][sparse] sparsity metadata refinement (#2901) Various improvements on sparsity metadata: (1) define single data structure for all sparsity related metadata (2) handle batched dense dimensions, as well as dense subtensor dimensions (3) refine sparsity propagation for deeper networks --- python/torch_mlir/extras/fx_importer.py | 73 ++++++++++++++++-------- test/python/fx_importer/sparse_test.py | 74 ++++++++++++++++++++++--- 2 files changed, 115 insertions(+), 32 deletions(-) diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py index 2e9ba233367f..6749c6078e49 100644 --- a/python/torch_mlir/extras/fx_importer.py +++ b/python/torch_mlir/extras/fx_importer.py @@ -14,6 +14,7 @@ import logging import operator import re +from dataclasses import dataclass from types import BuiltinMethodType, BuiltinFunctionType from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union import weakref @@ -208,34 +209,58 @@ } -def sparsity_encoding( - shape: torch.Size, sparsity: tuple[torch.layout, int, int] -) -> str: - """Returns sparse tensor encoding for the given sparse layout as string. +@dataclass(frozen=True) +class SparsityMeta: + """Class for keeping track of sparsity meta data.""" - The method currently just supports 2-dim sparse formats. This should be - generalized to the torch.sparse encodings for prefix dense batch dimensions - and suffix dense subtensor dimensions. Since MLIR supports a superset of what - is currently implememented in torch.sparse, this should not a be problem. - """ + layout: torch.layout + batch_dim: int + sparse_dim: int + dense_dim: int + pos_width: int + crd_width: int + + +def sparsity_encoding(shape: torch.Size, sparsity: SparsityMeta) -> str: + """Returns sparse tensor encoding for the given sparse layout as string.""" assert sparsity is not None - sparse_layout, posw, crdw = sparsity - - # TODO: any rank - if len(shape) != 2: - raise RuntimeError(f"Unsupported sparse rank {len(shape)}") - - if sparse_layout is torch.sparse_coo: - smap = f"(i,j)->(i:compressed(nonunique),j:singleton)" - elif sparse_layout is torch.sparse_csr: - smap = f"(i,j)->(i:dense,j:compressed)" - elif sparse_layout is torch.sparse_csc: - smap = f"(i,j)->(j:dense,i:compressed)" + + # Sparse tensors have the form + # [ , , ] + # which map directly to MLIR types. + batch_dim, sparse_dim, dense_dim = ( + sparsity.batch_dim, + sparsity.sparse_dim, + sparsity.dense_dim, + ) + dim = batch_dim + sparse_dim + dense_dim + assert dim == len(shape) + + dims = ",".join(f"d{d}" for d in range(0, dim)) + + if sparsity.layout is torch.sparse_coo: + assert sparse_dim == 2 # TODO: deeper sparse dims + lvls = f"d{batch_dim}:compressed(nonunique),d{batch_dim+1}:singleton" + elif sparsity.layout is torch.sparse_csr: + assert sparse_dim == 2 + lvls = f"d{batch_dim}:dense,d{batch_dim+1}:compressed" + elif sparsity.layout is torch.sparse_csc: + assert sparse_dim == 2 + lvls = f"d{batch_dim+1}:dense,d{batch_dim}:compressed" else: # TODO: block format (derive block size!) raise RuntimeError(f"Unsupported sparse layout {sparse_layout}") - return f"#sparse_tensor.encoding<{{map={smap},posWidth={posw},crdWidth={crdw}}}>" + if batch_dim > 0: + batch = ",".join(f"d{d}:dense" for d in range(0, batch_dim)) + lvls = f"{batch},{lvls}" + + if dense_dim > 0: + dense = ",".join(f"d{d}:dense" for d in range(batch_dim + sparse_dim, dim)) + lvls = f"{lvls},{dense}" + + posw, crdw = sparsity.pos_width, sparsity.crd_width + return f"#sparse_tensor.encoding<{{map=({dims})->({lvls}),posWidth={posw},crdWidth={crdw}}}>" def is_symbolic(obj: Any) -> bool: @@ -489,7 +514,7 @@ def get_vtensor_type( shape: torch.Size, dtype: torch.dtype, *, - sparsity: Optional[tuple[torch.layout, int, int]] = None, # keyword-only + sparsity: Optional[SparsityMeta] = None, # keyword-only ): shape_asm = self.format_asm_shape(shape) mlir_dtype = str(self.dtype_to_type(dtype)) @@ -544,7 +569,7 @@ def tensor_metadata_to_type( self, tm: TensorMetadata, *, - sparsity: Optional[tuple[torch.layout, int, int]] = None, # keyword-only + sparsity: Optional[SparsityMeta] = None, # keyword-only ) -> IrType: tm_shape = tuple( item.node if is_symbolic(item) else item for item in list(tm.shape) diff --git a/test/python/fx_importer/sparse_test.py b/test/python/fx_importer/sparse_test.py index f62dea11e222..161b29148dcd 100644 --- a/test/python/fx_importer/sparse_test.py +++ b/test/python/fx_importer/sparse_test.py @@ -12,6 +12,7 @@ import torch.nn as nn from torch_mlir.extras.fx_importer import FxImporter +from torch_mlir.extras.fx_importer import SparsityMeta from torch_mlir import ir from torch_mlir.dialects import torch as torch_d from torch_mlir.compiler_utils import run_pipeline_with_repro_report @@ -43,23 +44,35 @@ def sparse_overhead_width(d: torch.dtype) -> int: raise RuntimeError(f"Unsupported overhead type {d}") -def sparse_metadata(a: torch.Tensor) -> tuple[torch.layout, int, int]: +def sparse_metadata(a: torch.Tensor) -> SparsityMeta: """Returns a meta data tuple for the given sparse tensor.""" + sparse_dim = a.sparse_dim() + dense_dim = a.dense_dim() + batch_dim = a.ndim - dense_dim - sparse_dim if a.layout is torch.sparse_coo: - return ( + return SparsityMeta( a.layout, + batch_dim, + sparse_dim, + dense_dim, sparse_overhead_width(a.indices().dtype), sparse_overhead_width(a.indices().dtype), ) elif a.layout is torch.sparse_csr or a.layout is torch.sparse_bsr: - return ( + return SparsityMeta( a.layout, + batch_dim, + sparse_dim, + dense_dim, sparse_overhead_width(a.crow_indices().dtype), sparse_overhead_width(a.col_indices().dtype), ) elif a.layout is torch.sparse_csc or a.layout is torch.sparse_bsc: - return ( + return SparsityMeta( a.layout, + batch_dim, + sparse_dim, + dense_dim, sparse_overhead_width(a.ccol_indices().dtype), sparse_overhead_width(a.row_indices().dtype), ) @@ -93,11 +106,21 @@ def sparse_export( # Build the regular FX traced graph with only dense arguments # (the current version would crash otherwise, see issue above). prog = torch.export.export(f, dargs, kwargs, constraints=None) - # Annotate sparse arguments in the graph. - alen = len(args) + # Annotate sparse arguments in the graph. Note that we currently + # only account for sparsity defined by the user inputs to the model. + # TODO: support sparsity in model parameters (weights, biases) + # TODO: propagate sparsity into the layers + specs = prog.graph_signature.input_specs + alen = len(specs) + k = 0 for i, node in enumerate(prog.graph.nodes): - if node.op == "placeholder" and i < alen and mask[i]: - node.meta["sparsity"] = sparse_metadata(args[i]) + if i >= alen: + break + spec = specs[i] + if spec.kind is torch.export.graph_signature.InputKind.USER_INPUT: + if mask[k]: + node.meta["sparsity"] = sparse_metadata(args[k]) + k = k + 1 return prog @@ -233,3 +256,38 @@ def forward(self, x, y): print(res1) print("torch.mlir") print(res2) + + +@run +# CHECK-LABEL: test_sparse_eltwise +# CHECK: #[[$BCSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d1 : dense, d2 : compressed), posWidth = 64, crdWidth = 64 }> +# CHECK: func.func @main( +# CHECK-SAME: %[[A:.*]]: !torch.vtensor<[8,4,2],f32,#[[$BCSR]]>) -> !torch.vtensor<[8,4,2],f32> { +# CHECK: %[[R:.*]] = torch.aten.neg %arg0 : !torch.vtensor<[8,4,2],f32,#[[$BCSR]]> -> !torch.vtensor<[8,4,2],f32> +# CHECK: return %[[R]] : !torch.vtensor<[8,4,2],f32> +# CHECK: } +# CHECK: #[[$CSRD:.*]] = #sparse_tensor.encoding<{ map = (d0, d1, d2) -> (d0 : dense, d1 : compressed, d2 : dense), posWidth = 64, crdWidth = 64 }> +# CHECK: func.func @main( +# CHECK-SAME: %[[A:.*]]: !torch.vtensor<[8,4,2],f32,#[[$CSRD]]>) -> !torch.vtensor<[8,4,2],f32> { +# CHECK: %[[R:.*]] = torch.aten.neg %arg0 : !torch.vtensor<[8,4,2],f32,#[[$CSRD]]> -> !torch.vtensor<[8,4,2],f32> +# CHECK: return %[[R]] : !torch.vtensor<[8,4,2],f32> +# CHECK: } +def test_sparse_eltwise(): + class EltNet(torch.nn.Module): + def __init__(self): + super(EltNet, self).__init__() + + def forward(self, x): + return -x + + dense_input = torch.ones(8, 4, 2) + + # This yields a **batched** CSR. + sparse_input = dense_input.to_sparse_csr(dense_dim=0) + m = export_and_import(EltNet(), sparse_input) + print(m) + + # This yields a plain CSR with dense **sub**tensor + sparse_input = dense_input.to_sparse_csr(dense_dim=1) + m = export_and_import(EltNet(), sparse_input) + print(m) From 9b967f6b5ab49b344af9a2c56659784502f2c488 Mon Sep 17 00:00:00 2001 From: saienduri <77521230+saienduri@users.noreply.github.com> Date: Mon, 12 Feb 2024 23:08:21 -0800 Subject: [PATCH 186/283] [MLIR][ONNX] Add OnnxToTorch support for Mean, IsInf, IsNaN, PRelu op (#2801) This commit adds the OnnxToTorch support for Mean, IsInf, IsNaN, and PRelu ops. All high priority ops were taken so went with these. The non trivial ones are Mean and IsInf which might require extra review --------- Co-authored-by: MaheshRavishankar --- .../TorchOnnxToTorch/DefaultDomainGtoP.cpp | 103 ++++++++++++++++++ .../TorchOnnxToTorch/simple_ops_g_to_p.mlir | 54 +++++++++ 2 files changed, 157 insertions(+) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index 1760a0a20672..dd9fa211b551 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -1006,4 +1006,107 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( binder.op, resultType, tensor, /*memory_format=*/noneVal); return success(); }); + patterns.onOp( + "Mean", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + if (binder.op->getNumOperands() == 1) { + Torch::ValueTensorType resultType; + Value x; + if (binder.tensorOperand(x) || binder.tensorResultType(resultType)) + return failure(); + rewriter.replaceOp(binder.op, x); + return success(); + } + Torch::ValueTensorType resultType; + SmallVector valList; + int64_t numOperands = binder.op->getNumOperands(); + Value numOperandsConstant = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), numOperands)); + if (binder.tensorOperands(valList, numOperands) || + binder.tensorResultType(resultType)) + return failure(); + Value constOne = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), 1)); + // Short circuit to binary add + Value curr = rewriter.create( + binder.getLoc(), resultType, valList[0], valList[1], constOne); + if (numOperands == 2) { + rewriter.replaceOpWithNewOp( + binder.op, resultType, curr, numOperandsConstant); + return success(); + } + // When binder.op->getNumOperands() > 2 + auto baseType = Torch::ValueTensorType::getWithLeastStaticInformation( + binder.op->getContext()); + for (int i = 2; i < numOperands; i++) { + if (i == numOperands - 1) { + curr = rewriter.create( + binder.getLoc(), resultType, curr, valList[i], constOne); + } else { + curr = rewriter.create( + binder.getLoc(), baseType, curr, valList[i], constOne); + } + } + rewriter.replaceOpWithNewOp( + binder.op, resultType, curr, numOperandsConstant); + return success(); + }); + patterns.onOp( + "IsInf", 10, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value tensor; + int64_t neg; + int64_t pos; + if (binder.tensorOperand(tensor) || + binder.s64IntegerAttr(neg, "detect_negative", 1) || + binder.s64IntegerAttr(pos, "detect_positive", 1) || + binder.tensorResultType(resultType)) { + return failure(); + } + if (neg == 0) { + // replace all negative infs with 0 + tensor = rewriter.create( + binder.getLoc(), + dyn_cast(tensor.getType()), tensor); + } + if (pos == 0) { + // first use neg op to flip positive inf to negative inf. Then relu to + // replace all positive infs with 0. + Value flip = rewriter.create( + binder.getLoc(), + dyn_cast(tensor.getType()), tensor); + tensor = rewriter.create( + binder.getLoc(), dyn_cast(flip.getType()), + flip); + } + rewriter.replaceOpWithNewOp(binder.op, resultType, + tensor); + return success(); + }); + patterns.onOp("IsNaN", 9, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value tensor; + if (binder.tensorOperand(tensor) || + binder.tensorResultType(resultType)) { + return failure(); + } + rewriter.replaceOpWithNewOp( + binder.op, resultType, tensor); + return success(); + }); + patterns.onOp("PRelu", 1, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value tensor; + Value slope; + if (binder.tensorOperands(tensor, slope) || + binder.tensorResultType(resultType)) { + return failure(); + } + rewriter.replaceOpWithNewOp( + binder.op, resultType, tensor, slope); + return success(); + }); } diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir index bbef289ff5f2..0a154db29323 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir @@ -650,3 +650,57 @@ func.func @test_not_2d(%arg0: !torch.vtensor<[3,4],i1>) -> !torch.vtensor<[3,4], %0 = torch.operator "onnx.Identity"(%arg0) : (!torch.vtensor<[3,4], f32>) -> !torch.vtensor<[3,4], f32> return %0 : !torch.vtensor<[3,4], f32> } + +// CHECK-LABEL: func.func @test_mean_one_input + func.func @test_mean_one_input(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + %0 = torch.operator "onnx.Mean"(%arg0) : (!torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> + return %0 : !torch.vtensor<[3],f32> + } + +// CHECK-LABEL: func.func @test_mean_two_inputs + func.func @test_mean_two_inputs(%arg0: !torch.vtensor<[3],f32>, %arg1: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT2:.*]] = torch.constant.int 2 + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: torch.aten.add.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor<[3],f32> + // CHECK: torch.aten.div.Scalar %0, %int2 : !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor<[3],f32> + %0 = torch.operator "onnx.Mean"(%arg0, %arg1) : (!torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> + return %0 : !torch.vtensor<[3],f32> + } + +// CHECK-LABEL: func.func @test_isinf_negative + func.func @test_isinf_negative(%arg0: !torch.vtensor<[6],f32>) -> !torch.vtensor<[6],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.neg %arg0 : !torch.vtensor<[6],f32> -> !torch.vtensor<[6],f32> + // CHECK: torch.aten.relu %0 : !torch.vtensor<[6],f32> -> !torch.vtensor<[6],f32> + // CHECK: torch.aten.isinf %1 : !torch.vtensor<[6],f32> -> !torch.vtensor<[6],i1> + %0 = torch.operator "onnx.IsInf"(%arg0) {torch.onnx.detect_positive = 0 : si64} : (!torch.vtensor<[6],f32>) -> !torch.vtensor<[6],i1> + return %0 : !torch.vtensor<[6],i1> + } + +// CHECK-LABEL: func.func @test_isinf_positive + func.func @test_isinf_positive(%arg0: !torch.vtensor<[6],f32>) -> !torch.vtensor<[6],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.relu %arg0 : !torch.vtensor<[6],f32> -> !torch.vtensor<[6],f32> + // CHECK: torch.aten.isinf %0 : !torch.vtensor<[6],f32> -> !torch.vtensor<[6],i1> + %0 = torch.operator "onnx.IsInf"(%arg0) {torch.onnx.detect_negative = 0 : si64} : (!torch.vtensor<[6],f32>) -> !torch.vtensor<[6],i1> + return %0 : !torch.vtensor<[6],i1> + } + +// CHECK-LABEL: func.func @test_isnan + func.func @test_isnan(%arg0: !torch.vtensor<[6],f32>) -> !torch.vtensor<[6],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.isnan %arg0 : !torch.vtensor<[6],f32> -> !torch.vtensor<[6],i1> + %0 = torch.operator "onnx.IsNaN"(%arg0) : (!torch.vtensor<[6],f32>) -> !torch.vtensor<[6],i1> + return %0 : !torch.vtensor<[6],i1> + } + +// CHECK-LABEL: func.func @test_prelu_example + func.func @test_prelu_example(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 16 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.prelu %arg0, %arg1 : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> + %0 = torch.operator "onnx.PRelu"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> + } + +// CHECK-LABEL: func.func @test_prelu_broadcast + func.func @test_prelu_broadcast(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 16 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.prelu %arg0, %arg1 : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[5],f32> -> !torch.vtensor<[3,4,5],f32> + %0 = torch.operator "onnx.PRelu"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[5],f32>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> + } From 3e836d8dad551b6e5302de1b84840b90ee039c83 Mon Sep 17 00:00:00 2001 From: Sambhav Jain Date: Tue, 13 Feb 2024 12:38:32 -0800 Subject: [PATCH 187/283] [fx_importer] Convert non-persistent buffers lifted as tensor constants (#2902) The investigation is largely recorded in https://github.com/llvm/torch-mlir/pull/2881, but this change allows us to capture non-persistent buffers that were lifted as tensor constants (after https://github.com/pytorch/pytorch/pull/118969 landed in upstream PyTorch), and propagate them to `Torch` dialect as "frozen" `torch.vtensor.literal`. I believe this patch should work with both nightly and stable PyTorch, but will let CI confirm the same. Thanks @stellaraccident for the valuable pointers and guidance. --------- Co-authored-by: Vivek Khandelwal --- python/torch_mlir/extras/fx_importer.py | 27 ++++++++++++++++++------- pytorch-hash.txt | 2 +- pytorch-requirements.txt | 2 +- torchvision-requirements.txt | 2 +- 4 files changed, 23 insertions(+), 10 deletions(-) diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py index 6749c6078e49..b70487ad5ad9 100644 --- a/python/torch_mlir/extras/fx_importer.py +++ b/python/torch_mlir/extras/fx_importer.py @@ -373,13 +373,26 @@ def import_frozen_exported_program(self, prog: torch.export.ExportedProgram): sig = prog.graph_signature state_dict = prog.state_dict arg_replacements: dict[str, Any] = {} - # Lift buffers. - for input_name, state_name in sig.inputs_to_buffers.items(): - try: - state_value = state_dict[state_name] - except KeyError as e: - raise AssertionError("Could not find state mapping for buffer") from e - arg_replacements[input_name] = state_value + + # If there is no "constants" attribute, consult the "state_dict". Otherwise, only look + # at "constants". Relevant upstream patch: https://github.com/pytorch/pytorch/pull/118969 + if hasattr(prog, "constants"): + constants = prog.constants + # Lift tensor constants. + for input_name, state_name in sig.inputs_to_lifted_tensor_constants.items(): + try: + state_value = constants[state_name] + except KeyError as e: + raise AssertionError("Could not find state mapping for tensor constants") from e + arg_replacements[input_name] = state_value + else: + # Lift buffers. + for input_name, state_name in sig.inputs_to_buffers.items(): + try: + state_value = state_dict[state_name] + except KeyError as e: + raise AssertionError("Could not find state mapping for buffer") from e + arg_replacements[input_name] = state_value # Lift parameters. for input_name, state_name in sig.inputs_to_parameters.items(): diff --git a/pytorch-hash.txt b/pytorch-hash.txt index 16be42d6c147..d78b0d0694d4 100644 --- a/pytorch-hash.txt +++ b/pytorch-hash.txt @@ -1 +1 @@ -72fcb9ad662bb941a266e3d747835382634c2be6 +3cbc8e89fd09b0ffb4914187b438f15c121e2302 diff --git a/pytorch-requirements.txt b/pytorch-requirements.txt index 1de47ff9a195..540e78dccd49 100644 --- a/pytorch-requirements.txt +++ b/pytorch-requirements.txt @@ -1,3 +1,3 @@ -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html --pre -torch==2.3.0.dev20240122 +torch==2.3.0.dev20240207 diff --git a/torchvision-requirements.txt b/torchvision-requirements.txt index fad713123493..4f775c549c6c 100644 --- a/torchvision-requirements.txt +++ b/torchvision-requirements.txt @@ -1,3 +1,3 @@ -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html --pre -torchvision==0.18.0.dev20240122 +torchvision==0.18.0.dev20240207 From 24c2fc0b5f90d870cbfd967c81460bcc5686d24d Mon Sep 17 00:00:00 2001 From: Aart Bik <39774503+aartbik@users.noreply.github.com> Date: Tue, 13 Feb 2024 13:42:56 -0800 Subject: [PATCH 188/283] [torch-mlir][sparse] add JIT test to expose pending issues (#2906) This test exposes issues that need fixing (1) propagate sparsity into the FX graph (over elt-wise) (2) batched dimensions need a new "dense(batch)" format --- test/python/fx_importer/sparse_test.py | 34 +++++++++++++++++++++++++- 1 file changed, 33 insertions(+), 1 deletion(-) diff --git a/test/python/fx_importer/sparse_test.py b/test/python/fx_importer/sparse_test.py index 161b29148dcd..d0b94ac83656 100644 --- a/test/python/fx_importer/sparse_test.py +++ b/test/python/fx_importer/sparse_test.py @@ -272,6 +272,22 @@ def forward(self, x, y): # CHECK: %[[R:.*]] = torch.aten.neg %arg0 : !torch.vtensor<[8,4,2],f32,#[[$CSRD]]> -> !torch.vtensor<[8,4,2],f32> # CHECK: return %[[R]] : !torch.vtensor<[8,4,2],f32> # CHECK: } +# +# CHECK: torch.sparse +# CHECK: tensor(crow_indices=tensor([ 0, 4, 8, 12, 16, 20, 24, 28, 32]), +# CHECK: col_indices=tensor([0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, +# CHECK: 2, 3, 0, 1, 2, 3, 0, 1, 2, 3, 0, 1, 2, 3]), +# CHECK: values=tensor({{\[}}[ -1., -2.], +# CHECK: [ -3., -4.], +# ... +# CHECK: [-63., -64.]{{\]}}), size=(8, 4, 2), nnz=32, +# CHECK: layout=torch.sparse_csr) +# CHECK: torch.mlir +# CHECK: {{\[\[}}[ -1. -2.] +# CHECK: [ -3. -4.] +# ... +# CHECK: [-61. -62.] +# CHECK: [-63. -64.]{{\]\]}} def test_sparse_eltwise(): class EltNet(torch.nn.Module): def __init__(self): @@ -280,7 +296,9 @@ def __init__(self): def forward(self, x): return -x - dense_input = torch.ones(8, 4, 2) + dense_input = torch.reshape( + torch.arange(1, 65, dtype=torch.float32), shape=(8, 4, 2) + ) # This yields a **batched** CSR. sparse_input = dense_input.to_sparse_csr(dense_dim=0) @@ -291,3 +309,17 @@ def forward(self, x): sparse_input = dense_input.to_sparse_csr(dense_dim=1) m = export_and_import(EltNet(), sparse_input) print(m) + + # Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit. + # + # TODO: note several issues that need to be fixed + # (1) since we do not propagate sparsity into elt-wise, MLIR returns dense result + # (2) for dense_dim=0, this will need a dense(batched) property + sparse_input = dense_input.to_sparse_csr(dense_dim=1) + net = EltNet() + res1 = net(sparse_input) + res2 = sparse_jit(net, sparse_input) + print("torch.sparse") + print(res1) + print("torch.mlir") + print(res2) From d6e1d836ca8d49da25ad5e2f10d6816bfbb6ba2f Mon Sep 17 00:00:00 2001 From: Scott Todd Date: Tue, 13 Feb 2024 14:32:02 -0800 Subject: [PATCH 189/283] Drop torch attributes at the end of backend conversion. (#2876) Fixes https://github.com/llvm/torch-mlir/issues/2866 Some backends / downstream projects expect that a "fully converted" program has no remaining ops or attributes from the original dialect(s). --- .../Transforms/BackendTypeConversionPasses.cpp | 18 ++++++++++++++++++ test/Conversion/TorchToLinalg/basic.mlir | 2 +- .../finalizing-backend-type-conversion.mlir | 14 ++++++++++++++ 3 files changed, 33 insertions(+), 1 deletion(-) diff --git a/lib/Dialect/TorchConversion/Transforms/BackendTypeConversionPasses.cpp b/lib/Dialect/TorchConversion/Transforms/BackendTypeConversionPasses.cpp index 5f3a2609be8c..5dd3d778f8f4 100644 --- a/lib/Dialect/TorchConversion/Transforms/BackendTypeConversionPasses.cpp +++ b/lib/Dialect/TorchConversion/Transforms/BackendTypeConversionPasses.cpp @@ -115,6 +115,21 @@ static void setupFinalization(ConversionTarget &target, setupFinalization(target, patterns, typeConverter); } +static void stripTorchAttrs(func::FuncOp func) { + bool modified = false; + SmallVector newAttrs; + for (auto attr : func->getDialectAttrs()) { + if (attr.getName().getValue().starts_with("torch.")) + modified = true; + else + newAttrs.push_back(attr); + } + if (modified) + func->setDialectAttrs(newAttrs); + + // Note: this could also strip "arg" and "result" attrs if they were used. +} + namespace { struct FinalizingBackendTypeConversionPass : public FinalizingBackendTypeConversionBase< @@ -151,6 +166,9 @@ struct FinalizingBackendTypeConversionPass if (failed(applyFullConversion(func, target, std::move(patterns)))) signalPassFailure(); + + // Drop attributes that are no longer used after conversion out of Torch. + stripTorchAttrs(func); } }; } // namespace diff --git a/test/Conversion/TorchToLinalg/basic.mlir b/test/Conversion/TorchToLinalg/basic.mlir index cfb252cd104a..f063f234e4e5 100644 --- a/test/Conversion/TorchToLinalg/basic.mlir +++ b/test/Conversion/TorchToLinalg/basic.mlir @@ -30,7 +30,7 @@ func.func @torch.aten.mm$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.v // ----- // CHECK-LABEL: func.func @torch.aten.matmul.2d -func.func @torch.aten.matmul.2d(%arg0: !torch.vtensor<[8,16],f32>, %arg1: !torch.vtensor<[16,8],f32>) -> !torch.vtensor<[8,8],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { +func.func @torch.aten.matmul.2d(%arg0: !torch.vtensor<[8,16],f32>, %arg1: !torch.vtensor<[16,8],f32>) -> !torch.vtensor<[8,8],f32> { // CHECK-DAG: %[[LHS:.+]] = torch_c.to_builtin_tensor %arg0 : !torch.vtensor<[8,16],f32> -> tensor<8x16xf32> // CHECK-DAG: %[[RHS:.+]] = torch_c.to_builtin_tensor %arg1 : !torch.vtensor<[16,8],f32> -> tensor<16x8xf32> // CHECK-DAG: %[[ZERO:.+]] = arith.constant 0.000000e+00 : f32 diff --git a/test/Dialect/TorchConversion/finalizing-backend-type-conversion.mlir b/test/Dialect/TorchConversion/finalizing-backend-type-conversion.mlir index a16da0932640..46f80c06b4ce 100644 --- a/test/Dialect/TorchConversion/finalizing-backend-type-conversion.mlir +++ b/test/Dialect/TorchConversion/finalizing-backend-type-conversion.mlir @@ -54,6 +54,20 @@ func.func @eliminate_materializations$torch.Generator(%arg0: i64) -> i64 { // ----- +// CHECK-LABEL: func.func @eliminate_attributes() +// CHECK-NOT: attributes +// CHECK-NOT: torch.onnx_meta +func.func @eliminate_attributes() attributes { + torch.onnx_meta.ir_version = 8 : si64, + torch.onnx_meta.opset_version = 17 : si64, + torch.onnx_meta.producer_name = "pytorch", + torch.onnx_meta.producer_version = "2.1.0" +} { + return +} + +// ----- + func.func @unable_to_convert_lone_buffer_cast() -> tensor { // expected-error @+1 {{failed to legalize operation 'test.source'}} %0 = "test.source"() : () -> !torch.vtensor<[],f32> From e9cdd6cbc558880da2b6faa70b6c844bd7b5f494 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Tue, 13 Feb 2024 21:18:01 -0800 Subject: [PATCH 190/283] [torch] Fix tm_tensor.attention for end-to-end (#2907) Some operations include a backend matcher for specialized operations. We map these back to generics so they appropriately match to the high performance versions. This is done for the attention operation. --- .../Dialect/TMTensor/IR/TMTensorOps.td | 3 - .../TorchToTMTensor/TorchToTMTensor.cpp | 77 +++++++++++++--- lib/Dialect/TMTensor/IR/TMTensorOps.cpp | 43 +++++---- .../Torch/Transforms/ReduceOpVariants.cpp | 91 ++++++++++++++++++- projects/pt1/e2e_testing/xfail_sets.py | 3 +- .../linalg_on_tensors_backends/refbackend.py | 1 - .../torch_mlir_e2e_test/test_suite/basic.py | 24 ++--- test/Dialect/Torch/reduce-op-variants.mlir | 34 ++++++- 8 files changed, 226 insertions(+), 50 deletions(-) diff --git a/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.td b/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.td index ac2c114ded74..50dc0c1a0403 100644 --- a/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.td +++ b/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.td @@ -313,9 +313,6 @@ def TMTensor_AttentionOp : TMTensor_Op<"attention", int64_t getOutputRank() { return getOutputType().getRank(); } - int64_t getIterationDomainRank() { - return 2; - }; // Method to implement for specifying output range for // DestinationStyleOpInterface std::pair getDpsInitsPositionRange() { diff --git a/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp b/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp index c669e8b6b8cc..4aa82420c38e 100644 --- a/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp +++ b/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp @@ -1600,27 +1600,82 @@ class ConvertAtenScaledDotProductAttentionOp "only default scale supported"); } + auto opTy = cast(op.getType()).toBuiltinTensor(); + auto query = adaptor.getQuery(); + auto value = adaptor.getValue(); + auto key = adaptor.getKey(); + auto queryTy = cast(query.getType()); + auto valueTy = cast(value.getType()); + auto keyTy = cast(key.getType()); + + if (queryTy.getRank() != valueTy.getRank() || + queryTy.getRank() != keyTy.getRank()) + return rewriter.notifyMatchFailure(op, "operand ranks do not match"); + + if (queryTy.getRank() < 3) + return rewriter.notifyMatchFailure(op, "missing batch dimension"); + + llvm::SmallVector reassociation(3); + for (int i = 0, s = valueTy.getRank() - 2; i < s; ++i) + reassociation.front().push_back(i); + reassociation[1].push_back(valueTy.getRank() - 2); + reassociation[2].push_back(valueTy.getRank() - 1); + + auto loc = op.getLoc(); + auto collapseBatch = [&rewriter, &reassociation, + loc](Value value) -> Value { + auto valueTy = cast(value.getType()); + if (valueTy.getRank() == 3) + return value; + + llvm::SmallVector newShape(3, 1); + newShape[1] = valueTy.getDimSize(valueTy.getRank() - 2); + newShape[2] = valueTy.getDimSize(valueTy.getRank() - 1); + + for (int i = 0, s = valueTy.getRank() - 2; i < s; ++i) { + if (valueTy.isDynamicDim(i)) { + newShape[0] = ShapedType::kDynamic; + break; + } + newShape[0] = newShape[0] * valueTy.getDimSize(i); + } + + auto collapseTy = valueTy.clone(newShape); + return rewriter.create(loc, collapseTy, value, + reassociation); + }; + + query = collapseBatch(query); + key = collapseBatch(key); + value = collapseBatch(value); + SmallVector outSizes( - adaptor.getQuery().getType().cast().getShape()); + query.getType().cast().getShape()); SmallVector valueSizes( - adaptor.getValue().getType().cast().getShape()); + value.getType().cast().getShape()); outSizes[outSizes.size() - 1] = valueSizes[valueSizes.size() - 1]; SmallVector outSizesDynamic( - getTensorSizes(rewriter, op.getLoc(), adaptor.getQuery())); - outSizesDynamic[outSizesDynamic.size() - 1] = getTensorSizes( - rewriter, op.getLoc(), adaptor.getValue())[valueSizes.size() - 1]; + getTensorSizes(rewriter, op.getLoc(), query)); + outSizesDynamic[outSizesDynamic.size() - 1] = + getTensorSizes(rewriter, op.getLoc(), value)[valueSizes.size() - 1]; Type outType = RankedTensorType::get(outSizes, elementType); Value output = createZeroInitTensor(rewriter, op.getLoc(), outSizesDynamic, elementType); // Overwrite with tm_tensor::attention - auto attention = rewriter.create( - op.getLoc(), outType, - SmallVector{adaptor.getQuery(), adaptor.getKey(), - adaptor.getValue()}, - SmallVector{output}); + Value attention = + rewriter + .create(loc, outType, + SmallVector{query, key, value}, + SmallVector{output}) + .getResult()[0]; + + if (opTy != outType) { + attention = rewriter.create(loc, opTy, attention, + reassociation); + } - rewriter.replaceOp(op, attention.getResult()); + rewriter.replaceOp(op, attention); return success(); } diff --git a/lib/Dialect/TMTensor/IR/TMTensorOps.cpp b/lib/Dialect/TMTensor/IR/TMTensorOps.cpp index ec399fe9633e..0b827893cae3 100644 --- a/lib/Dialect/TMTensor/IR/TMTensorOps.cpp +++ b/lib/Dialect/TMTensor/IR/TMTensorOps.cpp @@ -94,31 +94,22 @@ LogicalResult AttentionOp::verify() { ShapedType keyType = getKeyType(); ArrayRef queryShape = queryType.getShape(); ArrayRef keyShape = keyType.getShape(); - if (keyShape[0] != queryShape[0]) - return op->emitOpError("query and key batch mismatch"); - if (keyShape[2] != queryShape[2]) + for (int i = 0, s = queryShape.size() - 2; i < s; ++i) { + if (keyShape[i] != queryShape[i]) + return op->emitOpError("query and key batch mismatch"); + } + if (keyShape.back() != queryShape.back()) return op->emitOpError("query and key head dimension mismatch"); return success(); } SmallVector AttentionOp::getIterationDomain(OpBuilder &builder) { - int64_t iterationDomainRank = getIterationDomainRank(); - SmallVector loopBounds(iterationDomainRank); - Location loc = getLoc(); - Value zero = builder.create(loc, 0); - Value one = builder.create(loc, 1); - Value source = getQuery(); - for (auto dim : llvm::seq(0, iterationDomainRank)) { - loopBounds[dim].offset = zero; - loopBounds[dim].size = getDimValue(builder, loc, source, dim); - loopBounds[dim].stride = one; - } + SmallVector loopBounds; return loopBounds; } SmallVector AttentionOp::getLoopIteratorTypes() { - SmallVector iteratorTypes(getIterationDomainRank(), - utils::IteratorType::parallel); + SmallVector iteratorTypes; return iteratorTypes; } @@ -189,6 +180,8 @@ LogicalResult AttentionOp::generateScalarImplementation(OpBuilder &b, Value zeroF = b.create(loc, elementType, b.getFloatAttr(elementType, 0.0)); + // TODO: This needs to be fixed, it assumes everything is dynamic however if + // any shapes are static the `memref.alloc` generated is illegal. SmallVector queryDynSizes, keyDynSizes, valueDynSizes, outputDynSizes; for (auto i = 0; i < queryRank; i++) queryDynSizes.push_back(b.create(loc, query, i)); @@ -204,9 +197,18 @@ LogicalResult AttentionOp::generateScalarImplementation(OpBuilder &b, auto weightSizes = SmallVector(queryType.getShape()); weightSizes[weightRank - 1] = keySizes[keyRank - 2]; auto weightType = MemRefType::get(weightSizes, queryType.getElementType()); + + // Setup the weight dynamic sizes: SmallVector weightDynSizes(queryDynSizes); weightDynSizes[weightRank - 1] = keyDynSizes[keyRank - 2]; - Value weight = b.create(loc, weightType, weightDynSizes); + + SmallVector weightFilteredDynSizes; + for (int i = 0; i < weightRank; ++i) + if (weightSizes[i] == ShapedType::kDynamic) + weightFilteredDynSizes.push_back(weightDynSizes[i]); + + Value weight = + b.create(loc, weightType, weightFilteredDynSizes); matmul(b, loc, query, queryDynSizes, key, keyDynSizes, weight, weightDynSizes, /*transposed=*/true); @@ -259,12 +261,17 @@ LogicalResult AttentionOp::generateScalarImplementation(OpBuilder &b, x = b.create(loc, x); b.create(loc, x, weight, localIVs); }); + + llvm::SmallVector expWeightDynDims(weightFilteredDynSizes); + if (weightSizes.back() == ShapedType::kDynamic) + expWeightDynDims.resize(expWeightDynDims.size() - 1); + Value expWeightSum = b.create( loc, MemRefType::get( SmallVector(weightSizes.begin(), weightSizes.end() - 1), elementType), - SmallVector{weightDynSizes.begin(), weightDynSizes.end() - 1}); + expWeightDynDims); b.create( loc, SmallVector(weightRank - 1, zero), SmallVector{weightDynSizes.begin(), weightDynSizes.end() - 1}, diff --git a/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp b/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp index 200f25c82c43..f8161de1fa0b 100644 --- a/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp +++ b/lib/Dialect/Torch/Transforms/ReduceOpVariants.cpp @@ -189,6 +189,78 @@ class ConvertHasValueSemanticsOpsToValueTensors : public RewritePattern { }; } // namespace +namespace { + +class TorchMatchSpecializedBackendOp + : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + using HandlerFn = LogicalResult (*)(OperatorOp op, + ConversionPatternRewriter &rewriter); + + LogicalResult + matchAndRewrite(Torch::OperatorOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + if (namedHandlers.contains(op.getNameAttr())) { + return namedHandlers.lookup(op.getNameAttr()).front()(op, rewriter); + } + + return failure(); + } + + static void + populateSpecializedConversions(TorchMatchSpecializedBackendOp &matcher); + + static std::unique_ptr + getPopulatedMatcher(MLIRContext *context) { + auto matcher = std::make_unique(context); + populateSpecializedConversions(*matcher); + return matcher; + }; + + void populate(StringRef name, HandlerFn fn) { + namedHandlers[StringAttr::get(getContext(), name)].push_back(fn); + } + + void populateLegalizedNames(llvm::DenseSet &set) { + for (auto handle : namedHandlers) { + set.insert(handle.first); + } + } + +private: + DenseMap> namedHandlers; +}; + +void TorchMatchSpecializedBackendOp::populateSpecializedConversions( + TorchMatchSpecializedBackendOp &matcher) { + matcher.populate( + "torch.aten._scaled_dot_product_flash_attention_for_cpu", + [](Torch::OperatorOp op, + ConversionPatternRewriter &rewriter) -> LogicalResult { + auto uses = op.getResult(1).getUses(); + if (uses.end() == uses.begin()) { + auto oldOperands = op->getOperands(); + llvm::SmallVector newOperands{ + oldOperands[0], oldOperands[1], oldOperands[2], oldOperands[5], + oldOperands[3], oldOperands[4], oldOperands[6]}; + + auto newOp = rewriter.create( + op.getLoc(), op->getResultTypes()[0], newOperands, + op->getAttrs()); + rewriter.replaceAllUsesWith(op.getResult(0), newOp.getResult()); + rewriter.eraseOp(op); + return success(); + } + return failure(); + }); +} + +bool isSpecializedOperation(Torch::OperatorOp op) { return true; } +} // namespace + // Reduce Ops without value semantics but the corresponding without trailing // underscore variant doesn't exist. namespace { @@ -353,12 +425,24 @@ struct ReduceOpVariantsPass patterns.add(reduceNonValueTensorLiteralOpToValueTensorLiteralOp); patterns.add(context); + // Create specialized matcher: + auto specialized = + TorchMatchSpecializedBackendOp::getPopulatedMatcher(context); + DenseSet specializedNames; + specialized->populateLegalizedNames(specializedNames); + patterns.insert(std::move(specialized)); + ConversionTarget target(*context); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); - target.markUnknownOpDynamicallyLegal([&extraLibraryModuleSymTable]( - Operation *op) { + target.markUnknownOpDynamicallyLegal([&extraLibraryModuleSymTable, + &specializedNames](Operation *op) { + if (isa(op)) { + if (specializedNames.contains(cast(op).getNameAttr())) { + return false; + } + } if (op->hasTrait() || (isa(op) && operatorOpHasValueSemantics(cast(op), @@ -377,6 +461,9 @@ struct ReduceOpVariantsPass if (op->hasTrait()) { return false; } + + if (isa(op) && isSpecializedOperation(cast(op))) + return false; return true; }); diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 26f3e843954f..3048ac04a248 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -303,8 +303,7 @@ # 'linalg.depthwise_conv_2d_nchw_chw' op inferred input/output operand #1 has shape's dimension #0 to be 4, but found 8 "Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier", - # Exception: Unsupported: node.meta['val'] is not a FakeTensor or list of FakeTensor's: _scaled_dot_product_flash_attention; - "ScaledDotProductAttentionSameModule_basic", + # AssertionError: Unregistered operation: torch.aten._scaled_dot_product_flash_attention_for_cpu "ScaledDotProductAttentionDifferentModule_basic", # AssertionError: Unregistered operation: torch.aten._embedding_bag_forward_only diff --git a/projects/pt1/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py b/projects/pt1/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py index c0b5eabf5f04..9c33d8fd504d 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py +++ b/projects/pt1/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py @@ -202,7 +202,6 @@ def compile(self, imported_module: Module): An opaque, backend specific compiled artifact object that can be passed to `load`. """ - run_pipeline_with_repro_report( imported_module, LOWERING_PIPELINE, "Lowering Linalg-on-Tensors IR to LLVM with RefBackend", diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py index c73d706f25cf..7e707893911a 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py @@ -4517,18 +4517,18 @@ def __init__(self): @export @annotate_args([ None, - ([-1, -1, -1, -1], torch.float32, True), - ([-1, -1, -1, -1], torch.float32, True), - ([-1, -1, -1, -1], torch.float32, True) + ([-1, -1, -1], torch.float32, True), + ([-1, -1, -1], torch.float32, True), + ([-1, -1, -1], torch.float32, True) ]) def forward(self, query, key, value): return torch.ops.aten.scaled_dot_product_attention(query, key, value) @register_test_case(module_factory=lambda: ScaledDotProductAttentionSameModule()) def ScaledDotProductAttentionSameModule_basic(module, tu: TestUtils): - query = torch.randn(1, 1, 5, 5, dtype=torch.float32) - key = torch.randn(1, 1, 5, 5, dtype=torch.float32) - value = torch.randn(1, 1, 5, 5, dtype=torch.float32) + query = torch.randn(1, 5, 5, dtype=torch.float32) + key = torch.randn(1, 5, 5, dtype=torch.float32) + value = torch.randn(1, 5, 5, dtype=torch.float32) module.forward(query, key, value) class ScaledDotProductAttentionDifferentModule(torch.nn.Module): @@ -4539,18 +4539,18 @@ def __init__(self): @export @annotate_args([ None, - ([-1, -1, -1, -1], torch.float32, True), - ([-1, -1, -1, -1], torch.float32, True), - ([-1, -1, -1, -1], torch.float32, True) + ([2, 3, 8, 4], torch.float32, True), + ([2, 3, 16, 4], torch.float32, True), + ([2, 3, 16, 4], torch.float32, True) ]) def forward(self, query, key, value): return torch.ops.aten.scaled_dot_product_attention(query, key, value) @register_test_case(module_factory=lambda: ScaledDotProductAttentionDifferentModule()) def ScaledDotProductAttentionDifferentModule_basic(module, tu: TestUtils): - query = torch.randn(3, 2, 8, 4, dtype=torch.float32) - key = torch.randn(3, 2, 16, 4, dtype=torch.float32) - value = torch.randn(3, 2, 16, 4, dtype=torch.float32) + query = torch.randn(2, 3, 8, 4, dtype=torch.float32) + key = torch.randn(2, 3, 16, 4, dtype=torch.float32) + value = torch.randn(2, 3, 16, 4, dtype=torch.float32) module.forward(query, key, value) # ============================================================================== diff --git a/test/Dialect/Torch/reduce-op-variants.mlir b/test/Dialect/Torch/reduce-op-variants.mlir index 1122a7b3f844..94bec8aa2160 100644 --- a/test/Dialect/Torch/reduce-op-variants.mlir +++ b/test/Dialect/Torch/reduce-op-variants.mlir @@ -1,4 +1,4 @@ -// RUN: torch-mlir-opt -torch-reduce-op-variants %s | FileCheck %s +// RUN: torch-mlir-opt -torch-reduce-op-variants --split-input-file %s | FileCheck %s // CHECK-LABEL: func.func @convert_to_value_semantic_tensors( // CHECK-SAME: %[[ARG:.*]]: !torch.tensor<[],f32>) -> !torch.tensor<[],f32> { @@ -11,6 +11,8 @@ func.func @convert_to_value_semantic_tensors(%arg0: !torch.tensor<[],f32>) -> !t return %0 : !torch.tensor<[],f32> } +// ----- + // CHECK-LABEL: func.func @convert_to_value_semantic_tensors_list( // CHECK-SAME: %[[VT0:.*]]: !torch.vtensor, %[[VT1:.*]]: !torch.vtensor, // CHECK-SAME: %[[VT2:.*]]: !torch.vtensor) -> !torch.tensor { @@ -40,6 +42,8 @@ func.func @convert_to_value_semantic_tensors_list(%vt0: !torch.vtensor, %vt1: !t return %ret : !torch.tensor } +// ----- + // CHECK-LABEL: func.func @convert_to_value_semantic_tensors_optional( // CHECK-SAME: %[[INPUT:.*]]: !torch.tensor, %[[FLOAT_TENSOR:.*]]: !torch.tensor<[4],f32>, // CHECK-SAME: %[[TRAINING:.*]]: !torch.bool, %[[CUDNN_ENABLE:.*]]: !torch.bool, @@ -83,6 +87,8 @@ func.func @convert_to_value_semantic_tensors_optional(%t: !torch.tensor, return %ret: !torch.tensor } +// ----- + // CHECK-LABEL: func.func @reduce_trailing_underscore_inplace_variant( // CHECK-SAME: %[[ARG0:.*]]: !torch.tensor<[2,2],f32>, // CHECK-SAME: %[[ARG1:.*]]: !torch.tensor<[2,2],f32>) -> (!torch.tensor<[2,2],f32>, !torch.tensor<[2,2],f32>) { @@ -106,6 +112,7 @@ func.func @reduce_trailing_underscore_inplace_variant(%arg0: !torch.tensor<[2,2] %0 = torch.aten.add_.Tensor %arg0, %arg1, %c1 : !torch.tensor<[2,2],f32>, !torch.tensor<[2,2],f32>, !torch.int -> !torch.tensor<[2,2],f32> return %0, %arg0 : !torch.tensor<[2,2],f32>, !torch.tensor<[2,2],f32> } +// ----- // CHECK-LABEL: func.func @torch.tensor.literal() -> !torch.tensor { // CHECK: %[[VTENSOR:.*]] = torch.vtensor.literal(dense<0.000000e+00> : tensor<7xf32>) : !torch.vtensor<[7],f32> @@ -117,6 +124,8 @@ func.func @torch.tensor.literal() -> !torch.tensor { return %0 : !torch.tensor } +// ----- + // CHECK-LABEL: func.func @convert_to_value_semantic_tensors_optional_list( // CHECK-SAME: %[[SELF:.*]]: !torch.tensor<[5],f32>, // CHECK-SAME: %[[INDICES:.*]]: !torch.tensor<[2,3],si64>) -> !torch.tensor { @@ -134,6 +143,8 @@ func.func @convert_to_value_semantic_tensors_optional_list(%self: !torch.tensor< return %ret : !torch.tensor } +// ----- + // CHECK-LABEL: func.func @convert_to_value_semantic_tensors_optional_list_nones_and_tensors( // CHECK-SAME: %[[SELF:.*]]: !torch.tensor<[5],f32>, // CHECK-SAME: %[[INDICES_0:.*]]: !torch.tensor<[2,3],si64>, @@ -155,6 +166,8 @@ func.func @convert_to_value_semantic_tensors_optional_list_nones_and_tensors(%se return %ret : !torch.tensor } +// ----- + // CHECK-LABEL: func.func @torch.aten.bernoulli_.float( // CHECK-SAME: %[[T:.*]]: !torch.tensor) -> !torch.tensor { // CHECK: %[[GENERATOR:.*]] = torch.constant.none @@ -171,3 +184,22 @@ func.func @torch.aten.bernoulli_.float(%t: !torch.tensor) -> !torch.tensor { %ret = torch.aten.bernoulli_.float %t, %p, %generator : !torch.tensor, !torch.float, !torch.none -> !torch.tensor return %ret : !torch.tensor } + +// ----- + +// CHECK-LABEL: func.func @scaled_dot_product_flash_attention_for_cpu +// CHECK-SAME: %[[ARG0:.+]]: !torch.vtensor<[1,1,5,5],f32>, %[[ARG1:.+]]: !torch.vtensor<[1,1,5,5],f32>, %[[ARG2:.+]]: !torch.vtensor<[1,1,5,5],f32> +// CHECK: %[[ZERO:.+]] = torch.constant.float 0.000000e+00 +// CHECK: %[[FALSE:.+]] = torch.constant.bool false +// CHECK: %[[NONE0:.+]] = torch.constant.none +// CHECK: %[[NONE1:.+]] = torch.constant.none +// CHECK: %[[ATTEN:.+]] = torch.aten.scaled_dot_product_attention %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[NONE0]], %[[ZERO]], %[[FALSE]], %[[NONE1]] +// CHECK: return %[[ATTEN]] +func.func @scaled_dot_product_flash_attention_for_cpu(%arg0: !torch.vtensor<[1,1,5,5],f32>, %arg1: !torch.vtensor<[1,1,5,5],f32>, %arg2: !torch.vtensor<[1,1,5,5],f32>) -> !torch.vtensor<[1,1,5,5],f32> { + %float0.000000e00 = torch.constant.float 0.000000e+00 + %false = torch.constant.bool false + %none = torch.constant.none + %none_0 = torch.constant.none + %0:2 = torch.operator "torch.aten._scaled_dot_product_flash_attention_for_cpu"(%arg0, %arg1, %arg2, %float0.000000e00, %false, %none, %none_0) : (!torch.vtensor<[1,1,5,5],f32>, !torch.vtensor<[1,1,5,5],f32>, !torch.vtensor<[1,1,5,5],f32>, !torch.float, !torch.bool, !torch.none, !torch.none) -> (!torch.vtensor<[1,1,5,5],f32>, !torch.vtensor<[1,1,5],f32>) + return %0#0 : !torch.vtensor<[1,1,5,5],f32> +} From d6d1a173dcebd7c0d62863673c4aee08a0c3853b Mon Sep 17 00:00:00 2001 From: Vivek Khandelwal Date: Wed, 14 Feb 2024 11:58:09 +0530 Subject: [PATCH 191/283] [MLIR][Torch] Add OnnxToTorch and TorchToLinalg support for trig ops (#2903) This commit adds the OnnxToTorch lowering for cosh, acosh, asin, asinh, and atanh op. This commit also adds the TorchToLinalg lowering for acosh, asin, asinh, and atanh op. Signed-Off By: Vivek Khandelwal --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 360 +++++++++++++----- .../TorchOnnxToTorch/DefaultDomainAtoF.cpp | 59 ++- .../TorchToLinalg/Uncategorized.cpp | 97 +++-- .../Transforms/AbstractInterpLibrary.cpp | 70 +++- .../build_tools/abstract_interp_lib_gen.py | 56 ++- .../build_tools/torch_ods_gen.py | 8 +- .../test_suite/elementwise.py | 176 +++++++++ .../TorchOnnxToTorch/simple_ops_a_to_f.mlir | 81 ++++ 8 files changed, 746 insertions(+), 161 deletions(-) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index adf5e8396751..0becb668636e 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -19,96 +19,6 @@ //===----------------------------------------------------------------------===// -def Torch_AtenTanhOp : Torch_Op<"aten.tanh", [ - AllowsTypeRefinement, - HasValueSemantics, - ReadOnly - ]> { - let summary = "Generated op for `aten::tanh : (Tensor) -> (Tensor)`"; - let arguments = (ins - AnyTorchTensorType:$self - ); - let results = (outs - AnyTorchTensorType:$result - ); - let hasCustomAssemblyFormat = 1; - let extraClassDefinition = [{ - ParseResult AtenTanhOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 1, 1); - } - void AtenTanhOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 1, 1); - } - }]; -} - -def Torch_AtenTanh_Op : Torch_Op<"aten.tanh_", [ - IsTrailingUnderscoreInplaceVariant, - AllowsTypeRefinement - ]> { - let summary = "Generated op for `aten::tanh_ : (Tensor) -> (Tensor)`"; - let arguments = (ins - Torch_NonValueTensorType:$self - ); - let results = (outs - Torch_NonValueTensorType:$result - ); - let hasCustomAssemblyFormat = 1; - let extraClassDefinition = [{ - ParseResult AtenTanh_Op::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 1, 1); - } - void AtenTanh_Op::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 1, 1); - } - }]; -} - -def Torch_AtenCoshOp : Torch_Op<"aten.cosh", [ - AllowsTypeRefinement, - HasValueSemantics, - ReadOnly - ]> { - let summary = "Generated op for `aten::cosh : (Tensor) -> (Tensor)`"; - let arguments = (ins - AnyTorchTensorType:$self - ); - let results = (outs - AnyTorchTensorType:$result - ); - let hasCustomAssemblyFormat = 1; - let extraClassDefinition = [{ - ParseResult AtenCoshOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 1, 1); - } - void AtenCoshOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 1, 1); - } - }]; -} - -def Torch_AtenCosh_Op : Torch_Op<"aten.cosh_", [ - IsTrailingUnderscoreInplaceVariant, - AllowsTypeRefinement - ]> { - let summary = "Generated op for `aten::cosh_ : (Tensor) -> (Tensor)`"; - let arguments = (ins - Torch_NonValueTensorType:$self - ); - let results = (outs - Torch_NonValueTensorType:$result - ); - let hasCustomAssemblyFormat = 1; - let extraClassDefinition = [{ - ParseResult AtenCosh_Op::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 1, 1); - } - void AtenCosh_Op::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 1, 1); - } - }]; -} - def Torch_AtenHardtanhOp : Torch_Op<"aten.hardtanh", [ AllowsTypeRefinement, HasValueSemantics, @@ -886,6 +796,96 @@ def Torch_AtenSin_Op : Torch_Op<"aten.sin_", [ }]; } +def Torch_AtenAsinOp : Torch_Op<"aten.asin", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::asin : (Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenAsinOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenAsinOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + +def Torch_AtenAsin_Op : Torch_Op<"aten.asin_", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::asin_ : (Tensor) -> (Tensor)`"; + let arguments = (ins + Torch_NonValueTensorType:$self + ); + let results = (outs + Torch_NonValueTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenAsin_Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenAsin_Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + +def Torch_AtenAsinhOp : Torch_Op<"aten.asinh", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::asinh : (Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenAsinhOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenAsinhOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + +def Torch_AtenAsinh_Op : Torch_Op<"aten.asinh_", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::asinh_ : (Tensor) -> (Tensor)`"; + let arguments = (ins + Torch_NonValueTensorType:$self + ); + let results = (outs + Torch_NonValueTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenAsinh_Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenAsinh_Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + def Torch_AtenExpOp : Torch_Op<"aten.exp", [ AllowsTypeRefinement, HasValueSemantics, @@ -1021,6 +1021,51 @@ def Torch_AtenCos_Op : Torch_Op<"aten.cos_", [ }]; } +def Torch_AtenCoshOp : Torch_Op<"aten.cosh", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::cosh : (Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenCoshOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenCoshOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + +def Torch_AtenCosh_Op : Torch_Op<"aten.cosh_", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::cosh_ : (Tensor) -> (Tensor)`"; + let arguments = (ins + Torch_NonValueTensorType:$self + ); + let results = (outs + Torch_NonValueTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenCosh_Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenCosh_Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + def Torch_AtenAcosOp : Torch_Op<"aten.acos", [ AllowsTypeRefinement, HasValueSemantics, @@ -1066,6 +1111,51 @@ def Torch_AtenAcos_Op : Torch_Op<"aten.acos_", [ }]; } +def Torch_AtenAcoshOp : Torch_Op<"aten.acosh", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::acosh : (Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenAcoshOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenAcoshOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + +def Torch_AtenAcosh_Op : Torch_Op<"aten.acosh_", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::acosh_ : (Tensor) -> (Tensor)`"; + let arguments = (ins + Torch_NonValueTensorType:$self + ); + let results = (outs + Torch_NonValueTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenAcosh_Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenAcosh_Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + def Torch_AtenTanOp : Torch_Op<"aten.tan", [ AllowsTypeRefinement, HasValueSemantics, @@ -1111,6 +1201,51 @@ def Torch_AtenTan_Op : Torch_Op<"aten.tan_", [ }]; } +def Torch_AtenTanhOp : Torch_Op<"aten.tanh", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::tanh : (Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenTanhOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenTanhOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + +def Torch_AtenTanh_Op : Torch_Op<"aten.tanh_", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::tanh_ : (Tensor) -> (Tensor)`"; + let arguments = (ins + Torch_NonValueTensorType:$self + ); + let results = (outs + Torch_NonValueTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenTanh_Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenTanh_Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + def Torch_AtenAtanOp : Torch_Op<"aten.atan", [ AllowsTypeRefinement, HasValueSemantics, @@ -1156,6 +1291,51 @@ def Torch_AtenAtan_Op : Torch_Op<"aten.atan_", [ }]; } +def Torch_AtenAtanhOp : Torch_Op<"aten.atanh", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::atanh : (Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenAtanhOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenAtanhOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + +def Torch_AtenAtanh_Op : Torch_Op<"aten.atanh_", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::atanh_ : (Tensor) -> (Tensor)`"; + let arguments = (ins + Torch_NonValueTensorType:$self + ); + let results = (outs + Torch_NonValueTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenAtanh_Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenAtanh_Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + def Torch_AtenAtan2Op : Torch_Op<"aten.atan2", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp index e39c42b50422..e8c36d8cad54 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp @@ -103,7 +103,6 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( binder.op, resultType, operand); return success(); }); - // TODO: Acosh unimplemented in torch-mlir // Add became forward compatible with Torch in version 7. patterns.onOp("Add", 7, [](OpBinder binder, ConversionPatternRewriter &rewriter) { @@ -203,9 +202,28 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( binder.op, resultType, operand, constAxis, constKeepDims); return success(); }); - // TODO: Asin unimplemented in torch-mlir - // TODO: Asinh unimplemented in torch-mlir - // TODO: Atanh unimplemented in torch-mlir + patterns.onOp("Asin", 7, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value operand; + if (binder.tensorOperand(operand) || + binder.tensorResultType(resultType)) + return failure(); + rewriter.replaceOpWithNewOp( + binder.op, resultType, operand); + return success(); + }); + patterns.onOp("Asinh", 9, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value operand; + if (binder.tensorOperand(operand) || + binder.tensorResultType(resultType)) + return failure(); + rewriter.replaceOpWithNewOp( + binder.op, resultType, operand); + return success(); + }); patterns.onOp("Atan", 7, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; @@ -217,6 +235,17 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( binder.op, resultType, operand); return success(); }); + patterns.onOp("Atanh", 9, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value operand; + if (binder.tensorOperand(operand) || + binder.tensorResultType(resultType)) + return failure(); + rewriter.replaceOpWithNewOp( + binder.op, resultType, operand); + return success(); + }); patterns.onOp("Acos", 7, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; @@ -228,6 +257,17 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( binder.op, resultType, operand); return success(); }); + patterns.onOp("Acosh", 9, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value operand; + if (binder.tensorOperand(operand) || + binder.tensorResultType(resultType)) + return failure(); + rewriter.replaceOpWithNewOp( + binder.op, resultType, operand); + return success(); + }); patterns.onOp("BatchNormalization", 15, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; @@ -1041,6 +1081,17 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( binder.op, resultType, operand); return success(); }); + patterns.onOp("Cosh", 9, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value operand; + if (binder.tensorOperand(operand) || + binder.tensorResultType(resultType)) + return failure(); + rewriter.replaceOpWithNewOp( + binder.op, resultType, operand); + return success(); + }); patterns.onOp( "CumSum", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Location loc = binder.getLoc(); diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index e4b683d41cee..25c2a93d2797 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -216,22 +216,6 @@ static Value createLinalgPayloadCalculationForElementwiseOp( return b.create(loc, payloadArgs[0]); if (isa(op)) return b.create(loc, payloadArgs[0]); - if (isa(op)) { - return createCalculationForMathOpWithDtypeConversion( - b, converter, payloadArgs[0], op); - } - if (isa(op)) { - return createCalculationForMathOpWithDtypeConversion( - b, converter, payloadArgs[0], op); - } - if (isa(op)) { - return createCalculationForMathOpWithDtypeConversion( - b, converter, payloadArgs[0], op); - } - if (isa(op)) { - return createCalculationForMathOpWithDtypeConversion( - b, converter, payloadArgs[0], op); - } if (isa(op)) { return createCalculationForMathOpWithDtypeConversion( b, converter, payloadArgs[0], op); @@ -276,18 +260,50 @@ static Value createLinalgPayloadCalculationForElementwiseOp( return createCalculationForMathOpWithDtypeConversion( b, converter, payloadArgs[0], op); } + if (isa(op)) { + return createCalculationForMathOpWithDtypeConversion( + b, converter, payloadArgs[0], op); + } + if (isa(op)) { + return createCalculationForMathOpWithDtypeConversion( + b, converter, payloadArgs[0], op); + } + if (isa(op)) { + return createCalculationForMathOpWithDtypeConversion( + b, converter, payloadArgs[0], op); + } if (isa(op)) { return createCalculationForMathOpWithDtypeConversion( b, converter, payloadArgs[0], op); } - if (isa(op)) { - return createCalculationForMathOpWithDtypeConversion( + if (isa(op)) { + return createCalculationForMathOpWithDtypeConversion( b, converter, payloadArgs[0], op); } if (isa(op)) { return createCalculationForMathOpWithDtypeConversion( b, converter, payloadArgs[0], op); } + if (isa(op)) { + return createCalculationForMathOpWithDtypeConversion( + b, converter, payloadArgs[0], op); + } + if (isa(op)) { + return createCalculationForMathOpWithDtypeConversion( + b, converter, payloadArgs[0], op); + } + if (isa(op)) { + return createCalculationForMathOpWithDtypeConversion( + b, converter, payloadArgs[0], op); + } + if (isa(op)) { + return createCalculationForMathOpWithDtypeConversion( + b, converter, payloadArgs[0], op); + } + if (isa(op)) { + return createCalculationForMathOpWithDtypeConversion( + b, converter, payloadArgs[0], op); + } if (auto clone = dyn_cast(op)) { int64_t memoryFormat; if (!clone.getMemoryFormat().getType().isa() && @@ -1505,7 +1521,8 @@ class ConvertElementwiseOp : public ConversionPattern { AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp, AtenLogicalXorOp, AtenLogicalNotOp, AtenIsinfOp, AtenTriuOp, AtenTrilOp, AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, - AtenFillTensorOp, AtenAtanOp, AtenAcosOp, AtenRealOp, AtenImagOp, + AtenFillTensorOp, AtenAtanOp, AtenAcosOp, AtenAtanhOp, AtenAcoshOp, + AtenAsinOp, AtenAsinhOp, AtenRealOp, AtenImagOp, AtenDequantizeSelfOp, AtenDequantizeTensorOp, AtenQuantizePerTensorOp>(op)) return rewriter.notifyMatchFailure(op, "not a supported elementwise op"); @@ -2350,27 +2367,27 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality( ConversionTarget &target) { MLIRContext *context = patterns.getContext(); target.addIllegalOp< - AtenTanOp, AtenTanhOp, AtenSinhOp, AtenCoshOp, AtenReluOp, AtenGeluOp, - AtenGeluBackwardOp, AtenAddTensorOp, AtenMulTensorOp, AtenDivTensorOp, - AtenDivTensorModeOp, AtenSubTensorOp, AtenLerpTensorOp, AtenSigmoidOp, - AtenMinimumOp, AtenAtan2Op, AtenMaximumOp, AtenToDtypeOp, AtenClampOp, - AtenClampTensorOp, AtenRsubScalarOp, AtenLogOp, AtenErfOp, AtenSqrtOp, - AtenFloorOp, AtenCeilOp, AtenPreluOp, AtenPowScalarOp, - AtenPowTensorScalarOp, AtenPowTensorTensorOp, AtenLog2Op, AtenLog10Op, - AtenLog1pOp, AtenRsqrtOp, AtenAbsOp, AtenReciprocalOp, - AtenBitwiseAndTensorOp, AtenBitwiseAndScalarOp, AtenBitwiseOrTensorOp, - AtenBitwiseXorTensorOp, AtenBitwiseLeftShiftTensorOp, - AtenBitwiseRightShiftTensorOp, AtenGtScalarOp, AtenGeScalarOp, - AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp, - AtenGtTensorOp, AtenGeTensorOp, AtenEqTensorOp, AtenNeTensorOp, - AtenLtTensorOp, AtenLeTensorOp, AtenThresholdOp, AtenThresholdBackwardOp, - AtenHardtanhBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp, - AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp, AtenAtanOp, - AtenAcosOp, AtenLogicalXorOp, AtenLogicalNotOp, AtenIsinfOp, AtenTriuOp, - AtenTrilOp, AtenRemainderScalarOp, AtenRemainderTensorOp, - AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp, - AtenRealOp, AtenImagOp, AtenDequantizeSelfOp, AtenDequantizeTensorOp, - AtenQuantizePerTensorOp>(); + AtenTanOp, AtenTanhOp, AtenSinhOp, AtenCoshOp, AtenAtanhOp, AtenAcoshOp, + AtenAsinOp, AtenAsinhOp, AtenReluOp, AtenGeluOp, AtenGeluBackwardOp, + AtenAddTensorOp, AtenMulTensorOp, AtenDivTensorOp, AtenDivTensorModeOp, + AtenSubTensorOp, AtenLerpTensorOp, AtenSigmoidOp, AtenMinimumOp, + AtenAtan2Op, AtenMaximumOp, AtenToDtypeOp, AtenClampOp, AtenClampTensorOp, + AtenRsubScalarOp, AtenLogOp, AtenErfOp, AtenSqrtOp, AtenFloorOp, + AtenCeilOp, AtenPreluOp, AtenPowScalarOp, AtenPowTensorScalarOp, + AtenPowTensorTensorOp, AtenLog2Op, AtenLog10Op, AtenLog1pOp, AtenRsqrtOp, + AtenAbsOp, AtenReciprocalOp, AtenBitwiseAndTensorOp, + AtenBitwiseAndScalarOp, AtenBitwiseOrTensorOp, AtenBitwiseXorTensorOp, + AtenBitwiseLeftShiftTensorOp, AtenBitwiseRightShiftTensorOp, + AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp, + AtenLeScalarOp, AtenWhereSelfOp, AtenGtTensorOp, AtenGeTensorOp, + AtenEqTensorOp, AtenNeTensorOp, AtenLtTensorOp, AtenLeTensorOp, + AtenThresholdOp, AtenThresholdBackwardOp, AtenHardtanhBackwardOp, + AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenMaskedFillTensorOp, + AtenLogicalOrOp, AtenLogicalAndOp, AtenAtanOp, AtenAcosOp, + AtenLogicalXorOp, AtenLogicalNotOp, AtenIsinfOp, AtenTriuOp, AtenTrilOp, + AtenRemainderScalarOp, AtenRemainderTensorOp, AtenBitwiseNotOp, + AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp, AtenRealOp, AtenImagOp, + AtenDequantizeSelfOp, AtenDequantizeTensorOp, AtenQuantizePerTensorOp>(); patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index 320f53f0b7b6..29c94304288b 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -6306,11 +6306,19 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %18 = torch.aten.append.t %7, %17 : !torch.list, !torch.int -> !torch.list\n" " return %7 : !torch.list\n" " }\n" -" func.func @\"__torch_mlir_shape_fn.aten.tan\"(%arg0: !torch.list) -> !torch.list {\n" +" func.func @\"__torch_mlir_shape_fn.aten.sin\"(%arg0: !torch.list) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" -" func.func @\"__torch_mlir_shape_fn.aten.atan\"(%arg0: !torch.list) -> !torch.list {\n" +" func.func @\"__torch_mlir_shape_fn.aten.asin\"(%arg0: !torch.list) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.asinh\"(%arg0: !torch.list) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.cos\"(%arg0: !torch.list) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" @@ -6318,10 +6326,30 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.acos\"(%arg0: !torch.list) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.acosh\"(%arg0: !torch.list) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.tan\"(%arg0: !torch.list) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.tanh\"(%arg0: !torch.list) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.atan\"(%arg0: !torch.list) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" +" func.func @\"__torch_mlir_shape_fn.aten.atanh\"(%arg0: !torch.list) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.erf\"(%arg0: !torch.list) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -6358,18 +6386,6 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" -" func.func @\"__torch_mlir_shape_fn.aten.sin\"(%arg0: !torch.list) -> !torch.list {\n" -" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" -" return %0 : !torch.list\n" -" }\n" -" func.func @\"__torch_mlir_shape_fn.aten.cos\"(%arg0: !torch.list) -> !torch.list {\n" -" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" -" return %0 : !torch.list\n" -" }\n" -" func.func @\"__torch_mlir_shape_fn.aten.acos\"(%arg0: !torch.list) -> !torch.list {\n" -" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" -" return %0 : !torch.list\n" -" }\n" " func.func @\"__torch_mlir_shape_fn.aten.cosine_similarity\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.int, %arg3: !torch.float) -> !torch.list {\n" " %none = torch.constant.none\n" " %int1 = torch.constant.int 1\n" @@ -9371,6 +9387,11 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = torch.prim.ListConstruct %int5, %int15, %int6, %int7 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.acosh\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" +" return %1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.tanh\"(%arg0: !torch.tuple) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" @@ -9391,6 +9412,16 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" " return %1 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.asin\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" +" return %1 : !torch.int\n" +" }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.asinh\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" +" return %1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.cos\"(%arg0: !torch.tuple) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %1 = call @__torch__._get_dtype_of_floating_point_op(%0#1) : (!torch.int) -> !torch.int\n" @@ -12473,6 +12504,17 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %2 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.atanh\"(%arg0: !torch.tuple) -> !torch.int {\n" +" %int6 = torch.constant.int 6\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %2 = torch.prim.If %1 -> (!torch.int) {\n" +" torch.prim.If.yield %int6 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %0#1 : !torch.int\n" +" }\n" +" return %2 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.linear\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.optional>) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index 7fe6e8457fe8..c014808af97a 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -89,18 +89,39 @@ def aten〇diagonal〡shape(self: List[int], offset: int = 0, dim1: int = 0, dim return diagonal -def aten〇tan〡shape(self: List[int]) -> List[int]: +def aten〇sin〡shape(self: List[int]) -> List[int]: return upstream_shape_functions.unary(self) -def aten〇atan〡shape(self: List[int]) -> List[int]: +def aten〇asin〡shape(self: List[int]) -> List[int]: + return upstream_shape_functions.unary(self) + +def aten〇asinh〡shape(self: List[int]) -> List[int]: + return upstream_shape_functions.unary(self) + +def aten〇cos〡shape(self: List[int]) -> List[int]: return upstream_shape_functions.unary(self) def aten〇cosh〡shape(self: List[int]) -> List[int]: return upstream_shape_functions.unary(self) +def aten〇acos〡shape(self: List[int]) -> List[int]: + return upstream_shape_functions.unary(self) + +def aten〇acosh〡shape(self: List[int]) -> List[int]: + return upstream_shape_functions.unary(self) + +def aten〇tan〡shape(self: List[int]) -> List[int]: + return upstream_shape_functions.unary(self) + def aten〇tanh〡shape(self: List[int]) -> List[int]: return upstream_shape_functions.unary(self) +def aten〇atan〡shape(self: List[int]) -> List[int]: + return upstream_shape_functions.unary(self) + +def aten〇atanh〡shape(self: List[int]) -> List[int]: + return upstream_shape_functions.unary(self) + def aten〇erf〡shape(self: List[int]) -> List[int]: return upstream_shape_functions.unary(self) @@ -128,15 +149,6 @@ def aten〇exp〡shape(self: List[int]) -> List[int]: def aten〇expm1〡shape(self: List[int]) -> List[int]: return upstream_shape_functions.unary(self) -def aten〇sin〡shape(self: List[int]) -> List[int]: - return upstream_shape_functions.unary(self) - -def aten〇cos〡shape(self: List[int]) -> List[int]: - return upstream_shape_functions.unary(self) - -def aten〇acos〡shape(self: List[int]) -> List[int]: - return upstream_shape_functions.unary(self) - def aten〇cosine_similarity〡shape(x1: List[int], x2: List[int], dim: int = 1, eps: float = 1e-08) -> List[int]: broadcast = upstream_shape_functions.broadcast(x1, x2) return broadcast[:dim] + broadcast[dim + 1:] @@ -1856,6 +1868,11 @@ def aten〇cosh〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype return _get_dtype_of_floating_point_op(self_dtype) +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇acosh〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return _get_dtype_of_floating_point_op(self_dtype) + @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) def aten〇tanh〡dtype(self_rank_dtype: Tuple[int, int]) -> int: @@ -1878,6 +1895,16 @@ def aten〇sin〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype return _get_dtype_of_floating_point_op(self_dtype) +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇asin〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return _get_dtype_of_floating_point_op(self_dtype) + +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇asinh〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + return _get_dtype_of_floating_point_op(self_dtype) + @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) def aten〇cos〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype @@ -4191,6 +4218,13 @@ def aten〇atan〡dtype(self_rank_dtype: Tuple[int, int]) -> int: return torch.float32 return self_dtype +@check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) +def aten〇atanh〡dtype(self_rank_dtype: Tuple[int, int]) -> int: + self_rank, self_dtype = self_rank_dtype + if is_integer_dtype(self_dtype): + return torch.float32 + return self_dtype + @check_dtype_function(_check_two_tensor_op()) def aten〇linear〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Tuple[int, int], bias_rank_dtype: Optional[Tuple[int, int]] = None) -> int: input_rank, input_dtype = input_rank_dtype diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index 6f674601393d..65e9f44c1126 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -255,8 +255,6 @@ def emit_with_mutating_variants(key, **kwargs): # Elementwise tensor compute ops for key in [ - "aten::tanh : (Tensor) -> (Tensor)", - "aten::cosh : (Tensor) -> (Tensor)", "aten::hardtanh : (Tensor, Scalar, Scalar) -> (Tensor)", "aten::elu : (Tensor, Scalar, Scalar, Scalar) -> (Tensor)", "aten::relu : (Tensor) -> (Tensor)", @@ -274,12 +272,18 @@ def emit_with_mutating_variants(key, **kwargs): "aten::erfinv : (Tensor) -> (Tensor)", "aten::silu : (Tensor) -> (Tensor)", "aten::sin : (Tensor) -> (Tensor)", + "aten::asin : (Tensor) -> (Tensor)", + "aten::asinh : (Tensor) -> (Tensor)", "aten::exp : (Tensor) -> (Tensor)", "aten::expm1 : (Tensor) -> (Tensor)", "aten::cos : (Tensor) -> (Tensor)", + "aten::cosh : (Tensor) -> (Tensor)", "aten::acos : (Tensor) -> (Tensor)", + "aten::acosh : (Tensor) -> (Tensor)", "aten::tan : (Tensor) -> (Tensor)", + "aten::tanh : (Tensor) -> (Tensor)", "aten::atan : (Tensor) -> (Tensor)", + "aten::atanh : (Tensor) -> (Tensor)", "aten::atan2 : (Tensor, Tensor) -> (Tensor)", "aten::neg : (Tensor) -> (Tensor)", "aten::ceil : (Tensor) -> (Tensor)", diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py index c1a827ffe108..2f74ceb84416 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py @@ -107,6 +107,182 @@ def ElementwiseCoshIntModule_basic(module, tu: TestUtils): # ============================================================================== +class ElementwiseAcoshModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True), + ]) + def forward(self, a): + return torch.acosh(a) + + +@register_test_case(module_factory=lambda: ElementwiseAcoshModule()) +def ElementwiseAcoshModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4)) + + +# ============================================================================== + + +class ElementwiseAcoshIntModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int32, True), + ]) + def forward(self, a): + return torch.acosh(a) + + +@register_test_case(module_factory=lambda: ElementwiseAcoshIntModule()) +def ElementwiseAcoshIntModule_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 4, low=1, high=10).to(torch.int32)) + + +# ============================================================================== + + +class ElementwiseAsinModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True), + ]) + def forward(self, a): + return torch.asin(a) + + +@register_test_case(module_factory=lambda: ElementwiseAsinModule()) +def ElementwiseAsinModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4)) + + +# ============================================================================== + + +class ElementwiseAsinIntModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int32, True), + ]) + def forward(self, a): + return torch.asin(a) + + +@register_test_case(module_factory=lambda: ElementwiseAsinIntModule()) +def ElementwiseAsinIntModule_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 4, low=1, high=10).to(torch.int32)) + + +# ============================================================================== + + +class ElementwiseAsinhModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True), + ]) + def forward(self, a): + return torch.asinh(a) + + +@register_test_case(module_factory=lambda: ElementwiseAsinhModule()) +def ElementwiseAsinhModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4)) + + +# ============================================================================== + + +class ElementwiseAsinhIntModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int32, True), + ]) + def forward(self, a): + return torch.asinh(a) + + +@register_test_case(module_factory=lambda: ElementwiseAsinhIntModule()) +def ElementwiseAsinhIntModule_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 4, low=1, high=10).to(torch.int32)) + + +# ============================================================================== + + +class ElementwiseAtanhModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True), + ]) + def forward(self, a): + return torch.atanh(a) + + +@register_test_case(module_factory=lambda: ElementwiseAtanhModule()) +def ElementwiseAtanhModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4)) + + +# ============================================================================== + + +class ElementwiseAtanhIntModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int32, True), + ]) + def forward(self, a): + return torch.atanh(a) + + +@register_test_case(module_factory=lambda: ElementwiseAtanhIntModule()) +def ElementwiseAtanhIntModule_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 4, low=1, high=10).to(torch.int32)) + + +# ============================================================================== + + class ElementwiseBinaryModule(torch.nn.Module): def __init__(self): diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir index 2ee21c1e3841..3e4a476dbfbb 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir @@ -127,6 +127,15 @@ func.func @test_atan(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4, // ----- +// CHECK-LABEL: @test_atanh +func.func @test_atanh(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.atanh %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> + %0 = torch.operator "onnx.Atanh"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> +} + +// ----- + // CHECK-LABEL: @test_acos func.func @test_acos(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 7 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.acos %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> @@ -558,6 +567,78 @@ func.func @test_cos(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5 // ----- +// CHECK-LABEL: @test_cosh_example +func.func @test_cosh_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.cosh %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> + %0 = torch.operator "onnx.Cosh"(%arg0) : (!torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> + return %0 : !torch.vtensor<[3],f32> +} + +// ----- + +// CHECK-LABEL: @test_cosh +func.func @test_cosh(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.cosh %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> + %0 = torch.operator "onnx.Cosh"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> +} + +// ----- + +// CHECK-LABEL: @test_acosh_example +func.func @test_acosh_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.acosh %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> + %0 = torch.operator "onnx.Acosh"(%arg0) : (!torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> + return %0 : !torch.vtensor<[3],f32> +} + +// ----- + +// CHECK-LABEL: @test_acosh +func.func @test_acosh(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.acosh %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> + %0 = torch.operator "onnx.Acosh"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> +} + +// ----- + +// CHECK-LABEL: @test_asin_example +func.func @test_asin_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 7 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.asin %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> + %0 = torch.operator "onnx.Asin"(%arg0) : (!torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> + return %0 : !torch.vtensor<[3],f32> +} + +// ----- + +// CHECK-LABEL: @test_asin +func.func @test_asin(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 7 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.asin %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> + %0 = torch.operator "onnx.Asin"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> +} + +// ----- + +// CHECK-LABEL: @test_asinh_example +func.func @test_asinh_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.asinh %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> + %0 = torch.operator "onnx.Asinh"(%arg0) : (!torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> + return %0 : !torch.vtensor<[3],f32> +} + +// ----- + +// CHECK-LABEL: @test_asinh +func.func @test_asinh(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.asinh %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> + %0 = torch.operator "onnx.Asinh"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> + return %0 : !torch.vtensor<[3,4,5],f32> +} + +// ----- + // CHECK-LABEL: @test_dequantizelinear_si8 func.func @test_dequantizelinear_si8(%arg0: !torch.vtensor<[6],si8>, %arg1: !torch.vtensor<[],f32>, %arg2: !torch.vtensor<[],si8>) -> !torch.vtensor<[6],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 19 : si64} { %0 = torch.operator "onnx.DequantizeLinear"(%arg0, %arg1, %arg2) : (!torch.vtensor<[6],si8>, !torch.vtensor<[],f32>, !torch.vtensor<[],si8>) -> !torch.vtensor<[6],f32> From e7a09440d380827e90b94ef33bd82f32fda8874a Mon Sep 17 00:00:00 2001 From: Ean Garvey <87458719+monorimet@users.noreply.github.com> Date: Wed, 14 Feb 2024 12:31:37 -0600 Subject: [PATCH 192/283] Bump torch to pytorch/pytorch@b51e024 (#2909) This version of pytorch includes a patch to enable dynamo support on Windows, so I would like to sync on this torch version across torch-mlir/shark-turbine for a seamless Windows import flow. --- pytorch-hash.txt | 2 +- pytorch-requirements.txt | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch-hash.txt b/pytorch-hash.txt index d78b0d0694d4..c23d10cf50fb 100644 --- a/pytorch-hash.txt +++ b/pytorch-hash.txt @@ -1 +1 @@ -3cbc8e89fd09b0ffb4914187b438f15c121e2302 +b51e0246b7f119770c47183b230c553f15ab4fbb diff --git a/pytorch-requirements.txt b/pytorch-requirements.txt index 540e78dccd49..25546907e856 100644 --- a/pytorch-requirements.txt +++ b/pytorch-requirements.txt @@ -1,3 +1,3 @@ -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html --pre -torch==2.3.0.dev20240207 +torch==2.3.0.dev20240214 From 77b7550997e91f19f27af085f8ae531e696e6406 Mon Sep 17 00:00:00 2001 From: Daniel Garvey <34486624+dan-garvey@users.noreply.github.com> Date: Wed, 14 Feb 2024 16:24:25 -0600 Subject: [PATCH 193/283] Add support for bfloat16 in fximporter (#2896) this introduces an additional soft dependency on the python ml_dtypes python packages in order to support bfloat16 Addresses #2843 --- python/torch_mlir/extras/fx_importer.py | 16 ++++++++++++++-- test-requirements.txt | 2 +- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py index b70487ad5ad9..5677ee4f75ba 100644 --- a/python/torch_mlir/extras/fx_importer.py +++ b/python/torch_mlir/extras/fx_importer.py @@ -43,6 +43,13 @@ Graph, GraphModule, ) +try: + import ml_dtypes +except ModuleNotFoundError: + # The third-party ml_dtypes package provides some optional + # low precision data-types. If used in this file, it is + # conditional. + ml_dtypes = None from torch.fx.node import ( Argument as NodeArgument, @@ -137,7 +144,6 @@ torch.int16: np.int16, torch.int32: np.int32, torch.int64: np.int64, - # torch.bf16: None, there's no equivalent np datatype so this isn't supported right now torch.float16: np.float16, torch.float32: np.float32, torch.float64: np.float64, @@ -146,6 +152,8 @@ torch.complex64: np.complex64, torch.complex128: np.complex128, } +if ml_dtypes is not None: + TORCH_DTYPE_TO_NPY_TYPE[torch.bfloat16] = ml_dtypes.bfloat16 TORCH_DTYPE_TO_INT = { torch.uint8: 0, @@ -1090,6 +1098,10 @@ def _make_vtensor_literal_op( ) -> Operation: mapping = py_attr_tracker.track(tensor) if mapping.is_empty: + # check support for bfloat16 + assert ( + not (tensor.dtype == torch.bfloat16 and ml_dtypes is None) + ), f"torch.bfloat16 requires the ml_dtypes package, please run:\n\npip install ml_dtypes\n" # Resolve the attribute. npy_dtype = TORCH_DTYPE_TO_NPY_TYPE.get(tensor.dtype) assert ( @@ -1115,7 +1127,7 @@ def _make_vtensor_literal_op( type=element_type, array=np_tensor, shape=np_tensor.shape ) else: - bytes_view = memoryview(np_tensor) + bytes_view = np_tensor.view(npy_dtype) tensor_type = create_mlir_tensor_type(tensor) shape_desc = "_".join([str(d) for d in tensor.shape]) blob_name = f"torch_tensor_{shape_desc}_{str(tensor.dtype)}" diff --git a/test-requirements.txt b/test-requirements.txt index 315e021308e8..c8e8e2bc6e5a 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -1,4 +1,4 @@ pillow dill multiprocess -onnx==1.15.0 \ No newline at end of file +onnx==1.15.0 From f3b38e5d1214afa6046ad21543ab3bb10d2d3b98 Mon Sep 17 00:00:00 2001 From: Ze Zhang Date: Wed, 14 Feb 2024 18:18:11 -0800 Subject: [PATCH 194/283] DecomposeComplexOps: update parseEquation to skip space char for AtenEinsumOp op (#2910) Just a minor update to skip the space char if included in the equation string --------- Co-authored-by: Ze Zhang --- lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index bc5276dca6a7..abd716c56afa 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -181,7 +181,7 @@ static bool parseEquation(const std::string &equation, inputToken.clear(); currentVariable = kIsResult; index++; - } else { + } else if (equation[index] != ' ') { return false; } index++; From 8e2e5eeae991c825496e22470e3d3fb766d54a66 Mon Sep 17 00:00:00 2001 From: saienduri <77521230+saienduri@users.noreply.github.com> Date: Wed, 14 Feb 2024 21:00:52 -0800 Subject: [PATCH 195/283] add support for decomposition (#2879) This commit adds decomposition support into the core aten operators before importing the module from torch. Also, this commit deals with the lifted tensor constants in torch.export.export(). We don't want to add unnecessary placeholder nodes in the graph (extra args in the block module), and should treat them like the constants that they are. The unnecessary clone is also removed for max efficiency. --- python/CMakeLists.txt | 1 + python/torch_mlir/extras/fx_decomp_util.py | 50 ++++++++++++++++++++++ python/torch_mlir/fx.py | 3 ++ 3 files changed, 54 insertions(+) create mode 100644 python/torch_mlir/extras/fx_decomp_util.py diff --git a/python/CMakeLists.txt b/python/CMakeLists.txt index 6300df01e4ec..e52135599864 100644 --- a/python/CMakeLists.txt +++ b/python/CMakeLists.txt @@ -44,6 +44,7 @@ declare_mlir_python_sources(TorchMLIRPythonSources.PublicAPI ADD_TO_PARENT TorchMLIRPythonSources SOURCES fx.py + extras/fx_decomp_util.py ) declare_mlir_python_sources(TorchMLIRPythonSources.Tools diff --git a/python/torch_mlir/extras/fx_decomp_util.py b/python/torch_mlir/extras/fx_decomp_util.py new file mode 100644 index 000000000000..47a79f95597e --- /dev/null +++ b/python/torch_mlir/extras/fx_decomp_util.py @@ -0,0 +1,50 @@ +import torch +from torch._decomp import get_decompositions + +# default decompositions pulled from SHARK / torch._decomp +DEFAULT_DECOMPOSITIONS = [ + torch.ops.aten.embedding_dense_backward, + torch.ops.aten.native_layer_norm_backward, + torch.ops.aten.slice_backward, + torch.ops.aten.select_backward, + torch.ops.aten.norm.ScalarOpt_dim, + torch.ops.aten.native_group_norm, + torch.ops.aten.upsample_bilinear2d.vec, + torch.ops.aten.split.Tensor, + torch.ops.aten.split_with_sizes, + torch.ops.aten.native_layer_norm, + torch.ops.aten.masked_fill.Tensor, + torch.ops.aten.masked_fill.Scalar, + torch.ops.aten.t, + torch.ops.aten.addmm, + # decompositions that aid us in handling nn.BatchNorm2d + torch.ops.aten._native_batch_norm_legit_functional, + torch.ops.aten._native_batch_norm_legit_no_training, + torch.ops.aten._native_batch_norm_legit, + torch.ops.aten._native_batch_norm_legit.no_stats, + torch.ops.aten.squeeze.dims, + # decompositions for miscellaneous ops that are not handled in torch-mlir but have available decompositions + torch.ops.aten.soft_margin_loss, + torch.ops.aten.im2col, + torch.ops.aten._euclidean_dist, + torch.ops.aten.index_copy, + torch.ops.aten.index_copy_, + torch.ops.aten.grid_sampler_2d, + torch.ops.aten.log_sigmoid_forward, + torch.ops.aten.unsafe_split.Tensor, + torch.ops.aten.binary_cross_entropy, + torch.ops.aten.dot, + torch.ops.aten._adaptive_avg_pool2d, + torch.ops.aten._prelu_kernel, + torch.ops.aten.full, + torch.ops.aten._log_softmax, + torch.ops.aten.nll_loss_forward, + torch.ops.aten.nll_loss_backward, + torch.ops.aten._to_copy, + torch.ops.aten._log_softmax_backward_data, + torch.ops.aten.lift_fresh_copy.default, + torch.ops.aten._unsafe_index.Tensor, +] + +def get_decomposition_table(): + return get_decompositions(DEFAULT_DECOMPOSITIONS) diff --git a/python/torch_mlir/fx.py b/python/torch_mlir/fx.py index 78b46cc3ea29..3abb70261db8 100644 --- a/python/torch_mlir/fx.py +++ b/python/torch_mlir/fx.py @@ -7,6 +7,7 @@ from torch_mlir.extras.fx_importer import FxImporter from torch_mlir import ir from torch_mlir.dialects import torch as torch_d +from torch_mlir.extras.fx_decomp_util import get_decomposition_table def export_and_import( f, @@ -21,5 +22,7 @@ def export_and_import( if fx_importer is None: fx_importer = FxImporter(context=context) prog = torch.export.export(f, args, kwargs, constraints=constraints) + decomp_table = get_decomposition_table() + prog = prog.run_decompositions(decomp_table) fx_importer.import_frozen_exported_program(prog) return fx_importer.module_op From f3e8199a6d2871312619608b839ccd8037b12264 Mon Sep 17 00:00:00 2001 From: Yuanqiang Liu Date: Fri, 16 Feb 2024 01:08:48 +0800 Subject: [PATCH 196/283] [Stablehlo] add refbackend (#2712) --- build_tools/ci/test_posix.sh | 4 + externals/stablehlo | 2 +- lib/CMakeLists.txt | 4 + lib/Conversion/TorchToStablehlo/Reduction.cpp | 56 +- lib/InitAll.cpp | 10 + projects/pt1/e2e_testing/main.py | 8 +- projects/pt1/e2e_testing/xfail_sets.py | 816 ++++++++---------- .../stablehlo_backends/linalg_on_tensors.py | 57 ++ 8 files changed, 482 insertions(+), 475 deletions(-) create mode 100644 projects/pt1/python/torch_mlir_e2e_test/stablehlo_backends/linalg_on_tensors.py diff --git a/build_tools/ci/test_posix.sh b/build_tools/ci/test_posix.sh index 8cc68d77bd79..73818051d06b 100755 --- a/build_tools/ci/test_posix.sh +++ b/build_tools/ci/test_posix.sh @@ -20,6 +20,10 @@ echo "::group::Run TOSA e2e integration tests" python -m e2e_testing.main --config=tosa -v echo "::endgroup::" +echo "::group::Run Stablehlo e2e integration tests" +python -m e2e_testing.main --config=stablehlo -v +echo "::endgroup::" + case $torch_version in nightly) # Failing with: NotImplementedError: diff --git a/externals/stablehlo b/externals/stablehlo index e191eb4c3c3f..4ac26f8786d4 160000 --- a/externals/stablehlo +++ b/externals/stablehlo @@ -1 +1 @@ -Subproject commit e191eb4c3c3f3144503a8a117d760de5ddcc7e89 +Subproject commit 4ac26f8786d491c5d8376e6e563d1b72af09de75 diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt index 0db753e4746a..e4ba46138f34 100644 --- a/lib/CMakeLists.txt +++ b/lib/CMakeLists.txt @@ -31,6 +31,10 @@ set(LinkedLibs TorchMLIRTorchOnnxToTorch ) +if(TORCH_MLIR_ENABLE_STABLEHLO) +list(APPEND LinkedLibs StablehloPasses StablehloLinalgTransforms) +endif() + if(TORCH_MLIR_ENABLE_REFBACKEND) add_subdirectory(RefBackend) list(APPEND LinkedLibs TorchMLIRRefBackend) diff --git a/lib/Conversion/TorchToStablehlo/Reduction.cpp b/lib/Conversion/TorchToStablehlo/Reduction.cpp index e413fe532654..0b27d0748855 100644 --- a/lib/Conversion/TorchToStablehlo/Reduction.cpp +++ b/lib/Conversion/TorchToStablehlo/Reduction.cpp @@ -24,6 +24,9 @@ #include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionOps.h" +#include +#include + using namespace mlir; using namespace mlir::torch; using namespace mlir::torch::Torch; @@ -116,6 +119,12 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input, initIndex = hlo::getConstTensor(rewriter, op, {0}, {}).value(); } + std::vector outputShape(inputShape.begin(), inputShape.end()); + outputShape.erase(outputShape.begin() + dim); + auto outputTy = RankedTensorType::get(outputShape, inputElemTy); + auto outputIndexTy = + RankedTensorType::get(outputShape, rewriter.getIntegerType(64)); + auto inputShapeTensor = rewriter.create( op->getLoc(), inputShapeVec); auto indexTensor = rewriter.create( @@ -125,7 +134,8 @@ getMaxInDim(ConversionPatternRewriter &rewriter, Operation *op, Value &input, inputShapeTensor, static_cast(dim)); auto stablehloReduceOp = rewriter.create( - op->getLoc(), ValueRange{input, indexTensor}, + op->getLoc(), TypeRange{outputTy, outputIndexTy}, + ValueRange{input, indexTensor}, ValueRange{ initValue, initIndex, @@ -412,7 +422,8 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( llvm::sort(dims.begin(), dims.end()); auto stablehloReduceOp = rewriter.create( - op.getLoc(), input, initValue, rewriter.getDenseI64ArrayAttr(dims)); + op.getLoc(), RankedTensorType::get({}, outTy.getElementType()), input, + initValue, rewriter.getDenseI64ArrayAttr(dims)); Block &block = stablehloReduceOp.getBody().emplaceBlock(); auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType()); @@ -473,7 +484,8 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( return failure(); llvm::sort(dims.begin(), dims.end()); auto stablehloReduceOp = rewriter.create( - op.getLoc(), input, initValue, rewriter.getDenseI64ArrayAttr(dims)); + op.getLoc(), RankedTensorType::get({}, inputElemTy), input, initValue, + rewriter.getDenseI64ArrayAttr(dims)); Block &block = stablehloReduceOp.getBody().emplaceBlock(); auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType()); @@ -535,7 +547,8 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( return failure(); llvm::sort(dims.begin(), dims.end()); auto stablehloReduceOp = rewriter.create( - op.getLoc(), input, initValue, rewriter.getDenseI64ArrayAttr(dims)); + op.getLoc(), RankedTensorType::get({}, inputElemTy), input, initValue, + rewriter.getDenseI64ArrayAttr(dims)); Block &block = stablehloReduceOp.getBody().emplaceBlock(); auto blockArgumentTy = RankedTensorType::get({}, inputTy.getElementType()); @@ -614,6 +627,14 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( } } + std::unordered_set dimsSet(dims.begin(), dims.end()); + SmallVector reduceResultShape; + for (int64_t i = 0; i < inputTy.getRank(); i++) { + if (dimsSet.find(i) == dimsSet.end()) { + reduceResultShape.push_back(inputTy.getDimSize(i)); + } + } + bool keepDim = false; if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim))) { return rewriter.notifyMatchFailure(op, "non-bool keepdim unsupported"); @@ -625,7 +646,9 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( llvm::sort(dims.begin(), dims.end()); auto stablehloReduceOp = rewriter.create( - op.getLoc(), input, initValue, rewriter.getDenseI64ArrayAttr(dims)); + op.getLoc(), + RankedTensorType::get(reduceResultShape, outTy.getElementType()), input, + initValue, rewriter.getDenseI64ArrayAttr(dims)); Region ®ion = stablehloReduceOp.getBody(); Block &block = region.emplaceBlock(); @@ -714,6 +737,14 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( // stable with unordered dims. std::sort(dims.begin(), dims.end()); + std::unordered_set dimsSet(dims.begin(), dims.end()); + SmallVector reduceResultShape; + for (int64_t i = 0; i < inputRank; i++) { + if (dimsSet.find(i) == dimsSet.end()) { + reduceResultShape.push_back(inputType.getDimSize(i)); + } + } + bool keepDim = false; if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim))) { return rewriter.notifyMatchFailure( @@ -728,8 +759,8 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( } auto reduceOp = rewriter.create( - op->getLoc(), squareOp.getResult(), initValue, - rewriter.getDenseI64ArrayAttr(dims)); + op->getLoc(), RankedTensorType::get(reduceResultShape, inputElemType), + squareOp.getResult(), initValue, rewriter.getDenseI64ArrayAttr(dims)); Region ®ion = reduceOp.getBody(); Block &block = region.emplaceBlock(); @@ -832,6 +863,14 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( std::sort(dims.begin(), dims.end()); } + std::unordered_set dimsSet(dims.begin(), dims.end()); + SmallVector reduceResultShape; + for (int64_t i = 0; i < inputType.getRank(); i++) { + if (dimsSet.find(i) == dimsSet.end()) { + reduceResultShape.push_back(inputType.getDimSize(i)); + } + } + bool keepDim = false; if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepDim))) { return rewriter.notifyMatchFailure( @@ -848,7 +887,8 @@ LogicalResult ConvertAtenReductionOp::matchAndRewrite( ord, nullptr); auto reduceOp = rewriter.create( - op->getLoc(), powValue, initValue, rewriter.getDenseI64ArrayAttr(dims)); + op->getLoc(), RankedTensorType::get(reduceResultShape, outElemType), + powValue, initValue, rewriter.getDenseI64ArrayAttr(dims)); Region ®ion = reduceOp.getBody(); Block &block = region.emplaceBlock(); diff --git a/lib/InitAll.cpp b/lib/InitAll.cpp index ace6c1a40e74..eebfc940870c 100644 --- a/lib/InitAll.cpp +++ b/lib/InitAll.cpp @@ -29,6 +29,11 @@ #include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h" #include "torch-mlir/RefBackend/Passes.h" +#ifdef TORCH_MLIR_ENABLE_STABLEHLO +#include "stablehlo/conversions/linalg/transforms/Passes.h" +#include "stablehlo/transforms/Passes.h" +#endif + void mlir::torch::registerAllDialects(mlir::DialectRegistry ®istry) { registry.insert(); registry.insert(); @@ -52,6 +57,11 @@ void mlir::torch::registerAllPasses() { mlir::torch::onnx_c::registerTorchOnnxToTorchPasses(); mlir::torch::TMTensor::registerPasses(); +#ifdef TORCH_MLIR_ENABLE_STABLEHLO + mlir::stablehlo::registerChloLegalizeToStablehloPass(); + mlir::stablehlo::registerStablehloLegalizeToLinalgPass(); +#endif + #ifdef TORCH_MLIR_ENABLE_REFBACKEND mlir::torch::RefBackend::registerRefBackendPasses(); #endif diff --git a/projects/pt1/e2e_testing/main.py b/projects/pt1/e2e_testing/main.py index 885f344778f5..b9cd04c1e80b 100644 --- a/projects/pt1/e2e_testing/main.py +++ b/projects/pt1/e2e_testing/main.py @@ -25,6 +25,7 @@ from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend from torch_mlir_e2e_test.tosa_backends.linalg_on_tensors import LinalgOnTensorsTosaBackend +from torch_mlir_e2e_test.stablehlo_backends.linalg_on_tensors import LinalgOnTensorsStablehloBackend from .xfail_sets import ( LINALG_XFAIL_SET, @@ -43,7 +44,7 @@ register_all_tests() def _get_argparse(): - config_choices = ["native_torch", "torchscript", "linalg", "make_fx_tosa", "tosa", "lazy_tensor_core", "torchdynamo"] + config_choices = ["native_torch", "torchscript", "linalg", "stablehlo", "make_fx_tosa", "tosa", "lazy_tensor_core", "torchdynamo"] parser = argparse.ArgumentParser(description="Run torchscript e2e tests.") parser.add_argument("-c", "--config", choices=config_choices, @@ -52,6 +53,7 @@ def _get_argparse(): Meaning of options: "linalg": run through torch-mlir"s default Linalg-on-Tensors backend. "tosa": run through torch-mlir"s default TOSA backend. +"stablehlo": run through torch-mlir"s default Stablehlo backend. "native_torch": run the torch.nn.Module as-is without compiling (useful for verifying model is deterministic; ALL tests should pass in this configuration). "torchscript": compile the model to a torch.jit.ScriptModule, and then run that as-is (useful for verifying TorchScript is modeling the program correctly). "lazy_tensor_core": run the model through the Lazy Tensor Core frontend and execute the traced graph. @@ -90,6 +92,10 @@ def main(): config = LinalgOnTensorsBackendTestConfig(RefBackendLinalgOnTensorsBackend()) xfail_set = LINALG_XFAIL_SET crashing_set = set() + elif args.config == "stablehlo": + config = StablehloBackendTestConfig(LinalgOnTensorsStablehloBackend()) + xfail_set = all_test_unique_names - STABLEHLO_PASS_SET + crashing_set = STABLEHLO_CRASHING_SET elif args.config == "tosa": config = TosaBackendTestConfig(LinalgOnTensorsTosaBackend()) xfail_set = all_test_unique_names - TOSA_PASS_SET diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 3048ac04a248..36a1d5662810 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -378,84 +378,16 @@ } STABLEHLO_PASS_SET = { - "TileBigDimsSizeModule_basic", - "TileSmallDimsSizeModule_basic", + "AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic", + "AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic", + "AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic", + "AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic", "AddIntModule_basic", - "AtenIntBoolOpModule_basic", - "AtenIntTensorByteDtypeModule_basic", - "AtenIntTensorCharDtypeModule_basic", - "BoolFloatFalseModule_basic", - "BoolFloatTrueModule_basic", - "BoolIntFalseModule_basic", - "BoolIntTrueModule_basic", - "CeilFloatModule_basic", - "DivFloatModule_basic", - "DivIntModule_basic", - "EqIntModule_basic", - "GeFloatIntModule_basic", - "GeFloatModule_basic", - "GeIntModule_basic", - "GtFloatIntModule_basic", - "GtIntModule_basic", - "MulIntModule_basic", - "NeFloatIntModule_basic", - "NeIntModule_basic", - "SqrtIntModule_basic", - "SubFloatModule_basic", - "SubIntModule_basic", - "TensorToBoolZeroRank_basic", - "TensorToIntZeroRank_basic", - "TensorToFloatZeroRank_basic", - "IndexTensorStaticContiguousWithNoneModule_basic", - "IndexTensorStaticNonContiguousWithNoneModule_basic", "AliasModule_basic", - "TensorIntModule_basic", "AllBoolFalseModule_basic", "AllBoolTrueModule_basic", "AnyBoolFalseModule_basic", "AnyBoolTrueModule_basic", - "AtenIntBoolOpConstFalseModule_basic", - "AtenIntBoolOpConstTrueModule_basic", - "AtenFloatScalarModule_basic", - "ScalarImplicitFloatModule_basic", - "ScalarImplicitIntModule_basic", - "AtenSubFloatModule_basic", - "BoolFloatConstantModule_basic", - "BoolIntConstantModule_basic", - "ContainsIntList_False", - "ContainsIntList_True", - "IntFloatModule_basic", - "IsFloatingPointFloat_True", - "IsFloatingPointInt_False", - "LenStrModule_basic", - "MeanDimAllReduceKeepdimModule_basic", - "MeanDimAllReduceModule_basic", - "MeanDimDtypeModule_basic", - "MeanDimKeepdimModule_basic", - "MeanDimModule_basic", - "MeanDimNegativeModule_basic", - "NumelZeroRankModule_basic", - "PowIntFloatModule_basic", - "PrimMaxIntModule_basic", - "PrimMinIntModule_basic", - "PrimMinIntDynamicModule_basic", - "SortIntListReverse_basic", - "SortIntList_basic", - "SqrtIntConstantModule_basic", - "StdBiasedModule_basic", - "StdDimBiasedModule_basic", - "TestMultipleTensorAndPrimitiveTypesReturn_basic", - "VarBiasedModule_basic", - "VarDimBiasedModule_basic", - "VarMeanBiasedModule_basic", - "VarMeanDimBiasedModule_basic", - "ConstantBoolParameterModule_basic", - "MaskedFillScalarIntValueStaticModule_basic", - "MaskedFillScalarFloatValueStaticModule_basic", - "AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic", - "AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic", - "AddSizeIntModule_basic", - "AddSizeIntNegDimModule_basic", "ArangeDtypeFloatModule_basic", "ArangeDtypeIntModule_basic", "ArangeFalsePinMemoryModule_basic", @@ -467,139 +399,161 @@ "ArangeStartIntModule_basic", "ArangeStartNegativeStepFloatModule_basic", "ArangeStartNegativeStepIntModule_basic", + "ArangeStartOutDtypeModule_basic", + "ArangeStartOutModule_basic", + "ArangeStartOutViewModule_basic", "ArangeStartStepFloatModule_basic", "ArangeStartStepIntModule_basic", "ArangeZeroElementOutputModule_basic", - "BatchMlpLayerModule_basic", - "BatchNorm1DModule_basic", - "BatchNorm1DWith2DInputModule_basic", - "BatchNorm2DModule_basic", - "BatchNorm3DModule_basic", - "BatchNorm1DStaticShapeModule_basic", - "ResNet18StaticModule_basic", - "BmmFloatModule_basic", - "BmmIntModule_basic", - "BroadcastToModule_basic", + "ArgmaxModule_with_dim", + "AtenComplex64Module_basic", + "AtenEyeMModuleCPUDevice_basic", + "AtenEyeMModuleDefaultDtype_basic", + "AtenEyeMModuleFalsePinMemory_basic", + "AtenEyeMModuleFloat2D_basic", + "AtenEyeMModuleInt2D_basic", + "AtenEyeModuleCPUDevice_basic", + "AtenEyeModuleDefaultDtype_basic", + "AtenEyeModuleFalsePinMemory_basic", + "AtenEyeModuleFloat2D_basic", + "AtenEyeModuleInt2D_basic", + "AtenFloatScalarModule_basic", + "AtenIntBoolOpConstFalseModule_basic", + "AtenIntBoolOpConstTrueModule_basic", + "AtenIntBoolOpModule_basic", + "AtenIntTensorByteDtypeModule_basic", + "AtenIntTensorCharDtypeModule_basic", + "AtenItemFpOpModule_basic", + "AtenItemIntOpModule_basic", + "AtenMmFloatTypes_basic", + "AtenMmIntTypes_basic", + "AtenRoundIntModule_basic", + "AtenSubFloatModule_basic", + "AtenToDeviceModule_basic", + "AvgPool1dStaticModule_basic", + "AvgPool2dStaticModule_basic", + "BaddbmmBroadcast1DInputModule_basic", + "BaddbmmBroadcast2DInputModule_basic", + "BaddbmmStaticModule_basic", + "BoolFloatConstantModule_basic", + "BoolFloatFalseModule_basic", + "BoolFloatTrueModule_basic", + "BoolIntConstantModule_basic", + "BoolIntFalseModule_basic", + "BoolIntTrueModule_basic", + "BoolTensorReturnFalseModule_basic", + "BoolTensorReturnMixedModule_basic", + "BoolTensorReturnTrueModule_basic", + "BroadcastListConstructWithMinusOneModule_basic", "BroadcastToSameRankStaticModule_basic", "BroadcastZeroRankInputStaticModule_basic", - "BroadcastListConstructWithMinusOneModule_basic", "BucketizeTensorStaticFloatModule_basic", "BucketizeTensorStaticModule_basic", + "CeilFloatModule_basic", + "ChunkListUnpackUneven_Module_basic", + "ChunkListUnpack_Module_basic", + "CloneModule_basic", + "ConstantBoolParameterModule_basic", + "ContainsIntList_False", + "ContainsIntList_True", + "ContiguousModule_basic", + "Conv2dWithPaddingDilationStrideStaticModule_basic", + "Conv2dWithPaddingDilationStrideStaticModule_depthwise", + "Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier", + "Conv2dWithPaddingDilationStrideStaticModule_grouped", + "Conv2dWithPaddingDilationStrideStaticModule_grouped_multiplier", + "Convolution2DStaticModule_basic", + "ConvolutionBackwardModule2DStatic_basic", + "ConvolutionBackwardModule2DStrided_basic", + "ConvolutionModule2DTransposeStridedStatic_basic", + "CosineSimilarityStaticBroadcastModule_basic", + "CosineSimilarityStaticModule_basic", + "CumsumInputDtypeInt32Module_basic", "CumsumStaticModule_basic", "CumsumStaticNegativeDimModule_basic", - "CosineSimilarityStaticModule_basic", - "CosineSimilarityStaticBroadcastModule_basic", "DetachModule_basic", - "ElementwiseIsnanModule_basic", + "DivFloatModule_basic", + "DivIntModule_basic", + "DropoutEvalFloatModule_basic", + "DropoutEvalIntModule_basic", + "DropoutTrainStaticShapeModule_basic", + "EinsumStaticContractRhsModule_basic", + "EinsumStaticFourDimensionModule_basic", + "EinsumStaticModule_basic", + "ElementwiseAbsFloatModule_basic", + "ElementwiseAbsIntModule_basic", + "ElementwiseAddScalar_NumToTensorFloat_Module_basic", + "ElementwiseAddScalar_TensorLiteralInt32_Module_basic", + "ElementwiseAtenIsinfOpModule_basic", + "ElementwiseAtenIsneginfOpModule_basic", + "ElementwiseAtenIsposinfOpModule_basic", "ElementwiseAtenLogicalAndOpPromoteBroadcastStaticShapeModule_basic", "ElementwiseAtenLogicalNotOpModule_basic", "ElementwiseAtenLogicalNotOpPromoteModule_basic", "ElementwiseAtenLogicalOrOpPromoteBroadcastStaticShapeModule_basic", "ElementwiseAtenLogicalXorOpPromoteBroadcastStaticShapeModule_basic", "ElementwiseAtenWhereSelfModule_basic", - "ElementwiseWhereScalarOtherStaticModule_basic", - "ElementwiseWhereScalarSelfStaticModule_basic", - "ElementwiseNanToNumModule_Basic", + "ElementwiseBinaryStaticShapeModule_basic", "ElementwiseBitwiseAndStaticShapeModule_basic", - "ElementwiseBitwiseNotInt64Module_basic", "ElementwiseBitwiseNotInt32Module_basic", - "ElementwiseOrTensorStaticShapeModule_basic", + "ElementwiseBitwiseNotInt64Module_basic", "ElementwiseBitwiseOrStaticShapeModule_basic", "ElementwiseBitwiseXorStaticShapeModule_basic", - "ElementwiseClampModule_basic", - "ElementwiseClampMinModule_basic", + "ElementwiseCeilModule_basic", "ElementwiseClampMaxModule_basic", - "ElementwiseSignModule_basic", - "ElementwisePowModule_basic", - "ElementwisePowTensorStaticModule_basic", - "ElementwisePowTensorBroadcastStaticModule_basic", + "ElementwiseClampMinModule_basic", + "ElementwiseClampModule_basic", + "ElementwiseClampTensorInt8Module_basic", + "ElementwiseCloneChannelsLastMemoryFormatModule_basic", + "ElementwiseCloneContiguousModule_basic", + "ElementwiseCloneModule_basic", + "ElementwiseCosModule_basic", + "ElementwiseErfModule_basic", "ElementwiseExpModule_basic", - "ElementwiseFlattenBroadcastModule_basic", - "ElementwiseLeakyReluModule_basic", - "ElementwiseEluModule_basic", - "ElementwiseEluNonDefaultModule_basic", - "ElementwiseSeluModule_basic", + "ElementwiseFloorIntModule_basic", + "ElementwiseFloorModule_basic", + "ElementwiseGeluModule_basic", + "ElementwiseLeakyReluStaticModule_basic", "ElementwiseLogModule_basic", + "ElementwiseNanToNumModule_Basic", + "ElementwiseNeFloatTensorStaticModule_basic", + "ElementwiseNeIntTensorStaticModule_basic", "ElementwiseNegModule_basic", + "ElementwiseOrTensorStaticShapeModule_basic", + "ElementwisePowTensorBroadcastStaticModule_basic", + "ElementwisePowTensorStaticModule_basic", + "ElementwiseReciprocalModule_basic", + "ElementwiseReluModule_basic", "ElementwiseRsqrtModule_basic", "ElementwiseSigmoidModule_basic", - "ElementwiseSqrtModule_basic", "ElementwiseSinModule_basic", - "ElementwiseCosModule_basic", - "ElementwiseCeilModule_basic", - "ElementwiseFloorModule_basic", - "ElementwiseUnaryModule_basic", - "ElementwiseUnsqueezeBroadcastModule_basic", - "ElementwiseUnsqueezeNegDimsModule_basic", + "ElementwiseSqrtModule_basic", "ElementwiseToDtypeF32ToI64Module_basic", - "ElementwiseAddModule_basic", - "ElementwiseAddScalarFloatModule_basic", - "ElementwiseAddScalarInt64Module_basic", - "ElementwiseAddScalarIntModule_basic", - "ElementwiseAddScalar_NumToTensorFloat_Module_basic", - "ElementwiseAddScalar_TensorLiteralInt32_Module_basic", - "ElementwiseDivScalarModule_basic", - "ElementwiseAtenDivIntScalarModule_basic", - "ElementwiseEqDiffWidthScalarModule_basic", - "ElementwiseEqFloatScalarModule_basic", - "ElementwiseEqIntScalarModule_basic", - "ElementwiseNeFloatScalarModule_basic", - "ElementwiseNeFloatTensorStaticModule_basic", - "ElementwiseNeIntTensorStaticModule_basic", - "ElementwiseEqBoolScalarModule_basic", - "ElementwiseErfModule_basic", - "ElementwiseGeluModule_basic", - "ElementwiseGtFloatScalarModule_basic", - "ElementwiseGtIntScalarModule_basic", - "ElementwiseGtMixed2ScalarModule_basic", - "ElementwiseGeFloatIntScalarModule_basic", - "ElementwiseGeFloatScalarModule_basic", - "ElementwiseGeIntScalarModule_basic", - "ElementwiseGeMixedIntScalarModule_basic", - "ElementwiseLeakyReluStaticModule_basic", - "ElementwiseLeFloatIntScalarModule_basic", - "ElementwiseLeFloatScalarModule_basic", - "ElementwiseLeIntScalarModule_basic", - "ElementwiseLeMixedIntScalarModule_basic", - "ElementwiseLtDiffWidthScalarModule_basic", - "ElementwiseLtFloatScalarModule_basic", - "ElementwiseLtIntScalarModule_basic", - "ElementwiseMulScalarModule_basic", - "ElementwiseMulScalarModule_float", - "ElementwiseMulScalarModule_int", - "ElementwiseNeIntScalarModule_basic", - "ElementwiseReciprocalModule_basic", - "ElementwiseRelu6Module_basic", - "ElementwiseReluModule_basic", - "ElementwiseRemainderScalarModule_Bool_basic", - "ElementwiseRemainderScalarModule_Float_basic", - "ElementwiseRemainderScalarModule_Int_Float_basic", - "ElementwiseRemainderScalarModule_Int_basic", - "ElementwiseSubScalarFloatModule_basic", - "ElementwiseSubScalarIntModule_basic", - "ElementwiseWhereScalarModule_basic", - "ElementwiseAbsFloatModule_basic", - "ElementwiseAbsIntModule_basic", - "EmbeddingModule1DIndices_basic", - "EmbeddingModuleI32Static_basic", - "EmbeddingModuleI32_basic", - "EmbeddingModuleI64_basic", - "EmbeddingModuleF16_basic", + "ElementwiseToDtypeI64ToI8Module_basic", + "ElementwiseToDtypeIdentityModule_basic", + "ElementwiseUnaryModule_basic", + "ElementwiseWhereScalarOtherStaticModule_basic", + "ElementwiseWhereScalarSelfStaticModule_basic", "EmptyLikeMemoryFormatModule_basic", "EmptyLikeModule_defaultDtype", "EmptyLikeModule_falsePinMemory", "EmptyLikeModule_float", "EmptyLikeModule_int", + "EmptyModule_contiguous", + "EmptyModule_defaultDtype", + "EmptyModule_falsePinMemory", + "EmptyModule_float", + "EmptyModule_int", + "EmptyStridedModule_basic", + "EqIntModule_basic", "ExpandAsIntModule_basic", - "ExpandModule_basic", - "EinsumStaticModule_basic", - "EinsumStaticFourDimensionModule_basic", - "EinsumStaticContractRhsModule_basic", + "Fill_TensorFloat64WithFloat32Static_basic", "Fill_TensorFloat64WithFloat32_basic", "Fill_TensorFloat64WithFloat64_basic", - "Fill_TensorFloat64WithInt64_basic", - "Fill_TensorFloat64WithFloat32Static_basic", "Fill_TensorFloat64WithInt64Static_basic", + "Fill_TensorFloat64WithInt64_basic", + "FlattenRank0Module_basic", + "FlattenStaticModule_basic", "FlipModuleStaticShape_basic", "FlipNegativeIndexModule_basic", "FullLikeModuleDefaultDtype_basic", @@ -616,188 +570,67 @@ "FullModuleFloat3D_basic", "FullModuleInt2D_basic", "FullModuleInt3D_basic", - "NewFullModuleDefaultDtype_basic", - "NewFullModuleFalsePinMemory_basic", - "NewFullModuleFloat2D_basic", - "NewFullModuleFloat3DStatic_basic", - "NewFullModuleFloat3D_basic", - "NewFullModuleInt2DStatic_basic", - "NewFullModuleInt2D_basic", - "NewFullModuleInt3D_basic", - "GroupNormModule_basic", "GatherStaticModule_basic", - "GatherModule_basic", - "Gather2DInputModdule_basic", - "GatherRandomIndexModule_basic", - "GatherNegativeDimModule_basic", + "GeFloatIntModule_basic", + "GeFloatModule_basic", + "GeIntModule_basic", "GeluBackwardModule_basic", - "HardswishModule_basic", - "HardswishRandomModule_basic", - "HardTanhIntModule_basic", - "HardTanhModule_basic", - "HardsigmoidModule_basic", - "HardsigmoidRandomModule_basic", - "IndexSelectDynamicIndexSizeModule_basic", - "IndexSelectSingleIdxModule_basic", - "IndexSelectTwoIdxModule_basic", - "IndexSelectWholeDimensionModule_basic", - "IndexSelectWholeTensorModule_basic", - "IndexSelectNegativeDimModule_basic", - "IndexTensorStaticModule_basic", + "GluStaticModule_basic", + "GroupNormModule_basic", + "GroupNormNoWeightAndBiasModule_basic", + "GtFloatIntModule_basic", + "GtIntModule_basic", "IndexTensorMultiIndexStaticModule_basic", + "IndexTensorStaticContiguousWithNoneModule_basic", + "IndexTensorStaticModule_basic", + "IndexTensorStaticNonContiguousWithNoneModule_basic", + "IntFloatModule_basic", + "IsFloatingPointFloat_True", + "IsFloatingPointInt_False", "LayerNormLastDimModule_basic", "LayerNormModule_basic", "LayerNormNormalizeOverAllDimsModule_basic", "LeakyReluBackwardStaticModule_basic", - "LinalgVectorNormModule_basic", - "LinalgVectorNormKeepDimModule_basic", - "MatmulBroadcastBatchDim_basic", - "MatmulSingleDynamicBatchDim_basic", - "Matmul_3d", - "Matmul_4d", - "MeanDimEmptyDimModule_basic", - "MeanDtypeModule_basic", - "MeanDynamicSizesModule_basic", - "MeanLargeInputModule_basic", - "MeanModule_basic", - "Mlp1LayerModule_basic", - "Mlp2LayerModule_basic", - "MmTanhModule_basic", - "Mv_basic", - "NativeLayerNormModule4D_basic", - "NativeLayerNormModule_basic", - "OneHotModule_basic", - "PrimsConvertElementTypeModule_basic", - "ReduceFrobeniusNormKeepDimModule_basic", - "ReduceSumDimIntListElementTypeBoolModule_basic", - "ReduceSumElementTypeBoolModule_basic", - "ReduceSumDimIntListEmptyDimModule_basic", - "ReduceSumDimIntListDtypeFloatModule_basic", - "ReduceSumDimIntListDtypeIntModule_basic", - "ReduceSumDimIntListKeepDimFloatModule_basic", - "ReduceSumDimIntListKeepDimIntModule_basic", - "ReduceSumDtypeFloatModule_basic", - "ReduceSumDtypeIntModule_basic", - "ReduceL1NormModule_basic", - "ReduceL1NormWithDTypeModule_basic", - "ReduceL2NormModule_basic", - "ReduceL3NormAllDimsModule_basic", - "ReduceL3NormKeepDimModule_basic", - "ReduceLN3NormModule_basic", - "NormScalarOptDimKeepDimModule_basic", - "NormScalarOptDimModule_basic", - "NormalizeModule_basic", - "ScalarConstantTupleModule_basic", - "SelectIntModule_basic", - "SelectIntNegativeDimAndIndexStaticModule_basic", - "SliceSingleIdxModule_basic", - "SqueezeDimModule_dynamic", - "SqueezeDimModule_negDim", - "ToCopyBoolDTypeStaticModule_basic", - "ToCopyModule_basic", - "ToCopyWithDTypeFalsePinMemoryModule_basic", - "ToCopyWithDTypeModule_basic", - "ReduceFrobeniusNormModule_basic", - "FlattenStaticModule_basic", - "FlattenRank0Module_basic", - "TensorsConcatNegativeDimModule_basic", - "TensorsConcatPromoteDTypeModule_basic", - "TensorsConcatStaticModule_basic", - "TensorsConcatNegativeDimStaticModule_basic", - "TensorsStackModule_basic", - "TensorsStackNegativeDimModule_basic", - "TensorsStackPromoteDTypeModule_basic", + "LenStrModule_basic", "LiftFreshCopyModule_basic", - "Mlp2LayerModuleNoBias_basic", - "NumelModule_basic", - "SiluModule_basic", - "SquareModule_basic", - "SqueezeModule_allUnitDim", - "SqueezeDimModule_unitDim", - "ViewCollapseOnesMiddleModule_basic", - "ViewDoubleMergeStaticModule_basic", - "ViewExpandDynamicDimModule_basic", - "ViewFlattenAndExpandModule_basic", - "ViewFiveTestStaticModule_basic", - "ViewOffsetTestStaticModule_basic", - "ViewTwoFiveThreeStaticModule_basic", - "ViewTwoToThreeStaticModule_basic", - "ViewExpandOnesMiddleOppModule_basic", - "ViewOffsetBackwardTestStaticModule_basic", - "NumToTensorFloatModule_basic", - "AtenToDeviceModule_basic", - "AvgPool1dStaticModule_basic", - "AvgPool2dStaticModule_basic", - "Conv2dWithPaddingDilationStrideStaticModule_basic", - "Conv2dWithPaddingDilationStrideStaticModule_depthwise", - "Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier", - "Conv2dWithPaddingDilationStrideStaticModule_grouped", - "Conv2dWithPaddingDilationStrideStaticModule_grouped_multiplier", - "Convolution2DStaticModule_basic", - "ConvolutionModule2DTransposeStridedStatic_basic", - "ElementwiseCloneContiguousModule_basic", - "ElementwiseCloneChannelsLastMemoryFormatModule_basic", - "ElementwiseCloneModule_basic", - "ElementwiseBinaryStaticShapeModule_basic", - "ReturnThreeTensorFloat32_basic", - "BoolTensorReturnFalseModule_basic", - "BoolTensorReturnTrueModule_basic", - "BoolTensorReturnMixedModule_basic", - "SqueezeModule_static", - "TModuleRank1_basic", - "TModuleRank0_basic", - "ElementwiseToDtypeIdentityModule_basic", - "View1DFoldModule_basic", - "UnsafeView1DFoldModule_basic", - "UnflattenStaticModule_basic", - "UnflattenIntStaticModule_basic", - "UnflattenIntNegativeOneDimStaticModule_basic", - "UnflattenIntNegativeOneSizeStaticModule_basic", - "RsubFloatModule_basic", - "RsubFloatModule_noalpha_basic", - "RsubIntModule_basic", - "RsubIntModule_noalpha_basic", - "RsubInt0d_NumToTensor_Module_basic", - "ScalarTensorDefaultDtypeModule_basic", - "ScalarTensorFloat32Module_basic", - "ScalarTensorInt32Module_basic", - "ScalarTensorInt64Module_basic", - "SelectScattertModule_basic", - "SelectScattertStaticModule_basic", - "SliceStaticModule_basic", - "SliceModule_basic", - "SliceNegIdxModule_basic", - "SliceOutOfLowerBoundStartIndexModule_basic", - "SliceOutOfUpperBoundIndexModule_basic", - "SliceOutOfUpperBoundIndexStaticModule_basic", - "SliceStartEqEndModule_basic", - "SliceSizeTwoStepModule_basic", - "SliceWholeTensorModule_basic", - "SliceScatterModule_basic", - "SliceScatterNegativeDimModule_basic", - "SliceScatterNegativeEndModule_basic", - "SliceScatterStaticModule_basic", - "SliceScatterStepVariationModule_basic", - "SliceScatterZeroDimModule_basic", - "SqueezeDimModule_static", - "SqueezeDimModule_identity", - "SqueezeModule_broadcast", - "ReturnTwoTensorF32I64_basic", + "MaskedFillScalarFloatValueStaticModule_basic", + "MaskedFillScalarIntValueStaticModule_basic", "Matmul4dStatic_basic", - "Matmul_dot", "Matmul_2d", + "Matmul_dot", "Matmul_matvec", "Matmul_vecmat", + "MaxPool2dStaticModule_basic", "MaxPool2dWithIndicesStaticModule_basic", + "MeanDimAllReduceKeepdimModule_basic", + "MeanDimAllReduceModule_basic", + "MeanDimEmptyDimModule_basic", + "MeanDtypeModule_basic", + "MeanDynamicSizesModule_basic", + "MeanModule_basic", + "Mlp2LayerModuleNoBias_basic", "MmDagModule_basic", "MmModule_basic", "MmModule_chained", - "MaxPool2dStaticModule_basic", - "EmptyModule_contiguous", - "EmptyModule_defaultDtype", - "EmptyModule_falsePinMemory", - "EmptyModule_int", - "EmptyModule_float", + "MmTanhModule_basic", + "MoveDimIntModule_basic", + "MoveDimIntNegativeIndexModule_basic", + "MulFloatModule_basic", + "MulIntModule_basic", + "Mv_basic", + "NarrowHorizontalTest2_basic", + "NarrowHorizontalTest_basic", + "NarrowTensorHorizontalModule_basic", + "NarrowTensorVerticalModule_basic", + "NarrowVerticalTest2_basic", + "NarrowVerticalTest_basic", + "NativeDropoutEvalFloatModule_basic", + "NativeDropoutTrainStaticShapeModule_basic", + "NativeGroupNormModule_basic", + "NativeLayerNormModule4D_basic", + "NativeLayerNormModule_basic", + "NeFloatIntModule_basic", + "NeIntModule_basic", "NewEmptyModuleDefaultDtype_basic", "NewEmptyModuleFalsePinMemory_basic", "NewEmptyModuleFloat2D_basic", @@ -808,117 +641,169 @@ "NewEmptyModuleNonDefaultFloatDtype_basic", "NewEmptyModuleNonDefaultIntDtype_basic", "NewEmptyStridedModuleDefaultDtype_basic", - "EmptyStridedModule_basic", - "EmptyStridedSizeIntStrideModule_basic", - "PermuteModule_basic", - "PermuteNegativeIndexModule_basic", - "ReduceSumDimIntListKeepDimNegativeDimStaticModule_basic", - "ZeroFloat32Module_basic", - "ZeroInt32Module_basic", - "ZeroInt64Module_basic", - "ZerosLikeModule_defaultDtype", - "ZerosLikeModule_falsePinMemory", - "ZerosLikeModule_float", - "ZerosLikeModule_int", - "ZerosModuleDefaultDtype_basic", - "ZerosModuleInt2D_basic", - "ZerosModuleInt3D_basic", - "ZerosModuleFloat2D_basic", - "ZerosModuleFloat3D_basic", - "ZerosModuleFalsePinMemory_basic", - "OnesModuleDefaultDtype_basic", - "OnesModuleInt_basic", - "OnesModuleFloat_basic", - "OnesModuleFalsePinMemory_basic", - "OnesLikeModule_defaultDtype", - "OnesLikeModule_falsePinMemory", - "OnesLikeModule_float", - "OnesLikeModule_int", - "NewZerosModuleDefaultDtype_basic", - "NewZerosModuleInt2D_basic", - "NewZerosModuleInt3D_basic", - "NewZerosModuleFloat2D_basic", - "NewZerosModuleFloat3D_basic", - "NewZerosModuleFalsePinMemory_basic", + "NewFullModuleDefaultDtype_basic", + "NewFullModuleFalsePinMemory_basic", + "NewFullModuleFloat3DStatic_basic", + "NewFullModuleFloat3D_basic", + "NewFullModuleInt2D_basic", + "NewFullModuleInt3D_basic", "NewOnesModuleDefaultDtype_basic", - "NewOnesModuleInt2D_basic", - "NewOnesModuleInt3D_basic", + "NewOnesModuleFalsePinMemory_basic", "NewOnesModuleFloat2D_basic", "NewOnesModuleFloat3D_basic", - "NewOnesModuleFalsePinMemory_basic", + "NewOnesModuleInt2D_basic", + "NewOnesModuleInt3D_basic", + "NewZerosModuleDefaultDtype_basic", + "NewZerosModuleFalsePinMemory_basic", + "NewZerosModuleFloat2D_basic", + "NewZerosModuleFloat3D_basic", + "NewZerosModuleInt2D_basic", + "NewZerosModuleInt3D_basic", "NewZerosStaticModuleLayoutStrided_basic", - "DropoutEvalIntModule_basic", - "DropoutEvalFloatModule_basic", - "DropoutTrainStaticShapeModule_basic", - "NativeDropoutEvalFloatModule_basic", - "NativeDropoutTrainStaticShapeModule_basic", - "ContiguousModule_basic", - "DropoutModule_basic", - "ViewCollapseModule_basic", - "ViewCollapseInferredDimModule_basic", - "ViewDynamicExpandCollapseModule_basic", - "ViewDynamicExpandModule_basic", - "ViewExpandModule_basic", - "ViewExpandOnesModule_basic", - "ViewExpandOnesBeforeAndAfterModule_basic", - "ViewExpandOnesMiddleModule_basic", - "ViewExpandCollapseModule_basic", - "ViewExpandCollapseWithOnesModule_basic", - "ViewExpandInferredDimModule_basic", - "ViewNegativeStaticModule_basic", - "ViewNoChangeStaticModule_basic", - "ViewNoChange1dModule_basic", - "ViewNoChange2dModule_basic", - "ViewNoChange3dModule_basic", - "UnsafeViewExpandModule_basic", + "NormalizeModule_basic", + "NumToTensorFloatModule_basic", + "NumToTensorIntModule_basic", + "NumelModule_basic", + "NumelZeroRankModule_basic", + "NumpyTRank0Module_basic", + "NumpyTRank1Module_basic", + "NumpyTRank2Module_basic", + "NumpyTRankNDynamicModule_basic", + "NumpyTRankNStaticModule_basic", + "OnesLikeModule_defaultDtype", + "OnesLikeModule_falsePinMemory", + "OnesLikeModule_float", + "OnesLikeModule_int", + "OnesModuleCPUDevice_basic", + "OnesModuleDefaultDtype_basic", + "OnesModuleFalsePinMemory_basic", + "OnesModuleFloat_basic", + "OnesModuleInt_basic", + "Permute0RankModule_basic", + "PermuteModule_basic", + "PermuteNegativeIndexModule_basic", + "PowIntFloatModule_basic", + "PrimMaxIntModule_basic", + "PrimMinIntDynamicModule_basic", + "PrimMinIntModule_basic", + "PrimsConvertElementTypeModule_basic", + "PrimsSqueezeEmptyDimensionsModule_basic", + "PrimsSqueezeModule_basic", + "PrimsViewOfModule_basic", + "PrimsViewOfZeroRankModule_basic", + "RandIntDtypeModule_basic", + "RandIntLowDtypeModule_basic", + "RandIntLowModule_basic", + "RandIntModule_basic", + "RandIntPinMemoryModule_basic", + "RandModule_basic", + "ReduceAmaxMultiDim_basic", + "ReduceAmaxOutOfOrderDim_basic", + "ReduceAmaxSingleDim_basic", + "ReduceFrobeniusNormModule_basic", "ReduceMaxAllDims_basic", + "ReduceMaxAlongDimNegative_basic", + "ReduceMaxAlongDimSignedInt_basic", + "ReduceMaxAlongDim_basic", "ReduceMaxFloatModule_basic", "ReduceMaxSignedIntModule_basic", "ReduceMaxUnsignedIntModule_basic", - "ReduceMinAllDims_basic", "ReduceMinFloatModule_basic", "ReduceMinSignedIntModule_basic", "ReduceMinUnsignedIntModule_basic", + "ReduceSumDimIntListDtypeFloatModule_basic", + "ReduceSumDimIntListDtypeIntModule_basic", + "ReduceSumDimIntListElementTypeBoolModule_basic", + "ReduceSumDimIntListEmptyDimModule_basic", "ReduceSumDimIntListFloatModule_basic", "ReduceSumDimIntListIntModule_basic", + "ReduceSumDimIntListKeepDimNegativeDimStaticModule_basic", + "ReduceSumDtypeFloatModule_basic", + "ReduceSumDtypeIntModule_basic", + "ReduceSumElementTypeBoolModule_basic", "ReduceSumFloatModule_basic", "ReduceSumSignedIntModule_basic", "ReduceSumUnsignedIntModule_basic", "RepeatModule_basic", "ReshapeAliasCollapseModule_basic", "ReshapeAliasExpandModule_basic", - "ReshapeExpandModule_basic", "ReshapeAsModule_basic", - "TestMultipleTensorReturn_basic", - "AdaptiveAvgPool1dUnitOutputSizeStaticModule_basic", - "AdaptiveAvgPool2dUnitOutputSizeStaticModule_basic", - "BaddbmmStaticModule_basic", - "BaddbmmBroadcast1DInputModule_basic", - "BaddbmmBroadcast2DInputModule_basic", - "NarrowHorizontalTest2_basic", - "NarrowHorizontalTest_basic", - "NarrowVerticalTest2_basic", - "NarrowVerticalTest_basic", - "NarrowTensorHorizontalModule_basic", - "NarrowTensorVerticalModule_basic", - "NumToTensorIntModule_basic", - "NumpyTRank0Module_basic", - "NumpyTRank1Module_basic", - "NumpyTRank2Module_basic", - "NumpyTRankNStaticModule_basic", - "NumpyTRankNDynamicModule_basic", + "ReshapeExpandModule_basic", + "ReturnThreeTensorFloat32_basic", + "ReturnTwoTensorF32I64_basic", + "RollModule_basic", + "RsubInt0d_NumToTensor_Module_basic", + "ScalarConstantTupleModule_basic", + "ScalarImplicitFloatModule_basic", + "ScalarImplicitIntModule_basic", + "ScalarTensorDefaultDtypeModule_basic", + "ScalarTensorFloat32Module_basic", + "ScalarTensorInt32Module_basic", + "ScalarTensorInt64Module_basic", + "SelectIntNegativeDimAndIndexStaticModule_basic", + "SelectScattertStaticModule_basic", + "SliceModule_basic", + "SliceNegIdxModule_basic", + "SliceOutOfLowerBoundStartIndexModule_basic", + "SliceOutOfUpperBoundIndexModule_basic", + "SliceOutOfUpperBoundIndexStaticModule_basic", + "SliceScatterModule_basic", + "SliceScatterNegativeDimModule_basic", + "SliceScatterNegativeEndModule_basic", + "SliceScatterStaticModule_basic", + "SliceScatterStepVariationModule_basic", + "SliceScatterZeroDimModule_basic", + "SliceSizeTwoStepModule_basic", + "SliceStartEqEndModule_basic", + "SliceStaticModule_basic", + "SliceWholeTensorModule_basic", + "SortIntListReverse_basic", + "SortIntList_basic", + "SplitTensorGetItem_Module_basic", + "SplitTensorLastSmallerModule_basic", + "SplitTensorListUnpackModule_basic", + "SplitTensorNegativeDimModule_basic", + "SplitWithSizesListUnpackModule_basic", + "SqrtIntConstantModule_basic", + "SqrtIntModule_basic", + "SqueezeDimModule_identity", + "SqueezeDimModule_static", + "SqueezeDimModule_unitDim", + "SqueezeModule_allUnitDim", + "SqueezeModule_static", + "SubFloatModule_basic", + "SubIntModule_basic", + "TModuleRank0_basic", + "TModuleRank1_basic", "TModuleRank2_basic", + "TensorIntModule_basic", "TensorLiteralModule_basic", - "TensorsConcatModule_basic", "TensorOpaqueLiteralModule_basic", - "TransposeIntModule_basic", - "TransposeIntNegDimsModule_basic", - "ToDtypeBoolLayoutNoneModule_basic", + "TensorToBoolZeroRank_basic", + "TensorToFloatZeroRank_basic", + "TensorToIntZeroRank_basic", + "TensorsConcatModule_basic", + "TensorsConcatNegativeDimModule_basic", + "TensorsConcatNegativeDimStaticModule_basic", + "TensorsConcatPromoteDTypeModule_basic", + "TensorsConcatStaticModule_basic", + "TestF16Return_basic", + "TestMultipleTensorAndPrimitiveTypesReturn_basic", + "TestMultipleTensorReturn_basic", + "TileBigDimsSizeModule_basic", + "TileSmallDimsSizeModule_basic", + "ToCopyBoolDTypeStaticModule_basic", "ToDtypeBoolLayoutNoneStaticModule_basic", + "ToDtypeLayoutCPUModule_basic", "ToDtypeLayoutNoneModule_basic", "ToDtypeLayoutStridedModule_basic", - "TypeAsSameModule_basic", + "TransposeIntModule_basic", + "TransposeIntNegDimsModule_basic", + "TriuBroadcastModule_basic", + "TriuModule_basic", + "TupleModule_basic", "TypeAsDifferentModule_basic", + "TypeAsSameModule_basic", "TypeConversionF32ToF64Module_basic", "TypeConversionF64ToF32Module_basic", "TypeConversionI1ToF32Module_basic", @@ -927,57 +812,58 @@ "TypeConversionI1ToI64Module_basic", "TypeConversionI32ToI64Module_basic", "TypeConversionI64ToI32Module_basic", - "TypePromotionAlphaWiderModule_basic", - "TypePromotionSameCategoryZeroRankWider_basic", - "TypePromotionZeroRankHigherCategoryModule_basic", - "OnesModuleCPUDevice_basic", - "Permute0RankModule_basic", - "UnsafeViewCollapseModule_basic", - "UnsafeViewDynamicExpandModule_basic", - "AtenRoundIntModule_basic", - "TestF16Return_basic", - "_LogSoftmaxModuleStable_basic", - "PrimsSqueezeModule_basic", - "PrimsSqueezeEmptyDimensionsModule_basic", - "MoveDimIntModule_basic", - "MoveDimIntNegativeIndexModule_basic", - "ConvolutionBackwardModule2DStatic_basic", - "ConvolutionBackwardModule2DStrided_basic", - "PrimsViewOfModule_basic", - "PrimsViewOfZeroRankModule_basic", - "AtenComplex64Module_basic", - "SplitTensorGetItem_Module_basic", - "SplitTensorListUnpackModule_basic", - "SplitTensorNegativeDimModule_basic", - "SplitTensorLastSmallerModule_basic", - "SplitWithSizesListUnpackModule_basic", - "UnbindIntListUnpack_Module_basic", "UnbindIntGetItem_Module_basic", - "ChunkListUnpack_Module_basic", - "ChunkListUnpackUneven_Module_basic", - "RandIntDtypeModule_basic", - "RandIntLowDtypeModule_basic", - "RandIntLowModule_basic", - "RandIntModule_basic", - "RandIntPinMemoryModule_basic", - "RandModule_basic", - "UniformStaticShapeModule_basic", + "UnbindIntListUnpack_Module_basic", + "UnflattenIntNegativeOneDimStaticModule_basic", + "UnflattenIntNegativeOneSizeStaticModule_basic", + "UnflattenIntStaticModule_basic", + "UnflattenStaticModule_basic", "UniformNoCorrelationModule_basic", - "TupleModule_basic", - "AtenEmbeddingBagStaticModule_basic", + "UniformStaticShapeModule_basic", + "UnsafeView1DFoldModule_basic", + "UnsafeViewCollapseModule_basic", + "UnsafeViewDynamicExpandModule_basic", + "UnsafeViewExpandModule_basic", + "View1DFoldModule_basic", + "ViewCollapseInferredDimModule_basic", + "ViewCollapseModule_basic", + "ViewCollapseOnesMiddleModule_basic", + "ViewDynamicExpandCollapseModule_basic", + "ViewDynamicExpandModule_basic", + "ViewExpandCollapseModule_basic", + "ViewExpandCollapseWithOnesModule_basic", + "ViewExpandDynamicDimModule_basic", + "ViewExpandInferredDimModule_basic", + "ViewExpandModule_basic", + "ViewExpandOnesBeforeAndAfterModule_basic", + "ViewExpandOnesMiddleModule_basic", + "ViewExpandOnesModule_basic", + "ViewNegativeStaticModule_basic", + "ViewNoChange1dModule_basic", + "ViewNoChange2dModule_basic", + "ViewNoChange3dModule_basic", + "ViewNoChangeStaticModule_basic", + "ViewOffsetBackwardTestStaticModule_basic", + "ViewOffsetTestStaticModule_basic", + "ViewTwoFiveThreeStaticModule_basic", + "ViewTwoToThreeStaticModule_basic", + "ZeroFloat32Module_basic", + "ZeroInt32Module_basic", + "ZeroInt64Module_basic", + "ZerosLikeModule_defaultDtype", + "ZerosLikeModule_falsePinMemory", + "ZerosLikeModule_float", + "ZerosLikeModule_int", + "ZerosModuleDefaultDtype_basic", + "ZerosModuleFalsePinMemory_basic", + "ZerosModuleFloat2D_basic", + "ZerosModuleFloat3D_basic", + "ZerosModuleInt2D_basic", + "ZerosModuleInt3D_basic", } -STABLEHLO_CRASHING_SET = { - # These e2e tests crash because currently mlir-hlo's shape-component-analysis - # only support exact one index in tensor::ExtractOp when it's related with - # some tensors' shape. REF: - # https://github.com/tensorflow/mlir-hlo/blob/master/mhlo/analysis/shape_component_analysis.cc#L586 - # FIXME if upstream mlir-hlo fix this. - "ViewCollapseDynamicWithAtenSizeIntModule_basic", - "UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic", - - "Aten_EmbeddingBagExample_basic", - "AtenEmbeddingBagSumExample_basic" +STABLEHLO_CRASHING_SET = { + "AtenEmbeddingBagSumExample_basic", } # Write the TOSA set as a "passing" set as it is very early in development diff --git a/projects/pt1/python/torch_mlir_e2e_test/stablehlo_backends/linalg_on_tensors.py b/projects/pt1/python/torch_mlir_e2e_test/stablehlo_backends/linalg_on_tensors.py new file mode 100644 index 000000000000..9143ae5eaf46 --- /dev/null +++ b/projects/pt1/python/torch_mlir_e2e_test/stablehlo_backends/linalg_on_tensors.py @@ -0,0 +1,57 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. + +from torch_mlir.ir import * +from torch_mlir.passmanager import * +from torch_mlir.compiler_utils import run_pipeline_with_repro_report + +from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend + +from .abc import StablehloBackend + +__all__ = [ + "LinalgOnTensorsStablehloBackend", +] + +# The pipeline of func.func passes that lower the STABLEHLO backend contract to the +# Linalg-on-Tensors backend contract accepted by RefBackend. +STABLEHLO_TO_LINALG_FUNC_PIPELINE = ",".join([ + "func.func(chlo-legalize-to-stablehlo)", + "canonicalize", + "stablehlo-legalize-to-linalg" +]) + + +class LinalgOnTensorsStablehloBackend(StablehloBackend): + """Main entry-point for the linalg-on-tensors based TOSA backend. + + This currently uses the linalg-on-tensors RefBackend for actual execution. + """ + + def __init__(self): + super().__init__() + self.refbackend = RefBackendLinalgOnTensorsBackend() + + def compile(self, imported_module: Module): + """Compiles an imported module that satisfied the TOSA backend contract. + + Args: + imported_module: The MLIR module consisting of funcs in the TOSA + dialect. + Returns: + An opaque, backend specific compiled artifact object that can be + passed to `load`. + """ + + run_pipeline_with_repro_report( + imported_module, + f"builtin.module({STABLEHLO_TO_LINALG_FUNC_PIPELINE})", + "Lowering STABLEHLO backend contract to Linalg-on-Tensors backend contract") + + return self.refbackend.compile(imported_module) + + def load(self, module): + """Loads a compiled artifact into the runtime.""" + return self.refbackend.load(module) From 5733c84443ed2ee3b0cbb75a9e425aa076c457a9 Mon Sep 17 00:00:00 2001 From: Yuanqiang Liu Date: Fri, 16 Feb 2024 01:38:13 +0800 Subject: [PATCH 197/283] [bazel] fix bazel with stablehlo refbackend and fix some typo (#2911) --- .../stablehlo_backends/linalg_on_tensors.py | 7 +++---- utils/bazel/torch-mlir-overlay/BUILD.bazel | 7 ++++++- 2 files changed, 9 insertions(+), 5 deletions(-) diff --git a/projects/pt1/python/torch_mlir_e2e_test/stablehlo_backends/linalg_on_tensors.py b/projects/pt1/python/torch_mlir_e2e_test/stablehlo_backends/linalg_on_tensors.py index 9143ae5eaf46..7dee2041c724 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/stablehlo_backends/linalg_on_tensors.py +++ b/projects/pt1/python/torch_mlir_e2e_test/stablehlo_backends/linalg_on_tensors.py @@ -25,7 +25,7 @@ class LinalgOnTensorsStablehloBackend(StablehloBackend): - """Main entry-point for the linalg-on-tensors based TOSA backend. + """Main entry-point for the linalg-on-tensors based Stablehlo backend. This currently uses the linalg-on-tensors RefBackend for actual execution. """ @@ -35,11 +35,10 @@ def __init__(self): self.refbackend = RefBackendLinalgOnTensorsBackend() def compile(self, imported_module: Module): - """Compiles an imported module that satisfied the TOSA backend contract. + """Compiles an imported module that satisfied the Stablehlo backend contract. Args: - imported_module: The MLIR module consisting of funcs in the TOSA - dialect. + imported_module: The MLIR module consisting of funcs in the Stablehlo dialect. Returns: An opaque, backend specific compiled artifact object that can be passed to `load`. diff --git a/utils/bazel/torch-mlir-overlay/BUILD.bazel b/utils/bazel/torch-mlir-overlay/BUILD.bazel index c37023c5e31f..11fcb5714a2c 100644 --- a/utils/bazel/torch-mlir-overlay/BUILD.bazel +++ b/utils/bazel/torch-mlir-overlay/BUILD.bazel @@ -867,7 +867,10 @@ cc_library( hdrs = [ "include/torch-mlir/InitAll.h", ], - copts = ["-DTORCH_MLIR_ENABLE_REFBACKEND"], + copts = [ + "-DTORCH_MLIR_ENABLE_REFBACKEND", + "-DTORCH_MLIR_ENABLE_STABLEHLO", + ], strip_include_prefix = "include", deps = [ ":TorchMLIRConversionPasses", @@ -882,6 +885,8 @@ cc_library( "@llvm-project//mlir:Dialect", "@llvm-project//mlir:DialectUtils", "@llvm-project//mlir:IR", + "@stablehlo//:stablehlo_passes", + "@stablehlo//:linalg_passes", ], ) From 49f63df0689a2c3351f051801bdd24833daa9a91 Mon Sep 17 00:00:00 2001 From: Yuanqiang Liu Date: Fri, 16 Feb 2024 01:56:09 +0800 Subject: [PATCH 198/283] [bazel] commit after run buildifier (#2912) --- utils/bazel/torch-mlir-overlay/BUILD.bazel | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/utils/bazel/torch-mlir-overlay/BUILD.bazel b/utils/bazel/torch-mlir-overlay/BUILD.bazel index 11fcb5714a2c..9edb488a0939 100644 --- a/utils/bazel/torch-mlir-overlay/BUILD.bazel +++ b/utils/bazel/torch-mlir-overlay/BUILD.bazel @@ -885,8 +885,8 @@ cc_library( "@llvm-project//mlir:Dialect", "@llvm-project//mlir:DialectUtils", "@llvm-project//mlir:IR", - "@stablehlo//:stablehlo_passes", "@stablehlo//:linalg_passes", + "@stablehlo//:stablehlo_passes", ], ) From 074f112d6afbfe48441083fa0e9764114d3c72de Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Thu, 15 Feb 2024 10:17:13 -0800 Subject: [PATCH 199/283] [onnx] Add testing using the `onnx` compilation using torch tests (#2795) We can route the torch tests via `onnx` using the `torch.onnx.export` tooling. We can then reimport, lower to torch, and compile to linalg to validate the onnx path is working correctly. The current implementation exposes some failures in the `onnx` path so we cannot enable the onnx test suite yet due to segmentation faults. --- build_tools/ci/test_posix.sh | 4 + .../python_deploy/build_linux_packages.sh | 3 + .../TorchOnnxToTorch/DefaultDomainGtoP.cpp | 4 + .../TorchToLinalg/Uncategorized.cpp | 11 +- projects/pt1/e2e_testing/main.py | 13 +- projects/pt1/e2e_testing/xfail_sets.py | 756 ++++++++++++++++++ .../torch_mlir_e2e_test/configs/__init__.py | 1 + .../configs/onnx_backend.py | 101 +++ .../onnx_backends/__init__.py | 0 .../torch_mlir_e2e_test/onnx_backends/abc.py | 49 ++ .../onnx_backends/linalg_on_tensors.py | 65 ++ python/torch_mlir/extras/onnx_importer.py | 8 +- .../torch_mlir/tools/import_onnx/__main__.py | 2 +- .../python/onnx_importer/import_smoke_test.py | 2 +- 14 files changed, 1009 insertions(+), 10 deletions(-) create mode 100644 projects/pt1/python/torch_mlir_e2e_test/configs/onnx_backend.py create mode 100644 projects/pt1/python/torch_mlir_e2e_test/onnx_backends/__init__.py create mode 100644 projects/pt1/python/torch_mlir_e2e_test/onnx_backends/abc.py create mode 100644 projects/pt1/python/torch_mlir_e2e_test/onnx_backends/linalg_on_tensors.py diff --git a/build_tools/ci/test_posix.sh b/build_tools/ci/test_posix.sh index 73818051d06b..71a22d0f714e 100755 --- a/build_tools/ci/test_posix.sh +++ b/build_tools/ci/test_posix.sh @@ -24,6 +24,10 @@ echo "::group::Run Stablehlo e2e integration tests" python -m e2e_testing.main --config=stablehlo -v echo "::endgroup::" +echo "::group::Run ONNX e2e integration tests" +python -m e2e_testing.main --config=onnx -v +echo "::endgroup::" + case $torch_version in nightly) # Failing with: NotImplementedError: diff --git a/build_tools/python_deploy/build_linux_packages.sh b/build_tools/python_deploy/build_linux_packages.sh index f0336b2a1a4b..401930887d2f 100755 --- a/build_tools/python_deploy/build_linux_packages.sh +++ b/build_tools/python_deploy/build_linux_packages.sh @@ -305,6 +305,9 @@ function test_in_tree() { echo ":::: Run Linalg e2e integration tests" python -m e2e_testing.main --config=linalg -v + echo ":::: Run Onnx e2e integration tests" + python -m e2e_testing.main --config=onnx -v + # Dynamo is changing a lot in nightly versions, and thus the implementation # tends to become incompatible to the stable version. echo ":::: Run TorchDynamo e2e integration tests" diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index dd9fa211b551..9b2f3673cf33 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -524,6 +524,10 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( loc, rewriter.getI64IntegerAttr(i)))); } + // Correct for negative axis: + if (axis < 0) + axis += dataRank; + // 4. We can not directly perform torch.gather as the onnx.gather op // collects the input data at different location of output compared to // torch.gather op. The output of torch.gather and onnx.gather ops are diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index 25c2a93d2797..0019acfc2944 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -586,14 +586,19 @@ static Value createLinalgPayloadCalculationForElementwiseOp( } if (auto add = dyn_cast(op)) { AtenAddTensorOp::Adaptor adaptor(operands); + Type resultElementType = add.getType().cast().getDtype(); Type dtype = converter->convertType(add.getType()) .cast() .getElementType(); - Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype); - Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype); + Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype, + /*srcOriginalDtype=*/std::nullopt, + /*dstOriginalDtype=*/resultElementType); + Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype, + /*srcOriginalDtype=*/std::nullopt, + /*dstOriginalDtype=*/resultElementType); Value alpha = convertScalarToDtype(b, loc, adaptor.getAlpha(), dtype, /*srcOriginalDtype=*/std::nullopt, - /*dstOriginalDtype=*/dtype); + /*dstOriginalDtype=*/resultElementType); if (dtype.isa()) { Value scaled = b.create(loc, rhs, alpha); return b.create(loc, lhs, scaled); diff --git a/projects/pt1/e2e_testing/main.py b/projects/pt1/e2e_testing/main.py index b9cd04c1e80b..cb7bc191ae47 100644 --- a/projects/pt1/e2e_testing/main.py +++ b/projects/pt1/e2e_testing/main.py @@ -18,12 +18,14 @@ LinalgOnTensorsBackendTestConfig, StablehloBackendTestConfig, NativeTorchTestConfig, + OnnxBackendTestConfig, TorchScriptTestConfig, TosaBackendTestConfig, TorchDynamoTestConfig, ) from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend +from torch_mlir_e2e_test.onnx_backends.linalg_on_tensors import LinalgOnTensorsOnnxBackend from torch_mlir_e2e_test.tosa_backends.linalg_on_tensors import LinalgOnTensorsTosaBackend from torch_mlir_e2e_test.stablehlo_backends.linalg_on_tensors import LinalgOnTensorsStablehloBackend @@ -36,7 +38,9 @@ LTC_XFAIL_SET, LTC_CRASHING_SET, TORCHDYNAMO_XFAIL_SET, - TORCHDYNAMO_CRASHING_SET + TORCHDYNAMO_CRASHING_SET, + ONNX_CRASHING_SET, + ONNX_XFAIL_SET, ) # Import tests to register them in the global registry. @@ -44,7 +48,7 @@ register_all_tests() def _get_argparse(): - config_choices = ["native_torch", "torchscript", "linalg", "stablehlo", "make_fx_tosa", "tosa", "lazy_tensor_core", "torchdynamo"] + config_choices = ["native_torch", "torchscript", "linalg", "stablehlo", "make_fx_tosa", "tosa", "lazy_tensor_core", "torchdynamo", "onnx"] parser = argparse.ArgumentParser(description="Run torchscript e2e tests.") parser.add_argument("-c", "--config", choices=config_choices, @@ -58,6 +62,7 @@ def _get_argparse(): "torchscript": compile the model to a torch.jit.ScriptModule, and then run that as-is (useful for verifying TorchScript is modeling the program correctly). "lazy_tensor_core": run the model through the Lazy Tensor Core frontend and execute the traced graph. "torchdynamo": run the model through the TorchDynamo frontend and execute the graph using Linalg-on-Tensors. +"onnx": export to the model via onnx and reimport using the torch-onnx-to-torch path. """) parser.add_argument("-f", "--filter", default=".*", help=""" Regular expression specifying which tests to include in this run. @@ -120,6 +125,10 @@ def main(): config = TorchDynamoTestConfig(RefBackendLinalgOnTensorsBackend()) xfail_set = TORCHDYNAMO_XFAIL_SET crashing_set = TORCHDYNAMO_CRASHING_SET + elif args.config == "onnx": + config = OnnxBackendTestConfig(LinalgOnTensorsOnnxBackend()) + xfail_set = ONNX_XFAIL_SET + crashing_set = ONNX_CRASHING_SET do_not_attempt = set(args.crashing_tests_to_not_attempt_to_run_and_a_bug_is_filed or []).union(crashing_set) available_tests = [test for test in GLOBAL_TEST_REGISTRY if test.unique_name not in do_not_attempt] diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 36a1d5662810..440b7d730c93 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -1451,3 +1451,759 @@ "ElementwiseBitwiseAndScalarInt8Module_basic", "Conv2dQInt8Module_basic", } + +ONNX_XFAIL_SET = { + # Failure - onnx_export + "AdaptiveAvgPool1dGeneralDynamic_basic", + "AdaptiveAvgPool1dNonUnitOutputSizeDynamicModule_basic", + "AdaptiveAvgPool1dStaticLargerOutput_basic", + "AdaptiveAvgPool2dNonUnitOutputSizeDynamicModule_basic", + "AdaptiveMaxPool2dDynamicWithIndices_basic", + "AdaptiveMaxPool2dDynamic_basic", + "AdaptiveMaxPool2dStaticWithIndices_basic", + "AdaptiveMaxPool2dStatic_basic", + "AddCDivModule_basic", + "AddIntModule_basic", + "Add_Module_basic", + "AllBoolFalseModule_basic", + "AllBoolTrueModule_basic", + "AnyBoolFalseModule_basic", + "AnyBoolTrueModule_basic", + "AtenComplex64Module_basic", + "AtenComplexImagModule_basic", + "AtenComplexRealModule_basic", + "AtenComplexViewModule_basic", + "AtenEmbeddingBagStaticModule_basic", + "AtenEmbeddingBagSumExample_basic", + "AtenFloatScalarModule_basic", + "AtenIntBoolOpConstFalseModule_basic", + "AtenIntBoolOpConstTrueModule_basic", + "AtenIntBoolOpModule_basic", + "AtenIntTensorByteDtypeModule_basic", + "AtenIntTensorCharDtypeModule_basic", + "AtenItemFpOpModule_basic", + "AtenItemIntOpModule_basic", + "AtenMmQuint8_basic", + "AtenRealView128Module_basic", + "AtenRealView64Module_basic", + "AtenSubFloatModule_basic", + "AtenTopKModule_basic", + "AtenTopKSmallestModule_basic", + "Aten_EmbeddingBagExample_basic", + "AvgPool2dWithoutPadModule_basic", + "BatchMlpLayerModule_basic", + "BincountMinlengthModule_basic", + "BincountModule_basic", + "BincountStaticSizeModule_basic", + "BoolFloatConstantModule_basic", + "BoolFloatFalseModule_basic", + "BoolFloatTrueModule_basic", + "BoolIntConstantModule_basic", + "BoolIntFalseModule_basic", + "BoolIntTrueModule_basic", + "CeilFloatModule_basic", + "ChunkListUnpackDynamic_Module_basic", + "ChunkListUnpackUnevenDynamic_Module_basic", + "CollapseAllDimensionsModule_basic", + "CollapseFullDynamicModule_basic", + "CollapsePartialDynamicModule_basic", + "CollapseRank1DynamicModule_basic", + "CollapseStaticModule_basic", + "ConstantBoolParameterModule_basic", + "ContainsIntList_False", + "ContainsIntList_True", + "Conv1dModule_basic", + "Conv2dBiasNoPaddingModule_basic", + "Conv2dModule_basic", + "Conv2dNoPaddingModule_basic", + "Conv2dQInt8Module_basic", + "Conv2dWithPaddingDilationStrideModule_basic", + "Conv2dWithPaddingModule_basic", + "Conv3dModule_basic", + "ConvTbcModule_basic", + "Conv_Transpose2dModule_basic", + "Convolution2DModule_basic", + "Convolution2DStridedModule_basic", + "ConvolutionBackwardModule2DPadded_basic", + "ConvolutionBackwardModule2DStatic_basic", + "ConvolutionBackwardModule2DStrided_basic", + "ConvolutionBackwardModule2D_basic", + "ConvolutionModule2DGroups_basic", + "ConvolutionModule2DTransposeNonUnitOutputPadding_basic", + "ConvolutionModule2DTransposeStrided_basic", + "ConvolutionModule2DTranspose_basic", + "DivFloatModule_basic", + "DivIntModule_basic", + "ElementwiseAcoshIntModule_basic", + "ElementwiseAcoshModule_basic", + "ElementwiseAsinhIntModule_basic", + "ElementwiseAsinhModule_basic", + "ElementwiseAtanhIntModule_basic", + "ElementwiseAtanhModule_basic", + "ElementwiseAtenIsneginfOpModule_basic", + "ElementwiseAtenIsposinfOpModule_basic", + "ElementwiseBitwiseAndModule_basic", + "ElementwiseBitwiseAndScalarInt32Module_basic", + "ElementwiseBitwiseAndScalarInt64Module_basic", + "ElementwiseBitwiseAndScalarInt8Module_basic", + "ElementwiseBitwiseAndStaticShapeModule_basic", + "ElementwiseBitwiseLeftShiftInt32Module_basic", + "ElementwiseBitwiseLeftShiftInt64Module_basic", + "ElementwiseBitwiseLeftShiftInt8Module_basic", + "ElementwiseBitwiseNotInt32Module_basic", + "ElementwiseBitwiseNotInt64Module_basic", + "ElementwiseBitwiseOrModule_basic", + "ElementwiseBitwiseOrStaticShapeModule_basic", + "ElementwiseBitwiseRightShiftInt32Module_basic", + "ElementwiseBitwiseRightShiftInt64Module_basic", + "ElementwiseBitwiseRightShiftInt8Module_basic", + "ElementwiseBitwiseXorModule_basic", + "ElementwiseBitwiseXorStaticShapeModule_basic", + "ElementwiseCoshIntModule_basic", + "ElementwiseCoshModule_basic", + "ElementwiseDequantizePerChannelModule_basic", + "ElementwiseDequantizePerTensorModule_basic", + "ElementwiseEluNonDefaultModule_basic", + "ElementwiseExpm1IntModule_basic", + "ElementwiseExpm1Module_basic", + "ElementwiseMulTensorComplexModule_basic", + "ElementwiseOrTensorModule_basic", + "ElementwiseOrTensorStaticShapeModule_basic", + "ElementwiseQuantizePerTensorModule_basic", + "ElementwiseRemainderTensorModule_Int_basic", + "EmptyStridedModule_basic", + "EmptyStridedSizeIntStrideModule_basic", + "EqIntModule_basic", + "ExponentialModule_basic", + "GeFloatIntModule_basic", + "GeFloatModule_basic", + "GeIntModule_basic", + "GeluBackwardModule_basic", + "GtFloatIntModule_basic", + "GtIntModule_basic", + "HardtanhBackward_basic", + "IndexPutImpl1DFloatAccumulateModule_basic", + "IndexPutImpl1DFloatNonAccumulateModule_basic", + "IndexPutImpl1DIntAccumulateModule_basic", + "IndexPutImpl1DIntNonAccumulateModule_basic", + "IndexPutImpl2DFloatAccumulateModule_basic", + "IndexPutImpl2DFloatNonAccumulateModule_basic", + "IndexPutImpl2DIndexModule_basic", + "IndexPutImpl2DNoneIndexStaticModule_basic", + "IndexPutImpl3DFloatAccumulateModule_basic", + "IndexPutImpl3DFloatNonAccumulateModule_basic", + "IndexPutImplIndexWithNoneModule_basic", + "IntFloatModule_basic", + "IouOfModule_basic", + "IsFloatingPointFloat_True", + "IsFloatingPointInt_False", + "IscloseStaticModuleTrue_basic", + "IscloseStaticModule_basic", + "LeakyReluBackwardModule_basic", + "LeakyReluBackwardStaticModule_basic", + "LenStrModule_basic", + "LiftFreshCopyModule_basic", + "LogSoftmaxBackwardModule_basic", + "MaxPool2dCeilModeTrueModule_basic", + "MaxPool2dModule_basic", + "MaxPool2dWithIndicesAllOnesModule_basic", + "MaxPool2dWithIndicesBackwardDynamic3DModule_basic", + "MaxPool2dWithIndicesBackwardDynamic4DModule_basic", + "MaxPool2dWithIndicesBackwardStatic3DModule_basic", + "MaxPool2dWithIndicesBackwardStatic4DModule_basic", + "MaxPool2dWithIndicesCeilModeTrueModule_basic", + "MaxPool2dWithIndicesFullSizeKernelModule_basic", + "MaxPool2dWithIndicesModule_basic", + "MaxPool2dWithIndicesNonDefaultDilationModule_basic", + "MaxPool2dWithIndicesNonDefaultParamsModule_basic", + "MaxPool2dWithIndicesNonDefaultStrideModule_basic", + "MaxPool3dCeilModeTrueModule_basic", + "MaxPool3dLargeDatadModule_basic", + "MaxPool3dModuleRandomSimple_basic", + "MaxPool3dModule_basic", + "MeanDimEmptyDimModule_basic", + "Mlp1LayerModule_basic", + "Mlp2LayerModuleNoBias_basic", + "Mlp2LayerModule_basic", + "MulFloatModule_basic", + "MulIntModule_basic", + "NarrowHorizontalTest2_basic", + "NarrowHorizontalTest_basic", + "NarrowTensorHorizontalModule_basic", + "NarrowTensorVerticalModule_basic", + "NarrowVerticalTest2_basic", + "NarrowVerticalTest_basic", + "NativeBatchNorm1DModule_basic", + "NativeBatchNorm2DModule_basic", + "NativeBatchNorm3DModule_basic", + "NativeBatchNormNoneWeightModule_basic", + "NativeDropoutEvalFloatModule_basic", + "NativeGroupNormBackwardModule_basic", + "NativeGroupNormModule_basic", + "NativeLayerNormDynamicModule_basic", + "NeFloatIntModule_basic", + "NeIntModule_basic", + "NewEmptyStridedModuleDefaultDtype_basic", + "NllLossModuleBackward1DMeanWeight_basic", + "NllLossModuleBackward1DMean_basic", + "NllLossModuleBackward1DSumWeight_basic", + "NllLossModuleBackward1DSum_basic", + "NllLossModuleBackward1DWeight_basic", + "NllLossModuleBackward1D_basic", + "NllLossModuleBackwardMeanWeight_basic", + "NllLossModuleBackwardMean_basic", + "NllLossModuleBackwardSumWeight_basic", + "NllLossModuleBackwardSum_basic", + "NllLossModuleBackwardWeight_basic", + "NllLossModuleBackward_basic", + "NllLossModuleBackward_ignore_index", + "NllLossModule_1D_basic", + "NllLossModule_basic", + "NllLossModule_ignore_index_out_of_bounds_basic", + "NllLossModule_mean_basic", + "NllLossModule_sum_basic", + "NormScalarOptDimKeepDimModule_basic", + "NormScalarOptDimModule_basic", + "NormalFunctionalModule_basic", + "NumToTensorFloatModule_basic", + "NumToTensorIntModule_basic", + "NumelModule_basic", + "NumelZeroRankModule_basic", + "PixelShuffleModuleFullDynamic_basic", + "PixelShuffleModuleSpatiallyDynamic_basic", + "PixelShuffleModuleSpatiallyStatic_basic", + "PixelShuffleModuleStaticRank3Int64_basic", + "PowIntFloatModule_basic", + "PrimMaxIntModule_basic", + "PrimMinIntDynamicModule_basic", + "PrimMinIntModule_basic", + "PrimsConvertElementTypeModule_basic", + "PrimsSqueezeEmptyDimensionsModule_basic", + "PrimsSqueezeModule_basic", + "PrimsViewOfModule_basic", + "PrimsViewOfZeroRankModule_basic", + "RandIntDtypeModule_basic", + "RandIntModule_basic", + "RandIntPinMemoryModule_basic", + "ReshapeAliasCollapseModule_basic", + "ReshapeAliasExpandModule_basic", + "ReshapeExpandModule_basic", + "ScalarConstantTupleModule_basic", + "ScalarImplicitFloatModule_basic", + "ScalarImplicitIntModule_basic", + "ScatterReduceFloatMaxModule", + "ScatterReduceFloatMeanModule", + "ScatterReduceFloatMeanModuleIncludeSelf", + "ScatterReduceFloatMinModule", + "ScatterReduceFloatProdModule", + "ScatterReduceFloatSumModule", + "ScatterReduceIntMaxModule", + "ScatterReduceIntMeanModule", + "ScatterReduceIntMeanModuleIncludeSelf", + "ScatterReduceIntMinModule", + "ScatterReduceIntProdModule", + "ScatterReduceIntSumModule", + "SelectScattertModule_basic", + "SelectScattertStaticModule_basic", + "SliceEndSleStartModule_basic", + "SliceOutOfUpperBoundIndexModule_basic", + "SliceScatterModule_basic", + "SliceScatterNegativeDimModule_basic", + "SliceScatterNegativeEndModule_basic", + "SliceScatterStaticModule_basic", + "SliceScatterStepVariationModule_basic", + "SliceScatterZeroDimModule_basic", + "SliceStartEqEndModule_basic", + "SoftmaxBackwardModule_basic", + "SortIntListReverse_basic", + "SortIntList_basic", + "SplitDimDynamicModule_basic", + "SplitDimStaticModule_basic", + "SqrtIntConstantModule_basic", + "SqrtIntModule_basic", + "StdCorrectionEmptyDimModule_basic", + "StdDimEmptyDimModule_basic", + "SubFloatModule_basic", + "SubIntModule_basic", + "TanhBackward_basic", + "TensorToBoolZeroRank_basic", + "TensorToBool_basic", + "TensorToFloatZeroRank_basic", + "TensorToFloat_basic", + "TensorToIntZeroRank_basic", + "TensorToInt_basic", + "TestMultipleTensorAndPrimitiveTypesReturn_basic", + "Threshold1dFloatModule_basic", + "Threshold1dIntI32Module_basic", + "Threshold1dIntModule_basic", + "Threshold2dFloatModule_basic", + "Threshold2dIntModule_basic", + "Threshold3dFloatModule_basic", + "Threshold3dIntModule_basic", + "ThresholdBackward1dFloatModule_basic", + "ThresholdBackward1dIntModule_basic", + "ThresholdBackward1dMixedModule_basic", + "ThresholdBackward2dFloatModule_basic", + "ThresholdBackward2dIntModule_basic", + "ThresholdBackward2dMixedModule_basic", + "ThresholdBackward3dFloatModule_basic", + "ThresholdBackward3dIntModule_basic", + "ThresholdBackward3dMixedModule_basic", + "ToCopyBoolDTypeStaticModule_basic", + "ToCopyModule_basic", + "ToCopyWithDTypeFalsePinMemoryModule_basic", + "ToCopyWithDTypeModule_basic", + "TorchPrimLoopForLikeModule_basic", + "TorchPrimLoopWhileLikeModule_basic", + "TraceModule_basic", + "TraceModule_empty", + "TraceModule_nonsquare", + "TraceSignedIntModule_basic", + "TraceUnsignedIntModule_basic", + "TraceUnsignedIntModule_empty", + "UniformModule_basic", + "UniformNoCorrelationModule_basic", + "UniformStaticShapeModule_basic", + "UnsafeIndexPutHackedTwin1DFloatNonAccumulateModule_basic", + "UnsafeView1DFoldModule_basic", + "UnsafeViewCollapseDynamicWithAtenSizeIntModule_basic", + "UnsafeViewCollapseModule_basic", + "UnsafeViewDynamicExpandModule_basic", + "UnsafeViewDynamicExpandWithAtenSizeIntModule_basic", + "UnsafeViewExpandModule_basic", + "UpSampleNearest2dBackwardScalesNone_basic", + "UpSampleNearest2dBackward_basic", + "UpSampleNearest2dDynamicFactor_basic", + "UpSampleNearest2dStaticFactor_basic", + "UpSampleNearest2d_basic", + "VarCorrectionEmptyDimModule_basic", + "VarDimEmptyDimModule_basic", + "ViewCollapseDynamicWithAtenSizeIntModule_basic", + "ViewCollapseModule_basic", + "ViewDynamicExpandCollapseModule_basic", + "ViewDynamicExpandCollapseWithAtenIntModule_basic", + "ViewDynamicExpandModule_basic", + "ViewDynamicExpandWithAtenSizeIntModule_basic", + "ViewExpandDynamicDimModule_basic", + "ViewNoChange1dModule_basic", + "ViewNoChange2dModule_basic", + "ViewNoChange3dModule_basic", + "_Convolution2DAllFalseModule_basic", + "_Convolution2DBenchmarkModule_basic", + "_Convolution2DCudnnModule_basic", + "_Convolution2DDeterministicModule_basic", + "_Convolution2DTF32Module_basic", + "_ConvolutionDeprecated2DAllFalseModule_basic", + "_ConvolutionDeprecated2DBenchmarkModule_basic", + "_ConvolutionDeprecated2DCudnnModule_basic", + "_ConvolutionDeprecated2DDeterministicModule_basic", + "_SoftmaxModule_basic", + + # Failure - onnx_import + "BucketizeTensorFloatModule_basic", + "BucketizeTensorModule_basic", + "BucketizeTensorOutInt32RightModule_basic", + "BucketizeTensorStaticFloatModule_basic", + "BucketizeTensorStaticModule_basic", + "DiagonalModule_basic", + "DiagonalModule_nonsquare", + "DiagonalModule_transposed", + "DiagonalModule_with_dims", + "DiagonalModule_with_dims_and_offset", + "DiagonalModule_with_negative_dims", + "DiagonalModule_with_offset", + "ElementwiseClampMaxModule_basic", + "ElementwiseClampMinModule_basic", + "ElementwiseClampMinTensorFloatModule_basic", + "ElementwiseClampMinTensorIntModule_basic", + "ElementwiseClampModule_basic", + "ElementwiseClampTensorFloatModule_basic", + "ElementwiseClampTensorInt8Module_basic", + "ElementwiseClampTensorIntModule_basic", + "EmptyLikeMemoryFormatModule_basic", + "EmptyLikeModule_defaultDtype", + "EmptyLikeModule_falsePinMemory", + "EmptyLikeModule_float", + "EmptyLikeModule_int", + "Fill_TensorFloat32WithFloat32_basic", + "Fill_TensorFloat32WithFloat64_basic", + "Fill_TensorFloat32WithInt64_basic", + "Fill_TensorFloat64WithFloat32_basic", + "Fill_TensorFloat64WithFloat64_basic", + "Fill_TensorFloat64WithInt64_basic", + "FullLikeModuleDefaultDtype_basic", + "FullLikeModuleFalsePinMemory_basic", + "FullLikeModuleFloat2D_basic", + "FullLikeModuleFloat3D_basic", + "FullLikeModuleInt2D_basic", + "FullLikeModuleInt3D_basic", + "HBC_basic", + "IndexPut1DFloatAccumulateModule_basic", + "IndexPut1DIntAccumulateModule_basic", + "IndexPut2DFloatAccumulateModule_basic", + "IndexPut2DIntAccumulateModule_basic", + "IndexPut3DFloatAccumulateModule_basic", + "IndexPut3DIntAccumulateModule_basic", + "IndexPutHackedTwin1DFloatAccumulateModule_basic", + "IndexPutHackedTwin1DIntAccumulateModule_basic", + "IndexPutHackedTwin2DFloatAccumulateModule_basic", + "IndexPutHackedTwin2DIntAccumulateModule_basic", + "IndexPutHackedTwin3DFloatAccumulateModule_basic", + "IndexPutHackedTwin3DIntAccumulateModule_basic", + "NormalizeModule_basic", + "OnesLikeModule_defaultDtype", + "OnesLikeModule_falsePinMemory", + "OnesLikeModule_float", + "OnesLikeModule_int", + "PadWithNoneValModule_basic", + "QuantizedMLP_basic", + "RandModule_basic", + "ScatterReduceFloatMaxModuleIncludeSelf", + "ScatterReduceFloatMinModuleIncludeSelf", + "ScatterReduceFloatProdModuleIncludeSelf", + "ScatterReduceFloatSumModuleIncludeSelf", + "ScatterReduceIntMaxModuleIncludeSelf", + "ScatterReduceIntMinModuleIncludeSelf", + "ScatterReduceIntProdModuleIncludeSelf", + "ScatterReduceIntSumModuleIncludeSelf", + "TileBigDimsSizeModule_basic", + "TileSmallDimsSizeModule_basic", + "UpSampleNearest2dDynamicSize_basic", + "UpSampleNearest2dStaticSize_basic", + "ZeroFloat32Module_basic", + "ZeroInt32Module_basic", + "ZeroInt64Module_basic", + "ZerosLikeModule_defaultDtype", + "ZerosLikeModule_falsePinMemory", + "ZerosLikeModule_float", + "ZerosLikeModule_int", + + # Failure - onnx_lowering + "AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic", + "AdaptiveAvgPool1dStaticEvenMultiple_basic", + "AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic", + "AtenMmFloatTypes_basic", + "AtenMmIntTypes_basic", + "AtenTrilModule_basic", + "AtenTrilWithNegDiagonalModule_basic", + "AtenTrilWithPosDiagonalModule_basic", + "AtenTriuModule_basic", + "AtenTriuWithNegDiagonalModule_basic", + "AtenTriuWithPosDiagonalModule_basic", + "AvgPool1dFloatModule_basic", + "AvgPool1dIntModule_basic", + "AvgPool1dStaticModule_basic", + "AvgPool2dCeilModeTrueModule_basic", + "AvgPool2dDivisorOverrideModule_basic", + "AvgPool2dFloatModule_basic", + "AvgPool2dIntModule_basic", + "AvgPool2dStaticModule_basic", + "BernoulliFloatModule_basic", + "BernoulliModule_basic", + "BernoulliPModule_basic", + "BernoulliTensorModule_basic", + "ConstantPad2dStaticModule_basic", + "ConstantPadNdModule_basic", + "ConstantPadNdPartialStaticModule_basic", + "ConstantPadNdStaticModule_basic", + "CrossEntropyLossModule_basic", + "CrossEntropyLossNoReductionModule_basic", + "DropoutTrainModule_basic", + "DropoutTrainStaticShapeModule_basic", + "EinsumStaticContractRhsModule_basic", + "EinsumStaticFourDimensionModule_basic", + "EinsumStaticModule_basic", + "ElementwiseMishModule_basic", + "ElementwiseRemainderScalarModule_Bool_basic", + "ElementwiseRemainderScalarModule_Int_basic", + "ElementwiseToDtypeI64ToI8Module_basic", + "ElementwiseToDtypeI64ToUI8Module_basic", + "GroupNormModule_basic", + "GroupNormNoWeightAndBiasModule_basic", + "HardswishModule_basic", + "HardswishRandomModule_basic", + "IndexPut1DFloatNonAccumulateModule_basic", + "IndexPut1DIntNonAccumulateModule_basic", + "IndexPut2DFloatNonAccumulateModule_basic", + "IndexPut2DIntNonAccumulateModule_basic", + "IndexPut3DFloatNonAccumulateModule_basic", + "IndexPut3DIntNonAccumulateModule_basic", + "IndexPutHackedTwin1DFloatNonAccumulateModule_basic", + "IndexPutHackedTwin1DIntNonAccumulateModule_basic", + "IndexPutHackedTwin2DFloatNonAccumulateModule_basic", + "IndexPutHackedTwin2DIntNonAccumulateModule_basic", + "IndexPutHackedTwin3DFloatNonAccumulateModule_basic", + "IndexPutHackedTwin3DIntNonAccumulateModule_basic", + "LogSoftmaxIntModule_basic", + "MaxPool2dWithIndicesAllNegativeValuesModule_basic", + "MaxPool2dWithIndicesNonDefaultPaddingModule_basic", + "MaxPool2dWithIndicesStaticModule_basic", + "MmDagModule_basic", + "MmModule_basic", + "MmModule_chained", + "MmTanhModule_basic", + "MobilenetV3Module_basic", + "MseLossSumReductionWithDifferentElemTypeModule_basic", + "NativeDropoutTrainModule_basic", + "NativeDropoutTrainStaticShapeModule_basic", + "OneHotModule_basic", + "PadModule_basic", + "RandIntLowDtypeModule_basic", + "RandIntLowModule_basic", + "RandLikeDtypeModule_basic", + "RandLikeModule_basic", + "RandnDtypeDeviceModule_basic", + "RandnGeneratorF64Module_basic", + "RandnGeneratorModule_basic", + "RandnLikeDtypeModule_basic", + "RandnLikeModule_basic", + "RandnModule_basic", + "ReduceL1NormModule_basic", + "ReduceL1NormWithDTypeModule_basic", + "ReduceL2NormModule_basic", + "ReduceL3NormAllDimsModule_basic", + "ReduceL3NormKeepDimModule_basic", + "ReduceProdDimIntFloatModule_basic", + "ReduceSumDtypeFloatModule_basic", + "ReduceSumDtypeIntModule_basic", + "ReduceSumElementTypeBoolModule_basic", + "ReduceSumFloatModule_basic", + "ReduceSumSignedIntModule_basic", + "ReduceSumUnsignedIntModule_basic", + "ReflectionPad1dModule2dInput_Right", + "ReflectionPad1dModule2dInput_basic", + "ReflectionPad1dModule3dInput_Left", + "ReflectionPad1dModule3dInput_basic", + "ReflectionPad2dModule_Bottom", + "ReflectionPad2dModule_Left", + "ReflectionPad2dModule_Right", + "ReflectionPad2dModule_Top", + "ReflectionPad2dModule_basic", + "ReplicationPad2dModule_basic", + "ReplicationPad2dModule_bottom0", + "ReplicationPad2dModule_left0", + "ReplicationPad2dModule_right0", + "ReplicationPad2dModule_top0", + "ScatterSrcModule_basic", + "ScatterSrcStaticModule_basic", + "ScatterValueFloatModule_basic", + "ScatterValueIntModule_basic", + "SoftplusModule_basic", + "SortTensorDescending_basic", + "SortTensorInteger_basic", + "SortTensorNegativeDimension_basic", + "SortTensorSpecificDimension_basic", + "SortTensor_basic", + "SqueezeModule_allUnitDim", + "SqueezeModule_broadcast", + "SqueezeModule_static", + "StdCorrectionAllDimReduceModule_basic", + "StdCorrectionKeepDimModule_basic", + "StdCorrectionLargeInputModule_basic", + "StdCorrectionModule_basic", + "StdCorrectionNoneModule_basic", + "StdCorrectionSingleDimReduceModule_basic", + "StdDimKeepDimFalseModule_basic", + "StdDimKeepDimTrueModule_basic", + "StdDimNoneDimModule_basic", + "StdUnbiasedModule_basic", + "TriuBroadcastModule_basic", + "TriuModule_basic", + "TypeConversionI1ToI32Module_basic", + "TypeConversionI64ToI32Module_basic", + "UnflattenIntNegativeOneDimStaticModule_basic", + "UnflattenIntNegativeOneSizeStaticModule_basic", + "UnflattenIntStaticModule_basic", + "UnflattenStaticModule_basic", + "VarCorrectionAllDimReduceModule_basic", + "VarCorrectionKeepDimModule_basic", + "VarCorrectionLargeInputModule_basic", + "VarCorrectionModule_basic", + "VarCorrectionNoneModule_basic", + "VarCorrectionSingleDimReduceModule_basic", + "VarDimAllDimReduceModule_basic", + "VarDimModule_basic", + "VarDimMultiDimModule_basic", + "VarDimNegativeModule_basic", + "VarDimNoneDimModule_basic", + "VarDimSingleDimModule_basic", + "VarDimUnbiasedModule_basic", + "VarMeanCorrectionModule_basic", + "VarMeanCorrectionNoneModule_basic", + "VarMeanDimModule_basic", + "VarMeanUnbiasedModule_basic", + "VarUnbiasedModule_basic", + "_LogSoftmaxModuleStable_basic", + "_LogSoftmaxModule_basic", + + # Failure - cast_error + "MeanDimNoneDimModule_basic", + "MeanDtypeModule_basic", + "MeanDynamicSizesModule_basic", + "MeanModule_basic", + "MseLossMeanReductionModule_basic", + "StdBiasedModule_basic", + "VarBiasedModule_basic", + "VarMeanBiasedModule_basic", + + # Failure - constant_int + "ReduceMinAlongDimNegative_basic", + "ReduceMinAlongDimSignedInt_basic", + "ReduceMinAlongDim_basic", + "ReduceMinFloatModule_basic", + "ReduceMinKeepDimReturnBoth_basic", + "ReduceMinSignedIntModule_basic", + "ReduceMinUnsignedIntModule_basic", + "SplitTensorGetItem_Module_basic", + "SplitTensorLastSmallerModule_basic", + "SplitTensorListUnpackModule_basic", + "SplitTensorNegativeDimModule_basic", + "SplitWithSizesListUnpackModule_basic", + "UnbindIntGetItem_Module_basic", + "UnbindIntListUnpack_Module_basic", + + # Failure - operand_type + "ElementwiseAcosIntModule_basic", + "ElementwiseAsinIntModule_basic", + "ElementwiseAtanTensorIntModule_basic", + "ElementwiseCosIntModule_basic", + "ElementwiseErfIntModule_basic", + "ElementwiseExpIntModule_basic", + "ElementwiseLog10IntModule_basic", + "ElementwiseLog2IntModule_basic", + "ElementwiseLogIntModule_basic", + "ElementwiseSinIntModule_basic", + "ElementwiseTanIntModule_basic", + "ElementwiseUnaryIntModule_basic", + + # Failure - expand_multidim + "IndexTensorHackedTwinModule3dInput_basic", + "IndexTensorHackedTwinModule_basic", + "IndexTensorModule3dInput_basic", + "IndexTensorModule_basic", + "IndexTensorMultiInputContiguousOneDimDynamic_basic", + "IndexTensorMultiInputNonContiguousOneDimDynamic_basic", + + # Failure - rankless_return + "ReduceAmaxMultiDim_basic", + "ReduceAmaxOutOfOrderDim_basic", + "ReduceAmaxSingleDim_basic", + "ReduceMaxAllDims_basic", + "ReduceMaxAlongDimNegative_basic", + "ReduceMaxAlongDimSignedInt_basic", + "ReduceMaxAlongDimUnsignedInt_basic", + "ReduceMaxAlongDim_basic", + "ReduceMaxFloatModule_basic", + "ReduceMaxSignedIntModule_basic", + "ReduceMaxUnsignedIntModule_basic", + + # Failure - slice_lowering + "ScaledDotProductAttentionDifferentModule_basic", + "ScaledDotProductAttentionSameModule_basic", + + # Failure - view_lowering + "AddSizeIntModule_basic", + "ElementwiseFlattenBroadcastModule_basic", + "FlattenRank0Module_basic", + "IndexTensorDyanmicInputContiguousWithNoneModule_basic", + "IndexTensorDyanmicInputNonContiguousWithNoneModule_basic", + "IndexTensorHackedTwinMultiInputNonContiguousMultipleStaticDims_basic", + "IndexTensorMultiInputContiguousCenter_basic", + "IndexTensorMultiInputNonContiguousDynamic_basic", + "IndexTensorMultiInputNonContiguousMultipleStaticDims_basic", + "IndexTensorMultiInputNonContiguous_basic", + "IndexTensorMultiInputOneDim_basic", + "IndexTensorMultiInputThreeIndexers_basic", + "IndexTensorMultiInput_basic", + "IndexTensorSelectDimModule_basic", + "IndexTensorStaticContiguousWithNoneModule_basic", + "RepeatModule_basic", + "SelectIntModule_basic", + "SelectIntNegativeDimAndIndexStaticModule_basic", + "SliceSingleIdxModule_basic", + "ViewFlattenAndExpandModule_basic", + "ViewSizeDimFollowedByCollapsedOnesModule_basic", + "ViewSizeDimFollowedByExpandedOnesModule_basic", + "ViewSizeDimLedAndFollowedByCollapsedOnesModule_basic", + "ViewSizeDimLedAndFollowedByExpandedOnesModule_basic", + "ViewSizeDimLedByCollapsedOnesModule_basic", + "ViewSizeDimLedByExpandedOnesModule_basic", + + # Failure - numerical + "AdaptiveAvgPool1dUnitOutputSizeDynamicModule_basic", + "AdaptiveAvgPool2dUnitOutputSizeDynamicModule_basic", + "ElementwiseSeluModule_basic", + "EmbeddingModule1DIndices_basic", + "EmbeddingModuleI32Static_basic", + "FlipNegativeIndexModule_basic", + "HardsigmoidModule_basic", + "HardsigmoidRandomModule_basic", + "IndexSelectDynamicIndexSizeModule_basic", + "IndexSelectDynamicInputSizeModule_basic", + "IndexSelectDynamicModulebasic", + "IndexSelectNegativeDimModule_basic", + "IndexSelectSingleIdxModule_basic", + "IndexSelectTwoIdxModule_basic", + "IndexSelectWholeDimensionModule_basic", + "IndexTensorStaticModule_basic", + "IndexTensorStaticNonContiguousWithNoneModule_basic", + "PixelShuffleModuleStaticRank4Float32_basic", + "ResNet18Module_basic", + "SliceCopyEndGreaterThanDimSize_Module_basic", + "SliceCopyNegative_Module_basic", + "SliceCopyNonZeroDim_Module_basic", + "SliceCopy_Module_basic", + "TupleModule_basic", + + # Failure - shape + "ArangeStartOutDtypeModule_basic", + "ArangeStartOutViewModule_basic", + "BroadcastDynamicDimModule_basic", + "BroadcastToModule_basic", + "EmbeddingModuleF16_basic", + "EmbeddingModuleI32_basic", + "EmbeddingModuleI64_basic", + "ExpandModule_basic", + "ReduceAmaxKeepDim_basic", + "ReduceMaxKeepDimReturnBoth_basic", + "ReduceMaxNegativeDim_basic", + "ViewSizeFromOtherTensor_basic", + + # Failure - unknown + "ChunkListUnpackUneven_Module_basic", + "ChunkListUnpack_Module_basic", + "Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier", + "CopyWithDifferentDTypesAndSizesModule_basic", + "CopyWithDifferentDTypesModule_basic", + "CosineSimilarityStaticBroadcastModule_basic", + "CumsumInputDtypeInt32Module_basic", + "ElementwiseAtan2TensorIntModule_basic", + "ElementwiseDivRoundingModeTruncModule_basic", + "ElementwisePreluModule_basic", + "ElementwiseUnsqueezeNegDimsModule_basic", + "ElementwiseWhereScalarModule_basic", + "FlattenDynamicModule_basic", + "FlipModuleStaticShape_basic", + "GluStaticModule_basic", + "MaskedFillTensorFloatValueModule_basic", + "ReduceAllDimEmpty_basic", + "ReduceAllDimFloat_basic", + "ReduceAllDimInt_basic", + "ReduceMinAlongDimUnsignedInt_basic", + "TensorsStackNegativeDimModule_basic", + "TensorsStackPromoteDTypeModule_basic", +} + +ONNX_CRASHING_SET = { + "ElementwiseSigmoidIntModule_basic", + "FlipModule_basic", + "IndexTensorNegativeIndexModule_basic", + "MoveDimIntNegativeIndexModule_basic", + "PermuteNegativeIndexModule_basic", + "RollModule_basic", + "SliceModule_basic", + "SliceNegIdxModule_basic", + "SliceOutOfLowerBoundEndIndexModule_basic", + "SliceOutOfLowerBoundStartIndexModule_basic", + "SliceSizeTwoStepModule_basic", +} diff --git a/projects/pt1/python/torch_mlir_e2e_test/configs/__init__.py b/projects/pt1/python/torch_mlir_e2e_test/configs/__init__.py index 4ca4c3dce803..b11c242db2cb 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/configs/__init__.py +++ b/projects/pt1/python/torch_mlir_e2e_test/configs/__init__.py @@ -6,6 +6,7 @@ from .lazy_tensor_core import LazyTensorCoreTestConfig from .linalg_on_tensors_backend import LinalgOnTensorsBackendTestConfig from .native_torch import NativeTorchTestConfig +from .onnx_backend import OnnxBackendTestConfig from .torchscript import TorchScriptTestConfig from .stablehlo_backend import StablehloBackendTestConfig from .tosa_backend import TosaBackendTestConfig diff --git a/projects/pt1/python/torch_mlir_e2e_test/configs/onnx_backend.py b/projects/pt1/python/torch_mlir_e2e_test/configs/onnx_backend.py new file mode 100644 index 000000000000..e411a7cbb67f --- /dev/null +++ b/projects/pt1/python/torch_mlir_e2e_test/configs/onnx_backend.py @@ -0,0 +1,101 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. + +from pathlib import Path +from typing import Any + +import io +import onnx +import torch +import torch_mlir + +from torch_mlir_e2e_test.onnx_backends.abc import OnnxBackend +from torch_mlir_e2e_test.framework import TestConfig, Trace, TraceItem +from torch_mlir_e2e_test.utils import convert_annotations_to_placeholders +from .utils import ( + recursively_convert_to_numpy, + recursively_convert_from_numpy, +) + +from torch_mlir.extras import onnx_importer +from torch_mlir.dialects import torch as torch_d +from torch_mlir.ir import Context, Module + + +def import_onnx(contents): + # Import the ONNX model proto from the file contents: + raw_model = onnx.load_from_string(contents) + model_proto = onnx.shape_inference.infer_shapes(raw_model) + + # Import the ONNX module into an MLIR module: + context = Context() + torch_d.register_dialect(context) + model_info = onnx_importer.ModelInfo(model_proto) + m = model_info.create_module(context=context) + imp = onnx_importer.NodeImporter.define_function(model_info.main_graph, m.operation) + imp.import_all() + return m + + +def convert_onnx(model, inputs): + buffer = io.BytesIO() + + # Process the type information so we export with the dynamic shape information + examples = [] + input_names = [] + dynamic_tensors = {} + for (index, arg) in enumerate(inputs): + shape = map(lambda d : d if d >= 0 else 1, arg.shape) + shape = tuple(shape) + examples.append(torch.zeros(size=shape, dtype=arg.dtype)) + + input_name = "input_{}".format(index) + input_names.append(input_name) + + dynamic_dims = {} + for (dimindex, dim) in enumerate(arg.shape): + if (dim < 0): + dynamic_dims[dimindex] = "dim_{}_{}".format(index, dimindex) + + if (dynamic_dims): + dynamic_tensors[input_name] = dynamic_dims + + + examples=tuple(examples) + torch.onnx.export(model, examples, buffer, input_names=input_names, dynamic_axes=dynamic_tensors) + buffer = buffer.getvalue() + return import_onnx(buffer) + +class OnnxBackendTestConfig(TestConfig): + """Base class for TestConfig's that are implemented with ONNX. + + This class handles all the common lowering that torch-mlir does before + reaching the ONNX abstraction level. + """ + def __init__(self, backend: OnnxBackend, use_make_fx: bool = False): + super().__init__() + self.backend = backend + self.use_make_fx = use_make_fx + + def compile(self, program: torch.nn.Module) -> Any: + example_args = convert_annotations_to_placeholders(program.forward) + onnx_module = convert_onnx(program, example_args) + compiled_module = self.backend.compile(onnx_module) + return compiled_module + + + + def run(self, artifact: Any, trace: Trace) -> Trace: + backend_module = self.backend.load(artifact) + result: Trace = [] + for item in trace: + numpy_inputs = recursively_convert_to_numpy(item.inputs) + outputs = getattr(backend_module, "main_graph")(*numpy_inputs) + output = recursively_convert_from_numpy(outputs) + result.append( + TraceItem(symbol=item.symbol, + inputs=item.inputs, + output=output)) + return result diff --git a/projects/pt1/python/torch_mlir_e2e_test/onnx_backends/__init__.py b/projects/pt1/python/torch_mlir_e2e_test/onnx_backends/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/projects/pt1/python/torch_mlir_e2e_test/onnx_backends/abc.py b/projects/pt1/python/torch_mlir_e2e_test/onnx_backends/abc.py new file mode 100644 index 000000000000..684c08df4fa1 --- /dev/null +++ b/projects/pt1/python/torch_mlir_e2e_test/onnx_backends/abc.py @@ -0,0 +1,49 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. + +import abc +from typing import TypeVar + +import torch + +from torch_mlir.ir import Module + +# A type shared between the result of `OnnxBackend.compile` and the +# input to `OnnxBackend.load`. Each backend will likely have a +# different definition of this type. +CompiledArtifact = TypeVar('CompiledArtifact') + +# A wrapper around a backend-specific loaded program representation +# that uniformly translates the `x.method(...)` interface expected of +# Torch modules into appropriate lower-level operations. +Invoker = TypeVar('Invoker') + + +class OnnxBackend(abc.ABC): + """The interface to an ONNX backend. + + Backends are recommended to raise meaningful exceptions in case of error, + ideally with easy reproduction instructions. + """ + @abc.abstractmethod + def compile(self, module: Module) -> CompiledArtifact: + """Compile the provided MLIR module into a compiled artifact. + + The module adheres to the ONNX backend contract + (see the VerifyOnnxBackendContract pass). + + The compiled artifact can be any type, but must be correctly + interpreted by the `load` method. + """ + + @abc.abstractmethod + def load(self, artifact: CompiledArtifact) -> Invoker: + """Load the compiled artifact into a uniformly invokable form. + + The compiled artifact is the result of a previous call to `compile`. + + See the description of `Invoker` for the requirements on the returned + type. + """ diff --git a/projects/pt1/python/torch_mlir_e2e_test/onnx_backends/linalg_on_tensors.py b/projects/pt1/python/torch_mlir_e2e_test/onnx_backends/linalg_on_tensors.py new file mode 100644 index 000000000000..e77d795b7269 --- /dev/null +++ b/projects/pt1/python/torch_mlir_e2e_test/onnx_backends/linalg_on_tensors.py @@ -0,0 +1,65 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. + + +from torch_mlir.compiler_utils import run_pipeline_with_repro_report +from torch_mlir.ir import * +from torch_mlir.passmanager import * +from torch_mlir.torchscript import OutputType +from torch_mlir.torchscript import _lower_mlir_module + +from torch_mlir_e2e_test.linalg_on_tensors_backends.refbackend import RefBackendLinalgOnTensorsBackend + +from .abc import OnnxBackend + +__all__ = [ + "LinalgOnTensorsOnnxBackend", +] + +# The pipeline of func.func passes that lower the ONNX backend contract to the +# Linalg-on-Tensors backend contract accepted by RefBackend. +ONNX_TO_TORCH_FUNC_PIPELINE = ",".join([ + "convert-torch-onnx-to-torch", +]) + + +class LinalgOnTensorsOnnxBackend(OnnxBackend): + """Main entry-point for the linalg-on-tensors based ONNX backend. + + This currently uses the linalg-on-tensors RefBackend for actual execution. + """ + + def __init__(self): + super().__init__() + self.refbackend = RefBackendLinalgOnTensorsBackend() + + def compile(self, imported_module: Module): + """Compiles an imported module that satisfied the ONNX backend contract. + + Args: + imported_module: The MLIR module consisting of ONNX operations wrapped by + torch.operator. + Returns: + An opaque, backend specific compiled artifact object that can be + passed to `load`. + """ + run_pipeline_with_repro_report( + imported_module, + f"builtin.module(func.func({ONNX_TO_TORCH_FUNC_PIPELINE}))", + "Lowering Onnx backend contract to Linalg-on-Tensors backend contract") + + run_pipeline_with_repro_report( + imported_module, + f"builtin.module(torch-lower-to-backend-contract)", + "Lowering TorchFX IR -> Torch Backend IR", + ) + + imported_module = _lower_mlir_module(False, OutputType.LINALG_ON_TENSORS, imported_module) + compiled_module = self.refbackend.compile(imported_module) + return compiled_module + + def load(self, module): + """Loads a compiled artifact into the runtime.""" + return self.refbackend.load(module) diff --git a/python/torch_mlir/extras/onnx_importer.py b/python/torch_mlir/extras/onnx_importer.py index c651f79b15fe..24520e9ce970 100644 --- a/python/torch_mlir/extras/onnx_importer.py +++ b/python/torch_mlir/extras/onnx_importer.py @@ -102,7 +102,7 @@ def __init__(self, model_proto: onnx.ModelProto, *, config: Config = Config()): def create_module(self, context: Optional[Context] = None) -> Operation: if not context: context = Context() - module_op = Module.create(Location.unknown(context)).operation + module_op = Module.create(Location.unknown(context)) # TODO: Populate module level metadata from the ModelProto return module_op @@ -334,7 +334,8 @@ def import_attributes( f"This likely means that this is a special node which requires specific " f"handling in the importer: {onnx_attr}" ) - attrs[f"torch.onnx.{onnx_attr.name}"] = handler(onnx_attr, self._cc) + result = handler(onnx_attr, self._cc) + attrs[f"torch.onnx.{onnx_attr.name}"] = result def import_initializer(self, initializer: onnx.TensorProto, extern_name: str = None) -> Value: # If an explicitly specified name is given, use that; otherwise, pick @@ -502,9 +503,10 @@ def tensor_proto_to_attr(self, tp: onnx.TensorProto) -> Attribute: if tp.HasField("raw_data"): # Conveniently, DenseResourceElementsAttr shares the raw data # format. We just give it maximum numeric alignment. - return DenseResourceElementsAttr.get_from_buffer( + resource = DenseResourceElementsAttr.get_from_buffer( tp.raw_data, self._sanitize_name(tp.name), tensor_type, alignment=8 ) + return resource else: # We have to do a data type specific instantiation from proto fields. # Since this is typically used for small tensor constants, we instantiate diff --git a/python/torch_mlir/tools/import_onnx/__main__.py b/python/torch_mlir/tools/import_onnx/__main__.py index 399073be1570..547fe5339dad 100644 --- a/python/torch_mlir/tools/import_onnx/__main__.py +++ b/python/torch_mlir/tools/import_onnx/__main__.py @@ -34,7 +34,7 @@ def main(args: argparse.Namespace): context = Context() torch_d.register_dialect(context) model_info = onnx_importer.ModelInfo(model_proto) - m = model_info.create_module(context=context) + m = model_info.create_module(context=context).operation imp = onnx_importer.NodeImporter.define_function(model_info.main_graph, m) imp.import_all() if not args.no_verify: diff --git a/test/python/onnx_importer/import_smoke_test.py b/test/python/onnx_importer/import_smoke_test.py index 39a0b3098150..f27cc9caf5bd 100644 --- a/test/python/onnx_importer/import_smoke_test.py +++ b/test/python/onnx_importer/import_smoke_test.py @@ -326,7 +326,7 @@ def run_import_test(self, norm_name: str, rel_path: str): model_info = onnx_importer.ModelInfo( self.load_onnx_model(ONNX_TEST_DATA_DIR / rel_path), ) - m = model_info.create_module(context=context) + m = model_info.create_module(context=context).operation try: imp = onnx_importer.NodeImporter.define_function(model_info.main_graph, m) imp.import_all() From 5253282c55546d19e00f8b244c2da74cf76c8486 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Fri, 16 Feb 2024 09:46:30 -0800 Subject: [PATCH 200/283] [fx] Support mutation in ExportedProgram. (#2916) As of https://github.com/pytorch/pytorch/pull/118969, `ExportedProgram` has the long awaited fixes to correctly categorize various things relating to parameters, buffers, mutated inputs and constants. With this additional modeling, we are finally able to implement (safely/soundly) the mutable semantics that were attempted on the TorchScript path. The difference is that on that path, we had to conservatively treat everything as mutable and run some dodgy heuristics (which have been the cause of many bugs relating to "MaximizeValueSemantics") to try to get back to an immutable state. The new model supports mutability at the graph edges, allowing both user inputs and buffers to be mutated (there is some more support than that, but that is all I fully tracked through to implementation). Therefore, when we receive programs like this, we now can selectively enable mutation at the edges. This happens to be the mutability model that IREE supports, which I expect to be a primary beneficiary. However, there is nothing stopping anyone else from handling the `!torch.tensor` types and the existing copy/overwrite ops that will be selectively added. Since this relies on API changes that will not release until 2.3, I'm being a bit cautious about not refactoring existing facilities. --- python/torch_mlir/extras/fx_importer.py | 505 +++++++++++++++--- python/torch_mlir/fx.py | 21 +- test/python/fx_importer/sparse_test.py | 2 +- test/python/fx_importer/v2.3/lit.local.cfg | 9 + .../fx_importer/v2.3/mutation_import.py | 163 ++++++ 5 files changed, 634 insertions(+), 66 deletions(-) create mode 100644 test/python/fx_importer/v2.3/lit.local.cfg create mode 100644 test/python/fx_importer/v2.3/mutation_import.py diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py index 5677ee4f75ba..89a3caa16843 100644 --- a/python/torch_mlir/extras/fx_importer.py +++ b/python/torch_mlir/extras/fx_importer.py @@ -42,7 +42,9 @@ from torch.fx import ( Graph, GraphModule, + Node, ) + try: import ml_dtypes except ModuleNotFoundError: @@ -89,11 +91,6 @@ "FxImporter", ] -# An external callback that, given a Python value and a GraphNodeImporter, may choose -# to materialize IR to load the value as a vtensor. If it returns None, then default -# literal resolution proceeds. -LiteralResolverCallback = Callable[[Any, "GraphNodeImporter"], Optional[Value]] - REQUIRED_DIALCTS = [ "builtin", "func", @@ -280,6 +277,44 @@ def is_builtin_function_or_method(obj: Any) -> bool: return isinstance(obj, (BuiltinMethodType, BuiltinFunctionType)) +@dataclass(frozen=True, slots=True) +class InputInfo: + """Provides additional metadata when resolving inputs.""" + + program: torch.export.ExportedProgram + input_spec: torch.export.graph_signature.InputSpec + node: Node + ir_type: IrType + mutable_producer_node_name: Optional[str] = None + + +class FxImporterHooks: + """Hooks to control the behavior of the FxImporter.""" + + def prepare_module(self, module_op: Operation): + """Performs any needed preparation work on the module.""" + ... + + def resolve_literal( + self, gni: "GraphNodeImporter", literal: Any + ) -> Optional[Value]: + """User overridable hook to resolve a literal value.""" + return None + + def resolve_input( + self, gni: "GraphNodeImporter", value: Any, info: InputInfo + ) -> Optional[Value]: + """Resolves a Parameter or Buffer input to an IR value. + + If the 'mutable_producer_node_name' option is set, then the result must + be a `!torch.tensor`. + Otherwise, it must be an immutable `!torch.vtensor`. If this constraint cannot + be met, the implementation must either error or return None to delegate to + the default. + """ + return None + + class FxImporter: """Main entry-point for importing an fx.GraphModule. @@ -302,10 +337,10 @@ class FxImporter: __slots__ = [ "_c", "_cc", - "_literal_resolver_callback", "_m", "_m_ip", "_py_attr_tracker", + "_hooks", "symbol_table", ] @@ -315,8 +350,8 @@ def __init__( module: Optional[Module] = None, context: Optional[Context] = None, config_check: bool = True, - literal_resolver_callback: Optional[LiteralResolverCallback] = None, py_attr_tracker: Optional["RefTracker"] = None, + hooks: Optional[FxImporterHooks] = None, ): if module is not None: assert context is None, "If configuring with a Module, context must be None" @@ -331,8 +366,9 @@ def __init__( self._py_attr_tracker = py_attr_tracker or RefTracker() self._cc = ContextCache(self._c, py_attr_tracker=self._py_attr_tracker) self._m_ip = InsertionPoint(self._m.body) - self._literal_resolver_callback = literal_resolver_callback + self._hooks = hooks or FxImporterHooks() self.symbol_table = SymbolTable(self._m.operation) + self._hooks.prepare_module(self._m.operation) def _config_check(self): for dname in REQUIRED_DIALCTS: @@ -352,7 +388,204 @@ def module(self) -> Module: def module_op(self) -> Operation: return self._m.operation - def import_frozen_exported_program(self, prog: torch.export.ExportedProgram): + def import_program( + self, prog: torch.export.ExportedProgram, *, func_name: str = "main" + ): + """Imports an ExportedProgram according to our chosen canonical representation. + + This mechanism is the fully general solution for handling an ExportedProgram + and should eventually supercede all others. However, it depends on the + PyTorch 2.3 release to function properly (specifically, this patch + made ExportedProgram minimally correct for mutation: + https://github.com/pytorch/pytorch/pull/118969). + + For stateless programs, the result of this import is a normal function + defined for immutable `!torch.vtensors`. + + However, if the program mutates its inputs or buffers, then it will be imported + with those parameters as `!torch.tensor` and appropriate copies and overwrites + will be done on the inside. Note that the function is still mostly stateless, + but with `torch.copy.to_vtensor` and `torch.overwrite.tensor.contents` + ops at the earliest consumer or latest producer to update an argument or + buffer. + + It is recommended that integrators subclass and override the `resolve_literal` + method to control access to mutable buffers and parameters. Without that, the + default policy is to capture them as frozen values. + """ + # Create lookaside table of placeholders/outputs. + placeholder_nodes: dict[str, Node] = {} + all_producer_nodes: dict[str, Node] = {} + loc: Optional[Location] = None + for node in prog.graph.nodes: + if loc is None: + loc = self._cc.get_node_location(node) + if node.op == "placeholder": + placeholder_nodes[node.name] = node + all_producer_nodes[node.name] = node + elif node.op == "call_function": + all_producer_nodes[node.name] = node + if loc is None: + loc = Location.unknown(self._c) + + # This API is fast evolving. We keep these imports local for now so that we + # can disable this entire function if needed. + from torch.export.graph_signature import ( + InputKind, + OutputKind, + TensorArgument, + SymIntArgument, + ) + + sig = prog.graph_signature + + # Invert the (producer, node_name) maps for mutated user inputs and mutated + # buffers. This is because we hit-detect based on the input node name. + mutated_user_inputs = { + node_name: producer + for producer, node_name in sig.user_inputs_to_mutate.items() + } + + # Additional bindings that we need to set up after the function is created. + mutable_buffer_target_producers: dict[str, str] = {} + constant_tensors: dict[Node, torch.Tensor] = {} + parameter_bindings: dict[Node, tuple[Any, InputInfo]] = {} + buffer_bindings: dict[Node, tuple[Any, InputInfo]] = {} + + # Derive user outputs that we preserve. These will be nodes of the + # producer for the output. + user_outputs: list[Node] = [] + user_output_types: list[IrType] = [] + for output_spec in sig.output_specs: + kind = output_spec.kind + arg = output_spec.arg + if kind == OutputKind.USER_OUTPUT: + if not isinstance(arg, (TensorArgument, SymIntArgument)): + raise NotImplementedError( + f"OutputKind.USER_OUTPUT for {type(arg)}: {arg}" + ) + output_producer_node = all_producer_nodes[arg.name] + user_outputs.append(output_producer_node) + user_output_types.append( + self._cc.node_val_to_type(output_producer_node) + ) + elif kind == OutputKind.BUFFER_MUTATION and isinstance(arg, TensorArgument): + mutable_buffer_target_producers[output_spec.target] = arg.name + + # Derive user inputs. These will be op=='placeholder' nodes. + user_inputs: list[Node] = [] + user_input_types: list[IrType] = [] + for input_spec in sig.input_specs: + arg = input_spec.arg + if input_spec.kind == InputKind.USER_INPUT: + # Set up user input. + if not isinstance(arg, (TensorArgument, SymIntArgument)): + raise NotImplementedError( + f"InputKind.USER_INPUT for {type(arg)}: {arg}" + ) + placeholder_node = placeholder_nodes[arg.name] + mutable = placeholder_node.name in mutated_user_inputs + user_inputs.append(placeholder_node) + user_input_types.append( + self._cc.node_val_to_type(placeholder_node, mutable=mutable) + ) + elif input_spec.kind == InputKind.CONSTANT_TENSOR and isinstance( + arg, TensorArgument + ): + # Remember constant tensor binding. + constant_tensors[placeholder_nodes[arg.name]] = prog.constants[ + input_spec.target + ] + elif input_spec.kind == InputKind.PARAMETER and isinstance( + arg, TensorArgument + ): + # Remember parameter binding. + value = prog.state_dict.get(input_spec.target) + assert ( + not input_spec.persistent or value is not None + ), "Expected state_dict value for persistent value" + node = placeholder_nodes[arg.name] + node_ir_type = self._cc.node_val_to_type(node, mutable=False) + parameter_bindings[node] = ( + value, + InputInfo(prog, input_spec, node=node, ir_type=node_ir_type), + ) + elif input_spec.kind == InputKind.BUFFER and isinstance( + arg, TensorArgument + ): + # Remember buffer binding. + value = prog.state_dict.get(input_spec.target) + assert ( + not input_spec.persistent or value is not None + ), "Expected state_dict value for persistent value" + node = placeholder_nodes[arg.name] + mutable_producer_node_name = mutable_buffer_target_producers.get( + input_spec.target + ) + node_ir_type = self._cc.node_val_to_type( + node, mutable=bool(mutable_producer_node_name) + ) + buffer_bindings[node] = ( + value, + InputInfo( + prog, + input_spec, + node=node, + ir_type=node_ir_type, + mutable_producer_node_name=mutable_producer_node_name, + ), + ) + else: + raise NotImplementedError( + f"InputSpec not of a known kind: {input_spec}" + ) + + ftype = FunctionType.get(user_input_types, user_output_types, context=self._c) + + # Create the function. + with loc: + func_op = func_dialect.FuncOp(func_name, ftype, ip=self._m_ip) + entry_block = Block.create_at_start(func_op.body, ftype.inputs) + + node_importer = GraphNodeImporter( + self, + self._c, + self._cc, + entry_block, + ) + + # Bind constants to IR values. + for constant_node, constant_tensor in constant_tensors.items(): + node_importer.import_constant(loc, constant_node, constant_tensor) + + # Bind user inputs to IR values. + for user_input_node, block_arg_value in zip(user_inputs, entry_block.arguments): + if user_input_node.name in mutated_user_inputs: + # Materialize + node_importer.import_mutable_to_vtensor( + loc, + user_input_node, + block_arg_value, + mutated_user_inputs[user_input_node.name], + ) + else: + # Normal value tensor binding. + node_importer.bind_node_value(user_input_node, block_arg_value) + + # Lazy bind buffer and parameter inputs. + for node, (parameter_value, info) in parameter_bindings.items(): + node_importer.lazy_import_parameter(loc, node, parameter_value, info) + for node, (buffer_value, info) in buffer_bindings.items(): + node_importer.lazy_import_buffer(loc, node, buffer_value, info) + + # Import all nodes and return. + node_importer.import_nodes( + all_producer_nodes.values(), skip_placeholders_outputs=True + ) + node_importer.return_node_values(loc, user_outputs) + self.symbol_table.insert(func_op) + + def import_frozen_program(self, prog: torch.export.ExportedProgram): """Imports a consolidated torch.export.ExportedProgram instance. If using the new torch.export path (vs a lower level precursor), then this is @@ -377,6 +610,10 @@ def import_frozen_exported_program(self, prog: torch.export.ExportedProgram): As we anticipate more nuanced treatment options in the future, we name this method to indicate that it is producing "frozen" modules. Additional top-level approaches to handling state can be introduced later as an addition. + + TODO: This mechanism should be eventually replaced by `import_program` with + hooks set on the subclass to freeze parameters and buffers. However, that is + waiting for the Torch 2.3 release cut. """ sig = prog.graph_signature state_dict = prog.state_dict @@ -391,7 +628,9 @@ def import_frozen_exported_program(self, prog: torch.export.ExportedProgram): try: state_value = constants[state_name] except KeyError as e: - raise AssertionError("Could not find state mapping for tensor constants") from e + raise AssertionError( + "Could not find state mapping for tensor constants" + ) from e arg_replacements[input_name] = state_value else: # Lift buffers. @@ -399,7 +638,9 @@ def import_frozen_exported_program(self, prog: torch.export.ExportedProgram): try: state_value = state_dict[state_name] except KeyError as e: - raise AssertionError("Could not find state mapping for buffer") from e + raise AssertionError( + "Could not find state mapping for buffer" + ) from e arg_replacements[input_name] = state_value # Lift parameters. @@ -426,11 +667,19 @@ def import_frozen_exported_program(self, prog: torch.export.ExportedProgram): self.import_stateless_graph(g) def import_graph_module(self, gm: GraphModule): - """Low-level import of a GraphModule assuming that it has been functionalized.""" + """Low-level import of a GraphModule assuming that it has been functionalized. + + TODO: This mechanism is deprecated by the `import_program` entry-point and + it should be removed when no longer required for backwards compatibility. + """ self.import_stateless_graph(gm.graph) def import_stateless_graph(self, g: Graph, func_name: str = "main"): - """Low-level import of a functionalized, assumed stateless Graph as a func.""" + """Low-level import of a functionalized, assumed stateless Graph as a func. + + TODO: This mechanism is deprecated by the `import_program` entry-point and + it should be removed when no longer required for backwards compatibility. + """ ftype, loc = self._graph_to_function_meta(g) # TODO: The FuncOp constructor requires a context-manager context. # Fix upstream and then unnest. @@ -447,7 +696,6 @@ def import_stateless_graph(self, g: Graph, func_name: str = "main"): self._c, self._cc, entry_block, - literal_resolver_callback=self._literal_resolver_callback, ) node_importer.import_nodes(g.nodes) self.symbol_table.insert(func) @@ -507,7 +755,9 @@ def __init__( ): self._c = context self._dtype_to_type: Dict[TorchDtype, IrType] = {} - self._tensor_metadata_cache: Dict[Tuple[torch.Size, torch.dtype], IrType] = {} + self._tensor_metadata_cache: Dict[ + Tuple[torch.Size, torch.dtype, Optional[SparsityMeta], bool], IrType + ] = {} self._py_attr_tracker = py_attr_tracker or RefTracker() # Common types. @@ -523,34 +773,34 @@ def integer_attr(self, value: int, bits: int) -> Attribute: c = self._c return IntegerAttr.get(IntegerType.get_signless(bits, c), value) - """Strips symbolic elements from a torch.Size object and returns shape asm""" - def format_asm_shape(self, shape: torch.Size) -> str: + """Strips symbolic elements from a torch.Size object and returns shape asm""" return ",".join("?" if is_symbolic(d) else str(d) for d in list(shape)) - """Return IrType for !torch.vtensor with the given shape and dtype""" - def get_vtensor_type( self, shape: torch.Size, dtype: torch.dtype, *, - sparsity: Optional[SparsityMeta] = None, # keyword-only + sparsity: Optional[SparsityMeta] = None, + mutable: bool = False, ): + """Return IrType for !torch.vtensor with the given shape and dtype""" + stem = "torch.tensor" if mutable else "torch.vtensor" shape_asm = self.format_asm_shape(shape) mlir_dtype = str(self.dtype_to_type(dtype)) if sparsity is not None: encoding = sparsity_encoding(shape, sparsity) assert encoding is not None return IrType.parse( - f"!torch.vtensor<[{shape_asm}],{str(mlir_dtype)},{encoding}>", + f"!{stem}<[{shape_asm}],{str(mlir_dtype)},{encoding}>", context=self._c, ) return IrType.parse( - f"!torch.vtensor<[{shape_asm}],{str(mlir_dtype)}>", context=self._c + f"!{stem}<[{shape_asm}],{str(mlir_dtype)}>", context=self._c ) - def node_val_to_type(self, node: torch_fx.Node) -> IrType: + def node_val_to_type(self, node: torch_fx.Node, *, mutable: bool = False) -> IrType: try: tensor_meta = node.meta.get("tensor_meta") val = node.meta.get("val") @@ -564,13 +814,15 @@ def node_val_to_type(self, node: torch_fx.Node) -> IrType: f"Quantized tensor meta data is not supported." ) else: - return self.tensor_metadata_to_type(tensor_meta, sparsity=sparsity) + return self.tensor_metadata_to_type( + tensor_meta, sparsity=sparsity, mutable=mutable + ) elif val is not None: # some nodes with symbolic inputs pass a 'val' attribute rather than # tensor_meta if isinstance(val, TorchFakeTensor): return self.get_vtensor_type( - val.size(), val.dtype, sparsity=sparsity + val.size(), val.dtype, sparsity=sparsity, mutable=mutable ) t = SCALAR_TYPE_TO_TORCH_MLIR_TYPE.get(type(val)) @@ -590,16 +842,19 @@ def tensor_metadata_to_type( self, tm: TensorMetadata, *, - sparsity: Optional[SparsityMeta] = None, # keyword-only + sparsity: Optional[SparsityMeta] = None, + mutable: bool = False, ) -> IrType: tm_shape = tuple( item.node if is_symbolic(item) else item for item in list(tm.shape) ) - key = (tm_shape, tm.dtype, sparsity) + key = (tm_shape, tm.dtype, sparsity, mutable) t = self._tensor_metadata_cache.get(key) if t is None: - t = self.get_vtensor_type(tm.shape, tm.dtype, sparsity=sparsity) + t = self.get_vtensor_type( + tm.shape, tm.dtype, sparsity=sparsity, mutable=mutable + ) self._tensor_metadata_cache[key] = t return t @@ -644,7 +899,7 @@ class GraphNodeImporter: "_b", "_c", "_cc", - "_literal_resolver_callback", + "_on_node_produced", "_v", "_multi_result_nodes", "fx_importer", @@ -656,21 +911,138 @@ def __init__( context: Context, context_cache: ContextCache, block: Block, - *, - literal_resolver_callback: Optional[LiteralResolverCallback] = None, ): self.fx_importer = fx_importer self._c = context self._cc = context_cache self._b = block - # Map of (Node, result_index) to MLIR Value. - self._v: Dict[Tuple[torch_fx.Node, int], Value] = {} + # Map of (Node, result_index) to MLIR Value or a callback that lazily + # constructs and returns a value. + self._v: Dict[Union[Callable[[], Value], Tuple[torch_fx.Node, int]], Value] = {} + # Map of node name to hook that should be called when it is produced. + self._on_node_produced: dict[str, Callable[[Value], None]] = {} # Statically multi-result nodes which we have de-tupled are noted here. # They will have their getitem calls short-circuited. self._multi_result_nodes: Set[torch_fx.Node] = set() - self._literal_resolver_callback = literal_resolver_callback - def import_nodes(self, nodes: Sequence[torch_fx.Node]): + def bind_node_value( + self, + node: Node, + value: Union[Value, Callable[[], Value]], + result_index: int = 0, + ): + """Binds a node to a value (and asserts if already bound). + + This is used by outside callers. Many internal callers poke directly + into the dict. + """ + key = (node, result_index) + assert key not in self._v, f"Node already has a value: {node}" + self._v[key] = value + + producer_callback = self._on_node_produced.get(node.name) + if producer_callback is not None: + producer_callback(value) + + def resolve_node_value(self, node: Node, result_index: int = 0) -> Value: + """Resolves a node to a value.""" + key = (node, result_index) + try: + binding = self._v[key] + except KeyError: + raise KeyError(f"FX Node {node} has not been bound to an MLIR value") + if isinstance(binding, Value): + return binding + + # It is a lazy callback. + value = binding() + self._v[key] = value + return value + + def import_mutable_to_vtensor( + self, loc: Location, node: Node, mutable_value: Value, producer_node_name: str + ) -> Value: + """Imports a node that is represented by a mutable IR value. + + This will generate and associate the following with the node: + %0 = torch.copy.to_vtensor {mutable_value} + + Then it will also add a trigger such that when `producer_node_name` is + produced, the following will be generated: + torch.overwrite.tensor.contents {producer}, {mutable_value} + """ + with loc, InsertionPoint(self._b): + immutable_type = self._cc.node_val_to_type(node) + copy_result = Operation.create( + "torch.copy.to_vtensor", + results=[immutable_type], + operands=[mutable_value], + ).result + self.bind_node_value(node, copy_result) + + # Add the producer trigger. + def on_produced(value: Value): + with loc, InsertionPoint(self._b): + Operation.create( + "torch.overwrite.tensor.contents", + results=[], + operands=[value, mutable_value], + ) + + self._on_node_produced[producer_node_name] = on_produced + return copy_result + + def import_constant(self, loc: Location, node: Node, constant: Any) -> Value: + with loc, InsertionPoint(self._b): + value = self._import_literal(constant) + self.bind_node_value(node, value) + return value + + def lazy_import_parameter( + self, loc, node: Node, parameter_value: Any, info: InputInfo + ): + def _on_access() -> Value: + with loc, InsertionPoint(self._b): + # TODO: Should go to a parameter binding hook. + return self._import_input(parameter_value, info) + + self.bind_node_value(node, _on_access) + + def lazy_import_buffer( + self, + loc, + node: Node, + buffer_value: Any, + info: InputInfo, + ): + def _on_access() -> Value: + with loc, InsertionPoint(self._b): + # TODO: Should go to a buffer binding hook. + return self._import_input(buffer_value, info) + + self.bind_node_value(node, _on_access) + + if info.mutable_producer_node_name is not None: + + def on_produced(value: Value): + mutable_buffer_value = self.resolve_node_value(node) + with loc, InsertionPoint(self._b): + Operation.create( + "torch.overwrite.tensor.contents", + results=[], + operands=[value, mutable_buffer_value], + ) + + self._on_node_produced[info.mutable_producer_node_name] = on_produced + + def return_node_values(self, loc, nodes: list[Node]): + with loc, InsertionPoint(self._b): + operands = [self.resolve_node_value(n) for n in nodes] + func_dialect.ReturnOp(operands, loc=loc) + + def import_nodes( + self, nodes: Sequence[Node], *, skip_placeholders_outputs: bool = False + ): with InsertionPoint(self._b): loc = Location.unknown() num_placeholders = 0 @@ -681,10 +1053,10 @@ def import_nodes(self, nodes: Sequence[torch_fx.Node]): new_loc = self._cc.get_node_location(node) if new_loc is not None: loc = new_loc - if op == "placeholder": + if op == "placeholder" and not skip_placeholders_outputs: # Associate the placeholder node with corresponding block # argument. - self._v[(node, 0)] = self._b.arguments[num_placeholders] + self.bind_node_value(node, self._b.arguments[num_placeholders]) num_placeholders += 1 elif op == "call_function": target = node.target @@ -696,9 +1068,10 @@ def import_nodes(self, nodes: Sequence[torch_fx.Node]): getitem_ref, getitem_index = node.args if getitem_ref in self._multi_result_nodes: try: - self._v[(node, 0)] = self._v[ - (getitem_ref, getitem_index) - ] + self.bind_node_value( + node, + self.resolve_node_value(getitem_ref, getitem_index), + ) except IndexError: raise RuntimeError( f"getitem de-aliasing failed. This likely " @@ -723,7 +1096,7 @@ def import_nodes(self, nodes: Sequence[torch_fx.Node]): raise NotImplementedError( f"FIX ME: Unimplemented call_function: target={node.target}, {node.meta}" ) - elif op == "output": + elif op == "output" and not skip_placeholders_outputs: # args[0] is a singleton tuple that we flatten into multiple # results. operands = [self._import_argument(loc, arg) for arg in node.args[0]] @@ -731,7 +1104,7 @@ def import_nodes(self, nodes: Sequence[torch_fx.Node]): def _promote_symbolic_scalar_int_float(self, loc, graph, param): temp_target = torch.ops.aten.Float.Scalar - temp_node = torch.fx.Node( + temp_node = Node( graph=graph, name=f"{str(param)}_as_float", op="call_function", @@ -756,11 +1129,7 @@ def _import_symbolic_torch_op( # operations on symbolic arguments as regular python expressions rather than as torch ops if is_builtin_function_or_method(target): arg_types = [ - ( - arg.meta["val"].node.pytype - if isinstance(arg, torch.fx.Node) - else type(arg) - ) + (arg.meta["val"].node.pytype if isinstance(arg, Node) else type(arg)) for arg in node.args ] is_int = [item == int for item in arg_types] @@ -776,7 +1145,7 @@ def _import_symbolic_torch_op( # promote int argument to float - following torch-mlir convention arg0, arg1 = node.args if is_int[0]: - if isinstance(arg0, torch.fx.Node): + if isinstance(arg0, Node): prom_arg = self._promote_symbolic_scalar_int_float( loc, node.graph, arg0 ) @@ -785,7 +1154,7 @@ def _import_symbolic_torch_op( arg0 = float(arg0) new_args = (arg0, arg1) else: - if isinstance(arg1, torch.fx.Node): + if isinstance(arg1, Node): prom_arg = self._promote_symbolic_scalar_int_float( loc, node.graph, arg1 ) @@ -923,7 +1292,7 @@ def _import_torch_op_overload( # Record value mapping. for i, value in enumerate(operation.results): - self._v[(node, i)] = value + self.bind_node_value(node, value, i) def _import_argument( self, loc: Location, arg: NodeArgument, expected_jit_type=None @@ -943,9 +1312,9 @@ def _import_argument( ), f"Attempting to retrieve attribute '{arg.target}' from module, but no such attribute exists" obj = getattr(gm, arg.target) with loc: - self._v[(arg, 0)] = self._import_literal(obj) + self.bind_node_value(arg, self._import_literal(obj)) - return self._v[(arg, 0)] + return self.resolve_node_value(arg) elif isinstance(arg, torch_fx.immutable_collections.immutable_list): return self._import_list_argument(loc, arg, expected_jit_type) elif isinstance(expected_jit_type, torch.TensorType) and not isinstance( @@ -959,12 +1328,10 @@ def _import_argument( def _import_literal(self, py_value: Any) -> Value: # Apply the conversion callback. - user_callback = self._literal_resolver_callback - if user_callback: - user_value = user_callback(py_value, self) - if user_value is not None: - assert isinstance(user_value, Value) - return user_value + user_value = self.fx_importer._hooks.resolve_literal(self, py_value) + if user_value is not None: + assert isinstance(user_value, Value) + return user_value # Default conversion path. converter = LITERAL_CONVERTER_MAP.lookup(type(py_value)) @@ -974,6 +1341,20 @@ def _import_literal(self, py_value: Any) -> Value: ) return converter(py_value, self, self._cc) + def _import_input(self, py_value: Any, info: InputInfo) -> Value: + # Try the hook. + user_value = self.fx_importer._hooks.resolve_input(self, py_value, info) + if user_value is not None: + assert isinstance(user_value, Value) + return user_value + + # Fall-back to treating as a literal if not mutating. + if info.mutable_producer_node_name is not None: + raise ValueError( + f"Cannot import {info.input_spec} as a literal because it is mutable" + ) + return self._import_literal(py_value) + def _import_scalar_as_tensor(self, loc: Location, arg: NodeArgument) -> Value: tensor_arg = torch.tensor(arg) result_type = self._cc.get_vtensor_type(tensor_arg.size(), tensor_arg.dtype) @@ -1020,10 +1401,10 @@ def _import_list_argument( for operand in arg: operand_type = type(operand) - if isinstance(operand, torch.fx.Node): + if isinstance(operand, Node): if operand in self._multi_result_nodes: raise RuntimeError(f"Attempt to de-reference a multi-result node") - val = self._v[(operand, 0)] + val = self.resolve_node_value(operand) val_type = str(val.type) assert ( isinstance(element_type, str) and element_type in val_type @@ -1099,8 +1480,8 @@ def _make_vtensor_literal_op( mapping = py_attr_tracker.track(tensor) if mapping.is_empty: # check support for bfloat16 - assert ( - not (tensor.dtype == torch.bfloat16 and ml_dtypes is None) + assert not ( + tensor.dtype == torch.bfloat16 and ml_dtypes is None ), f"torch.bfloat16 requires the ml_dtypes package, please run:\n\npip install ml_dtypes\n" # Resolve the attribute. npy_dtype = TORCH_DTYPE_TO_NPY_TYPE.get(tensor.dtype) diff --git a/python/torch_mlir/fx.py b/python/torch_mlir/fx.py index 3abb70261db8..1f5aa8f74add 100644 --- a/python/torch_mlir/fx.py +++ b/python/torch_mlir/fx.py @@ -1,10 +1,17 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. + from typing import Optional +import warnings + import torch import torch.export import torch.nn as nn -from torch_mlir.extras.fx_importer import FxImporter +from torch_mlir.extras.fx_importer import FxImporter, FxImporterHooks from torch_mlir import ir from torch_mlir.dialects import torch as torch_d from torch_mlir.extras.fx_decomp_util import get_decomposition_table @@ -14,15 +21,23 @@ def export_and_import( *args, fx_importer: Optional[FxImporter] = None, constraints: Optional[torch.export.Constraint] = None, + experimental_support_mutation: bool = False, + hooks: Optional[FxImporterHooks] = None, **kwargs, ): context = ir.Context() torch_d.register_dialect(context) if fx_importer is None: - fx_importer = FxImporter(context=context) + fx_importer = FxImporter(context=context, hooks=hooks) prog = torch.export.export(f, args, kwargs, constraints=constraints) decomp_table = get_decomposition_table() prog = prog.run_decompositions(decomp_table) - fx_importer.import_frozen_exported_program(prog) + if experimental_support_mutation: + if torch.__version__ < "2.3.0.dev20240207": + warnings.warn("Mutable program import only supported on PyTorch 2.3+") + fx_importer.import_program(prog) + else: + fx_importer.import_frozen_program(prog) + return fx_importer.module_op diff --git a/test/python/fx_importer/sparse_test.py b/test/python/fx_importer/sparse_test.py index d0b94ac83656..679eede5b1a7 100644 --- a/test/python/fx_importer/sparse_test.py +++ b/test/python/fx_importer/sparse_test.py @@ -130,7 +130,7 @@ def export_and_import(f, *args, **kwargs): torch_d.register_dialect(context) fx_importer = FxImporter(context=context) prog = sparse_export(f, args, kwargs) - fx_importer.import_frozen_exported_program(prog) + fx_importer.import_frozen_program(prog) return fx_importer.module diff --git a/test/python/fx_importer/v2.3/lit.local.cfg b/test/python/fx_importer/v2.3/lit.local.cfg new file mode 100644 index 000000000000..b10b239f8b3a --- /dev/null +++ b/test/python/fx_importer/v2.3/lit.local.cfg @@ -0,0 +1,9 @@ +config.unsupported = True + +try: + import torch + if torch.__version__ >= "2.3.0.dev20240207": + print("Enabling Torch v2.3+ tests") + config.unsupported = False +except ModuleNotFoundError: + ... diff --git a/test/python/fx_importer/v2.3/mutation_import.py b/test/python/fx_importer/v2.3/mutation_import.py new file mode 100644 index 000000000000..ef293b8cb134 --- /dev/null +++ b/test/python/fx_importer/v2.3/mutation_import.py @@ -0,0 +1,163 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. + +# RUN: %PYTHON %s | FileCheck %s + +from typing import Optional + +import torch +import torch.export +import torch.nn as nn + +from torch_mlir import fx + +from torch_mlir.ir import ( + Operation, +) + + +def run(f): + print(f"{f.__name__}") + print("-" * len(f.__name__)) + f() + print() + + +@run +# Tests that constants and parameters work generally with the mutation path. +# This doesn't do mutation but ensures that the basics remain functional. +# CHECK-LABEL: test_import_frozen_exported_program +# CHECK: func.func @main(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> +# CHECK-DAG: %[[a:.+]] = torch.vtensor.literal(dense_resource : tensor<1x4xf32>) : !torch.vtensor<[1,4],f32> +# CHECK-DAG: %[[b:.+]] = torch.vtensor.literal(dense_resource : tensor<3x1xf32>) : !torch.vtensor<[3,1],f32> +# CHECK-DAG: %[[p:.+]] = torch.vtensor.literal(dense<{{.*>+}} : tensor<1x1xf32>) : !torch.vtensor<[1,1],f32> +# CHECK-DAG: %[[tanh:.+]] = torch.aten.tanh %[[ARG0]] +# CHECK-DAG: %[[mul_a:.+]] = torch.aten.mul.Tensor %[[tanh]], %[[a]] +# CHECK-DAG: %[[mul_b:.+]] = torch.aten.mul.Tensor %[[mul_a]], %[[b]] +# CHECK-DAG: %[[mul_p:.+]] = torch.aten.mul.Tensor %[[mul_b]], %[[p]] +# CHECK: return %[[mul_p]] +def test_import_frozen_exported_program(): + @torch._dynamo.assume_constant_result + def get_a(): + return torch.randn(1, 4) + + class Basic(nn.Module): + def __init__(self): + super().__init__() + self.b = torch.randn(3, 1) + self.p = nn.Parameter(torch.randn(1, 1)) + + def forward(self, x): + return torch.tanh(x) * get_a() * self.b * self.p + + m = fx.export_and_import( + Basic(), torch.randn(3, 4), experimental_support_mutation=True + ) + print(m) + m.operation.verify() + + +@run +# CHECK-LABEL: test_user_input_mutate +# CHECK: func.func @main(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.tensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> +# CHECK-DAG: %[[arg1_copy:.+]] = torch.copy.to_vtensor %arg1 : !torch.vtensor<[3,4],f32> +# CHECK-DAG: %[[arg1_mul:.+]] = torch.aten.mul.Tensor %[[arg1_copy]], %arg0 +# CHECK-DAG: torch.overwrite.tensor.contents %[[arg1_mul]] overwrites %arg1 +# CHECK-DAG: %[[arg0_mul:.+]] = torch.aten.mul.Tensor %arg0, %[[arg1_mul]] +# CHECK: return %[[arg0_mul]] +def test_user_input_mutate(): + class Basic(nn.Module): + def forward(self, x, y): + y.mul_(x) + return x * y + + m = fx.export_and_import( + Basic(), + torch.randn(3, 4), + torch.randn(3, 4), + experimental_support_mutation=True, + ) + print(m) + m.operation.verify() + + +@run +# CHECK-LABEL: test_frozen_buffer +# CHECK: %[[buffer_literal:.+]] = torch.vtensor.literal +# CHECK: %[[mul:.+]] = torch.aten.mul.Tensor %arg0, %0 +# CHECK: return %[[mul]] +def test_frozen_buffer(): + class Basic(nn.Module): + def __init__(self): + super().__init__() + self.register_buffer("buffer", torch.randn(3, 4)) + + def forward(self, x): + return x * self.buffer + + m = fx.export_and_import( + Basic(), torch.randn(3, 4), experimental_support_mutation=True + ) + print(m) + m.operation.verify() + + +class ExternalBufferHooks(fx.FxImporterHooks): + def prepare_module(self, module_op: Operation): + module_op.context.allow_unregistered_dialects = True + + def resolve_input(self, gni, value, info): + return Operation.create( + "my_dialect.import_buffer", results=[info.ir_type] + ).result + + +@run +# CHECK-LABEL: test_mutable_buffer +# CHECK: %[[buffer:.+]] = "my_dialect.import_buffer"() : () -> !torch.tensor<[3,4],f32> +# CHECK: %[[mul:.+]] = torch.aten.mul.Tensor %[[buffer]], %arg0 +# CHECK: torch.overwrite.tensor.contents %[[mul]] overwrites %[[buffer]] +# CHECK: return %arg0 +def test_mutable_buffer(): + class Basic(nn.Module): + def __init__(self): + super().__init__() + self.register_buffer("buffer", torch.randn(3, 4)) + + def forward(self, x): + self.buffer.mul_(x) + return x + + m = fx.export_and_import( + Basic(), + torch.randn(3, 4), + experimental_support_mutation=True, + hooks=ExternalBufferHooks(), + ) + print(m) + m.operation.verify() + + +@run +# CHECK-LABEL: test_mutable_buffer_not_supported_from_literal +# CHECK: ERROR: Cannot import {{.*}} as a literal because it is mutable +def test_mutable_buffer_not_supported_from_literal(): + class Basic(nn.Module): + def __init__(self): + super().__init__() + self.register_buffer("buffer", torch.randn(3, 4)) + + def forward(self, x): + self.buffer.mul_(x) + return x + + try: + m = fx.export_and_import( + Basic(), + torch.randn(3, 4), + experimental_support_mutation=True, + ) + except ValueError as e: + print("ERROR:", e) From c5d8c12469d7b7badd35369106c4b975f718536c Mon Sep 17 00:00:00 2001 From: Aart Bik <39774503+aartbik@users.noreply.github.com> Date: Fri, 16 Feb 2024 13:02:00 -0800 Subject: [PATCH 201/283] [torch-mlir][sparse][NFC] fixed typo (#2917) grammar police --- test/python/fx_importer/sparse_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/python/fx_importer/sparse_test.py b/test/python/fx_importer/sparse_test.py index 679eede5b1a7..e936e40cb039 100644 --- a/test/python/fx_importer/sparse_test.py +++ b/test/python/fx_importer/sparse_test.py @@ -98,7 +98,7 @@ def sparse_export( annotation sparse parameters with their actual sparse layout attributes. This temporary solution accelerates testing torch-mlir with PyTorch sparse tensors until the issue is - resovled. + resolved. """ # Convert all arguments to dense. dargs = tuple(a.to_dense() if a.layout in SPARSE_LAYOUTS else a for a in args) From 468c5339424e0f42d474106943c750de5519ff4d Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Fri, 16 Feb 2024 13:04:47 -0800 Subject: [PATCH 202/283] [onnx] Fix crash when negative transpose values exist (#2915) We are crashing due to indexing into a negative shape. Updated the lowering to avoid the crash. --- .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 8e46aa9ec7ed..525335161db8 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -1447,7 +1447,17 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( current[i] = i; } - // Convert dynamic shape dimension. + for (auto &dim : permutations) + dim = dim < 0 ? dim + rank : dim; + + // We need to override to the destination if known: + if (resultType.hasSizes()) { + for (int i = 0; i < rank; ++i) { + shape[permutations[i]] = resultType.getSizes()[i]; + } + } + + // Convert dynamic shape dimension: for (unsigned i = 0; i < shape.size(); i++) { if (shape[i] == ShapedType::kDynamic) shape[i] = Torch::kUnknownSize; From 7a0d0e954b145d28c6e495b5324d11cb03402f60 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Fri, 16 Feb 2024 13:05:44 -0800 Subject: [PATCH 203/283] [onnx] Fix onnx.gather lowering to use torch.aten.index_select (#2913) Onnx's gather maps directly to `torch.aten.index_select`. We should just use that path. --- .../TorchOnnxToTorch/DefaultDomainGtoP.cpp | 186 +++++------------- projects/pt1/e2e_testing/xfail_sets.py | 5 +- .../TorchOnnxToTorch/simple_ops_g_to_p.mlir | 43 ++-- 3 files changed, 68 insertions(+), 166 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index 9b2f3673cf33..50d4fae53812 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -471,146 +471,66 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( binder.s64IntegerAttr(axis, "axis", 0)) return failure(); Location loc = binder.getLoc(); + auto ctx = binder.op->getContext(); + auto indicesTy = cast(indices.getType()); + auto dataTy = cast(data.getType()); + if (!dataTy || !dataTy.hasSizes()) + return failure(); + if (axis < 0) + axis += dataTy.getSizes().size(); - // 1. Get data shape and rank. - auto dataTensorType = data.getType().cast(); - if (!dataTensorType || !dataTensorType.hasSizes()) { - return rewriter.notifyMatchFailure(binder.op, - "Expect non empty input data"); - } - ArrayRef dataShape = dataTensorType.getSizes(); - unsigned dataRank = dataShape.size(); - - // 2. Get indices shape and rank. - auto indexType = indices.getType().cast(); - if (!indexType || !indexType.hasSizes()) { - return rewriter.notifyMatchFailure(binder.op, - "Expect non empty index tensor"); - } - ArrayRef indexShape = indexType.getSizes(); - unsigned indexRank = indexShape.size(); + Value index = rewriter.create( + loc, Torch::IntType::get(ctx), rewriter.getI64IntegerAttr(axis)); - // 3. Compute total elements in the indices tensor, as we will collapse - // the indices tensor to a unary tensor. Also compute index shape and - // data shape tensors as they will be used for creating output types. - int64_t indexElemCount = 1; - for (int64_t dim : indexShape) { - if (dim == Torch::kUnknownSize) { - indexElemCount = Torch::kUnknownSize; + // Apply bounds checking on the input: + auto intTy = rewriter.getType(); + auto boolTy = rewriter.getType( + indicesTy.getSizes(), rewriter.getI1Type()); + Value zero = rewriter.create( + loc, intTy, rewriter.getI64IntegerAttr(0)); + Value one = rewriter.create( + loc, intTy, rewriter.getI64IntegerAttr(1)); + Value lt = + rewriter.create(loc, boolTy, indices, zero); + Value dim = + rewriter.create(loc, intTy, data, index); + Value add = rewriter.create(loc, indicesTy, + indices, dim, one); + indices = rewriter.create(loc, indicesTy, lt, + add, indices); + + auto intListTy = rewriter.getType( + rewriter.getType()); + auto indicesSize = + rewriter.create(loc, intListTy, indices); + + // Determine the collapsed dim size: + auto indicesCt = 1; + for (auto sz : indicesTy.getSizes()) { + if (sz == Torch::kUnknownSize) { + indicesCt = Torch::kUnknownSize; break; } - indexElemCount *= dim; - } - - Value constOne = rewriter.create( - loc, rewriter.getI64IntegerAttr(1)); - SmallVector indexShapeTensor; - Value indexElemCountVal = constOne; - for (unsigned i = 0; i < indexRank; ++i) { - Value indexDimVal = rewriter.create( - loc, indices, - rewriter.create( - loc, rewriter.getI64IntegerAttr(i))); - indexShapeTensor.emplace_back(indexDimVal); - indexElemCountVal = rewriter.create( - loc, indexElemCountVal, indexDimVal); - } - - SmallVector dataShapeTensor; - for (unsigned i = 0; i < dataRank; ++i) { - dataShapeTensor.emplace_back(rewriter.create( - loc, data, - rewriter.create( - loc, rewriter.getI64IntegerAttr(i)))); - } - // Correct for negative axis: - if (axis < 0) - axis += dataRank; - - // 4. We can not directly perform torch.gather as the onnx.gather op - // collects the input data at different location of output compared to - // torch.gather op. The output of torch.gather and onnx.gather ops are - // indexed differently. - // check https://onnx.ai/onnx/operators/onnx__Gather.html for more - // details. So we will collapse indices tensor to a unary tensor and - // materialize to non-axis dimension of data tensor. For example, - // assuming indices is of shape (4, 5, 6), data is (8, 10, 11, 12) and - // axis=1. we will collapse indices into a (120,) unary tensor, - // materialize to non-axis dimension of data i.e. reshaping the unary - // indices tensor to (1, 120, 1, 1) and then perform the torch.gather - // operation. Now broadcast the output of gather operation to non-axis - // dimensions of data tensor. This would make the result of shape (8, - // 10, 120, 12). Post the broadcasting, expand the indices dimensions by - // reshaping (8, 10, 120, 12) to (8, 10, 4, 5, 6, 12) tensor, which is - // our expected final result. - SmallVector collapsedIndexShape(dataRank, 1); - collapsedIndexShape[axis] = indexElemCount; - Type collapsedIndexType = Torch::ValueTensorType::get( - indexType.getContext(), llvm::ArrayRef(collapsedIndexShape), - indexType.getOptionalDtype()); - - SmallVector collapsedIndexSize(dataRank, constOne); - collapsedIndexSize[axis] = indexElemCountVal; - auto collapsedIndexSizeList = - rewriter.create( - loc, - rewriter.getType( - rewriter.getType()), - collapsedIndexSize); - - auto collapsedIndices = rewriter.create( - loc, collapsedIndexType, indices, collapsedIndexSizeList); - - // 5. Compute gather result type and perform gather operation. - Type gatherResultType = Torch::ValueTensorType::get( - dataTensorType.getContext(), llvm::ArrayRef(collapsedIndexShape), - dataTensorType.getOptionalDtype()); - Value constAxis = rewriter.create( - binder.getLoc(), rewriter.getType(), - rewriter.getIntegerAttr(rewriter.getIntegerType(64), axis)); - Value constFalse = rewriter.create( - binder.getLoc(), rewriter.getType(), - rewriter.getBoolAttr(false)); - auto gatherOp = rewriter.create( - loc, gatherResultType, data, constAxis, collapsedIndices, - /*sparseGrad=*/constFalse); - - // 6. Broadcast the gather output to non-axis dimensions of data tensor. - SmallVector dataShapeVector(dataShape); - dataShapeVector[axis] = indexElemCount; - Type expandResultType = Torch::ValueTensorType::get( - dataTensorType.getContext(), llvm::ArrayRef(dataShapeVector), - dataTensorType.getOptionalDtype()); - - dataShapeTensor[axis] = indexElemCountVal; - auto expandSizeList = rewriter.create( - loc, Torch::ListType::get(Torch::IntType::get(data.getContext())), - dataShapeTensor); - auto expandedGather = rewriter.create( - loc, expandResultType, gatherOp, expandSizeList, - /*implicit=*/constFalse); - - // 7. Compute the result type of reshape op which expands the collapsed - // indices shapes back to the original indices shapes and reshape the - // output produced at step 6. This will produce our expected result of - // onnx.gather op. - SmallVector resultShapeTensor; - for (unsigned i = 0; i < dataRank; ++i) { - if (i == axis) { - resultShapeTensor.insert(resultShapeTensor.end(), - indexShapeTensor.begin(), - indexShapeTensor.end()); - continue; - } - resultShapeTensor.emplace_back(dataShapeTensor[i]); + indicesCt *= sz; } - auto resultSizeList = rewriter.create( - loc, Torch::ListType::get(Torch::IntType::get(data.getContext())), - resultShapeTensor); - rewriter.replaceOpWithNewOp( - binder.op, resultType, expandedGather, resultSizeList); + auto flattenTy = rewriter.getType( + SmallVector{indicesCt}, indicesTy.getOptionalDtype()); + Value rank = rewriter.create(loc, intTy, indices); + Value end = rewriter.create(loc, rank, one); + indices = rewriter.create( + loc, flattenTy, indices, zero, end); + + llvm::SmallVector gatherShape(dataTy.getSizes()); + gatherShape[axis] = indicesCt; + + auto gatherTy = rewriter.getType( + gatherShape, dataTy.getOptionalDtype()); + Value gather = rewriter.create( + loc, gatherTy, data, index, indices); + rewriter.replaceOpWithNewOp( + binder.op, resultType, gather, index, indicesSize); return success(); }); patterns.onOp( diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 440b7d730c93..66fbc41588e6 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -2134,17 +2134,14 @@ "AdaptiveAvgPool2dUnitOutputSizeDynamicModule_basic", "ElementwiseSeluModule_basic", "EmbeddingModule1DIndices_basic", - "EmbeddingModuleI32Static_basic", "FlipNegativeIndexModule_basic", "HardsigmoidModule_basic", "HardsigmoidRandomModule_basic", "IndexSelectDynamicIndexSizeModule_basic", "IndexSelectDynamicInputSizeModule_basic", "IndexSelectDynamicModulebasic", - "IndexSelectNegativeDimModule_basic", - "IndexSelectSingleIdxModule_basic", - "IndexSelectTwoIdxModule_basic", "IndexSelectWholeDimensionModule_basic", + "IndexSelectWholeTensorModule_basic", "IndexTensorStaticModule_basic", "IndexTensorStaticNonContiguousWithNoneModule_basic", "PixelShuffleModuleStaticRank4Float32_basic", diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir index 0a154db29323..c5b28156a88e 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir @@ -39,35 +39,20 @@ func.func @test_less(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[ // CHECK-LABEL: func.func @test_gather func.func @test_gather(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[8,10,20,40], si64>) -> !torch.vtensor<[8,10,20,40,4,5],f32> attributes {torch.onnx_meta.opset_version = 13 : si64} { - // CHECK-DAG: %[[INT1:.+]] = torch.constant.int 1 - // CHECK-DAG: %[[INT0:.+]] = torch.constant.int 0 - // CHECK: %[[ARG1_SIZE0:.+]] = torch.aten.size.int %arg1, %[[INT0]] - // CHECK: %[[MUL1:.+]] = torch.aten.mul.int %[[INT1]], %[[ARG1_SIZE0]] - // CHECK: %[[INT1_0:.+]] = torch.constant.int 1 - // CHECK: %[[ARG1_SIZE1:.+]] = torch.aten.size.int %arg1, %[[INT1_0]] - // CHECK: %[[MUL2:.+]] = torch.aten.mul.int %[[MUL1]], %[[ARG1_SIZE1]] - // CHECK-DAG: %[[INT2:.+]] = torch.constant.int 2 - // CHECK: %[[ARG1_SIZE2:.+]] = torch.aten.size.int %arg1, %[[INT2]] : !torch.vtensor<[8,10,20,40],si64>, !torch.int -> !torch.int - // CHECK: %[[MUL3:.+]] = torch.aten.mul.int %[[MUL2]], %[[ARG1_SIZE2]] : !torch.int, !torch.int -> !torch.int - // CHECK-DAG: %[[INT3:.+]] = torch.constant.int 3 - // CHECK: %[[ARG1_SIZE3:.+]] = torch.aten.size.int %arg1, %[[INT3]] : !torch.vtensor<[8,10,20,40],si64>, !torch.int -> !torch.int - // CHECK: %[[MUL4:.+]] = torch.aten.mul.int %[[MUL3]], %[[ARG1_SIZE3]] : !torch.int, !torch.int -> !torch.int - // CHECK: %[[INT0_2:.+]] = torch.constant.int 0 - // CHECK: %[[ARG0_SIZE0:.+]] = torch.aten.size.int %arg0, %[[INT0_2]] : !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.int - // CHECK: %[[INT1_3:.+]] = torch.constant.int 1 - // CHECK: %[[ARG0_SIZE1:.+]] = torch.aten.size.int %arg0, %[[INT1_3]] : !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.int - // CHECK: %[[INT2_4:.+]] = torch.constant.int 2 - // CHECK: %[[ARG0_SIZE2:.+]] = torch.aten.size.int %arg0, %[[INT2_4]] : !torch.vtensor<[3,4,5],f32>, !torch.int -> !torch.int - // CHECK: %[[LIST1:.+]] = torch.prim.ListConstruct %[[MUL4]], %[[INT1]], %[[INT1]] : (!torch.int, !torch.int, !torch.int) -> !torch.list - // CHECK: %[[VIEW1:.+]] = torch.aten.view %arg1, %[[LIST1]] : !torch.vtensor<[8,10,20,40],si64>, !torch.list -> !torch.vtensor<[64000,1,1],si64> - // CHECK: %[[INT0_1:.+]] = torch.constant.int 0 - // CHECK: %[[FALSE:.+]] = torch.constant.bool false - // CHECK: %[[GATHER:.+]] = torch.aten.gather %arg0, %[[INT0_1]], %[[VIEW1]], %[[FALSE]] : !torch.vtensor<[3,4,5],f32>, !torch.int, !torch.vtensor<[64000,1,1],si64>, !torch.bool -> !torch.vtensor<[64000,1,1],f32> - // CHECK: %[[LIST2:.+]] = torch.prim.ListConstruct %[[MUL4]], %[[ARG0_SIZE1]], %[[ARG0_SIZE2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list - // CHECK: %[[EXPAND:.+]] = torch.aten.expand %[[GATHER]], %[[LIST2]], %[[FALSE]] : !torch.vtensor<[64000,1,1],f32>, !torch.list, !torch.bool -> !torch.vtensor<[64000,4,5],f32> - // CHECK: %[[LIST3:.+]] = torch.prim.ListConstruct %[[ARG1_SIZE0]], %[[ARG1_SIZE1]], %[[ARG1_SIZE2]], %[[ARG1_SIZE3]], %[[ARG0_SIZE1]], %[[ARG0_SIZE2]] : (!torch.int, !torch.int, !torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - // CHECK: %[[RES:.+]] = torch.aten.view %[[EXPAND]], %[[LIST3]] : !torch.vtensor<[64000,4,5],f32>, !torch.list -> !torch.vtensor<[8,10,20,40,4,5],f32> - // CHECK: return %[[RES]] : !torch.vtensor<[8,10,20,40,4,5],f32> + // CHECK: %[[AXIS:.+]] = torch.constant.int 0 + // CHECK: %[[ZERO:.+]] = torch.constant.int 0 + // CHECK: %[[ONE:.+]] = torch.constant.int 1 + // CHECK: %[[LT:.+]] = torch.aten.le.Scalar %arg1, %[[ZERO]] + // CHECK: %[[SZ:.+]] = torch.aten.size.int %arg0, %[[AXIS]] + // CHECK: %[[ADD:.+]] = torch.aten.add.Scalar %arg1, %[[SZ]], %[[ONE]] + // CHECK: %[[SEL:.+]] = torch.aten.where.self %[[LT]], %[[ADD]], %arg1 + // CHECK: %[[SZ:.+]] = torch.aten.size %[[SEL]] + // CHECK: %[[DIM:.+]] = torch.aten.dim %[[SEL]] + // CHECK: %[[SUB:.+]] = torch.aten.sub.int %[[DIM]], %[[ONE]] + // CHECK: %[[FLAT:.+]] = torch.aten.flatten.using_ints %[[SEL]], %[[ZERO]], %[[SUB]] + // CHECK: %[[ISEL:.+]] = torch.aten.index_select %arg0, %[[AXIS]], %[[FLAT]] + // CHECK: %[[RES:.+]] = torch.aten.unflatten.int %[[ISEL]], %[[AXIS]], %[[SZ]] + // CHECK: return %[[RES]] %0 = torch.operator "onnx.Gather"(%arg0, %arg1) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[8,10,20,40], si64>) -> !torch.vtensor<[8,10,20,40,4,5],f32> return %0 : !torch.vtensor<[8,10,20,40,4,5],f32> } From d65925a8b465d4a84be947d37197178d5c5cc6d2 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Fri, 16 Feb 2024 13:35:25 -0800 Subject: [PATCH 204/283] [onnx] Fix `onnx.sigmoid` for integer inputs/outputs (#2914) Sample compilation crashes due to sigmoid with integer inputs/outputs. This fix avoids crashing but still experiences an error. --- .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 5 +- .../TorchToLinalg/Uncategorized.cpp | 106 +++++++++--------- projects/pt1/e2e_testing/xfail_sets.py | 4 +- 3 files changed, 56 insertions(+), 59 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 525335161db8..405a02bb3c58 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -1615,9 +1615,10 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( llvm::SmallVector intermediateShape(operandTy.getShape()); for (int i = 0, s = operandTy.getRank(); i < s; ++i) { - if (operandTy.getDimSize(i) != resultTy.getDimSize(i)) { + if (operandTy.getDimSize(i) != resultTy.getDimSize(i)) intermediateShape[i] = -1; - } + if (intermediateShape[i] == ShapedType::kDynamic) + intermediateShape[i] = Torch::kUnknownSize; } auto intermediateType = Torch::ValueTensorType::get( context, intermediateShape, resultTorchType.getOptionalDtype()); diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index 0019acfc2944..e8e671955835 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -128,16 +128,20 @@ static Value buildUnitNormalCdf(OpBuilder &b, Location &loc, Value x) { } template -static Value -createCalculationForMathOpWithDtypeConversion(OpBuilder &b, - const TypeConverter *converter, - Value payloadArg, Operation *op) { - Type dtype = converter->convertType(op->getResult(0).getType()) - .template cast() - .getElementType(); +static Value createFpOpWithDtype(OpBuilder &b, const TypeConverter *converter, + Value payloadArg, Operation *op) { + Type inTTy = cast(op->getOperand(0).getType()).getDtype(); + Type outTTy = cast(op->getResult(0).getType()).getDtype(); + Type outTy = + cast(converter->convertType(op->getResult(0).getType())) + .getElementType(); + Type computeTy = outTy; + if (isa(computeTy)) + computeTy = b.getF32Type(); Location loc = op->getLoc(); - Value arg = convertScalarToDtype(b, loc, payloadArg, dtype); - return b.create(loc, arg); + Value arg = convertScalarToDtype(b, loc, payloadArg, computeTy, inTTy); + auto newOp = b.create(loc, arg); + return convertScalarToDtype(b, loc, newOp, outTy, std::nullopt, outTTy); } template @@ -217,92 +221,70 @@ static Value createLinalgPayloadCalculationForElementwiseOp( if (isa(op)) return b.create(loc, payloadArgs[0]); if (isa(op)) { - return createCalculationForMathOpWithDtypeConversion( - b, converter, payloadArgs[0], op); + return createFpOpWithDtype(b, converter, payloadArgs[0], op); } if (isa(op)) { - return createCalculationForMathOpWithDtypeConversion( - b, converter, payloadArgs[0], op); + return createFpOpWithDtype(b, converter, payloadArgs[0], op); } if (isa(op)) { - return createCalculationForMathOpWithDtypeConversion( - b, converter, payloadArgs[0], op); + return createFpOpWithDtype(b, converter, payloadArgs[0], op); } if (isa(op)) { - return createCalculationForMathOpWithDtypeConversion( - b, converter, payloadArgs[0], op); + return createFpOpWithDtype(b, converter, payloadArgs[0], op); } if (isa(op)) { - return createCalculationForMathOpWithDtypeConversion( - b, converter, payloadArgs[0], op); + return createFpOpWithDtype(b, converter, payloadArgs[0], op); } if (isa(op)) { - return createCalculationForMathOpWithDtypeConversion( - b, converter, payloadArgs[0], op); + return createFpOpWithDtype(b, converter, payloadArgs[0], op); } if (isa(op)) { - return createCalculationForMathOpWithDtypeConversion( - b, converter, payloadArgs[0], op); + return createFpOpWithDtype(b, converter, payloadArgs[0], op); } if (isa(op)) { - return createCalculationForMathOpWithDtypeConversion( - b, converter, payloadArgs[0], op); + return createFpOpWithDtype(b, converter, payloadArgs[0], op); } if (isa(op)) { - return createCalculationForMathOpWithDtypeConversion( - b, converter, payloadArgs[0], op); + return createFpOpWithDtype(b, converter, payloadArgs[0], op); } if (isa(op)) { - return createCalculationForMathOpWithDtypeConversion( - b, converter, payloadArgs[0], op); + return createFpOpWithDtype(b, converter, payloadArgs[0], op); } if (isa(op)) { - return createCalculationForMathOpWithDtypeConversion( - b, converter, payloadArgs[0], op); + return createFpOpWithDtype(b, converter, payloadArgs[0], op); } if (isa(op)) { - return createCalculationForMathOpWithDtypeConversion( - b, converter, payloadArgs[0], op); + return createFpOpWithDtype(b, converter, payloadArgs[0], op); } if (isa(op)) { - return createCalculationForMathOpWithDtypeConversion( - b, converter, payloadArgs[0], op); + return createFpOpWithDtype(b, converter, payloadArgs[0], op); } if (isa(op)) { - return createCalculationForMathOpWithDtypeConversion( - b, converter, payloadArgs[0], op); + return createFpOpWithDtype(b, converter, payloadArgs[0], op); } if (isa(op)) { - return createCalculationForMathOpWithDtypeConversion( - b, converter, payloadArgs[0], op); + return createFpOpWithDtype(b, converter, payloadArgs[0], op); } if (isa(op)) { - return createCalculationForMathOpWithDtypeConversion( - b, converter, payloadArgs[0], op); + return createFpOpWithDtype(b, converter, payloadArgs[0], op); } if (isa(op)) { - return createCalculationForMathOpWithDtypeConversion( - b, converter, payloadArgs[0], op); + return createFpOpWithDtype(b, converter, payloadArgs[0], op); } if (isa(op)) { - return createCalculationForMathOpWithDtypeConversion( - b, converter, payloadArgs[0], op); + return createFpOpWithDtype(b, converter, payloadArgs[0], op); } if (isa(op)) { - return createCalculationForMathOpWithDtypeConversion( - b, converter, payloadArgs[0], op); + return createFpOpWithDtype(b, converter, payloadArgs[0], op); } if (isa(op)) { - return createCalculationForMathOpWithDtypeConversion( - b, converter, payloadArgs[0], op); + return createFpOpWithDtype(b, converter, payloadArgs[0], op); } if (isa(op)) { - return createCalculationForMathOpWithDtypeConversion( - b, converter, payloadArgs[0], op); + return createFpOpWithDtype(b, converter, payloadArgs[0], op); } if (isa(op)) { - return createCalculationForMathOpWithDtypeConversion( - b, converter, payloadArgs[0], op); + return createFpOpWithDtype(b, converter, payloadArgs[0], op); } if (auto clone = dyn_cast(op)) { int64_t memoryFormat; @@ -453,13 +435,25 @@ static Value createLinalgPayloadCalculationForElementwiseOp( return createEqual(b, loc, abs.getType(), abs, infinity); } if (isa(op)) { - auto negate = createCalculationForMathOpWithDtypeConversion( - b, converter, payloadArgs[0], op); + Type inTTy = cast(op->getOperand(0).getType()).getDtype(); + Type outTTy = cast(op->getResult(0).getType()).getDtype(); + Type outTy = cast( + converter->convertType(op->getResult(0).getType())) + .getElementType(); + Type computeTy = outTy; + if (isa(computeTy)) + computeTy = b.getF32Type(); + + Value arg = payloadArgs[0]; + arg = convertScalarToDtype(b, loc, payloadArgs[0], computeTy, inTTy); + auto negate = b.create(loc, arg); auto one = b.create(loc, FloatAttr::get(negate.getType(), 1)); auto exp = b.create(loc, negate); auto added = b.create(loc, exp, one); - return b.create(loc, one, added); + auto div = b.create(loc, one, added); + outTy.dump(); + return convertScalarToDtype(b, loc, div, outTy, std::nullopt, outTTy); } if (auto relu = dyn_cast(op)) { if (!relu.getType() diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 66fbc41588e6..a1cee9037933 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -2165,6 +2165,9 @@ "ReduceMaxKeepDimReturnBoth_basic", "ReduceMaxNegativeDim_basic", "ViewSizeFromOtherTensor_basic", + + # Failure - onnx traces differently + "ElementwiseSigmoidIntModule_basic", # Failure - unknown "ChunkListUnpackUneven_Module_basic", @@ -2192,7 +2195,6 @@ } ONNX_CRASHING_SET = { - "ElementwiseSigmoidIntModule_basic", "FlipModule_basic", "IndexTensorNegativeIndexModule_basic", "MoveDimIntNegativeIndexModule_basic", From 78e10ff09b78ba14c4c97100360205550c36669e Mon Sep 17 00:00:00 2001 From: Aart Bik <39774503+aartbik@users.noreply.github.com> Date: Fri, 16 Feb 2024 20:56:42 -0800 Subject: [PATCH 205/283] [torch-mlir][sparse] inline sparse helper methods (#2918) Even though the reference compiler is not about performance, inlining the generated sparse helper methods has a rather big positive impact on performance, leaving a much better first impression. Therefore, we added this inlining pass (which leaves all other PyTorch modules unaffected, since they tend to be one big main() method to start with). testing: $./tools/e2e_test.sh --config linalg Summary: Passed: 1164 Expectedly Failed: 8 $ python -m e2e_testing.main --config=torchdynamo Summary: Passed: 976 Expectedly Failed: 162 --- .../torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py | 1 + 1 file changed, 1 insertion(+) diff --git a/projects/pt1/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py b/projects/pt1/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py index 9c33d8fd504d..ad2669c51d50 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py +++ b/projects/pt1/python/torch_mlir_e2e_test/linalg_on_tensors_backends/refbackend.py @@ -139,6 +139,7 @@ def invoke(*args): "sparse-assembler", "sparsification-and-bufferization", "sparse-storage-specifier-to-llvm", + "inline", # inline sparse helper methods where useful # Bufferize. "func.func(scf-bufferize)", "func.func(tm-tensor-bufferize)", From d29157b33fb66a6ef37971452cc6f8399bfbf374 Mon Sep 17 00:00:00 2001 From: aldesilv <35313749+aldesilv@users.noreply.github.com> Date: Mon, 19 Feb 2024 06:23:48 -0800 Subject: [PATCH 206/283] OnnxToTorch support for onnx.InstanceNormalization op (#2710) https://github.com/nod-ai/SHARK-Turbine/issues/327 --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 31 ++++ .../TorchOnnxToTorch/DefaultDomainGtoP.cpp | 31 ++++ .../Transforms/AbstractInterpLibrary.cpp | 8 + .../Torch/Transforms/DecomposeComplexOps.cpp | 146 ++++++++++++++++++ .../Transforms/LowerToBackendContract.cpp | 1 + projects/pt1/e2e_testing/xfail_sets.py | 4 + .../build_tools/abstract_interp_lib_gen.py | 8 + .../build_tools/torch_ods_gen.py | 3 + .../test_suite/norm_like.py | 18 +++ .../TorchOnnxToTorch/simple_ops_g_to_p.mlir | 9 ++ 10 files changed, 259 insertions(+) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 0becb668636e..99d00e287106 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -6203,6 +6203,37 @@ def Torch_AtenBatchNormOp : Torch_Op<"aten.batch_norm", [ }]; } +def Torch_AtenInstanceNormOp : Torch_Op<"aten.instance_norm", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::instance_norm : (Tensor, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, float, bool) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$input, + AnyTorchOptionalTensorType:$weight, + AnyTorchOptionalTensorType:$bias, + AnyTorchOptionalTensorType:$running_mean, + AnyTorchOptionalTensorType:$running_var, + Torch_BoolType:$use_input_stats, + Torch_FloatType:$momentum, + Torch_FloatType:$eps, + Torch_BoolType:$cudnn_enabled + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenInstanceNormOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 9, 1); + } + void AtenInstanceNormOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 9, 1); + } + }]; +} + def Torch_AtenNativeGroupNormOp : Torch_Op<"aten.native_group_norm", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index 50d4fae53812..12b7ab559f4f 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -392,6 +392,37 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( binder.op, resultType, lhs, rhs); return success(); }); + patterns.onOp( + "InstanceNormalization", 6, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + llvm::SmallVector operands; + float eps; + + if (binder.tensorOperands(operands, 3) || + binder.tensorResultType(resultType) || operands.size() != 3 || + binder.f32FloatAttr(eps, "epsilon", 1e-05f)) { + return failure(); + } + Value none = rewriter.create(binder.getLoc()); + Value boolTrue = + rewriter.create(binder.getLoc(), true); + Value boolFalse = + rewriter.create(binder.getLoc(), false); + auto epsValue = rewriter.create( + binder.getLoc(), rewriter.getF64FloatAttr(eps)); + + auto momentum = rewriter.create( + binder.getLoc(), rewriter.getF64FloatAttr(0.0f)); + rewriter.replaceOpWithNewOp( + binder.op, resultType, /* input */ operands[0], + /* weight */ operands[1], + /* bias */ operands[2], /* running mean */ none, + /* running var */ none, + /* use input stats */ boolTrue, momentum, epsValue, + /* cudnn enabled */ boolFalse); + return success(); + }); patterns.onOp( "Max", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index 29c94304288b..39813da66e85 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -8784,6 +8784,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %3 = torch.prim.TupleConstruct %0, %1, %2 : !torch.list, !torch.list, !torch.list -> !torch.tuple, list, list>\n" " return %3 : !torch.tuple, list, list>\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.instance_norm\"(%arg0: !torch.list, %arg1: !torch.optional>, %arg2: !torch.optional>, %arg3: !torch.optional>, %arg4: !torch.optional>, %arg5: !torch.bool, %arg6: !torch.float, %arg7: !torch.float, %arg8: !torch.bool) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.slice.Tensor\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.int) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.slice(%arg0, %arg1, %arg2, %arg3, %arg4) : (!torch.list, !torch.int, !torch.optional, !torch.optional, !torch.int) -> !torch.list\n" " return %0 : !torch.list\n" @@ -9643,6 +9647,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %3 = torch.prim.TupleConstruct %0#1, %0#1, %0#1 : !torch.int, !torch.int, !torch.int -> !torch.tuple\n" " return %3 : !torch.tuple\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.instance_norm\"(%arg0: !torch.tuple, %arg1: !torch.optional>, %arg2: !torch.optional>, %arg3: !torch.optional>, %arg4: !torch.optional>, %arg5: !torch.bool, %arg6: !torch.float, %arg7: !torch.float, %arg8: !torch.bool) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.bernoulli_.float\"(%arg0: !torch.tuple, %arg1: !torch.float, %arg2: !torch.any) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index abd716c56afa..f9c1f63b568c 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -3962,6 +3962,151 @@ class DecomposeAtenLayerNormOp : public OpRewritePattern { }; } // namespace +namespace { +class DecomposeAtenInstanceNormOp + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenInstanceNormOp op, + PatternRewriter &rewriter) const override { + + Location loc = op.getLoc(); + auto context = op.getContext(); + + auto inputTy = op.getInput().getType().cast(); + int64_t inputRank = inputTy.getSizes().size(); + auto reduceDimInts = + llvm::SmallVector({inputRank - 2, inputRank - 1}); + + SmallVector reducedShape(inputTy.getSizes()); + reducedShape[inputRank - 1] = 1; + reducedShape[inputRank - 2] = 1; + + Type dtype = inputTy.getOptionalDtype(); + Type reducedTy = ValueTensorType::get(op.getContext(), + llvm::ArrayRef(reducedShape), dtype); + + auto sizeListType = ListType::get(IntType::get(context)); + SmallVector reduceDimVals; + reduceDimVals.reserve(reduceDimInts.size()); + std::transform(reduceDimInts.begin(), reduceDimInts.end(), + std::back_inserter(reduceDimVals), [&](int64_t d) { + return rewriter.create( + loc, rewriter.getI64IntegerAttr(d)); + }); + Value reduceDimList = + rewriter.create(loc, sizeListType, reduceDimVals); + Value cstTrue = rewriter.create(loc, true); + Value none = rewriter.create(loc); + + Value one = rewriter.create( + loc, rewriter.getI64IntegerAttr(1)); + + // mean(x) + Value inputMean = rewriter.create( + loc, reducedTy, op.getInput(), reduceDimList, cstTrue, none); + + // x - mean(x) + Value inputMeanExpanded = + rewriter.create(loc, inputTy, inputMean, op.getInput()); + Value inputSubMean = rewriter.create( + loc, inputTy, op.getInput(), inputMeanExpanded, one); + // (x - mean(x))^2 + Value inputSubMeanSquare = rewriter.create( + loc, inputTy, inputSubMean, inputSubMean); + + Value variancesum = rewriter.create( + loc, reducedTy, inputSubMeanSquare, reduceDimList, cstTrue, + /*dtype=*/none); + + Value hw = rewriter.create( + loc, rewriter.getI64IntegerAttr(inputTy.getSizes()[inputRank - 1] * + inputTy.getSizes()[inputRank - 2])); + Value inputVar = + rewriter.create(loc, reducedTy, variancesum, hw); + + // rsqrt(var(x) + eps) + Value inputVarPlusEps = rewriter.create( + loc, reducedTy, inputVar, op.getEps(), one); + Value inputRsqrtVar = + rewriter.create(loc, reducedTy, inputVarPlusEps); + + // (x - mean(x)) * rsqrt(var(x) + eps) + Value inputRsqrtVarExpanded = rewriter.create( + loc, inputTy, inputRsqrtVar, op.getInput()); + Value inputNormalized = rewriter.create( + loc, inputTy, inputSubMean, inputRsqrtVarExpanded); + Value out = rewriter.create( + loc, op.getResult().getType(), inputNormalized); + + Value weight = op.getWeight(); + auto weightTy = weight.getType().cast(); + dtype = weightTy.getOptionalDtype(); + + SmallVector weightShape(weightTy.getSizes()); + SmallVector newWeightShape; + newWeightShape.push_back(1); + newWeightShape.append(weightShape); + + Value zero = rewriter.create( + loc, rewriter.getI64IntegerAttr(0)); + Type newWeightTy = ValueTensorType::get( + op.getContext(), llvm::ArrayRef(newWeightShape), dtype); + weight = rewriter.create(loc, newWeightTy, weight, zero); + + Value two = rewriter.create( + loc, rewriter.getI64IntegerAttr(2)); + newWeightShape.push_back(1); + newWeightTy = ValueTensorType::get(op.getContext(), + llvm::ArrayRef(newWeightShape), dtype); + weight = rewriter.create(loc, newWeightTy, weight, two); + + Value three = rewriter.create( + loc, rewriter.getI64IntegerAttr(3)); + newWeightShape.push_back(1); + newWeightTy = ValueTensorType::get(op.getContext(), + llvm::ArrayRef(newWeightShape), dtype); + weight = rewriter.create(loc, newWeightTy, weight, three); + + Value weightExpanded = + rewriter.create(loc, inputTy, weight, op.getInput()); + + Value bias = op.getBias(); + auto biasTy = bias.getType().cast(); + dtype = biasTy.getOptionalDtype(); + + SmallVector biasShape(biasTy.getSizes()); + SmallVector newBiasShape; + newBiasShape.push_back(1); + newBiasShape.append(biasShape); + + Type newBiasTy = ValueTensorType::get(op.getContext(), + llvm::ArrayRef(newBiasShape), dtype); + bias = rewriter.create(loc, newBiasTy, bias, zero); + + newBiasShape.push_back(1); + newBiasTy = ValueTensorType::get(op.getContext(), + llvm::ArrayRef(newBiasShape), dtype); + bias = rewriter.create(loc, newBiasTy, bias, two); + + newBiasShape.push_back(1); + newBiasTy = ValueTensorType::get(op.getContext(), + llvm::ArrayRef(newBiasShape), dtype); + bias = rewriter.create(loc, newBiasTy, bias, three); + + Value biasExpanded = + rewriter.create(loc, inputTy, bias, op.getInput()); + + out = rewriter.create(loc, out.getType(), out, + weightExpanded); + out = rewriter.create(loc, out.getType(), out, + biasExpanded, one); + + rewriter.replaceOp(op, out); + return success(); + } +}; +} // namespace + namespace { class DecomposeAtenNativeLayerNormOp : public OpRewritePattern { @@ -6733,6 +6878,7 @@ class DecomposeComplexOpsPass DecomposeAtenAddCLikeOp>(patterns); addPatternIfTargetOpIsIllegal< DecomposeAtenAddCLikeOp>(patterns); + addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp index 306b2446adb6..5d3488b11aed 100644 --- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp +++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp @@ -409,6 +409,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, }); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index a1cee9037933..0cb0888745b3 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -233,6 +233,7 @@ # END tests failing due to: 'torch.aten.mul.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float' # START tests failing due to: 'torch.aten.add.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float' + "AtenInstanceNormModule_basic", "BatchNorm1DModule_basic", "BatchNorm1DWith2DInputModule_basic", "BatchNorm2DModule_basic", @@ -898,6 +899,7 @@ "AtenEyeModuleFalsePinMemory_basic", "AtenEyeModuleFloat2D_basic", "AtenRoundIntModule_basic", + "AtenInstanceNormModule_basic", "AtenToDeviceModule_basic", "BaddbmmBroadcast1DInputModule_basic", "BaddbmmBroadcast2DInputModule_basic", @@ -1306,6 +1308,8 @@ "Conv2dNoPaddingModule_basic", "Conv2dWithPaddingDilationStrideModule_basic", "Conv2dWithPaddingModule_basic", + + "AtenInstanceNormModule_basic", } LTC_CRASHING_SET = { diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index c014808af97a..a856ac02639a 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -1415,6 +1415,9 @@ def aten〇group_norm〡shape(input: List[int], num_groups: int, weight: Optiona def aten〇native_group_norm〡shape(input: List[int], weight: Optional[List[int]], bias: Optional[List[int]], N: int, C: int, HxW: int, group: int, eps: float) -> Tuple[List[int], List[int], List[int]]: return upstream_shape_functions.unary(input), [N, group], [N, group] +def aten〇instance_norm〡shape(input: List[int], weight: Optional[List[int]], bias: Optional[List[int]], running_mean: Optional[List[int]], running_var: Optional[List[int]], use_input_stats: bool, momentum: float, eps: float, cudnn_enabled: bool) -> List[int]: + return upstream_shape_functions.unary(input) + def aten〇slice〇Tensor〡shape(self: List[int], dim: int = 0, start: Optional[int] = None, end: Optional[int] = None, step: int = 1) -> List[int]: return upstream_shape_functions.slice(self, dim, start, end, step) @@ -2048,6 +2051,11 @@ def aten〇native_group_norm〡dtype(input_rank_dtype: Tuple[int, int], weight_r assert not is_integer_dtype(input_dtype) return input_dtype, input_dtype, input_dtype +# device is not supported hence unable to check the dtype function +def aten〇instance_norm〡dtype(input_rank_dtype: Tuple[int, int], weight_rank_dtype: Optional[Tuple[int, int]], bias_rank_dtype: Optional[Tuple[int, int]], running_mean_rank_dtype: Optional[Tuple[int, int]], running_var_rank_dtype: Optional[Tuple[int, int]], use_input_stats: bool, momentum: float, eps: float, cudnn_enabled: bool) -> int: + input_rank, input_dtype = input_rank_dtype + return input_dtype + @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) def aten〇bernoulli_〇float〡dtype(self_rank_dtype: Tuple[int, int], p: float = 0.5, generator: Any = None) -> int: self_rank, self_dtype = self_rank_dtype diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index 65e9f44c1126..1dc8585d7486 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -437,6 +437,9 @@ def emit_with_mutating_variants(key, **kwargs): emit( "aten::batch_norm : (Tensor, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, float, bool) -> (Tensor)" ) + emit( + "aten::instance_norm : (Tensor, Tensor?, Tensor?, Tensor?, Tensor?, bool, float, float, bool) -> (Tensor)" + ) emit( "aten::native_group_norm : (Tensor, Tensor?, Tensor?, int, int, int, int, float) -> (Tensor, Tensor, Tensor)" ) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/norm_like.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/norm_like.py index 3b17f516f9e5..56821fb694f3 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/norm_like.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/norm_like.py @@ -489,3 +489,21 @@ def forward(self, x): def LayerNormNormalizeOverAllDimsModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 2, 3)) +class AtenInstanceNormModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([1, 2, 1, 3], torch.float32, True), + ([2], torch.float32, True), + ([2], torch.float32, True) + ]) + def forward(self, x, w, b): + return torch.ops.aten.instance_norm(x, w, b, None, + None, True, 0.0, 1e-05, False) + +@register_test_case(module_factory=lambda: AtenInstanceNormModule()) +def AtenInstanceNormModule_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 2, 1, 3), tu.rand(2), tu.rand(2)) diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir index c5b28156a88e..8729e7f2dd5a 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir @@ -603,6 +603,15 @@ func.func @test_globalaveragepool_precomputed(%arg0: !torch.vtensor<[1,1,3,3],f3 // ----- +// CHECK-LABEL: func.func @test_instancenorm + func.func @test_instancenorm(%arg0: !torch.vtensor<[1,2,1,3],f32>, %arg1: !torch.vtensor<[2],f32>, %arg2: !torch.vtensor<[2],f32>) -> !torch.vtensor<[1,2,1,3],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 6 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.instance_norm %arg0, %arg1, %arg2, %none, %none, %true, %float0.000000e00, %float9.999990e-06, %false : !torch.vtensor<[1,2,1,3],f32>, !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>, !torch.none, !torch.none, !torch.bool, !torch.float, !torch.float, !torch.bool -> !torch.vtensor<[1,2,1,3],f32> + %0 = torch.operator "onnx.InstanceNormalization"(%arg0, %arg1, %arg2) : (!torch.vtensor<[1,2,1,3],f32>, !torch.vtensor<[2],f32>, !torch.vtensor<[2],f32>) -> !torch.vtensor<[1,2,1,3],f32> + return %0 : !torch.vtensor<[1,2,1,3],f32> + } + +// ----- + // CHECK-LABEL: func.func @test_not_2d func.func @test_not_2d(%arg0: !torch.vtensor<[3,4],i1>) -> !torch.vtensor<[3,4],i1> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 1 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.bitwise_not %arg0 : !torch.vtensor<[3,4],i1> -> !torch.vtensor<[3,4],i1> From fd08578bdb24e0b2a3c85c907b766c972509b634 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Mon, 19 Feb 2024 10:26:21 -0800 Subject: [PATCH 207/283] [torch] Support dynamic step size for `torch.slice` (#2922) For some reason we did not directly use the step size dynamically despite its constructed using the dynamic value. --- lib/Conversion/TorchToLinalg/DataMovement.cpp | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp index add32ff05cd6..b5eea7da619a 100644 --- a/lib/Conversion/TorchToLinalg/DataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp @@ -73,13 +73,7 @@ LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor, torchTypeEnd.getType().isa()) return rewriter.notifyMatchFailure(op, "unimplemented optional type arg"); - int64_t step; - if (!matchPattern(op.getStep(), m_TorchConstantInt(&step))) { - if (!op.getStep().getType().template isa()) - return op->emitError("unimplemented: step is not constant"); - step = 1; - } - + Value stepIndex = castIntToIndex(rewriter, loc, adaptor.getStep()); Value start = toPositiveValidDim(rewriter, loc, torchTypeStart, builtinTypeStart, zero, dimSize); Value end = toPositiveValidDim(rewriter, loc, torchTypeEnd, builtinTypeEnd, @@ -89,7 +83,6 @@ LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor, Value endSgeStart = rewriter.create( loc, arith::CmpIPredicate::sge, end, start); end = rewriter.create(loc, endSgeStart, end, start); - Value stepIndex = rewriter.create(loc, step); // Slice logic: resultSize = floordiv(end - start + step - 1, step) resultShape = getTensorSizes(rewriter, loc, input); From cea51897a5255363f5f09dcb91433dfc11492598 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Mon, 19 Feb 2024 10:26:29 -0800 Subject: [PATCH 208/283] [onnx] Simplify onnx.slice lowering (#2919) Onnx slice lowering used arange needlessly instead of directly constructing the constant dimension values. This makes lowerings to linalg struggle as multiple folders are required to get what is a constant index value. --- .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 32 +++++++---------- projects/pt1/e2e_testing/xfail_sets.py | 4 --- .../TorchOnnxToTorch/simple_ops_q_to_z.mlir | 35 +++++++------------ 3 files changed, 24 insertions(+), 47 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 405a02bb3c58..bc2cde573967 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -1540,15 +1540,6 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( if (binder.tensorOperandAtIndex(axes, 3)) { return failure(); } - } else { - // The default axes value is the range from 0 to the size of first - // dimension of `starts` and `ends`. - Value none = rewriter.create(loc); - Value arangeLength = rewriter.create( - loc, rewriter.getType(), - rewriter.getIntegerAttr(rewriter.getIntegerType(64), startSize)); - axes = rewriter.create( - loc, startsTorchTy, arangeLength, none, none, none, none); } // Binding `steps` from its arguments or through a default value @@ -1579,14 +1570,16 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( binder.op, "Expected the rank of starts and ends tensors to be 1 " "and their dimensions to match"); - auto axesTorchTy = axes.getType().cast(); - auto axesTy = - axesTorchTy.toBuiltinTensor().dyn_cast(); - int64_t numAxes = axesTy.getDimSize(0); + if (axes) { + auto axesTorchTy = axes.getType().cast(); + auto axesTy = + axesTorchTy.toBuiltinTensor().dyn_cast(); + int64_t numAxes = axesTy.getDimSize(0); - if (!(axesTy && numAxes == endSize)) - return rewriter.notifyMatchFailure( - binder.op, "Axes should be the same size of starts and ends"); + if (!(axesTy && numAxes == endSize)) + return rewriter.notifyMatchFailure( + binder.op, "Axes should be the same size of starts and ends"); + } auto stepsTy = steps.getType() .cast() @@ -1622,7 +1615,7 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( } auto intermediateType = Torch::ValueTensorType::get( context, intermediateShape, resultTorchType.getOptionalDtype()); - for (int i = 0; i < numAxes; ++i) { + for (int i = 0; i < endSize; ++i) { Value k = rewriter.create( loc, rewriter.getType(), @@ -1636,12 +1629,11 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( Value start = select(starts, kTensor); Value end = select(ends, kTensor); - Value axis = select(axes, kTensor); + Value axis = axes ? select(axes, kTensor) : k; Value step = select(steps, kTensor); auto sliceType = intermediateType; - if (i == numAxes - 1) - sliceType = resultTorchType; + sliceType = i == (endSize - 1) ? resultTorchType : sliceType; operand = rewriter.create( loc, sliceType, operand, axis, start, end, step); } diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 0cb0888745b3..7de8047fa98a 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -2101,10 +2101,6 @@ "ReduceMaxSignedIntModule_basic", "ReduceMaxUnsignedIntModule_basic", - # Failure - slice_lowering - "ScaledDotProductAttentionDifferentModule_basic", - "ScaledDotProductAttentionSameModule_basic", - # Failure - view_lowering "AddSizeIntModule_basic", "ElementwiseFlattenBroadcastModule_basic", diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir index 9f2354d13e39..704e03acc1e2 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir @@ -1157,9 +1157,6 @@ func.func @test_slice(%arg0: !torch.vtensor<[20,10,5],f32>, %arg1: !torch.vtenso func.func @test_slice_default_axes_and_slices(%arg0: !torch.vtensor<[20,10,5],f32>, %arg1: !torch.vtensor<[3],si64>, %arg2: !torch.vtensor<[3],si64>) -> !torch.vtensor<[20,10,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { //CHECK: %[[NONE_1:.*]] = torch.constant.none //CHECK: %[[AXES_DEFAULT_SIZE:.*]] = torch.constant.int 3 - //CHECK: %[[DEFAULT_AXES:.*]] = torch.aten.arange %[[AXES_DEFAULT_SIZE:.*]], %[[NONE_1:.*]], %[[NONE_1:.*]], %[[NONE_1:.*]], %[[NONE_1:.*]] : !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[3],si64> - //CHECK: %[[NONE_2:.*]] = torch.constant.none - //CHECK: %[[DEFAULT_SIZE_AMOUNT:.*]] = torch.constant.int 3 //CHECK: %[[DEFAULT_SIZE_INPUT:.*]] = torch.prim.ListConstruct %[[DEFAULT_SIZE_AMOUNT:.*]] : (!torch.int) -> !torch.list //CHECK: %[[DEFAULT_SIZES:.*]] = torch.aten.ones %[[DEFAULT_SIZE_INPUT:.*]], %[[NONE_2:.*]], %[[NONE_2:.*]], %[[NONE_2:.*]], %[[NONE_2:.*]] : !torch.list, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[3],si64> //CHECK: %[[INDEX_TO_GRAB:.*]] = torch.constant.int 0 @@ -1170,11 +1167,9 @@ func.func @test_slice_default_axes_and_slices(%arg0: !torch.vtensor<[20,10,5],f3 //CHECK: %[[STARTS_ELEMENT_0:.*]] = torch.aten.item %[[STARTS_INDEX_VEC_0:.*]] : !torch.vtensor<[1],si64> -> !torch.int //CHECK: %[[ENDS_INDEX_VEC_0:.*]] = torch.aten.index_select %arg2, %[[INDEX_TO_GRAB:.*]], %[[ZERO_INDEX_VEC:.*]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> //CHECK: %[[ENDS_ELEMENT_0:.*]] = torch.aten.item %[[ENDS_INDEX_VEC_0:.*]] : !torch.vtensor<[1],si64> -> !torch.int - //CHECK: %[[AXES_INDEX_VEC_0:.*]] = torch.aten.index_select %[[DEFAULT_AXES:.*]], %[[INDEX_TO_GRAB:.*]], %[[ZERO_INDEX_VEC:.*]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> - //CHECK: %[[AXES_ELEMENT_0:.*]] = torch.aten.item %[[AXES_INDEX_VEC_0:.*]] : !torch.vtensor<[1],si64> -> !torch.int //CHECK: %[[STEPS_INDEX_VEC_0:.*]] = torch.aten.index_select %[[DEFAULT_SIZES:.*]], %[[INDEX_TO_GRAB:.*]], %[[ZERO_INDEX_VEC:.*]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> //CHECK: %[[STEPS_ELEMENT_0:.*]] = torch.aten.item %[[STEPS_INDEX_VEC_0:.*]] : !torch.vtensor<[1],si64> -> !torch.int - //CHECK: %[[SLICE_0:.*]] = torch.aten.slice.Tensor %arg0, %[[AXES_ELEMENT_0:.*]], %[[STARTS_ELEMENT_0:.*]], %[[ENDS_ELEMENT_0:.*]], %[[STEPS_ELEMENT_0:.*]] : !torch.vtensor<[20,10,5],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[20,10,?],f32> + //CHECK: %[[SLICE_0:.*]] = torch.aten.slice.Tensor %arg0, %[[CONST_0:.*]], %[[STARTS_ELEMENT_0:.*]], %[[ENDS_ELEMENT_0:.*]], %[[STEPS_ELEMENT_0:.*]] : !torch.vtensor<[20,10,5],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[20,10,?],f32> //CHECK: %[[CONST_1:.*]] = torch.constant.int 1 //CHECK: %[[ONE_INDEX_VEC:.*]] = torch.prim.NumToTensor.Scalar %[[CONST_1:.*]] : !torch.int -> !torch.vtensor<[1],si64> @@ -1182,11 +1177,9 @@ func.func @test_slice_default_axes_and_slices(%arg0: !torch.vtensor<[20,10,5],f3 //CHECK: %[[STARTS_ELEMENT_1:.*]] = torch.aten.item %[[STARTS_INDEX_VEC_1:.*]] : !torch.vtensor<[1],si64> -> !torch.int //CHECK: %[[ENDS_INDEX_VEC_1:.*]] = torch.aten.index_select %arg2, %[[INDEX_TO_GRAB:.*]], %[[ONE_INDEX_VEC:.*]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> //CHECK: %[[ENDS_ELEMENT_1:.*]] = torch.aten.item %[[ENDS_INDEX_VEC_1:.*]] : !torch.vtensor<[1],si64> -> !torch.int - //CHECK: %[[AXES_INDEX_VEC_1:.*]] = torch.aten.index_select %[[DEFAULT_AXES:.*]], %[[INDEX_TO_GRAB:.*]], %[[ONE_INDEX_VEC:.*]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> - //CHECK: %[[AXES_ELEMENT_1:.*]] = torch.aten.item %[[AXES_INDEX_VEC_1:.*]] : !torch.vtensor<[1],si64> -> !torch.int //CHECK: %[[STEPS_INDEX_VEC_1:.*]] = torch.aten.index_select %[[DEFAULT_SIZES:.*]], %[[INDEX_TO_GRAB:.*]], %[[ONE_INDEX_VEC:.*]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> //CHECK: %[[STEPS_ELEMENT_1:.*]] = torch.aten.item %[[STEPS_INDEX_VEC_1:.*]] : !torch.vtensor<[1],si64> -> !torch.int - //CHECK: %[[TWO_INDEX_VEC:.*]] = torch.aten.slice.Tensor %[[SLICE_0:.*]], %[[AXES_ELEMENT_1:.*]], %[[STARTS_ELEMENT_1:.*]], %[[ENDS_ELEMENT_1:.*]], %[[STEPS_ELEMENT_1:.*]] : !torch.vtensor<[20,10,?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[20,10,?],f32> + //CHECK: %[[TWO_INDEX_VEC:.*]] = torch.aten.slice.Tensor %[[SLICE_0:.*]], %[[CONST_1:.*]], %[[STARTS_ELEMENT_1:.*]], %[[ENDS_ELEMENT_1:.*]], %[[STEPS_ELEMENT_1:.*]] : !torch.vtensor<[20,10,?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[20,10,?],f32> //CHECK: %[[CONST_2:.*]] = torch.constant.int 2 //CHECK: %[[TWO_INDEX_VEC:.*]] = torch.prim.NumToTensor.Scalar %[[CONST_2:.*]] : !torch.int -> !torch.vtensor<[1],si64> @@ -1194,11 +1187,9 @@ func.func @test_slice_default_axes_and_slices(%arg0: !torch.vtensor<[20,10,5],f3 //CHECK: %[[STARTS_ELEMENT_2:.*]] = torch.aten.item %[[STARTS_INDEX_VEC_2:.*]] : !torch.vtensor<[1],si64> -> !torch.int //CHECK: %[[ENDS_INDEX_VEC_2:.*]] = torch.aten.index_select %arg2, %[[INDEX_TO_GRAB:.*]], %[[TWO_INDEX_VEC:.*]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> //CHECK: %[[ENDS_ELEMENT_2:.*]] = torch.aten.item %[[ENDS_INDEX_VEC_2:.*]] : !torch.vtensor<[1],si64> -> !torch.int - //CHECK: %[[AXES_INDEX_VEC_2:.*]] = torch.aten.index_select %[[DEFAULT_AXES:.*]], %[[INDEX_TO_GRAB:.*]], %[[TWO_INDEX_VEC:.*]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> - //CHECK: %[[AXES_ELEMENT_2:.*]] = torch.aten.item %[[AXES_INDEX_VEC_2:.*]] : !torch.vtensor<[1],si64> -> !torch.int //CHECK: %[[STEPS_INDEX_VEC_2:.*]] = torch.aten.index_select %[[DEFAULT_SIZES:.*]], %[[INDEX_TO_GRAB:.*]], %[[TWO_INDEX_VEC:.*]] : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> //CHECK: %[[STEPS_ELEMENT_2:.*]] = torch.aten.item %[[STEPS_INDEX_VEC_2:.*]] : !torch.vtensor<[1],si64> -> !torch.int - //CHECK: torch.aten.slice.Tensor %[[TWO_INDEX_VEC:.*]], %[[AXES_ELEMENT_2:.*]], %[[STARTS_ELEMENT_2:.*]], %[[ENDS_ELEMENT_2:.*]], %[[STEPS_ELEMENT_2:.*]] : !torch.vtensor<[20,10,?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[20,10,1],f32> + //CHECK: torch.aten.slice.Tensor %[[TWO_INDEX_VEC:.*]], %[[CONST_2:.*]], %[[STARTS_ELEMENT_2:.*]], %[[ENDS_ELEMENT_2:.*]], %[[STEPS_ELEMENT_2:.*]] : !torch.vtensor<[20,10,?],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[20,10,1],f32> %0 = torch.operator "onnx.Slice"(%arg0, %arg1, %arg2) : (!torch.vtensor<[20,10,5],f32>, !torch.vtensor<[3],si64>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[20,10,1],f32> return %0 : !torch.vtensor<[20,10,1],f32> } @@ -1211,17 +1202,15 @@ func.func @test_slice_default_axes_and_slices(%arg0: !torch.vtensor<[20,10,5],f3 // CHECK-SAME: %[[ARG2:.*]]: !torch.vtensor<[1],si64> // CHECK: %[[ZERO0:.*]] = torch.constant.int 0 -// CHECK: %[[ZERO1:.*]] = torch.constant.int 0 -// CHECK: %[[SCALAR:.*]] = torch.prim.NumToTensor.Scalar %[[ZERO1]] : !torch.int -> !torch.vtensor<[1],si64> -// CHECK: %[[SELECT0:.*]] = torch.aten.index_select %[[ARG1]], %[[ZERO]], %[[SCALAR]] : !torch.vtensor<[1],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> -// CHECK: %[[ITEM0:.*]] = torch.aten.item %[[SELECT0]] : !torch.vtensor<[1],si64> -> !torch.int -// CHECK: %[[SELECT1:.*]] = torch.aten.index_select %[[ARG2]], %[[ZERO]], %[[SCALAR]] : !torch.vtensor<[1],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> -// CHECK: %[[ITEM1:.*]] = torch.aten.item %[[SELECT1]] : !torch.vtensor<[1],si64> -> !torch.int -// CHECK: %[[SELECT2:.*]] = torch.aten.index_select %{{.*}}, %[[ZERO]], %[[SCALAR]] : !torch.vtensor<[1],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> -// CHECK: %[[ITEM2:.*]] = torch.aten.item %[[SELECT2]] : !torch.vtensor<[1],si64> -> !torch.int -// CHECK: %[[SELECT3:.*]] = torch.aten.index_select %{{.*}}, %[[ZERO]], %[[SCALAR]] : !torch.vtensor<[1],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> -// CHECK: %[[ITEM3:.*]] = torch.aten.item %[[SELECT3]] : !torch.vtensor<[1],si64> -> !torch.int -// CHECK: torch.aten.slice.Tensor %[[ARG0]], %[[ITEM2]], %[[ITEM0]], %[[ITEM1]], %[[ITEM3]] : !torch.vtensor<[20,10,5],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[20,10,1],f32> +// CHECK-NEXT: %[[ZERO1:.*]] = torch.constant.int 0 +// CHECK-NEXT: %[[SCALAR:.*]] = torch.prim.NumToTensor.Scalar %[[ZERO1]] : !torch.int -> !torch.vtensor<[1],si64> +// CHECK-NEXT: %[[SELECT0:.*]] = torch.aten.index_select %[[ARG1]], %[[ZERO]], %[[SCALAR]] : !torch.vtensor<[1],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> +// CHECK-NEXT: %[[ITEM0:.*]] = torch.aten.item %[[SELECT0]] : !torch.vtensor<[1],si64> -> !torch.int +// CHECK-NEXT: %[[SELECT1:.*]] = torch.aten.index_select %[[ARG2]], %[[ZERO]], %[[SCALAR]] : !torch.vtensor<[1],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> +// CHECK-NEXT: %[[ITEM1:.*]] = torch.aten.item %[[SELECT1]] : !torch.vtensor<[1],si64> -> !torch.int +// CHECK-NEXT: %[[SELECT3:.*]] = torch.aten.index_select %{{.*}}, %[[ZERO]], %[[SCALAR]] : !torch.vtensor<[1],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> +// CHECK-NEXT: %[[ITEM3:.*]] = torch.aten.item %[[SELECT3]] : !torch.vtensor<[1],si64> -> !torch.int +// CHECK: torch.aten.slice.Tensor %[[ARG0]], %[[ZERO1]], %[[ITEM0]], %[[ITEM1]], %[[ITEM3]] : !torch.vtensor<[20,10,5],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[20,10,1],f32> func.func @test_slice_default_axes_and_steps(%arg0: !torch.vtensor<[20,10,5],f32>, %arg1: !torch.vtensor<[1],si64>, %arg2: !torch.vtensor<[1],si64>) -> !torch.vtensor<[20,10,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64} { %0 = torch.operator "onnx.Slice"(%arg0, %arg1, %arg2) : (!torch.vtensor<[20,10,5],f32>, !torch.vtensor<[1],si64>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[20,10,1],f32> From e80054a3cca385bf50760ad43a6d8e8bb799001d Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Mon, 19 Feb 2024 10:28:23 -0800 Subject: [PATCH 209/283] [torch] Folders for `torch.aten.*.tensor` operators [add, sub, mul] (#2878) Simple folder for limited size aten tensor operations. This is primarily useful for shape computation folding as they unfortunately can use `aten` operators. Add, sub, mul are common examples of these folders. --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 3 + lib/Dialect/Torch/IR/TorchOps.cpp | 213 ++++++++++++++++++ .../build_tools/torch_ods_gen.py | 6 +- test/Dialect/Torch/canonicalize.mlir | 5 +- .../Torch/torch-nary-canonicalize.mlir | 143 ++++++++++++ 5 files changed, 364 insertions(+), 6 deletions(-) create mode 100644 test/Dialect/Torch/torch-nary-canonicalize.mlir diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 99d00e287106..5e4662369caf 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -3790,6 +3790,7 @@ def Torch_AtenMulTensorOp : Torch_Op<"aten.mul.Tensor", [ printDefaultTorchOp(printer, *this, 2, 1); } }]; + let hasFolder = 1; let hasCanonicalizer = 1; } @@ -3839,6 +3840,7 @@ def Torch_AtenAddTensorOp : Torch_Op<"aten.add.Tensor", [ printDefaultTorchOp(printer, *this, 3, 1); } }]; + let hasFolder = 1; let hasCanonicalizer = 1; } @@ -3889,6 +3891,7 @@ def Torch_AtenSubTensorOp : Torch_Op<"aten.sub.Tensor", [ printDefaultTorchOp(printer, *this, 3, 1); } }]; + let hasFolder = 1; let hasCanonicalizer = 1; } diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index d831b70767c4..18c8501df3d0 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -1106,6 +1106,177 @@ LogicalResult rewrite0DBinaryTensorOp(Operation *op, return success(); } +//===----------------------------------------------------------------------===// +// NAry folder helpers +//===----------------------------------------------------------------------===// + +static bool checkSameDTypes(llvm::ArrayRef attrs) { + bool allFp = true; + bool allInt = true; + + for (auto attr : attrs) { + if (!attr) + return false; + + Type attrty; + if (auto dense = dyn_cast_or_null(attr)) + attrty = dense.getType(); + if (auto fp = dyn_cast_or_null(attr)) + attrty = fp.getType(); + if (auto integer = dyn_cast_or_null(attr)) + attrty = integer.getType(); + if (auto shaped = dyn_cast_or_null(attrty)) + attrty = shaped.getElementType(); + allFp &= isa(attrty); + allInt &= isa(attrty); + } + + return allFp || allInt; +} + +static bool checkAllSplats(llvm::ArrayRef attrs) { + for (auto attr : attrs) { + if (auto dense = dyn_cast_or_null(attr)) { + if (!dense.isSplat()) + return false; + } + } + + return true; +} + +llvm::SmallVector getFoldValueAtIndexFp(llvm::ArrayRef attrs, + int64_t idx = 0) { + llvm::SmallVector splattrs; + + for (auto attr : attrs) { + if (auto dense = dyn_cast(attr)) { + if (dense.isSplat()) { + splattrs.push_back(dense.getSplatValue().convertToDouble()); + } else { + splattrs.push_back(dense.getValues()[idx].convertToDouble()); + } + } else if (auto intattr = dyn_cast(attr)) { + splattrs.push_back(intattr.getValueAsDouble()); + } else { + return {}; + } + } + + return splattrs; +} + +llvm::SmallVector getFoldValueAtIndexInt(llvm::ArrayRef attrs, + int64_t bitwidth, + int64_t idx = 0) { + llvm::SmallVector splattrs; + + for (auto attr : attrs) { + bool isunsigned = false; + if (auto dense = dyn_cast(attr)) { + isunsigned = dyn_cast(dense.getElementType()).isUnsigned(); + if (dense.isSplat()) { + splattrs.push_back(dense.getSplatValue()); + } else { + splattrs.push_back(dense.getValues()[idx]); + } + } else if (auto intattr = dyn_cast(attr)) { + isunsigned = cast(intattr.getType()).isUnsigned(); + splattrs.push_back(intattr.getValue()); + } else { + return {}; + } + + auto &apint = splattrs.back(); + if (apint.getBitWidth() < bitwidth) { + if (isunsigned) { + apint = apint.zextOrTrunc(bitwidth); + } else { + apint = apint.sextOrTrunc(bitwidth); + } + } + } + + return splattrs; +} + +using NAryFoldFpOperator = std::function)>; +using NAryFoldIntOperator = std::function)>; + +static OpFoldResult naryFolderHelper(ArrayRef operands, Type ty, + NAryFoldFpOperator fpFolder, + NAryFoldIntOperator intFolder) { + constexpr int64_t maxFold = 16; + if (!checkSameDTypes(operands)) + return nullptr; + + auto resultTy = dyn_cast(ty); + if (!resultTy || !resultTy.hasDtype() || !resultTy.hasSizes()) + return nullptr; + + auto dty = resultTy.getDtype(); + auto resultBTy = resultTy.toBuiltinTensor().clone(dty); + + auto fpTy = dyn_cast(dty); + auto intTy = dyn_cast(dty); + if (!fpTy && !intTy) + return nullptr; + + bool allSplats = checkAllSplats(operands); + bool withinMaxFold = + resultBTy.hasStaticShape() && resultBTy.getNumElements() <= maxFold; + + if (!allSplats && !withinMaxFold) + return nullptr; + + // We do not support broadcasting in the non-splat case so validate same + // shaped inputs / outputs: + if (!allSplats) { + auto resultShape = resultBTy.getShape(); + for (int i = 0, s = operands.size(); i < s; ++i) { + if (auto dense = dyn_cast(operands[i])) { + if (dense.isSplat()) + continue; + auto operandShape = cast(dense.getType()).getShape(); + if (operandShape.size() != resultShape.size()) + return nullptr; + for (int i = 0, s = operandShape.size(); i < s; ++i) + if (operandShape[i] != resultShape[i]) + return nullptr; + } + } + } + + const int64_t numValues = allSplats ? 1 : resultBTy.getNumElements(); + + if (fpTy) { + llvm::SmallVector folded; + for (int i = 0, s = numValues; i < s; ++i) { + auto inputs = getFoldValueAtIndexFp(operands, i); + double fold = fpFolder(inputs); + + APFloat val(fold); + bool unused; + val.convert(fpTy.getFloatSemantics(), APFloat::rmNearestTiesToEven, + &unused); + folded.push_back(val); + } + return DenseElementsAttr::get(resultBTy, folded); + } + + if (intTy) { + llvm::SmallVector folded; + for (int i = 0, s = numValues; i < s; ++i) { + auto inputs = + getFoldValueAtIndexInt(operands, dty.getIntOrFloatBitWidth(), i); + folded.push_back(intFolder(inputs)); + } + return DenseElementsAttr::get(resultBTy, folded); + } + + return nullptr; +} + //===----------------------------------------------------------------------===// // AtenAddTensorOp //===----------------------------------------------------------------------===// @@ -1116,6 +1287,20 @@ void AtenAddTensorOp::getCanonicalizationPatterns(RewritePatternSet &patterns, }); } +OpFoldResult AtenAddTensorOp::fold(FoldAdaptor adaptor) { + auto fpFold = [](llvm::ArrayRef inputs) { + assert(inputs.size() == 3); + return inputs[0] + (inputs[1] * inputs[2]); + }; + + auto intFold = [](llvm::ArrayRef inputs) { + assert(inputs.size() == 3); + return inputs[0] + (inputs[1] * inputs[2]); + }; + + return naryFolderHelper(adaptor.getOperands(), getType(), fpFold, intFold); +} + //===----------------------------------------------------------------------===// // AtenAddScalarOp //===----------------------------------------------------------------------===// @@ -1136,6 +1321,20 @@ void AtenSubTensorOp::getCanonicalizationPatterns(RewritePatternSet &patterns, }); } +OpFoldResult AtenSubTensorOp::fold(FoldAdaptor adaptor) { + auto fpFold = [](llvm::ArrayRef inputs) { + assert(inputs.size() == 3); + return inputs[0] - (inputs[1] * inputs[2]); + }; + + auto intFold = [](llvm::ArrayRef inputs) { + assert(inputs.size() == 3); + return inputs[0] - (inputs[1] * inputs[2]); + }; + + return naryFolderHelper(adaptor.getOperands(), getType(), fpFold, intFold); +} + //===----------------------------------------------------------------------===// // AtenSubScalarOp //===----------------------------------------------------------------------===// @@ -1166,6 +1365,20 @@ void AtenMulTensorOp::getCanonicalizationPatterns(RewritePatternSet &patterns, }); } +OpFoldResult AtenMulTensorOp::fold(FoldAdaptor adaptor) { + auto fpFold = [](llvm::ArrayRef inputs) { + assert(inputs.size() == 2); + return inputs[0] * inputs[1]; + }; + + auto intFold = [](llvm::ArrayRef inputs) { + assert(inputs.size() == 2); + return inputs[0] * inputs[1]; + }; + + return naryFolderHelper(adaptor.getOperands(), getType(), fpFold, intFold); +} + //===----------------------------------------------------------------------===// // AtenEqTensorOp //===----------------------------------------------------------------------===// diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index 1dc8585d7486..bfbebf86be0b 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -340,9 +340,9 @@ def emit_with_mutating_variants(key, **kwargs): # Elementwise tensor compute ops that don't have the standard mutating # variants. emit_with_mutating_variants("aten::div.Tensor_mode : (Tensor, Tensor, str?) -> (Tensor)", has_canonicalizer=True) - emit_with_mutating_variants("aten::mul.Tensor : (Tensor, Tensor) -> (Tensor)", has_canonicalizer=True) - emit_with_mutating_variants("aten::add.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)", has_canonicalizer=True) - emit_with_mutating_variants("aten::sub.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)", has_canonicalizer=True) + emit_with_mutating_variants("aten::mul.Tensor : (Tensor, Tensor) -> (Tensor)", has_canonicalizer=True, has_folder=True) + emit_with_mutating_variants("aten::add.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)", has_canonicalizer=True, has_folder=True) + emit_with_mutating_variants("aten::sub.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)", has_canonicalizer=True, has_folder=True) emit_with_mutating_variants("aten::add.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)", has_canonicalizer=True) emit_with_mutating_variants("aten::sub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)", has_canonicalizer=True) emit_with_mutating_variants("aten::mul.Scalar : (Tensor, Scalar) -> (Tensor)", has_canonicalizer=True) diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir index 03eeaaeb525b..bb57135075bd 100644 --- a/test/Dialect/Torch/canonicalize.mlir +++ b/test/Dialect/Torch/canonicalize.mlir @@ -1916,9 +1916,8 @@ func.func @torch.aten.mul.Scalar$canonicalize_numtotensor_0d() -> !torch.vtensor } // CHECK-LABEL: func.func @torch.aten.mul.Tensor$canonicalize_literal_0d() -> !torch.vtensor<[],si64> { -// CHECK: %[[INT6]] = torch.constant.int 6 -// CHECK: %[[PR0:.*]] = torch.prim.NumToTensor.Scalar %[[INT6]] : !torch.int -> !torch.vtensor<[],si64> -// CHECK: return %[[PR0]] : !torch.vtensor<[],si64> +// CHECK: %[[INT6:.+]] = torch.vtensor.literal(dense<6> : tensor) : !torch.vtensor<[],si64> +// CHECK: return %[[INT6]] : !torch.vtensor<[],si64> func.func @torch.aten.mul.Tensor$canonicalize_literal_0d() -> !torch.vtensor<[],si64> { %0 = torch.vtensor.literal(dense<2> : tensor) : !torch.vtensor<[],si64> %1 = torch.vtensor.literal(dense<3> : tensor) : !torch.vtensor<[],si64> diff --git a/test/Dialect/Torch/torch-nary-canonicalize.mlir b/test/Dialect/Torch/torch-nary-canonicalize.mlir new file mode 100644 index 000000000000..b0d22e35da9c --- /dev/null +++ b/test/Dialect/Torch/torch-nary-canonicalize.mlir @@ -0,0 +1,143 @@ +// RUN: torch-mlir-opt %s -canonicalize | FileCheck %s + +// CHECK-LABEL: @fold_aten_add_splat_int +func.func @fold_aten_add_splat_int() -> !torch.vtensor<[4],si64> { + // CHECK: torch.vtensor.literal(dense<29> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + %cst_7 = torch.vtensor.literal(dense<7> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + %cst_11 = torch.vtensor.literal(dense<11> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + %int2 = torch.constant.int 2 + %0 = torch.aten.add.Tensor %cst_7, %cst_11, %int2 : !torch.vtensor<[4],si64>, !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> + return %0 : !torch.vtensor<[4],si64> +} + +// ----- + +// CHECK-LABEL: @fold_aten_add_splat_int_mismatch +func.func @fold_aten_add_splat_int_mismatch() -> !torch.vtensor<[4],si64> { + // CHECK: torch.vtensor.literal(dense<29> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + %cst_7 = torch.vtensor.literal(dense<7> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + %cst_11 = torch.vtensor.literal(dense<11> : tensor<4xsi32>) : !torch.vtensor<[4],si32> + %int2 = torch.constant.int 2 + %0 = torch.aten.add.Tensor %cst_7, %cst_11, %int2 : !torch.vtensor<[4],si64>, !torch.vtensor<[4],si32>, !torch.int -> !torch.vtensor<[4],si64> + return %0 : !torch.vtensor<[4],si64> +} + +// ----- + +// CHECK-LABEL: @fold_aten_add_splat_float +func.func @fold_aten_add_splat_float() -> !torch.vtensor<[4],f32> { + // CHECK: torch.vtensor.literal(dense<2.900000e+01> : tensor<4xf32>) + %int2 = torch.constant.float 2.0 + %cst_7 = torch.vtensor.literal(dense<7.0> : tensor<4xf32>) : !torch.vtensor<[4],f32> + %cst_11 = torch.vtensor.literal(dense<11.0> : tensor<4xf32>) : !torch.vtensor<[4],f32> + %0 = torch.aten.add.Tensor %cst_7, %cst_11, %int2 : !torch.vtensor<[4],f32>, !torch.vtensor<[4],f32>, !torch.float -> !torch.vtensor<[4],f32> + return %0 : !torch.vtensor<[4],f32> +} + +// ----- + +// CHECK-LABEL: @fold_aten_add_splat_float_mismatch +func.func @fold_aten_add_splat_float_mismatch() -> !torch.vtensor<[4],f32> { + // CHECK: torch.vtensor.literal(dense<2.900000e+01> : tensor<4xf32>) + %int2 = torch.constant.float 2.0 + %cst_7 = torch.vtensor.literal(dense<7.0> : tensor<4xf64>) : !torch.vtensor<[4],f64> + %cst_11 = torch.vtensor.literal(dense<11.0> : tensor<4xf32>) : !torch.vtensor<[4],f32> + %0 = torch.aten.add.Tensor %cst_7, %cst_11, %int2 : !torch.vtensor<[4],f64>, !torch.vtensor<[4],f32>, !torch.float -> !torch.vtensor<[4],f32> + return %0 : !torch.vtensor<[4],f32> +} + +// ----- + +// CHECK-LABEL: @fold_aten_add_arr0_int +func.func @fold_aten_add_arr0_int() -> !torch.vtensor<[4],si64> { + // CHECK: torch.vtensor.literal(dense<[28, 29, 30, 31]> : tensor<4xsi64>) + %cst_7 = torch.vtensor.literal(dense<[6,7,8,9]> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + %cst_11 = torch.vtensor.literal(dense<11> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + %int2 = torch.constant.int 2 + %0 = torch.aten.add.Tensor %cst_7, %cst_11, %int2 : !torch.vtensor<[4],si64>, !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> + return %0 : !torch.vtensor<[4],si64> +} + + +// ----- + +// CHECK-LABEL: @fold_aten_add_arr1_int +func.func @fold_aten_add_arr1_int() -> !torch.vtensor<[4],si64> { + // CHECK: torch.vtensor.literal(dense<[27, 29, 31, 33]> : tensor<4xsi64>) + %int2 = torch.constant.int 2 + %cst_7 = torch.vtensor.literal(dense<7> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + %cst_11 = torch.vtensor.literal(dense<[10,11,12,13]> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + %0 = torch.aten.add.Tensor %cst_7, %cst_11, %int2 : !torch.vtensor<[4],si64>, !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> + return %0 : !torch.vtensor<[4],si64> +} + + +// ----- + +// CHECK-LABEL: @fold_aten_add_arr0_float +func.func @fold_aten_add_arr0_float() -> !torch.vtensor<[4],f32> { + // CHECK: torch.vtensor.literal(dense<[2.800000e+01, 2.900000e+01, 3.000000e+01, 3.100000e+01]> : tensor<4xf32>) + %int2 = torch.constant.float 2.0 + %cst_7 = torch.vtensor.literal(dense<[6.0, 7.0, 8.0, 9.0]> : tensor<4xf32>) : !torch.vtensor<[4],f32> + %cst_11 = torch.vtensor.literal(dense<11.0> : tensor<4xf32>) : !torch.vtensor<[4],f32> + %0 = torch.aten.add.Tensor %cst_7, %cst_11, %int2 : !torch.vtensor<[4],f32>, !torch.vtensor<[4],f32>, !torch.float -> !torch.vtensor<[4],f32> + return %0 : !torch.vtensor<[4],f32> +} + +// ----- + +// CHECK-LABEL: @fold_aten_add_arr1_float +func.func @fold_aten_add_arr1_float() -> !torch.vtensor<[4],f32> { + // CHECK: torch.vtensor.literal(dense<[2.700000e+01, 2.900000e+01, 3.100000e+01, 3.300000e+01]> : tensor<4xf32>) + %fp_2 = torch.constant.float 2.0 + %cst_7 = torch.vtensor.literal(dense<7.0> : tensor<4xf32>) : !torch.vtensor<[4],f32> + %cst_11 = torch.vtensor.literal(dense<[10.0,11.0,12.0,13.0]> : tensor<4xf32>) : !torch.vtensor<[4],f32> + %0 = torch.aten.add.Tensor %cst_7, %cst_11, %fp_2 : !torch.vtensor<[4],f32>, !torch.vtensor<[4],f32>, !torch.float -> !torch.vtensor<[4],f32> + return %0 : !torch.vtensor<[4],f32> +} + +// ----- + +// CHECK-LABEL: @fold_aten_sub_splat_int +func.func @fold_aten_sub_splat_int() -> !torch.vtensor<[4],si64> { + // CHECK: torch.vtensor.literal(dense<-15> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + %int_2 = torch.constant.int 2 + %cst_7 = torch.vtensor.literal(dense<7> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + %cst_11 = torch.vtensor.literal(dense<11> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + %0 = torch.aten.sub.Tensor %cst_7, %cst_11, %int_2 : !torch.vtensor<[4],si64>, !torch.vtensor<[4],si64>, !torch.int -> !torch.vtensor<[4],si64> + return %0 : !torch.vtensor<[4],si64> +} + +// ----- + +// CHECK-LABEL: @fold_aten_sub_splat_float +func.func @fold_aten_sub_splat_float() -> !torch.vtensor<[4],f32> { + // CHECK: torch.vtensor.literal(dense<-1.500000e+01> : tensor<4xf32>) : !torch.vtensor<[4],f32> + %fp_2 = torch.constant.float 2.0 + %cst_7 = torch.vtensor.literal(dense<7.0> : tensor<4xf32>) : !torch.vtensor<[4],f32> + %cst_11 = torch.vtensor.literal(dense<11.0> : tensor<4xf32>) : !torch.vtensor<[4],f32> + %0 = torch.aten.sub.Tensor %cst_7, %cst_11, %fp_2 : !torch.vtensor<[4],f32>, !torch.vtensor<[4],f32>, !torch.float -> !torch.vtensor<[4],f32> + return %0 : !torch.vtensor<[4],f32> +} + +// ----- + +// CHECK-LABEL: @fold_aten_mul_splat_int +func.func @fold_aten_mul_splat_int() -> !torch.vtensor<[4],si64> { + // CHECK: torch.vtensor.literal(dense<77> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + %cst_7 = torch.vtensor.literal(dense<7> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + %cst_11 = torch.vtensor.literal(dense<11> : tensor<4xsi64>) : !torch.vtensor<[4],si64> + %0 = torch.aten.mul.Tensor %cst_7, %cst_11: !torch.vtensor<[4],si64>, !torch.vtensor<[4],si64> -> !torch.vtensor<[4],si64> + return %0 : !torch.vtensor<[4],si64> +} + +// ----- + +// CHECK-LABEL: @fold_aten_mul_splat_float +func.func @fold_aten_mul_splat_float() -> !torch.vtensor<[4],f32> { + // CHECK: torch.vtensor.literal(dense<7.700000e+01> : tensor<4xf32>) : !torch.vtensor<[4],f32> + %cst_7 = torch.vtensor.literal(dense<7.0> : tensor<4xf32>) : !torch.vtensor<[4],f32> + %cst_11 = torch.vtensor.literal(dense<11.0> : tensor<4xf32>) : !torch.vtensor<[4],f32> + %0 = torch.aten.mul.Tensor %cst_7, %cst_11 : !torch.vtensor<[4],f32>, !torch.vtensor<[4],f32> -> !torch.vtensor<[4],f32> + return %0 : !torch.vtensor<[4],f32> +} From 135c81a4165f9e4c9070d72c485efece887d64f8 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Mon, 19 Feb 2024 11:55:54 -0800 Subject: [PATCH 210/283] [torch] Add folder for `prim.NumToTensor.Scalar` (#2921) Useful for `slice` lowerings that depend on tensors made form scalars. --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 1 + .../TorchToLinalg/Uncategorized.cpp | 1 - lib/Dialect/Torch/IR/TorchOps.cpp | 24 +++ projects/pt1/e2e_testing/xfail_sets.py | 1 + .../build_tools/torch_ods_gen.py | 2 +- test/Conversion/TorchToStablehlo/basic.mlir | 10 +- test/Dialect/Torch/canonicalize.mlir | 137 ++++++------------ 7 files changed, 77 insertions(+), 99 deletions(-) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 5e4662369caf..c5fec66913b0 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -15070,6 +15070,7 @@ def Torch_PrimNumToTensorScalarOp : Torch_Op<"prim.NumToTensor.Scalar", [ printDefaultTorchOp(printer, *this, 1, 1); } }]; + let hasFolder = 1; } def Torch_PrimMinSelfIntOp : Torch_Op<"prim.min.self_int", [ diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index e8e671955835..08d69ca718b9 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -452,7 +452,6 @@ static Value createLinalgPayloadCalculationForElementwiseOp( auto exp = b.create(loc, negate); auto added = b.create(loc, exp, one); auto div = b.create(loc, one, added); - outTy.dump(); return convertScalarToDtype(b, loc, div, outTy, std::nullopt, outTTy); } if (auto relu = dyn_cast(op)) { diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index 18c8501df3d0..36e089fb28d3 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -3641,6 +3641,30 @@ OpFoldResult PrimMaxIntOp::fold(FoldAdaptor adaptor) { std::max(lhs.getValue().getSExtValue(), rhs.getValue().getSExtValue())); } +//===----------------------------------------------------------------------===// +// PrimNumToTensorScalarOp +//===----------------------------------------------------------------------===// + +OpFoldResult PrimNumToTensorScalarOp::fold(FoldAdaptor adaptor) { + Attribute a = adaptor.getA(); + auto resultTy = cast(getType()); + if (!a) + return {}; + if (!resultTy.hasDtype() || !resultTy.hasSizes()) + return {}; + + auto dty = resultTy.getDtype(); + if (auto iattr = dyn_cast(a)) { + a = IntegerAttr::get(dty, iattr.getInt()); + } else if (auto fattr = dyn_cast(a)) { + a = FloatAttr::get(dty, fattr.getValueAsDouble()); + } + + auto mlirTensorType = + RankedTensorType::get(resultTy.getSizes(), resultTy.getDtype()); + return SplatElementsAttr::get(mlirTensorType, a); +} + //===----------------------------------------------------------------------===// // PrimMinSelfIntOp //===----------------------------------------------------------------------===// diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 7de8047fa98a..52e1ea3321b8 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -419,6 +419,7 @@ "AtenEyeModuleFloat2D_basic", "AtenEyeModuleInt2D_basic", "AtenFloatScalarModule_basic", + "AtenInstanceNormModule_basic", "AtenIntBoolOpConstFalseModule_basic", "AtenIntBoolOpConstTrueModule_basic", "AtenIntBoolOpModule_basic", diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index bfbebf86be0b..64f03add759e 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -846,7 +846,7 @@ def emit_with_mutating_variants(key, **kwargs): emit("prim::device : (Tensor) -> (Device)", has_canonicalizer=True) emit("prim::dtype : (Tensor) -> (int)", has_folder=True) emit("prim::TupleUnpack : (Any) -> (...)", has_canonicalizer=True) - emit("prim::NumToTensor.Scalar : (Scalar) -> (Tensor)") + emit("prim::NumToTensor.Scalar : (Scalar) -> (Tensor)", has_folder=True) emit("prim::min.self_int : (int[]) -> (int)", has_folder=True) emit("prim::min.int : (int, int) -> (int)", has_folder=True) emit("prim::max.self_int : (int[]) -> (int)") diff --git a/test/Conversion/TorchToStablehlo/basic.mlir b/test/Conversion/TorchToStablehlo/basic.mlir index b502d3ffcce9..5f096205ea8c 100644 --- a/test/Conversion/TorchToStablehlo/basic.mlir +++ b/test/Conversion/TorchToStablehlo/basic.mlir @@ -27,13 +27,9 @@ func.func @torch.vtensor.literal$signed() -> !torch.vtensor<[2],si64> { // CHECK-LABEL: func.func @torch.prim.NumToTensor.Scalar$basic( // CHECK-SAME: ) -> !torch.vtensor<[],si64> { -// CHECK: %[[INT1:.*]] = torch.constant.int 1 -// CHECK: %[[T0:.*]] = torch_c.to_i64 %[[INT1]] -// CHECK: %[[T1:.*]] = tensor.from_elements %[[T0]] : tensor<1xi64> -// CHECK: %[[T2:.*]] = stablehlo.convert %[[T1]] : tensor<1xi64> -// CHECK: %[[T3:.*]] = stablehlo.reshape %[[T2]] : (tensor<1xi64>) -> tensor -// CHECK: %[[T4:.*]] = torch_c.from_builtin_tensor %[[T3]] : tensor -> !torch.vtensor<[],si64> -// CHECK: return %[[T4]] : !torch.vtensor<[],si64> +// CHECK: %[[CST:.*]] = stablehlo.constant dense<1> : tensor +// CHECK: %[[FROM:.*]] = torch_c.from_builtin_tensor %[[CST]] : tensor -> !torch.vtensor<[],si64> +// CHECK: return %[[FROM]] : !torch.vtensor<[],si64> func.func @torch.prim.NumToTensor.Scalar$basic() -> !torch.vtensor<[], si64> { %int1 = torch.constant.int 1 %0 = torch.prim.NumToTensor.Scalar %int1 : !torch.int -> !torch.vtensor<[], si64> diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir index bb57135075bd..4df52cfb174b 100644 --- a/test/Dialect/Torch/canonicalize.mlir +++ b/test/Dialect/Torch/canonicalize.mlir @@ -1687,13 +1687,8 @@ func.func @torch.aten.Bool.int$fold_cst() -> !torch.bool { } // CHECK-LABEL: func.func @torch.aten.add.Tensor$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> { -// CHECK: %[[INT6:.*]] = torch.constant.int 6 -// CHECK: %[[INT0:.*]] = torch.constant.int 0 -// CHECK: %[[INT2:.*]] = torch.constant.int 2 -// CHECK: %[[PR1:.*]] = torch.prim.NumToTensor.Scalar %[[INT0]] : !torch.int -> !torch.vtensor<[],si64> -// CHECK: %[[PR2:.*]] = torch.prim.NumToTensor.Scalar %[[INT2]] : !torch.int -> !torch.vtensor<[],si64> -// CHECK: %[[PR3:.*]] = torch.prim.NumToTensor.Scalar %[[INT6]] : !torch.int -> !torch.vtensor<[],si64> -// CHECK: return %[[PR3]] : !torch.vtensor<[],si64> +// CHECK: %[[CST:.*]] = torch.vtensor.literal(dense<6> : tensor) : !torch.vtensor<[],si64> +// CHECK: return %[[CST]] : !torch.vtensor<[],si64> func.func @torch.aten.add.Tensor$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> { %int0 = torch.constant.int 0 %int2 = torch.constant.int 2 @@ -1705,11 +1700,8 @@ func.func @torch.aten.add.Tensor$canonicalize_numtotensor_0d() -> !torch.vtensor } // CHECK-LABEL: @torch.aten.add.Tensor$canonicalize_literal_0d() -> !torch.vtensor<[],si64> { -// CHECK: %[[INT6:.*]] = torch.constant.int 6 -// CHECK: %[[INT2:.*]] = torch.constant.int 2 -// CHECK: %[[PR1:.*]] = torch.prim.NumToTensor.Scalar %[[INT2]] : !torch.int -> !torch.vtensor<[],si64> -// CHECK: %[[PR2:.*]] = torch.prim.NumToTensor.Scalar %[[INT6]] : !torch.int -> !torch.vtensor<[],si64> -// CHECK: return %[[PR2]] : !torch.vtensor<[],si64> +// CHECK: %[[CST:.*]] = torch.vtensor.literal(dense<6> : tensor) : !torch.vtensor<[],si64> +// CHECK: return %[[CST]] : !torch.vtensor<[],si64> func.func @torch.aten.add.Tensor$canonicalize_literal_0d() -> !torch.vtensor<[],si64> { %int0 = torch.constant.int 0 %int2 = torch.constant.int 2 @@ -1760,11 +1752,8 @@ func.func @prim.ListUnpack$fold_list(%arg0: !torch.vtensor<[2,3],f32>, %arg1: !t } // CHECK-LABEL: func.func @torch.aten.div.Tensor_mode$canonicalize_literal_0d() -> !torch.vtensor<[],si64> { -// CHECK: %[[INT3:.*]] = torch.constant.int 3 -// CHECK: %[[INT6:.*]] = torch.constant.int 6 -// CHECK: %[[PR1:.*]] = torch.prim.NumToTensor.Scalar %[[INT6]] : !torch.int -> !torch.vtensor<[],si64> -// CHECK: %[[PR2:.*]] = torch.prim.NumToTensor.Scalar %[[INT3]] : !torch.int -> !torch.vtensor<[],si64> -// CHECK: return %[[PR2]] : !torch.vtensor<[],si64> +// CHECK: %[[CST:.*]] = torch.vtensor.literal(dense<3> : tensor) : !torch.vtensor<[],si64> +// CHECK: return %[[CST]] : !torch.vtensor<[],si64> func.func @torch.aten.div.Tensor_mode$canonicalize_literal_0d() -> !torch.vtensor<[],si64> { %int6 = torch.constant.int 6 %str = torch.constant.str "floor" @@ -1775,13 +1764,8 @@ func.func @torch.aten.div.Tensor_mode$canonicalize_literal_0d() -> !torch.vtenso } // CHECK-LABEL: func.func @torch.aten.div.Tensor_mode$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> { -// CHECK: %[[INT3:.*]] = torch.constant.int 3 -// CHECK: %[[INT6:.*]] = torch.constant.int 6 -// CHECK: %[[INT2:.*]] = torch.constant.int 2 -// CHECK: %[[PR1:.*]] = torch.prim.NumToTensor.Scalar %[[INT2]] : !torch.int -> !torch.vtensor<[],si64> -// CHECK: %[[PR2:.*]] = torch.prim.NumToTensor.Scalar %[[INT6]] : !torch.int -> !torch.vtensor<[],si64> -// CHECK: %[[PR3:.*]] = torch.prim.NumToTensor.Scalar %[[INT3]] : !torch.int -> !torch.vtensor<[],si64> -// CHECK: return %[[PR3]] : !torch.vtensor<[],si64> +// CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<3> : tensor) : !torch.vtensor<[],si64> +// CHECK: return %[[CST]] : !torch.vtensor<[],si64> func.func @torch.aten.div.Tensor_mode$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> { %int6 = torch.constant.int 6 %int2 = torch.constant.int 2 @@ -1793,11 +1777,8 @@ func.func @torch.aten.div.Tensor_mode$canonicalize_numtotensor_0d() -> !torch.vt } // CHECK-LABEL: func.func @torch.aten.add.Scalar$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> { -// CHECK: %[[INT6:.*]] = torch.constant.int 6 -// CHECK: %[[INT0:.*]] = torch.constant.int 0 -// CHECK: %[[PR0:.*]] = torch.prim.NumToTensor.Scalar %[[INT0]] : !torch.int -> !torch.vtensor<[],si64> -// CHECK: %[[PR1:.*]] = torch.prim.NumToTensor.Scalar %[[INT6]] : !torch.int -> !torch.vtensor<[],si64> -// CHECK: return %[[PR1]] : !torch.vtensor<[],si64> +// CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<6> : tensor) : !torch.vtensor<[],si64> +// CHECK: return %[[CST]] : !torch.vtensor<[],si64> func.func @torch.aten.add.Scalar$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> { %int0 = torch.constant.int 0 %int2 = torch.constant.int 2 @@ -1808,9 +1789,8 @@ func.func @torch.aten.add.Scalar$canonicalize_numtotensor_0d() -> !torch.vtensor } // CHECK-LABEL: func.func @torch.aten.add.Scalar$canonicalize_literal_0d() -> !torch.vtensor<[],si64> { -// CHECK: %[[INT6:.*]] = torch.constant.int 6 -// CHECK: %[[PR0:.*]] = torch.prim.NumToTensor.Scalar %[[INT6]] : !torch.int -> !torch.vtensor<[],si64> -// CHECK: return %[[PR0]] : !torch.vtensor<[],si64> +// CHECK: %[[CST]] = torch.vtensor.literal(dense<6> : tensor) : !torch.vtensor<[],si64> +// CHECK: return %[[CST]] : !torch.vtensor<[],si64> func.func @torch.aten.add.Scalar$canonicalize_literal_0d() -> !torch.vtensor<[],si64> { %int2 = torch.constant.int 2 %int3 = torch.constant.int 3 @@ -1820,13 +1800,8 @@ func.func @torch.aten.add.Scalar$canonicalize_literal_0d() -> !torch.vtensor<[], } // CHECK-LABEL: func.func @torch.aten.sub.Tensor$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> { -// CHECK: %[[INT_6:.*]] = torch.constant.int -6 -// CHECK: %[[INT0:.*]] = torch.constant.int 0 -// CHECK: %[[INT2:.*]] = torch.constant.int 2 -// CHECK: %[[PR1:.*]] = torch.prim.NumToTensor.Scalar %[[INT0]] : !torch.int -> !torch.vtensor<[],si64> -// CHECK: %[[PR2:.*]] = torch.prim.NumToTensor.Scalar %[[INT2]] : !torch.int -> !torch.vtensor<[],si64> -// CHECK: %[[PR3:.*]] = torch.prim.NumToTensor.Scalar %[[INT_6]] : !torch.int -> !torch.vtensor<[],si64> -// CHECK: return %[[PR3]] : !torch.vtensor<[],si64> +// CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<-6> : tensor) : !torch.vtensor<[],si64> +// CHECK: return %[[CST]] : !torch.vtensor<[],si64> func.func @torch.aten.sub.Tensor$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> { %int0 = torch.constant.int 0 %int2 = torch.constant.int 2 @@ -1838,11 +1813,8 @@ func.func @torch.aten.sub.Tensor$canonicalize_numtotensor_0d() -> !torch.vtensor } // CHECK-LABEL: @torch.aten.sub.Tensor$canonicalize_literal_0d() -> !torch.vtensor<[],si64> { -// CHECK: %[[INT_6:.*]] = torch.constant.int -6 -// CHECK: %[[INT2:.*]] = torch.constant.int 2 -// CHECK: %[[PR1:.*]] = torch.prim.NumToTensor.Scalar %[[INT2]] : !torch.int -> !torch.vtensor<[],si64> -// CHECK: %[[PR2:.*]] = torch.prim.NumToTensor.Scalar %[[INT_6]] : !torch.int -> !torch.vtensor<[],si64> -// CHECK: return %[[PR2]] : !torch.vtensor<[],si64> +// CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<-6> : tensor) : !torch.vtensor<[],si64> +// CHECK: return %[[CST]] func.func @torch.aten.sub.Tensor$canonicalize_literal_0d() -> !torch.vtensor<[],si64> { %int0 = torch.constant.int 0 %int2 = torch.constant.int 2 @@ -1854,11 +1826,8 @@ func.func @torch.aten.sub.Tensor$canonicalize_literal_0d() -> !torch.vtensor<[], } // CHECK-LABEL: func.func @torch.aten.sub.Scalar$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> { -// CHECK: %[[INT_6:.*]] = torch.constant.int -6 -// CHECK: %[[INT0:.*]] = torch.constant.int 0 -// CHECK: %[[PR0:.*]] = torch.prim.NumToTensor.Scalar %[[INT0]] : !torch.int -> !torch.vtensor<[],si64> -// CHECK: %[[PR1:.*]] = torch.prim.NumToTensor.Scalar %[[INT_6]] : !torch.int -> !torch.vtensor<[],si64> -// CHECK: return %[[PR1]] : !torch.vtensor<[],si64> +// CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<-6> : tensor) : !torch.vtensor<[],si64> +// CHECK: return %[[CST]] : !torch.vtensor<[],si64> func.func @torch.aten.sub.Scalar$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> { %int0 = torch.constant.int 0 %int2 = torch.constant.int 2 @@ -1869,9 +1838,8 @@ func.func @torch.aten.sub.Scalar$canonicalize_numtotensor_0d() -> !torch.vtensor } // CHECK-LABEL: func.func @torch.aten.sub.Scalar$canonicalize_literal_0d() -> !torch.vtensor<[],si64> { -// CHECK: %[[INT_6:.*]] = torch.constant.int -6 -// CHECK: %[[PR0:.*]] = torch.prim.NumToTensor.Scalar %[[INT_6]] : !torch.int -> !torch.vtensor<[],si64> -// CHECK: return %[[PR0]] : !torch.vtensor<[],si64> +// CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<-6> : tensor) : !torch.vtensor<[],si64> +// CHECK: return %[[CST]] : !torch.vtensor<[],si64> func.func @torch.aten.sub.Scalar$canonicalize_literal_0d() -> !torch.vtensor<[],si64> { %int2 = torch.constant.int 2 %int3 = torch.constant.int 3 @@ -1891,9 +1859,8 @@ func.func @torch.aten.sub.float$fold() -> !torch.float { } // CHECK-LABEL: func.func @torch.aten.mul.Scalar$canonicalize_literal_0d() -> !torch.vtensor<[],si64> { -// CHECK: %[[INT6]] = torch.constant.int 6 -// CHECK: %[[PR0:.*]] = torch.prim.NumToTensor.Scalar %[[INT6]] : !torch.int -> !torch.vtensor<[],si64> -// CHECK: return %[[PR0]] : !torch.vtensor<[],si64> +// CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<6> : tensor) : !torch.vtensor<[],si64> +// CHECK: return %[[CST]] : !torch.vtensor<[],si64> func.func @torch.aten.mul.Scalar$canonicalize_literal_0d() -> !torch.vtensor<[],si64> { %int3 = torch.constant.int 3 %0 = torch.vtensor.literal(dense<2> : tensor) : !torch.vtensor<[],si64> @@ -1902,11 +1869,8 @@ func.func @torch.aten.mul.Scalar$canonicalize_literal_0d() -> !torch.vtensor<[], } // CHECK-LABEL: func.func @torch.aten.mul.Scalar$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> { -// CHECK: %[[INT6:.*]] = torch.constant.int 6 -// CHECK: %[[INT2:.*]] = torch.constant.int 2 -// CHECK: %[[PR0:.*]] = torch.prim.NumToTensor.Scalar %[[INT2]] : !torch.int -> !torch.vtensor<[],si64> -// CHECK: %[[PR1:.*]] = torch.prim.NumToTensor.Scalar %[[INT6]] : !torch.int -> !torch.vtensor<[],si64> -// CHECK: return %[[PR1]] : !torch.vtensor<[],si64> +// CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<6> : tensor) : !torch.vtensor<[],si64> +// CHECK: return %[[CST]] : !torch.vtensor<[],si64> func.func @torch.aten.mul.Scalar$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> { %int2 = torch.constant.int 2 %int3 = torch.constant.int 3 @@ -1916,8 +1880,8 @@ func.func @torch.aten.mul.Scalar$canonicalize_numtotensor_0d() -> !torch.vtensor } // CHECK-LABEL: func.func @torch.aten.mul.Tensor$canonicalize_literal_0d() -> !torch.vtensor<[],si64> { -// CHECK: %[[INT6:.+]] = torch.vtensor.literal(dense<6> : tensor) : !torch.vtensor<[],si64> -// CHECK: return %[[INT6]] : !torch.vtensor<[],si64> +// CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<6> : tensor) : !torch.vtensor<[],si64> +// CHECK: return %[[CST]] : !torch.vtensor<[],si64> func.func @torch.aten.mul.Tensor$canonicalize_literal_0d() -> !torch.vtensor<[],si64> { %0 = torch.vtensor.literal(dense<2> : tensor) : !torch.vtensor<[],si64> %1 = torch.vtensor.literal(dense<3> : tensor) : !torch.vtensor<[],si64> @@ -1926,13 +1890,8 @@ func.func @torch.aten.mul.Tensor$canonicalize_literal_0d() -> !torch.vtensor<[], } // CHECK-LABEL: func.func @torch.aten.mul.Tensor$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> { -// CHECK: %[[INT6:.*]] = torch.constant.int 6 -// CHECK: %[[INT2:.*]] = torch.constant.int 2 -// CHECK: %[[INT3:.*]] = torch.constant.int 3 -// CHECK: %[[PR0:.*]] = torch.prim.NumToTensor.Scalar %[[INT2]] : !torch.int -> !torch.vtensor<[],si64> -// CHECK: %[[PR1:.*]] = torch.prim.NumToTensor.Scalar %[[INT3]] : !torch.int -> !torch.vtensor<[],si64> -// CHECK: %[[PR2:.*]] = torch.prim.NumToTensor.Scalar %[[INT6]] : !torch.int -> !torch.vtensor<[],si64> -// CHECK: return %[[PR2]] : !torch.vtensor<[],si64> +// CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<6> : tensor) : !torch.vtensor<[],si64> +// CHECK: return %[[CST]] : !torch.vtensor<[],si64> func.func @torch.aten.mul.Tensor$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> { %int2 = torch.constant.int 2 %int3 = torch.constant.int 3 @@ -1943,13 +1902,8 @@ func.func @torch.aten.mul.Tensor$canonicalize_numtotensor_0d() -> !torch.vtensor } // CHECK-LABEL: func.func @torch.aten.div.Tensor_mode$canonicalize_numtotensor_0d_trunc() -> !torch.vtensor<[],si64> { -// CHECK: %[[INT3:.*]] = torch.constant.int 3 -// CHECK: %[[INT6:.*]] = torch.constant.int 6 -// CHECK: %[[INT2:.*]] = torch.constant.int 2 -// CHECK: %[[PR1:.*]] = torch.prim.NumToTensor.Scalar %[[INT2]] : !torch.int -> !torch.vtensor<[],si64> -// CHECK: %[[PR2:.*]] = torch.prim.NumToTensor.Scalar %[[INT6]] : !torch.int -> !torch.vtensor<[],si64> -// CHECK: %[[PR3:.*]] = torch.prim.NumToTensor.Scalar %[[INT3]] : !torch.int -> !torch.vtensor<[],si64> -// CHECK: return %[[PR3]] : !torch.vtensor<[],si64> +// CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<3> : tensor) : !torch.vtensor<[],si64> +// CHECK: return %[[CST]] : !torch.vtensor<[],si64> func.func @torch.aten.div.Tensor_mode$canonicalize_numtotensor_0d_trunc() -> !torch.vtensor<[],si64> { %int6 = torch.constant.int 6 %int2 = torch.constant.int 2 @@ -1961,11 +1915,8 @@ func.func @torch.aten.div.Tensor_mode$canonicalize_numtotensor_0d_trunc() -> !to } // CHECK-LABEL: func.func @torch.aten.div.Tensor_mode$canonicalize_literal_0d_trunc() -> !torch.vtensor<[],si64> { -// CHECK: %[[INT3:.*]] = torch.constant.int 3 -// CHECK: %[[INT6:.*]] = torch.constant.int 6 -// CHECK: %[[PR1:.*]] = torch.prim.NumToTensor.Scalar %[[INT6]] : !torch.int -> !torch.vtensor<[],si64> -// CHECK: %[[PR2:.*]] = torch.prim.NumToTensor.Scalar %[[INT3]] : !torch.int -> !torch.vtensor<[],si64> -// CHECK: return %[[PR2]] : !torch.vtensor<[],si64> +// CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<3> : tensor) : !torch.vtensor<[],si64> +// CHECK: return %[[CST]] : !torch.vtensor<[],si64> func.func @torch.aten.div.Tensor_mode$canonicalize_literal_0d_trunc() -> !torch.vtensor<[],si64> { %int6 = torch.constant.int 6 %str = torch.constant.str "trunc" @@ -2151,9 +2102,8 @@ func.func @torch.aten.slice.tensor$fold_dim_0() -> (!torch.vtensor<[1, 1],f32>, // CHECK-LABEL: func.func @torch.aten.rsub.Scalar$canonicalize_literal_0d() -> !torch.vtensor<[],si64> { -// CHECK: %int-1 = torch.constant.int -1 -// CHECK: %[[VAL_0:.*]] = torch.prim.NumToTensor.Scalar %int-1 : !torch.int -> !torch.vtensor<[],si64> -// CHECK: return %[[VAL_0]] : !torch.vtensor<[],si64> +// CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<-1> : tensor) : !torch.vtensor<[],si64> +// CHECK: return %[[CST]] : !torch.vtensor<[],si64> func.func @torch.aten.rsub.Scalar$canonicalize_literal_0d() -> !torch.vtensor<[],si64> { %int2 = torch.constant.int 2 %int3 = torch.constant.int 3 @@ -2163,11 +2113,8 @@ func.func @torch.aten.rsub.Scalar$canonicalize_literal_0d() -> !torch.vtensor<[] } // CHECK-LABEL: func.func @torch.aten.rsub.Scalar$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> { -// CHECK: %int-1 = torch.constant.int -1 -// CHECK: %int1 = torch.constant.int 1 -// CHECK: %[[VAL_0:.*]] = torch.prim.NumToTensor.Scalar %int1 : !torch.int -> !torch.vtensor<[],si64> -// CHECK: %[[VAL_1:.*]] = torch.prim.NumToTensor.Scalar %int-1 : !torch.int -> !torch.vtensor<[],si64> -// CHECK: return %[[VAL_1]] : !torch.vtensor<[],si64> +// CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<-1> : tensor) : !torch.vtensor<[],si64> +// CHECK: return %[[CST]] : !torch.vtensor<[],si64> func.func @torch.aten.rsub.Scalar$canonicalize_numtotensor_0d() -> !torch.vtensor<[],si64> { %int1 = torch.constant.int 1 %int2 = torch.constant.int 2 @@ -2179,7 +2126,6 @@ func.func @torch.aten.rsub.Scalar$canonicalize_numtotensor_0d() -> !torch.vtenso // CHECK-LABEL: func.func @torch.aten.ScalarImplicit$canonicalize_numtotensor_0d() -> !torch.number { // CHECK: %int1 = torch.constant.int 1 -// CHECK: %[[VAL_0:.*]] = torch.prim.NumToTensor.Scalar %int1 : !torch.int -> !torch.vtensor<[],si64> // CHECK: %[[VAL_1:.*]] = torch.derefine %int1 : !torch.int to !torch.number // CHECK: return %[[VAL_1]] : !torch.number func.func @torch.aten.ScalarImplicit$canonicalize_numtotensor_0d() -> !torch.number { @@ -2347,6 +2293,17 @@ func.func @fold_aten_where_true_attr() -> !torch.vtensor<[4],si64> { // ----- +// CHECK-LABEL: @fold_prim_numtotensor_scalar +func.func @fold_prim_numtotensor_scalar() -> !torch.vtensor<[1],si64> { + %int42 = torch.constant.int 42 + // CHECK: %[[TENSOR:.+]] = torch.vtensor.literal(dense<42> : tensor<1xsi64>) : !torch.vtensor<[1],si64> + // CHECK: return %[[TENSOR]] + %0 = torch.prim.NumToTensor.Scalar %int42 : !torch.int -> !torch.vtensor<[1],si64> + return %0 : !torch.vtensor<[1],si64> +} + +// ----- + // CHECK-LABEL: @fold_aten_where_false_attr func.func @fold_aten_where_false_attr() -> !torch.vtensor<[4],si64> { // CHECK: %[[RET:.+]] = torch.vtensor.literal(dense<11> : tensor<4xsi64>) : !torch.vtensor<[4],si64> From 4446fa00d8258311867496fc79d0b1dddd22a972 Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Tue, 20 Feb 2024 08:54:02 -0800 Subject: [PATCH 211/283] Migrate passes in TorchConversion to use FunctionOpInterface. (#2935) This enables better re-use in downstreams which use different func implementations and should have no impact on those that don't except in opt pipelines if using the old form. With interfaces, explicit pipelines via `--pass-pipeline=` must be used. --- .../Dialect/TorchConversion/Transforms/Passes.h | 10 ++++++---- .../Dialect/TorchConversion/Transforms/Passes.td | 6 +++--- .../Transforms/BackendTypeConversionPasses.cpp | 4 ++-- .../Transforms/ConvertCustomQuantOp.cpp | 2 +- lib/Dialect/TorchConversion/Transforms/PassDetail.h | 2 +- lib/Dialect/TorchConversion/Transforms/Passes.cpp | 1 + .../TorchConversion/Transforms/UnpackQuantTensor.cpp | 2 +- .../TorchConversion/convert-custom-quant-op.mlir | 2 +- .../finalizing-backend-type-conversion.mlir | 2 +- test/Dialect/TorchConversion/unpack-quant-tensor.mlir | 2 +- 10 files changed, 18 insertions(+), 15 deletions(-) diff --git a/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h b/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h index d762bd840f7f..2f70cf990219 100644 --- a/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h +++ b/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.h @@ -10,7 +10,7 @@ #ifndef TORCHMLIR_DIALECT_TORCHCONVERSION_TRANSFORMS_PASSES_H #define TORCHMLIR_DIALECT_TORCHCONVERSION_TRANSFORMS_PASSES_H -#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Pass/Pass.h" #include "torch-mlir/Dialect/Torch/Transforms/Passes.h" @@ -54,7 +54,7 @@ createVerifyStablehloBackendContractPass(); std::unique_ptr> createFuncBackendTypeConversionPass(); -std::unique_ptr> +std::unique_ptr> createFinalizingBackendTypeConversionPass(); // These passes do a one-off conversion of a specific kind of quantized group @@ -62,8 +62,10 @@ createFinalizingBackendTypeConversionPass(); // obviate them but that are being carried for now in order to unblock progress // on full integrations. See https://github.com/llvm/torch-mlir/issues/2417 for // the plan to support a more generalized lowering for these graphs. -std::unique_ptr> createUnpackQuantTensorPass(); -std::unique_ptr> createConvertCustomQuantOpPass(); +std::unique_ptr> +createUnpackQuantTensorPass(); +std::unique_ptr> +createConvertCustomQuantOpPass(); std::unique_ptr> createVerifyLinalgOnTensorsBackendContractPass(); diff --git a/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.td b/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.td index 4d3e16a81c5c..73654c6f8034 100644 --- a/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.td +++ b/include/torch-mlir/Dialect/TorchConversion/Transforms/Passes.td @@ -22,7 +22,7 @@ def FuncBackendTypeConversion : Pass<"torch-func-backend-type-conversion", "Modu } def FinalizingBackendTypeConversion - : Pass<"torch-finalizing-backend-type-conversion", "func::FuncOp"> { + : InterfacePass<"torch-finalizing-backend-type-conversion", "mlir::FunctionOpInterface"> { let summary = "Finalizes a partial conversion to builtin tensors"; let constructor = "mlir::torch::TorchConversion::createFinalizingBackendTypeConversionPass()"; @@ -51,12 +51,12 @@ def VerifyStablehloBackendContract : Pass<"torch-verify-stablehlo-backend-contra // The following passes are for a one-off conversion of a specific kind of quantized group matmul. // They should not be included in default lowering flows until further along. -def UnpackQuantTensor : Pass<"torch-unpack-quant-tensor", "func::FuncOp"> { +def UnpackQuantTensor : InterfacePass<"torch-unpack-quant-tensor", "mlir::FunctionOpInterface"> { let summary = "Unpack quantized int4 tensor from int8 containter"; let constructor = "mlir::torch::TorchConversion::createUnpackQuantTensorPass()"; } -def ConvertCustomQuantOp : Pass<"torch-convert-custom-quant-op", "func::FuncOp"> { +def ConvertCustomQuantOp : InterfacePass<"torch-convert-custom-quant-op", "mlir::FunctionOpInterface"> { let summary = "Convert torch custom quant op to linalg"; let constructor = "mlir::torch::TorchConversion::createConvertCustomQuantOpPass()"; } diff --git a/lib/Dialect/TorchConversion/Transforms/BackendTypeConversionPasses.cpp b/lib/Dialect/TorchConversion/Transforms/BackendTypeConversionPasses.cpp index 5dd3d778f8f4..896dd9577617 100644 --- a/lib/Dialect/TorchConversion/Transforms/BackendTypeConversionPasses.cpp +++ b/lib/Dialect/TorchConversion/Transforms/BackendTypeConversionPasses.cpp @@ -115,7 +115,7 @@ static void setupFinalization(ConversionTarget &target, setupFinalization(target, patterns, typeConverter); } -static void stripTorchAttrs(func::FuncOp func) { +static void stripTorchAttrs(FunctionOpInterface func) { bool modified = false; SmallVector newAttrs; for (auto attr : func->getDialectAttrs()) { @@ -173,7 +173,7 @@ struct FinalizingBackendTypeConversionPass }; } // namespace -std::unique_ptr> +std::unique_ptr> mlir::torch::TorchConversion::createFinalizingBackendTypeConversionPass() { return std::make_unique(); } diff --git a/lib/Dialect/TorchConversion/Transforms/ConvertCustomQuantOp.cpp b/lib/Dialect/TorchConversion/Transforms/ConvertCustomQuantOp.cpp index 514d05234486..7bcb67b17c61 100644 --- a/lib/Dialect/TorchConversion/Transforms/ConvertCustomQuantOp.cpp +++ b/lib/Dialect/TorchConversion/Transforms/ConvertCustomQuantOp.cpp @@ -229,7 +229,7 @@ class ConvertCustomQuantOpPass }; } // namespace -std::unique_ptr> +std::unique_ptr> mlir::torch::TorchConversion::createConvertCustomQuantOpPass() { return std::make_unique(); } diff --git a/lib/Dialect/TorchConversion/Transforms/PassDetail.h b/lib/Dialect/TorchConversion/Transforms/PassDetail.h index 224ad8e2d89a..cb80ebd89a3c 100644 --- a/lib/Dialect/TorchConversion/Transforms/PassDetail.h +++ b/lib/Dialect/TorchConversion/Transforms/PassDetail.h @@ -10,7 +10,7 @@ #ifndef TORCHMLIR_DIALECT_TORCHCONVERSION_TRANSFORMS_PASSDETAIL_H #define TORCHMLIR_DIALECT_TORCHCONVERSION_TRANSFORMS_PASSDETAIL_H -#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Interfaces/FunctionInterfaces.h" #include "mlir/Pass/Pass.h" namespace mlir { diff --git a/lib/Dialect/TorchConversion/Transforms/Passes.cpp b/lib/Dialect/TorchConversion/Transforms/Passes.cpp index 673d7083f585..9ff447371a76 100644 --- a/lib/Dialect/TorchConversion/Transforms/Passes.cpp +++ b/lib/Dialect/TorchConversion/Transforms/Passes.cpp @@ -9,6 +9,7 @@ #include "torch-mlir/Dialect/TorchConversion/Transforms/Passes.h" #include "mlir/Conversion/Passes.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Func/Transforms/Passes.h" #include "mlir/Dialect/Linalg/Passes.h" #include "mlir/Dialect/MemRef/Transforms/Passes.h" diff --git a/lib/Dialect/TorchConversion/Transforms/UnpackQuantTensor.cpp b/lib/Dialect/TorchConversion/Transforms/UnpackQuantTensor.cpp index 25f325399f12..064c87f6e6a8 100644 --- a/lib/Dialect/TorchConversion/Transforms/UnpackQuantTensor.cpp +++ b/lib/Dialect/TorchConversion/Transforms/UnpackQuantTensor.cpp @@ -137,7 +137,7 @@ class UnpackQuantTensorPass }; } // namespace -std::unique_ptr> +std::unique_ptr> mlir::torch::TorchConversion::createUnpackQuantTensorPass() { return std::make_unique(); } diff --git a/test/Dialect/TorchConversion/convert-custom-quant-op.mlir b/test/Dialect/TorchConversion/convert-custom-quant-op.mlir index 4f72f24e8868..7aca3551cfc2 100644 --- a/test/Dialect/TorchConversion/convert-custom-quant-op.mlir +++ b/test/Dialect/TorchConversion/convert-custom-quant-op.mlir @@ -1,4 +1,4 @@ -// RUN: torch-mlir-opt %s -torch-convert-custom-quant-op -split-input-file -verify-diagnostics | FileCheck %s +// RUN: torch-mlir-opt %s '-pass-pipeline=builtin.module(func.func(torch-convert-custom-quant-op))' -split-input-file -verify-diagnostics | FileCheck %s // CHECK: #map = affine_map<(d0, d1, d2) -> (d0, d1, d2)> // CHECK: #map1 = affine_map<(d0, d1, d2) -> (d0, d1, 0)> diff --git a/test/Dialect/TorchConversion/finalizing-backend-type-conversion.mlir b/test/Dialect/TorchConversion/finalizing-backend-type-conversion.mlir index 46f80c06b4ce..57077a723ada 100644 --- a/test/Dialect/TorchConversion/finalizing-backend-type-conversion.mlir +++ b/test/Dialect/TorchConversion/finalizing-backend-type-conversion.mlir @@ -1,4 +1,4 @@ -// RUN: torch-mlir-opt %s -torch-finalizing-backend-type-conversion -split-input-file -verify-diagnostics -allow-unregistered-dialect | FileCheck %s +// RUN: torch-mlir-opt %s '-pass-pipeline=builtin.module(func.func(torch-finalizing-backend-type-conversion))' -split-input-file -verify-diagnostics -allow-unregistered-dialect | FileCheck %s // This test is largely copied from `finalizing-bufferize` upstream, as it // covers the same scope. diff --git a/test/Dialect/TorchConversion/unpack-quant-tensor.mlir b/test/Dialect/TorchConversion/unpack-quant-tensor.mlir index 0ca64ae09397..8fa1a775b66d 100644 --- a/test/Dialect/TorchConversion/unpack-quant-tensor.mlir +++ b/test/Dialect/TorchConversion/unpack-quant-tensor.mlir @@ -1,4 +1,4 @@ -// RUN: torch-mlir-opt %s -torch-unpack-quant-tensor -split-input-file -verify-diagnostics | FileCheck %s +// RUN: torch-mlir-opt %s '-pass-pipeline=builtin.module(func.func(torch-unpack-quant-tensor))' -split-input-file -verify-diagnostics | FileCheck %s // CHECK-LABEL: func @forward func.func @forward(%arg0: !torch.vtensor<[1,1,8],f16>) -> !torch.vtensor<[1,1,8],f16> { From 13553d49c9488e09fec6ba790fb095eea66c48ea Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Tue, 20 Feb 2024 09:30:30 -0800 Subject: [PATCH 212/283] [onnx] Update the importer to create a `none` for missing operands (#2931) Some operands are optional so we require a placeholder for missing operands. We invent an `onnx.None` operation as our placeholder. --- projects/pt1/e2e_testing/xfail_sets.py | 13 ++--- python/torch_mlir/extras/onnx_importer.py | 20 ++++++- .../python/onnx_importer/import_smoke_test.py | 53 ++----------------- 3 files changed, 25 insertions(+), 61 deletions(-) diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 52e1ea3321b8..632b15e85c74 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -2184,6 +2184,7 @@ "ElementwiseUnsqueezeNegDimsModule_basic", "ElementwiseWhereScalarModule_basic", "FlattenDynamicModule_basic", + "FlipModule_basic", "FlipModuleStaticShape_basic", "GluStaticModule_basic", "MaskedFillTensorFloatValueModule_basic", @@ -2193,17 +2194,9 @@ "ReduceMinAlongDimUnsignedInt_basic", "TensorsStackNegativeDimModule_basic", "TensorsStackPromoteDTypeModule_basic", -} -ONNX_CRASHING_SET = { - "FlipModule_basic", - "IndexTensorNegativeIndexModule_basic", "MoveDimIntNegativeIndexModule_basic", "PermuteNegativeIndexModule_basic", - "RollModule_basic", - "SliceModule_basic", - "SliceNegIdxModule_basic", - "SliceOutOfLowerBoundEndIndexModule_basic", - "SliceOutOfLowerBoundStartIndexModule_basic", - "SliceSizeTwoStepModule_basic", } + +ONNX_CRASHING_SET = { } diff --git a/python/torch_mlir/extras/onnx_importer.py b/python/torch_mlir/extras/onnx_importer.py index 24520e9ce970..c62324832520 100644 --- a/python/torch_mlir/extras/onnx_importer.py +++ b/python/torch_mlir/extras/onnx_importer.py @@ -258,6 +258,8 @@ def import_all(self): # much unused crap. for init in self._gi.initializer_map.values(): self.import_initializer(init) + + self.get_none() for node in self._gi.graph_proto.node: self.import_node(node) @@ -272,6 +274,20 @@ def import_all(self): with InsertionPoint(self._b), Location.unknown(): func_dialect.ReturnOp(outputs) + def get_none(self): + if '' in self._nv_map: + return self._nv_map[''] + + with InsertionPoint(self._b), Location.name("onnx_importer.none"): + nne = Operation.create( + name="torch.constant.none", + results=[self._cc.get_none_type()], + operands=[], + attributes={}, + ).results[0] + self._nv_map[''] = nne + return nne + def import_node(self, node: onnx.NodeProto): with InsertionPoint(self._b), Location.name(node.name): op_type = node.op_type @@ -283,7 +299,6 @@ def import_node(self, node: onnx.NodeProto): was_handled = getattr(self, special_key)(node) if was_handled: return - # General node import. input_values = [] for input_name in node.input: @@ -449,6 +464,9 @@ def tensor_element_type(self, elem_type: int) -> IrType: self._elem_type_map[elem_type] = t return t + def get_none_type(self): + return IrType.parse("!torch.none", context=self._c) + def get_vtensor_type( self, dims: tuple[Optional[int]], element_type: IrType ) -> IrType: diff --git a/test/python/onnx_importer/import_smoke_test.py b/test/python/onnx_importer/import_smoke_test.py index f27cc9caf5bd..708324e72db6 100644 --- a/test/python/onnx_importer/import_smoke_test.py +++ b/test/python/onnx_importer/import_smoke_test.py @@ -102,22 +102,12 @@ "node_test_castlike_FLOAT_to_STRING_model", "node_test_castlike_STRING_to_FLOAT_expanded_model", "node_test_castlike_STRING_to_FLOAT_model", - "node_test_center_crop_pad_crop_axes_chw_expanded_model", - "node_test_center_crop_pad_crop_axes_hwc_expanded_model", - "node_test_center_crop_pad_crop_negative_axes_hwc_expanded_model", - "node_test_clip_default_inbounds_model", - "node_test_clip_default_int8_inbounds_model", - "node_test_clip_default_int8_max_model", - "node_test_clip_default_max_model", "node_test_constantofshape_float_ones_model", "node_test_constantofshape_int_shape_zero_model", "node_test_constantofshape_int_zeros_model", "node_test_dequantizelinear_e4m3fn_model", "node_test_dequantizelinear_e4m3fn_zero_point_model", "node_test_dequantizelinear_e5m2_model", - "node_test_dft_axis_model", - "node_test_dft_inverse_model", - "node_test_dft_model", "node_test_equal_string_broadcast_model", "node_test_equal_string_model", "node_test_gru_defaults_model", @@ -175,8 +165,6 @@ "node_test_optional_get_element_optional_sequence_model", "node_test_optional_get_element_optional_tensor_model", "node_test_optional_get_element_sequence_model", - "node_test_optional_has_element_empty_no_input_name_optional_input_model", - "node_test_optional_has_element_empty_no_input_name_tensor_input_model", "node_test_optional_has_element_empty_optional_input_model", "node_test_optional_has_element_optional_input_model", "node_test_optional_has_element_tensor_input_model", @@ -187,43 +175,6 @@ "node_test_regex_full_match_basic_model", "node_test_regex_full_match_email_domain_model", "node_test_regex_full_match_empty_model", - "node_test_resize_downsample_scales_cubic_A_n0p5_exclude_outside_model", - "node_test_resize_downsample_scales_cubic_align_corners_model", - "node_test_resize_downsample_scales_cubic_antialias_model", - "node_test_resize_downsample_scales_cubic_model", - "node_test_resize_downsample_scales_linear_align_corners_model", - "node_test_resize_downsample_scales_linear_antialias_model", - "node_test_resize_downsample_scales_linear_half_pixel_symmetric_model", - "node_test_resize_downsample_scales_linear_model", - "node_test_resize_downsample_scales_nearest_model", - "node_test_resize_downsample_sizes_cubic_antialias_model", - "node_test_resize_downsample_sizes_cubic_model", - "node_test_resize_downsample_sizes_linear_antialias_model", - "node_test_resize_downsample_sizes_linear_pytorch_half_pixel_model", - "node_test_resize_downsample_sizes_nearest_model", - "node_test_resize_downsample_sizes_nearest_not_larger_model", - "node_test_resize_downsample_sizes_nearest_not_smaller_model", - "node_test_resize_tf_crop_and_resize_axes_2_3_model", - "node_test_resize_tf_crop_and_resize_axes_3_2_model", - "node_test_resize_tf_crop_and_resize_model", - "node_test_resize_upsample_scales_cubic_A_n0p5_exclude_outside_model", - "node_test_resize_upsample_scales_cubic_align_corners_model", - "node_test_resize_upsample_scales_cubic_asymmetric_model", - "node_test_resize_upsample_scales_cubic_model", - "node_test_resize_upsample_scales_linear_align_corners_model", - "node_test_resize_upsample_scales_linear_half_pixel_symmetric_model", - "node_test_resize_upsample_scales_linear_model", - "node_test_resize_upsample_scales_nearest_axes_2_3_model", - "node_test_resize_upsample_scales_nearest_axes_3_2_model", - "node_test_resize_upsample_scales_nearest_model", - "node_test_resize_upsample_sizes_cubic_model", - "node_test_resize_upsample_sizes_nearest_axes_2_3_model", - "node_test_resize_upsample_sizes_nearest_axes_3_2_model", - "node_test_resize_upsample_sizes_nearest_ceil_half_pixel_model", - "node_test_resize_upsample_sizes_nearest_floor_align_corners_model", - "node_test_resize_upsample_sizes_nearest_model", - "node_test_resize_upsample_sizes_nearest_not_larger_model", - "node_test_resize_upsample_sizes_nearest_round_prefer_ceil_asymmetric_model", "node_test_rnn_seq_length_model", "node_test_scan9_sum_model", "node_test_scan_sum_model", @@ -246,7 +197,6 @@ "node_test_split_to_sequence_1_model", "node_test_split_to_sequence_2_model", "node_test_split_to_sequence_nokeepdims_model", - "node_test_stft_model", "node_test_string_concat_broadcasting_model", "node_test_string_concat_empty_string_model", "node_test_string_concat_model", @@ -281,6 +231,9 @@ ] + + + class ImportSmokeTest(unittest.TestCase): @classmethod def setUpClass(cls): From 13113df33e53a1d0dd7a1f2313ec101df142a152 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Tue, 20 Feb 2024 10:34:21 -0800 Subject: [PATCH 213/283] [onnx] Enable crashing tests (#2928) Crashing tests no longer crash, enable as either passing or xfail tests. Co-authored-by: Xida Ren (Cedar) --- projects/pt1/e2e_testing/xfail_sets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 632b15e85c74..866c569eb147 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -2194,7 +2194,7 @@ "ReduceMinAlongDimUnsignedInt_basic", "TensorsStackNegativeDimModule_basic", "TensorsStackPromoteDTypeModule_basic", - + "FlipModule_basic", "MoveDimIntNegativeIndexModule_basic", "PermuteNegativeIndexModule_basic", } From 534b266f2d198d2b6b2ed62d5b31f82d0e3d9f3c Mon Sep 17 00:00:00 2001 From: Aart Bik <39774503+aartbik@users.noreply.github.com> Date: Tue, 20 Feb 2024 11:23:14 -0800 Subject: [PATCH 214/283] [torch-mlir][NFC] remove trailing whitespace (#2936) --- .../pt1/python/torch_mlir/compiler_utils.py | 2 +- projects/pt1/python/torch_mlir/dynamo.py | 4 +-- .../build_tools/abstract_interp_lib_gen.py | 28 +++++++++---------- .../python/torch_mlir_e2e_test/framework.py | 2 +- .../onnx_backends/linalg_on_tensors.py | 2 +- .../torch_mlir_e2e_test/test_suite/conv.py | 2 +- .../test_suite/diagonal.py | 2 +- .../test_suite/elementwise.py | 2 +- .../torch_mlir_e2e_test/test_suite/pooling.py | 16 +++++------ .../test_suite/reduction.py | 16 +++++------ 10 files changed, 38 insertions(+), 38 deletions(-) diff --git a/projects/pt1/python/torch_mlir/compiler_utils.py b/projects/pt1/python/torch_mlir/compiler_utils.py index 3a64473de118..7792006032af 100644 --- a/projects/pt1/python/torch_mlir/compiler_utils.py +++ b/projects/pt1/python/torch_mlir/compiler_utils.py @@ -64,7 +64,7 @@ def run_pipeline_with_repro_report(module, {sys.stderr.getvalue()} python exception: {e} - + For Torch-MLIR developers, the error can be reproduced with: $ torch-mlir-opt -pass-pipeline='{pipeline}' {filename} Add '{debug_options}' to get the IR dump for debugging purpose. diff --git a/projects/pt1/python/torch_mlir/dynamo.py b/projects/pt1/python/torch_mlir/dynamo.py index fa00bb9a847f..1b78b2a06e00 100644 --- a/projects/pt1/python/torch_mlir/dynamo.py +++ b/projects/pt1/python/torch_mlir/dynamo.py @@ -19,12 +19,12 @@ def _get_decomposition_table(): """Get a decomposition table suitable for Torch-MLIR. - + Sometimes TorchDynamo traces slightly different ops than what TorchScript captures. Historically we have been driven by the ops captured by TorchScript, so we try to decompose the ops captured by TorchDynamo into other ops that we already support. - + There isn't a highly principled solution here. Torch-MLIR currently supports a somewhat random set of ops, added in a demand-driven way over time, including direct backend support and decompositions internal to Torch-MLIR. diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index a856ac02639a..1a87bbb6bee1 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -81,7 +81,7 @@ def aten〇diagonal〡shape(self: List[int], offset: int = 0, dim1: int = 0, dim pass else: diagonal.append(self_dim) - + diag_size = max(min(self[dim1], self[dim2] - offset), 0) if offset<0: diag_size = max(min(self[dim1] + offset, self[dim2]), 0) @@ -295,10 +295,10 @@ def prims〇collapse〡shape(a: List[int], start: int, end: int) -> List[int]: assert end >= 0, "end out of bounds" assert start <= end, "start must be less than or equal to end" - # Examples: + # Examples: # # torch._prims.collapse(torch.empty(2,3,4), 1,2).shape - # is + # is # torch.Size([2, 12]) # # torch._prims.collapse(torch.empty(2,3,4), 1,3).shape @@ -592,7 +592,7 @@ def aten〇pixel_shuffle〡shape(self: List[int], upscale_factor: int) -> List[i assert len(self) >= 3, "input must be at least rank-3 in pixel_shuffle" upscale_factor_squared = upscale_factor * upscale_factor assert self[-3] % (upscale_factor_squared) == 0, "number of input channels must be divisible by upscale_factor^2 in pixel_shuffle" - + out = self[0:-3] out.append(self[-3] // upscale_factor_squared) out.append(self[-2] * upscale_factor) @@ -756,7 +756,7 @@ def _max_pool3d( assert ( len(stride) == 0 or len(stride) == 1 or len(stride) == 3 ), "max_pool3d: stride must either be omitted, a single int, or a tuple of three ints" - + if len(stride) == 0: (dD, dH, dW) = (kD, kD, kD) elif len(stride) == 1: @@ -808,14 +808,14 @@ def _max_pool3d( return [nInputPlane, outputDepth, outputHeight, outputWidth] else: return [nbatch, nInputPlane, outputDepth, outputHeight, outputWidth] - + def aten〇max_pool2d〡shape(self: List[int], kernel_size: List[int], stride: List[int] = (), padding: List[int] = (0, 0,), dilation: List[int] = (1, 1,), ceil_mode: bool = False) -> List[int]: return upstream_shape_functions.max_pool2d(self, kernel_size, stride, padding, dilation, ceil_mode) @check_shape_function([ Invocation(TensorOfShape(3, 6, 10, 10, 10), [2]), # Basic using defaults Invocation(TensorOfShape(3, 6, 10, 10, 10), [4], [2], [2], [2]), # Using single values for each parameter - Invocation(TensorOfShape(3, 6, 64, 64, 64), [4, 6, 8], [2, 4, 2], [1, 2, 4], [1, 2, 4]), # Using dimensions should be + Invocation(TensorOfShape(3, 6, 64, 64, 64), [4, 6, 8], [2, 4, 2], [1, 2, 4], [1, 2, 4]), # Using dimensions should be ErrorInvocation(TensorOfShape(3, 6, 2, 2, 2), [4]), # Input is too small ErrorInvocation(TensorOfShape(3, 6, 10, 10, 10), [4], [2], [4], [2]), # The following relationship between kernel and padding needs to apply: Kernel size >= 2 * padding size ]) @@ -1374,15 +1374,15 @@ def aten〇conv_tbc〡shape(self: List[int], weight: List[int], bias: List[int], assert channels == channels_w # the out_channels in weights and biases should also match, but this assert doesn't work because typing problems - # assert out_channels == out_channels_b - + # assert out_channels == out_channels_b + self_bct = [batch, channels, time] weight_bct = [out_channels, channels, kernel_width] bias_bct = bias - # use existing shape inf + # use existing shape inf output_size_bct = upstream_shape_functions.conv_forwards(self, weight, bias, stride=[1], padding=[pad], dilation=[], transposed=False, output_padding=[], groups=1) - + batch_out, channels_out, time_out = output_size_bct # bct -> tbc @@ -1544,7 +1544,7 @@ def aten〇replication_pad2d〡shape(self: List[int], padding: List[int]) -> Lis return pad_shape_fn(self, padding) def aten〇replication_pad2d〡dtype(self_rank_dtype: Tuple[int, int], padding: List[int]) -> int: - self_rank, self_dtype = self_rank_dtype + self_rank, self_dtype = self_rank_dtype return self_dtype def aten〇pad〡shape(self: List[int], pad: List[int], mode: str = "constant", value: Optional[float] = None) -> List[int]: @@ -3618,7 +3618,7 @@ def aten〇where〇ScalarSelf〡dtype(condition_rank_dtype: Tuple[int, int], sel @check_dtype_function( _check_tensors_with_the_same_dtype(num_of_tensors=1)) def aten〇nan_to_num〡dtype(self_rank_dtype: Tuple[int, int], nan: Optional[float] = None, posinf: Optional[float] = None, neginf: Optional[float] = None) -> int: - self_rank, self_dtype = self_rank_dtype + self_rank, self_dtype = self_rank_dtype return self_dtype @check_dtype_function( @@ -4258,7 +4258,7 @@ def aten〇cat〡dtype(tensors_rank_dtype: List[Tuple[int, int]], dim: int = 0) return promote_dtypes(ranks, dtypes) @check_dtype_function( - [Invocation("i,j->ij", [TensorOfShape(1, dtype=torch.float32), + [Invocation("i,j->ij", [TensorOfShape(1, dtype=torch.float32), TensorOfShape(1, dtype=torch.int32)]),]) def aten〇einsum〡dtype(equation: str, tensors_rank_dtype: List[Tuple[int, int]], path: Optional[List[int]] = None) -> int: ranks: List[Optional[int]] = [] diff --git a/projects/pt1/python/torch_mlir_e2e_test/framework.py b/projects/pt1/python/torch_mlir_e2e_test/framework.py index 388976256591..d3fecf54d99c 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/framework.py +++ b/projects/pt1/python/torch_mlir_e2e_test/framework.py @@ -324,7 +324,7 @@ def compile_and_run_test(test: Test, config: TestConfig, verbose=False) -> Any: def run_tests(tests: List[Test], config: TestConfig, sequential=False, verbose=False) -> List[TestResult]: - """Invoke the given `Test`'s with the provided `TestConfig`.""" + """Invoke the given `Test`'s with the provided `TestConfig`.""" num_processes = min(int(mp.cpu_count() * 0.8) + 1, len(tests)) try: env_concurrency = int(os.getenv("TORCH_MLIR_TEST_CONCURRENCY", "0")) diff --git a/projects/pt1/python/torch_mlir_e2e_test/onnx_backends/linalg_on_tensors.py b/projects/pt1/python/torch_mlir_e2e_test/onnx_backends/linalg_on_tensors.py index e77d795b7269..0e5073fdd89d 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/onnx_backends/linalg_on_tensors.py +++ b/projects/pt1/python/torch_mlir_e2e_test/onnx_backends/linalg_on_tensors.py @@ -49,7 +49,7 @@ def compile(self, imported_module: Module): imported_module, f"builtin.module(func.func({ONNX_TO_TORCH_FUNC_PIPELINE}))", "Lowering Onnx backend contract to Linalg-on-Tensors backend contract") - + run_pipeline_with_repro_report( imported_module, f"builtin.module(torch-lower-to-backend-contract)", diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py index b12424cbb7b2..5872df170c48 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/conv.py @@ -853,7 +853,7 @@ def __init__(self): ]) def forward(self, x, weight, bias): return torch.conv_tbc(x, weight, bias) - + @register_test_case(module_factory=lambda: ConvTbcModule()) def ConvTbcModule_basic(module, tu: TestUtils): module.forward(tu.rand(9, 4, 5), tu.rand(3, 5, 6), tu.rand(6)) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/diagonal.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/diagonal.py index 13d49cea0737..d54bd11cb7d6 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/diagonal.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/diagonal.py @@ -94,7 +94,7 @@ def __init__(self): @export @annotate_args([ - None, + None, ([-1, -1], torch.float32, True), ]) def forward(self, a): diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py index 2f74ceb84416..ad4abd9f1752 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py @@ -3833,7 +3833,7 @@ def __init__(self): ([2, 5], torch.float32, True), ]) def forward(self, x): - return torch.ops.aten.isposinf(x) + return torch.ops.aten.isposinf(x) @register_test_case(module_factory=lambda: ElementwiseAtenIsposinfOpModule()) def ElementwiseAtenIsposinfOpModule_basic(module, tu:TestUtils): diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py index d26d9b121cf3..22ff3bb330ad 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/pooling.py @@ -945,7 +945,7 @@ def AvgPool1dStaticModule_basic(module, tu: TestUtils): # ============================================================================== class AdaptiveAvgPool1dStaticLargerOutput(torch.nn.Module): - + def __init__(self): super().__init__() self.aap1d = torch.nn.AdaptiveAvgPool1d(output_size=13) @@ -965,7 +965,7 @@ def AdaptiveAvgPool1dStaticLargerOutput_basic( module.forward(tu.rand(5, 512, 7)) class AdaptiveAvgPool1dStaticEvenMultiple(torch.nn.Module): - + def __init__(self): super().__init__() self.aap1d = torch.nn.AdaptiveAvgPool1d(output_size=7) @@ -985,7 +985,7 @@ def AdaptiveAvgPool1dStaticEvenMultiple_basic( module.forward(tu.rand(5, 512, 147)) class AdaptiveAvgPool1dGeneralDynamic(torch.nn.Module): - + def __init__(self): super().__init__() self.aap1d = torch.nn.AdaptiveAvgPool1d(output_size=7) @@ -1085,7 +1085,7 @@ def AdaptiveAvgPool1dUnitOutputSizeDynamicModule_basic( module.forward(tu.rand(1, 512, 7)) class AdaptiveMaxPool2dDynamic(torch.nn.Module): - + def __init__(self): super().__init__() self.amp2d = torch.nn.AdaptiveMaxPool2d(output_size=(7,13), return_indices=False) @@ -1105,7 +1105,7 @@ def AdaptiveMaxPool2dDynamic_basic( module.forward(tu.rand(1, 512, 10, 16)) class AdaptiveMaxPool2dDynamicWithIndices(torch.nn.Module): - + def __init__(self): super().__init__() self.amp2d = torch.nn.AdaptiveMaxPool2d(output_size=(7,13), return_indices=True) @@ -1123,10 +1123,10 @@ def forward(self,x): def AdaptiveMaxPool2dDynamicWithIndices_basic( module, tu: TestUtils): module.forward(tu.rand(1, 512, 10, 16)) - + class AdaptiveMaxPool2dStatic(torch.nn.Module): - + def __init__(self): super().__init__() self.amp2d = torch.nn.AdaptiveMaxPool2d(output_size=(7,13), return_indices=False) @@ -1146,7 +1146,7 @@ def AdaptiveMaxPool2dStatic_basic( module.forward(tu.rand(1, 512, 10, 9)) class AdaptiveMaxPool2dStaticWithIndices(torch.nn.Module): - + def __init__(self): super().__init__() self.amp2d = torch.nn.AdaptiveMaxPool2d(output_size=(7,13), return_indices=True) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py index 804476e6a686..2c61524bd797 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py @@ -327,13 +327,13 @@ def __init__(self): ]) def forward(self, a): return torch.ops.aten.all(a, dim=0, keepdim=False) - + @register_test_case(module_factory=lambda: ReduceAllDimEmpty()) def ReduceAllDimEmpty_basic(module, tu: TestUtils): module.forward(torch.tensor([])) # ============================================================================== - + class ReduceAllDimFloat(torch.nn.Module): def __init__(self): super().__init__() @@ -345,13 +345,13 @@ def __init__(self): ]) def forward(self, a): return torch.ops.aten.all(a, dim=1, keepdim=True) - + @register_test_case(module_factory=lambda: ReduceAllDimFloat()) def ReduceAllDimFloat_basic(module, tu: TestUtils): module.forward(torch.tensor([[5.0,1e-6,-5.0],[0,5.0,0]])) # ============================================================================== - + class ReduceAllDimInt(torch.nn.Module): def __init__(self): super().__init__() @@ -363,13 +363,13 @@ def __init__(self): ]) def forward(self, a): return torch.ops.aten.all(a, dim=1, keepdim=True) - + @register_test_case(module_factory=lambda: ReduceAllDimInt()) def ReduceAllDimInt_basic(module, tu: TestUtils): module.forward(torch.tensor([[5,-5,0],[5,1e10,5]]).to(torch.int32)) # ============================================================================== - + class ReduceAllDimBool(torch.nn.Module): def __init__(self): super().__init__() @@ -381,13 +381,13 @@ def __init__(self): ]) def forward(self, a): return torch.ops.aten.all(a, dim=1, keepdim=False) - + @register_test_case(module_factory=lambda: ReduceAllDimBool()) def ReduceAllDimBool_basic(module, tu: TestUtils): module.forward(torch.tensor([[True, False, True], [True, True, True]])) # ============================================================================== - + class ReduceMaxAlongDim(torch.nn.Module): def __init__(self): super().__init__() From 0f80e75c2eb6dfed00bf051644a5e3fb97207bb8 Mon Sep 17 00:00:00 2001 From: Srinath Avadhanula Date: Tue, 20 Feb 2024 17:22:38 -0500 Subject: [PATCH 215/283] allow tosa.cast to convert from f32 to f16 (#2934) According to the [official TOSA spec](https://www.mlplatform.org/tosa/tosa_spec.html#_cast), `tosa.cast` allows a cast from `fp32` to `fp16`. We were not previously accounting for this in the `TorchToTosa` lowering. Also did a tiny bit of cleanup in the code to make it easier to spot which conversions are currently allowed. --------- Co-authored-by: Srinath Avadhanula --- .../TorchToTosa/TosaLegalizeUtils.cpp | 26 +++++++++++++++---- .../TorchToTosa/cast_fp32_to_fp16.mlir | 12 +++++++++ 2 files changed, 33 insertions(+), 5 deletions(-) create mode 100644 test/Conversion/TorchToTosa/cast_fp32_to_fp16.mlir diff --git a/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp b/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp index 781a5912d83c..9259fdacff24 100644 --- a/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp +++ b/lib/Conversion/TorchToTosa/TosaLegalizeUtils.cpp @@ -266,28 +266,44 @@ std::optional getConstTensor(PatternRewriter &rewriter, } static LogicalResult checkValidityOfCast(Type src, Type dest) { - if ((src == dest) || (src.isInteger(64) && dest.isInteger(32)) || + // clang-format off + if ((src == dest) || + // int64 -> * + (src.isInteger(64) && dest.isInteger(32)) || (src.isInteger(64) && dest.isInteger(8)) || (src.isInteger(64) && dest.isInteger(1)) || (src.isInteger(64) && dest.isF32()) || + // int32 -> * (src.isInteger(32) && dest.isInteger(64)) || (src.isInteger(32) && dest.isInteger(1)) || (src.isInteger(32) && dest.isF32()) || (src.isInteger(32) && dest.isBF16()) || + // int16 -> * (src.isInteger(16) && dest.isBF16()) || + // int8 -> * (src.isInteger(8) && dest.isInteger(1)) || (src.isInteger(8) && dest.isBF16()) || + // int1 -> * (src.isInteger(1) && dest.isInteger(64)) || - (src.isInteger(1) && dest.isF32()) || (src.isF32() && dest.isF64()) || - (src.isF32() && dest.isBF16()) || (src.isF64() && dest.isF32()) || - (src.isF64() && dest.isBF16()) || (src.isF32() && dest.isInteger(8)) || + (src.isInteger(1) && dest.isF32()) || + // f64 -> * + (src.isF64() && dest.isF32()) || + (src.isF64() && dest.isBF16()) || + // f32 -> * + (src.isF32() && dest.isF64()) || + (src.isF32() && dest.isBF16()) || + (src.isF32() && dest.isF16()) || + (src.isF32() && dest.isInteger(8)) || (src.isF32() && dest.isInteger(64)) || (src.isF32() && dest.isInteger(1)) || + // bf16 -> * (src.isBF16() && dest.isInteger(8)) || (src.isBF16() && dest.isInteger(16)) || - (src.isBF16() && dest.isInteger(32)) || (src.isBF16() && dest.isF32())) { + (src.isBF16() && dest.isInteger(32)) || + (src.isBF16() && dest.isF32())) { return success(); } + // clang-format on return failure(); } diff --git a/test/Conversion/TorchToTosa/cast_fp32_to_fp16.mlir b/test/Conversion/TorchToTosa/cast_fp32_to_fp16.mlir new file mode 100644 index 000000000000..5504ac0e4002 --- /dev/null +++ b/test/Conversion/TorchToTosa/cast_fp32_to_fp16.mlir @@ -0,0 +1,12 @@ +// RUN: torch-mlir-opt <%s -convert-torch-to-tosa -split-input-file + +// CHECK: %{{.*}} = tosa.cast %{{.*}} : (tensor<1x32x220x220xf32>) -> tensor<1x32x220x220xf16> +func.func @forward(%arg0: !torch.vtensor<[1,32,220,220],f32>) -> !torch.vtensor<[1,32,220,220],f16> { + %int5 = torch.constant.int 5 + %false = torch.constant.bool false + %none = torch.constant.none + %out = torch.aten.to.dtype %arg0, %int5, %false, %false, %none : !torch.vtensor<[1,32,220,220],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[1,32,220,220],f16> + return %out : !torch.vtensor<[1,32,220,220],f16> +} + + From df2aa1a3699cbb161c0f6c0b475dee5ca8dab98d Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Wed, 21 Feb 2024 21:28:44 -0800 Subject: [PATCH 216/283] [torch] Fixed edge conditions for strided slicing (#2929) Strided slicing can occur with a negative stride. In these cases we need to bound end differently. This included removing a function that was generating bad limits. --- .../torch-mlir/Dialect/Torch/IR/TorchOps.h | 23 ----------- lib/Conversion/TorchToLinalg/DataMovement.cpp | 39 +++++++++++++++---- lib/Dialect/Torch/IR/TorchOps.cpp | 8 +++- projects/pt1/e2e_testing/xfail_sets.py | 7 ++-- 4 files changed, 40 insertions(+), 37 deletions(-) diff --git a/include/torch-mlir/Dialect/Torch/IR/TorchOps.h b/include/torch-mlir/Dialect/Torch/IR/TorchOps.h index c82a98cc5aba..e6a9e1622cc1 100644 --- a/include/torch-mlir/Dialect/Torch/IR/TorchOps.h +++ b/include/torch-mlir/Dialect/Torch/IR/TorchOps.h @@ -309,29 +309,6 @@ inline int64_t getIntAttrAsSigned(IntegerAttr intAttr) { return intAttr.getValue().getSExtValue(); } -/// Returns the value from an `IntegerAttr` as an integral index. -/// -/// @param intAttr the `IntegerAttr` from which to extract the index -/// @param dimSize the size of the dimension that the attribute indexes into -/// @return the index value -/// -/// Use this function when the given `IntegerAttr` represents an index into -/// a range, such as an index into a tensor dimension. If `dimSize` is given, -/// negative index values are converted into positive vales by counting -/// elements from the "right" side of the dimension, as in python, numpy, etc. -/// For example, an index of -2 and a dimSize of 10 returns 8 because 8 is the -/// 2nd index from the high end of the range 0 to 9. If `dimSize` is not -/// given, any negative indices are returned as negative numbers. -/// -/// No bounds checking is performed on the index to ensure that it is within -/// the legal range for `dimSize`. -inline int64_t getIntAttrAsIndex(IntegerAttr intAttr, int dimSize = -1) { - int64_t signedIndex = getIntAttrAsSigned(intAttr); - if (dimSize < 0 || signedIndex > 0) - return signedIndex; - return dimSize + signedIndex; // count backwards from dimSize -} - } // namespace Torch } // namespace torch } // namespace mlir diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp index b5eea7da619a..d9132317e32f 100644 --- a/lib/Conversion/TorchToLinalg/DataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp @@ -51,6 +51,7 @@ LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor, Value zero = rewriter.create(loc, 0); Value one = rewriter.create(loc, 1); + Value negone = rewriter.create(loc, -1); int64_t dim; if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) @@ -76,27 +77,49 @@ LogicalResult prepareArgumentsForSlicingOp(OpTy op, OpAdaptor adaptor, Value stepIndex = castIntToIndex(rewriter, loc, adaptor.getStep()); Value start = toPositiveValidDim(rewriter, loc, torchTypeStart, builtinTypeStart, zero, dimSize); - Value end = toPositiveValidDim(rewriter, loc, torchTypeEnd, builtinTypeEnd, - dimSize, dimSize); - // end >= start ? end : start - Value endSgeStart = rewriter.create( - loc, arith::CmpIPredicate::sge, end, start); - end = rewriter.create(loc, endSgeStart, end, start); + // We cannot use to positive valid dim as for negative strides we need to + // clamp to `-1` so that the full tensor bounds are available: + Value end = builtinTypeEnd; + if (torchTypeEnd.getType().isa()) { + end = dimSize; + } else { + end = castIntToIndex(rewriter, loc, end); + Value endcmp = rewriter.create( + loc, arith::CmpIPredicate::slt, end, zero); + Value endadd = rewriter.create(loc, end, dimSize); + end = rewriter.create(loc, endcmp, endadd, end); + endcmp = rewriter.create(loc, arith::CmpIPredicate::slt, end, + zero); + end = rewriter.create(loc, endcmp, negone, end); + endcmp = rewriter.create(loc, arith::CmpIPredicate::sgt, end, + dimSize); + end = rewriter.create(loc, endcmp, dimSize, end); + } // Slice logic: resultSize = floordiv(end - start + step - 1, step) resultShape = getTensorSizes(rewriter, loc, input); Value len = rewriter.create(loc, end, start); + + // We check the difference between start and end to determine the total size: + Value stepcmp = rewriter.create(loc, arith::CmpIPredicate::sge, + stepIndex, zero); + Value stepsign = rewriter.create(loc, stepcmp, one, negone); Value resultSize = rewriter.create(loc, len, stepIndex); - resultSize = rewriter.create(loc, resultSize, one); + resultSize = rewriter.create(loc, resultSize, stepsign); resultSize = rewriter.create(loc, resultSize, stepIndex); + + // Clamp the size to [0, ...]: + Value szcmp = rewriter.create(loc, arith::CmpIPredicate::slt, + resultSize, zero); + resultSize = rewriter.create(loc, szcmp, zero, resultSize); resultShape[dim] = resultSize; strides.resize(inputType.getRank(), one); offsets.resize(inputType.getRank(), zero); offsets[dim] = start; - strides[dim] = rewriter.create(loc, strides[dim], stepIndex); + strides[dim] = stepIndex; return success(); } diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index 36e089fb28d3..da6f71015942 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -3327,11 +3327,15 @@ OpFoldResult AtenIndexSelectOp::fold(FoldAdaptor adaptor) { // Get the single index value for the selected dimension auto splatValue = indexAttr.getSplatValue(); - int64_t indexInt = getIntAttrAsIndex(splatValue, selfSizes[dimInt]); + int64_t indexInt = getIntAttrAsSigned(splatValue); + indexInt = indexInt < 0 && selfSizes[dimInt] ? indexInt + selfSizes[dimInt] + : indexInt; // Extract the single constant value from the input tensor and turn the // extracted value into a single-element tensor of the output shape and dtype - auto splattr = selfAttr.getValues()[indexInt]; + Attribute splattr = selfAttr.isSplat() + ? selfAttr.getSplatValue() + : selfAttr.getValues()[indexInt]; auto dty = resultTy.getDtype(); auto attrTy = resultTy.toBuiltinTensor().clone(dty); diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 866c569eb147..e600a6be8a52 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -2162,6 +2162,8 @@ "EmbeddingModuleI32_basic", "EmbeddingModuleI64_basic", "ExpandModule_basic", + "MoveDimIntNegativeIndexModule_basic", + "PermuteNegativeIndexModule_basic", "ReduceAmaxKeepDim_basic", "ReduceMaxKeepDimReturnBoth_basic", "ReduceMaxNegativeDim_basic", @@ -2184,7 +2186,6 @@ "ElementwiseUnsqueezeNegDimsModule_basic", "ElementwiseWhereScalarModule_basic", "FlattenDynamicModule_basic", - "FlipModule_basic", "FlipModuleStaticShape_basic", "GluStaticModule_basic", "MaskedFillTensorFloatValueModule_basic", @@ -2194,9 +2195,7 @@ "ReduceMinAlongDimUnsignedInt_basic", "TensorsStackNegativeDimModule_basic", "TensorsStackPromoteDTypeModule_basic", - "FlipModule_basic", - "MoveDimIntNegativeIndexModule_basic", - "PermuteNegativeIndexModule_basic", } ONNX_CRASHING_SET = { } + From 53f6d06ab8ae619fdeecabaa0e25f7b621598ae8 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Wed, 21 Feb 2024 21:34:43 -0800 Subject: [PATCH 217/283] [onnx] Drop `ConstantOfShape` logic form importer, fix torch lowering (#2930) There is no reason to treat `ConstantOfShape` as a specialized import any as there exists a onnx-to-torch equivalent. Dropping the import coding and adding support for resource conversion substantially increases test coverage for dynamically shaped tests. --- .../TorchOnnxToTorch/DefaultDomainAtoF.cpp | 71 ++++++++++++------- projects/pt1/e2e_testing/xfail_sets.py | 28 -------- python/torch_mlir/extras/onnx_importer.py | 31 -------- .../TorchOnnxToTorch/simple_ops_a_to_f.mlir | 36 +++++----- .../python/onnx_importer/import_smoke_test.py | 41 ----------- 5 files changed, 65 insertions(+), 142 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp index e8c36d8cad54..99a3985a2993 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp @@ -1571,7 +1571,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( return success(); }); patterns.onOp( - "ConstantOfShape", 20, + "ConstantOfShape", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value shape; @@ -1582,15 +1582,14 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( auto shapeSizes = dyn_cast(shape.getType()).getSizes(); SmallVector dimList; - SmallVector selectSizes; - selectSizes.push_back(1); Torch::BaseTensorType shapeType = shape.getType().cast(); - Type selectResultType = shapeType.getWithSizesAndDtype( - llvm::ArrayRef(selectSizes), shapeType.getOptionalDtype()); + Type selectResultType = rewriter.getType( + ArrayRef({}), shapeType.getOptionalDtype()); Value zero = rewriter.create( binder.getLoc(), rewriter.getType(), rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); + for (int i = 0; i < shapeSizes[0]; i++) { Value selectIndex = rewriter.create( binder.getLoc(), rewriter.getType(), @@ -1601,6 +1600,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( binder.getLoc(), rewriter.getType(), extract); dimList.push_back(dim); } + Value dimValueList = rewriter.create( binder.getLoc(), Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), @@ -1609,7 +1609,6 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( // Get fill_value if it is present. // Assumption : resultDType and value attr type match. - Value value_const; auto attr = binder.op->getAttr("torch.onnx.value"); auto resultDType = resultType.getDtype(); @@ -1620,34 +1619,58 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( resultType.toBuiltinTensor().clone(resultDType), rewriter.getFloatAttr(resultDType, 0.0)); } - if (!isa(attr)) { - return rewriter.notifyMatchFailure( - binder.op, "`value` attr needs to be a tensor."); + + // If its a dense resource attr we need to convert to a dense type: + if (DenseResourceElementsAttr rattr = + attr.dyn_cast_or_null()) { + // Bytes are stored in little endian order. Big endian support will + // require swizzling. + if (!Endian::little) { + binder.op->emitError( + "unimplemented: importing on big endian systems"); + return failure(); + } + + auto ty = cast(rattr.getType()); + auto ptr = rattr.getRawHandle().getBlob()->getData(); + auto denseAttr = DenseElementsAttr::getFromRawBuffer(ty, ptr); + attr = dyn_cast_or_null(denseAttr); + } + + Attribute splattr; + if (isa(attr)) { + auto denseAttr = attr.cast(); + splattr = denseAttr.getSplatValue(); } - auto denseAttr = attr.cast(); - auto denseAttrEleType = denseAttr.getElementType(); - if (!isa(denseAttrEleType)) { + if (!isa(splattr)) { return rewriter.notifyMatchFailure( binder.op, "`value` attr tensor only supports types int and float for now."); } - // Create constant op for value - if (denseAttrEleType.isa()) { - int64_t intVal = denseAttr.getSplatValue().getSInt(); - value_const = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(intVal)); - } - if (denseAttrEleType.isa()) { - float floatVal = - denseAttr.getSplatValue().getValue().convertToFloat(); - value_const = rewriter.create( - binder.getLoc(), rewriter.getF64FloatAttr(floatVal)); + Value splatvalue; + if (auto intattr = dyn_cast(splattr)) { + IntegerType intty = cast(intattr.getType()); + int64_t value; + if (intty.isUnsignedInteger()) { + value = intattr.getUInt(); + } else if (intty.isSignedInteger()) { + value = intattr.getSInt(); + } else { + value = intattr.getInt(); + } + splatvalue = + rewriter.create(binder.getLoc(), value); } + if (auto fpattr = dyn_cast(splattr)) + splatvalue = rewriter.create( + binder.getLoc(), + rewriter.getF64FloatAttr(fpattr.getValueAsDouble())); + rewriter.replaceOpWithNewOp( - binder.op, resultType, dimValueList, value_const, /*dtype=*/noneVal, + binder.op, resultType, dimValueList, splatvalue, /*dtype=*/noneVal, /*layout=*/noneVal, /*device=*/noneVal, /*pin_memory=*/noneVal); return success(); }); diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index e600a6be8a52..e749b5834cc6 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -1825,23 +1825,6 @@ "ElementwiseClampTensorFloatModule_basic", "ElementwiseClampTensorInt8Module_basic", "ElementwiseClampTensorIntModule_basic", - "EmptyLikeMemoryFormatModule_basic", - "EmptyLikeModule_defaultDtype", - "EmptyLikeModule_falsePinMemory", - "EmptyLikeModule_float", - "EmptyLikeModule_int", - "Fill_TensorFloat32WithFloat32_basic", - "Fill_TensorFloat32WithFloat64_basic", - "Fill_TensorFloat32WithInt64_basic", - "Fill_TensorFloat64WithFloat32_basic", - "Fill_TensorFloat64WithFloat64_basic", - "Fill_TensorFloat64WithInt64_basic", - "FullLikeModuleDefaultDtype_basic", - "FullLikeModuleFalsePinMemory_basic", - "FullLikeModuleFloat2D_basic", - "FullLikeModuleFloat3D_basic", - "FullLikeModuleInt2D_basic", - "FullLikeModuleInt3D_basic", "HBC_basic", "IndexPut1DFloatAccumulateModule_basic", "IndexPut1DIntAccumulateModule_basic", @@ -1856,10 +1839,6 @@ "IndexPutHackedTwin3DFloatAccumulateModule_basic", "IndexPutHackedTwin3DIntAccumulateModule_basic", "NormalizeModule_basic", - "OnesLikeModule_defaultDtype", - "OnesLikeModule_falsePinMemory", - "OnesLikeModule_float", - "OnesLikeModule_int", "PadWithNoneValModule_basic", "QuantizedMLP_basic", "RandModule_basic", @@ -1875,13 +1854,6 @@ "TileSmallDimsSizeModule_basic", "UpSampleNearest2dDynamicSize_basic", "UpSampleNearest2dStaticSize_basic", - "ZeroFloat32Module_basic", - "ZeroInt32Module_basic", - "ZeroInt64Module_basic", - "ZerosLikeModule_defaultDtype", - "ZerosLikeModule_falsePinMemory", - "ZerosLikeModule_float", - "ZerosLikeModule_int", # Failure - onnx_lowering "AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic", diff --git a/python/torch_mlir/extras/onnx_importer.py b/python/torch_mlir/extras/onnx_importer.py index c62324832520..a0cfbf26ed30 100644 --- a/python/torch_mlir/extras/onnx_importer.py +++ b/python/torch_mlir/extras/onnx_importer.py @@ -408,37 +408,6 @@ def _handle_node_Constant(self, node: onnx.NodeProto) -> bool: self._gi.initializer_map[const_name] = value_proto.t return True - def _handle_node_ConstantOfShape(self, node: onnx.NodeProto) -> bool: - # This op is special: It has an input of the shape, and in full generality - # could involve eager production of constants of variable size. In - # practice, the DNN profile for ONNX makes this very difficult to do - # and we hard-assert that the input can be resolved to an immediate - # value. - assert len(node.input) == 1 - assert len(node.output) == 1 - shape = self._get_immediate_tensor(node.input[0]).astype(np.int64) - value_proto = _get_attr(node, "value") - assert value_proto.type == onnx.AttributeProto.AttributeType.TENSOR - tensor_proto = value_proto.t - element_type = self._cc.tensor_element_type(tensor_proto.data_type) - vtensor_type = self._cc.get_vtensor_type(tuple(shape), element_type) - assert len(tensor_proto.dims) == 1 and tensor_proto.dims[0] == 1 - try: - cb = ELEM_TYPE_SPLAT_TENSOR_PROTO_CB[tensor_proto.data_type] - except KeyError: - raise OnnxImportError( - f"Unhandled splat type for ConstantOfShape: {node} (possible missing mapping in ELEM_TYPE_SPLAT_TENSOR_PROTO_CB)" - ) - value_attr = cb(tensor_proto, tuple(shape)) - literal_op = Operation.create( - name="torch.vtensor.literal", - results=[vtensor_type], - attributes={"value": value_attr}, - ) - self._nv_map[node.output[0]] = literal_op.result - return True - - class ContextCache: """Caches per-context lookups of various things.""" diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir index 3e4a476dbfbb..525583b7660e 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir @@ -1539,14 +1539,14 @@ func.func @test_constant_of_shape_dense_float_default() -> !torch.vtensor<[2,3,4 // CHECK: %[[SHAPE_CST:.*]] = torch.vtensor.literal(dense<[2, 3, 4]> : tensor<3xsi64>) : !torch.vtensor<[3],si64> // CHECK: %[[INT0:.*]] = torch.constant.int 0 // CHECK: %[[INT0_0:.*]] = torch.constant.int 0 - // CHECK: %[[EXTRACT_0:.*]] = torch.aten.select.int %[[SHAPE_CST]], %[[INT0]], %[[INT0_0]] : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: %[[ELE_0:.*]] = torch.aten.item %[[EXTRACT_0]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[EXTRACT_0:.*]] = torch.aten.select.int %[[SHAPE_CST]], %[[INT0]], %[[INT0_0]] : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> + // CHECK: %[[ELE_0:.*]] = torch.aten.item %[[EXTRACT_0]] : !torch.vtensor<[],si64> -> !torch.int // CHECK: %[[INT1:.*]] = torch.constant.int 1 - // CHECK: %[[EXTRACT_1:.*]] = torch.aten.select.int %[[SHAPE_CST]], %[[INT0]], %[[INT1]] : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: %[[ELE_1:.*]] = torch.aten.item %[[EXTRACT_1]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[EXTRACT_1:.*]] = torch.aten.select.int %[[SHAPE_CST]], %[[INT0]], %[[INT1]] : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> + // CHECK: %[[ELE_1:.*]] = torch.aten.item %[[EXTRACT_1]] : !torch.vtensor<[],si64> -> !torch.int // CHECK: %[[INT2:.*]] = torch.constant.int 2 - // CHECK: %[[EXTRACT_2:.*]] = torch.aten.select.int %[[SHAPE_CST]], %[[INT0]], %[[INT2]] : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: %[[ELE_2:.*]] = torch.aten.item %[[EXTRACT_2]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[EXTRACT_2:.*]] = torch.aten.select.int %[[SHAPE_CST]], %[[INT0]], %[[INT2]] : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> + // CHECK: %[[ELE_2:.*]] = torch.aten.item %[[EXTRACT_2]] : !torch.vtensor<[],si64> -> !torch.int // CHECK: %[[DIM_LIST:.*]] = torch.prim.ListConstruct %[[ELE_0]], %[[ELE_1]], %[[ELE_2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list // CHECK: %[[NONE:.*]] = torch.constant.none // CHECK: %[[FILL_VAL:.*]] = torch.constant.float 0.000000e+00 @@ -1563,14 +1563,14 @@ func.func @test_constant_of_shape_dense_float_cst() -> !torch.vtensor<[2,3,4], f // CHECK: %[[SHAPE_CST:.*]] = torch.vtensor.literal(dense<[2, 3, 4]> : tensor<3xsi64>) : !torch.vtensor<[3],si64> // CHECK: %[[INT0:.*]] = torch.constant.int 0 // CHECK: %[[INT0_0:.*]] = torch.constant.int 0 - // CHECK: %[[EXTRACT_0:.*]] = torch.aten.select.int %[[SHAPE_CST]], %[[INT0]], %[[INT0_0]] : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: %[[ELE_0:.*]] = torch.aten.item %[[EXTRACT_0]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[EXTRACT_0:.*]] = torch.aten.select.int %[[SHAPE_CST]], %[[INT0]], %[[INT0_0]] : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> + // CHECK: %[[ELE_0:.*]] = torch.aten.item %[[EXTRACT_0]] : !torch.vtensor<[],si64> -> !torch.int // CHECK: %[[INT1:.*]] = torch.constant.int 1 - // CHECK: %[[EXTRACT_1:.*]] = torch.aten.select.int %[[SHAPE_CST]], %[[INT0]], %[[INT1]] : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: %[[ELE_1:.*]] = torch.aten.item %[[EXTRACT_1]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[EXTRACT_1:.*]] = torch.aten.select.int %[[SHAPE_CST]], %[[INT0]], %[[INT1]] : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> + // CHECK: %[[ELE_1:.*]] = torch.aten.item %[[EXTRACT_1]] : !torch.vtensor<[],si64> -> !torch.int // CHECK: %[[INT2:.*]] = torch.constant.int 2 - // CHECK: %[[EXTRACT_2:.*]] = torch.aten.select.int %[[SHAPE_CST]], %[[INT0]], %[[INT2]] : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: %[[ELE_2:.*]] = torch.aten.item %[[EXTRACT_2]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[EXTRACT_2:.*]] = torch.aten.select.int %[[SHAPE_CST]], %[[INT0]], %[[INT2]] : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> + // CHECK: %[[ELE_2:.*]] = torch.aten.item %[[EXTRACT_2]] : !torch.vtensor<[],si64> -> !torch.int // CHECK: %[[DIM_LIST:.*]] = torch.prim.ListConstruct %[[ELE_0]], %[[ELE_1]], %[[ELE_2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list // CHECK: %[[NONE:.*]] = torch.constant.none // CHECK: %[[FILL_VAL:.*]] = torch.constant.float 3.4000000953674316 @@ -1587,14 +1587,14 @@ func.func @test_constant_of_shape_dense_int_cst() -> !torch.vtensor<[2,3,4], si6 // CHECK: %[[SHAPE_CST:.*]] = torch.vtensor.literal(dense<[2, 3, 4]> : tensor<3xsi64>) : !torch.vtensor<[3],si64> // CHECK: %[[INT0:.*]] = torch.constant.int 0 // CHECK: %[[INT0_0:.*]] = torch.constant.int 0 - // CHECK: %[[EXTRACT_0:.*]] = torch.aten.select.int %[[SHAPE_CST]], %[[INT0]], %[[INT0_0]] : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: %[[ELE_0:.*]] = torch.aten.item %[[EXTRACT_0]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[EXTRACT_0:.*]] = torch.aten.select.int %[[SHAPE_CST]], %[[INT0]], %[[INT0_0]] : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> + // CHECK: %[[ELE_0:.*]] = torch.aten.item %[[EXTRACT_0]] : !torch.vtensor<[],si64> -> !torch.int // CHECK: %[[INT1:.*]] = torch.constant.int 1 - // CHECK: %[[EXTRACT_1:.*]] = torch.aten.select.int %[[SHAPE_CST]], %[[INT0]], %[[INT1]] : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: %[[ELE_1:.*]] = torch.aten.item %[[EXTRACT_1]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[EXTRACT_1:.*]] = torch.aten.select.int %[[SHAPE_CST]], %[[INT0]], %[[INT1]] : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> + // CHECK: %[[ELE_1:.*]] = torch.aten.item %[[EXTRACT_1]] : !torch.vtensor<[],si64> -> !torch.int // CHECK: %[[INT2:.*]] = torch.constant.int 2 - // CHECK: %[[EXTRACT_2:.*]] = torch.aten.select.int %[[SHAPE_CST]], %[[INT0]], %[[INT2]] : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: %[[ELE_2:.*]] = torch.aten.item %[[EXTRACT_2]] : !torch.vtensor<[1],si64> -> !torch.int + // CHECK: %[[EXTRACT_2:.*]] = torch.aten.select.int %[[SHAPE_CST]], %[[INT0]], %[[INT2]] : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> + // CHECK: %[[ELE_2:.*]] = torch.aten.item %[[EXTRACT_2]] : !torch.vtensor<[],si64> -> !torch.int // CHECK: %[[DIM_LIST:.*]] = torch.prim.ListConstruct %[[ELE_0]], %[[ELE_1]], %[[ELE_2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list // CHECK: %[[NONE:.*]] = torch.constant.none // CHECK: %[[FILL_VAL:.*]] = torch.constant.int 3 diff --git a/test/python/onnx_importer/import_smoke_test.py b/test/python/onnx_importer/import_smoke_test.py index 708324e72db6..22d460050cae 100644 --- a/test/python/onnx_importer/import_smoke_test.py +++ b/test/python/onnx_importer/import_smoke_test.py @@ -102,9 +102,6 @@ "node_test_castlike_FLOAT_to_STRING_model", "node_test_castlike_STRING_to_FLOAT_expanded_model", "node_test_castlike_STRING_to_FLOAT_model", - "node_test_constantofshape_float_ones_model", - "node_test_constantofshape_int_shape_zero_model", - "node_test_constantofshape_int_zeros_model", "node_test_dequantizelinear_e4m3fn_model", "node_test_dequantizelinear_e4m3fn_zero_point_model", "node_test_dequantizelinear_e5m2_model", @@ -118,44 +115,6 @@ "node_test_if_model", "node_test_if_opt_model", "node_test_if_seq_model", - "node_test_layer_normalization_2d_axis0_expanded_model", - "node_test_layer_normalization_2d_axis0_expanded_ver18_model", - "node_test_layer_normalization_2d_axis1_expanded_model", - "node_test_layer_normalization_2d_axis1_expanded_ver18_model", - "node_test_layer_normalization_2d_axis_negative_1_expanded_model", - "node_test_layer_normalization_2d_axis_negative_1_expanded_ver18_model", - "node_test_layer_normalization_2d_axis_negative_2_expanded_model", - "node_test_layer_normalization_2d_axis_negative_2_expanded_ver18_model", - "node_test_layer_normalization_3d_axis0_epsilon_expanded_model", - "node_test_layer_normalization_3d_axis0_epsilon_expanded_ver18_model", - "node_test_layer_normalization_3d_axis1_epsilon_expanded_model", - "node_test_layer_normalization_3d_axis1_epsilon_expanded_ver18_model", - "node_test_layer_normalization_3d_axis2_epsilon_expanded_model", - "node_test_layer_normalization_3d_axis2_epsilon_expanded_ver18_model", - "node_test_layer_normalization_3d_axis_negative_1_epsilon_expanded_model", - "node_test_layer_normalization_3d_axis_negative_1_epsilon_expanded_ver18_model", - "node_test_layer_normalization_3d_axis_negative_2_epsilon_expanded_model", - "node_test_layer_normalization_3d_axis_negative_2_epsilon_expanded_ver18_model", - "node_test_layer_normalization_3d_axis_negative_3_epsilon_expanded_model", - "node_test_layer_normalization_3d_axis_negative_3_epsilon_expanded_ver18_model", - "node_test_layer_normalization_4d_axis0_expanded_model", - "node_test_layer_normalization_4d_axis0_expanded_ver18_model", - "node_test_layer_normalization_4d_axis1_expanded_model", - "node_test_layer_normalization_4d_axis1_expanded_ver18_model", - "node_test_layer_normalization_4d_axis2_expanded_model", - "node_test_layer_normalization_4d_axis2_expanded_ver18_model", - "node_test_layer_normalization_4d_axis3_expanded_model", - "node_test_layer_normalization_4d_axis3_expanded_ver18_model", - "node_test_layer_normalization_4d_axis_negative_1_expanded_model", - "node_test_layer_normalization_4d_axis_negative_1_expanded_ver18_model", - "node_test_layer_normalization_4d_axis_negative_2_expanded_model", - "node_test_layer_normalization_4d_axis_negative_2_expanded_ver18_model", - "node_test_layer_normalization_4d_axis_negative_3_expanded_model", - "node_test_layer_normalization_4d_axis_negative_3_expanded_ver18_model", - "node_test_layer_normalization_4d_axis_negative_4_expanded_model", - "node_test_layer_normalization_4d_axis_negative_4_expanded_ver18_model", - "node_test_layer_normalization_default_axis_expanded_model", - "node_test_layer_normalization_default_axis_expanded_ver18_model", "node_test_loop11_model", "node_test_loop13_seq_model", "node_test_loop16_seq_none_model", From 5af249566b85a97e5c96c847877e4fc9fda57ec1 Mon Sep 17 00:00:00 2001 From: Vivek Khandelwal Date: Thu, 22 Feb 2024 21:16:53 +0530 Subject: [PATCH 218/283] build: manually update PyTorch version (#2933) Set PyTorch and TorchVision version to nightly release 2024-02-20. Signed-Off By: Vivek Khandelwal --- pytorch-hash.txt | 2 +- pytorch-requirements.txt | 2 +- torchvision-requirements.txt | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch-hash.txt b/pytorch-hash.txt index c23d10cf50fb..81f0390b4ebb 100644 --- a/pytorch-hash.txt +++ b/pytorch-hash.txt @@ -1 +1 @@ -b51e0246b7f119770c47183b230c553f15ab4fbb +8efa066dc0870521652c1319bd6b5b0f6dc3fe25 diff --git a/pytorch-requirements.txt b/pytorch-requirements.txt index 25546907e856..26abce08d1aa 100644 --- a/pytorch-requirements.txt +++ b/pytorch-requirements.txt @@ -1,3 +1,3 @@ -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html --pre -torch==2.3.0.dev20240214 +torch==2.3.0.dev20240220 diff --git a/torchvision-requirements.txt b/torchvision-requirements.txt index 4f775c549c6c..ce099fb91709 100644 --- a/torchvision-requirements.txt +++ b/torchvision-requirements.txt @@ -1,3 +1,3 @@ -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html --pre -torchvision==0.18.0.dev20240207 +torchvision==0.18.0.dev20240220 From 55dc8deb9221c9ec0fe2a991542f2f788c62a3e1 Mon Sep 17 00:00:00 2001 From: Andreas Falkenberg <149819731+afalkenberg1@users.noreply.github.com> Date: Fri, 23 Feb 2024 09:14:38 -0800 Subject: [PATCH 219/283] [torch] GridSample TorchToLinalg lowering (#2883) Lowers `torch.grid_sample` to the equilvalent `linalg` representation. --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 27 +++ .../TorchToLinalg/Uncategorized.cpp | 168 ++++++++++++++++++ .../Transforms/AbstractInterpLibrary.cpp | 15 ++ .../build_tools/abstract_interp_lib_gen.py | 8 + .../build_tools/torch_ods_gen.py | 1 + .../Conversion/TorchToLinalg/gridsampler.mlir | 60 +++++++ test/Conversion/TorchToLinalg/pooling.mlir | 1 - 7 files changed, 279 insertions(+), 1 deletion(-) create mode 100644 test/Conversion/TorchToLinalg/gridsampler.mlir diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index c5fec66913b0..cc8be7c6910b 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -12353,6 +12353,33 @@ def Torch_AtenScaledDotProductAttentionOp : Torch_Op<"aten.scaled_dot_product_at }]; } +def Torch_AtenGridSamplerOp : Torch_Op<"aten.grid_sampler", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::grid_sampler : (Tensor, Tensor, int, int, bool) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$input, + AnyTorchTensorType:$grid, + Torch_IntType:$interpolation_mode, + Torch_IntType:$padding_mode, + Torch_BoolType:$align_corners + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenGridSamplerOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 5, 1); + } + void AtenGridSamplerOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 5, 1); + } + }]; +} + def Torch_Aten__Contains__StrOp : Torch_Op<"aten.__contains__.str", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index 08d69ca718b9..ed6883000cf9 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -2360,6 +2360,172 @@ class ConvertCastEquivalentOp : public OpConversionPattern { }; } // namespace +namespace { +class ConvertAtenGridSamplerOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(AtenGridSamplerOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op->getLoc(); + Type int64type = rewriter.getI64Type(); + Type floatType = rewriter.getF32Type(); + Value zeroIndex = rewriter.create(loc, 0); + Value oneIndex = rewriter.create(loc, 1); + Value twoIndex = rewriter.create(loc, 2); + Value zeroFloat = rewriter.create( + loc, rewriter.getFloatAttr(floatType, 0.0)); + Value oneFloat = rewriter.create( + loc, rewriter.getFloatAttr(floatType, 1.0)); + Value twoFloat = rewriter.create( + loc, rewriter.getFloatAttr(floatType, 2.0)); + Value input = adaptor.getInput(); + auto inputType = input.getType().cast(); + auto inputShape = inputType.getShape(); + Value innerDim0a = rewriter.create(loc, input, 2); + Value innerDim1a = rewriter.create(loc, input, 3); + Value innerDim0b = + rewriter.create(loc, innerDim0a, oneIndex); + Value innerDim1b = + rewriter.create(loc, innerDim1a, oneIndex); + Value innerDim0c = + rewriter.create(loc, int64type, innerDim0b); + Value innerDim1c = + rewriter.create(loc, int64type, innerDim1b); + Value innerDim0d = + rewriter.create(loc, floatType, innerDim0c); + Value innerDim1d = + rewriter.create(loc, floatType, innerDim1c); + Value innerDim0e = + rewriter.create(loc, innerDim0d, twoFloat); + Value innerDim1e = + rewriter.create(loc, innerDim1d, twoFloat); + Value grid = adaptor.getGrid(); + auto gridType = grid.getType().cast(); + auto gridShape = gridType.getShape(); + auto gridRank = gridType.getRank(); + SmallVector extractGridOffsets0(gridRank, zeroIndex); + SmallVector extractGridShape = getTensorSizes(rewriter, loc, grid); + SmallVector extractGridStride(gridRank, oneIndex); + int64_t lastGridDim = gridRank - 1; + extractGridShape[lastGridDim] = oneIndex; + extractGridStride[lastGridDim] = twoIndex; + SmallVector extractGridOffsets1(gridRank, zeroIndex); + extractGridOffsets1[lastGridDim] = oneIndex; + SmallVector gridShapeExtracted(gridShape); + gridShapeExtracted.back() = 1; + SmallVector gridShapeCollapsed{gridShape[0], gridShape[1], + gridShape[2]}; + auto grid0 = rewriter.create( + loc, grid, extractGridOffsets0, extractGridShape, extractGridStride); + auto grid1 = rewriter.create( + loc, grid, extractGridOffsets1, extractGridShape, extractGridStride); + SmallVector associations{ReassociationIndices{0}, + ReassociationIndices{1}, + ReassociationIndices{2, 3}}; + auto gridCollapsed0 = + rewriter.create(loc, grid0, associations); + auto gridCollapsed1 = + rewriter.create(loc, grid1, associations); + AffineMap gridMap = AffineMap::get(4, 0, + {rewriter.getAffineDimExpr(0), + rewriter.getAffineDimExpr(2), + rewriter.getAffineDimExpr(3)}, + op->getContext()); + SmallVector gridMaps{gridMap, gridMap, + rewriter.getMultiDimIdentityMap(gridRank)}; + SmallVector gridIterators( + gridRank, utils::IteratorType::parallel); + SmallVector resultShape{inputShape[0], inputShape[1], gridShape[1], + gridShape[2]}; + auto lambdaExtract = [](OpBuilder &b, Location loc, Value input, Value idxA, + Value idxB, Value idxC, Value idxD) -> Value { + SmallVector index{idxA, idxB, idxC, idxD}; + Value result = b.create(loc, input, index); + return result; + }; + auto lambdaInter = [&](OpBuilder &b, Location loc, Value x, Value y, + Value d) -> Value { + Value dm = b.create(loc, oneFloat, d); + Value ra = b.create(loc, x, dm); + Value rb = b.create(loc, y, d); + Value res = b.create(loc, ra, rb); + return res; + }; + auto resultType = getTypeConverter() + ->convertType(op.getResult().getType()) + .cast(); + llvm::SmallVector resultSize{ + rewriter.create(loc, input, 0), + rewriter.create(loc, input, 1), + rewriter.create(loc, grid, 1), + rewriter.create(loc, grid, 2)}; + Value resultFinal = + rewriter.create(loc, resultType, resultSize); + auto sGrid = rewriter.create( + loc, TypeRange{resultType}, ValueRange{gridCollapsed0, gridCollapsed1}, + ValueRange(resultFinal), gridMaps, gridIterators, + [&](OpBuilder &b, Location loc, ValueRange args) { + Value gr0 = args[0]; + Value gr1 = args[1]; + Value gplus0 = b.create(loc, gr0, oneFloat); + Value gplus1 = b.create(loc, gr1, oneFloat); + Value result0 = b.create(loc, gplus0, innerDim0e); + Value result1 = b.create(loc, gplus1, innerDim1e); + Value lower0 = b.create(loc, int64type, result0); + Value lower1 = b.create(loc, int64type, result1); + Value oneInt = + b.create(loc, b.getIntegerAttr(int64type, 1)); + Value upper0 = + b.create(loc, int64type, lower0, oneInt); + Value upper1 = + b.create(loc, int64type, lower1, oneInt); + Value notValid0 = rewriter.create( + loc, arith::CmpIPredicate::sgt, upper0, innerDim0c); + Value notValid1 = rewriter.create( + loc, arith::CmpIPredicate::sgt, upper1, innerDim1c); + Value upperValid0 = + b.create(loc, notValid0, lower0, upper0); + Value upperValid1 = + b.create(loc, notValid1, lower1, upper1); + Value lw0 = + b.create(loc, b.getIndexType(), lower0); + Value lw1 = + b.create(loc, b.getIndexType(), lower1); + Value up0 = + b.create(loc, b.getIndexType(), upperValid0); + Value up1 = + b.create(loc, b.getIndexType(), upperValid1); + Value N = b.create(loc, 0); + Value C = b.create(loc, 1); + Value result00 = lambdaExtract(b, loc, input, N, C, lw0, lw1); + Value result01 = lambdaExtract(b, loc, input, N, C, lw0, up1); + Value result01a = + b.create(loc, notValid1, zeroFloat, result01); + Value result10 = lambdaExtract(b, loc, input, N, C, up0, lw1); + Value result10a = + b.create(loc, notValid0, zeroFloat, result10); + Value result11 = lambdaExtract(b, loc, input, N, C, up0, up1); + Value result11a = + b.create(loc, notValid0, zeroFloat, result11); + Value result11b = + b.create(loc, notValid1, zeroFloat, result11a); + Value lw0a = b.create(loc, floatType, lower0); + Value lw1a = b.create(loc, floatType, lower1); + Value d0 = b.create(loc, result0, lw0a); + Value d1 = b.create(loc, result1, lw1a); + Value resultScaled0 = lambdaInter(b, loc, result00, result01a, d0); + Value resultScaled1 = lambdaInter(b, loc, result10a, result11b, d0); + Value resultScaled = + lambdaInter(b, loc, resultScaled0, resultScaled1, d1); + b.create(loc, resultScaled); + }); + rewriter.replaceOp(op, sGrid.getResults()); + return success(); + } +}; +} // namespace + void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality( TypeConverter &typeConverter, RewritePatternSet &patterns, ConversionTarget &target) { @@ -2412,4 +2578,6 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality( typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); + target.addIllegalOp(); + patterns.add(typeConverter, context); } diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index 39813da66e85..bfc2fc6a1d0c 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -6597,6 +6597,17 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.grid_sampler\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.int, %arg3: !torch.int, %arg4: !torch.bool) -> !torch.list {\n" +" %int0 = torch.constant.int 0\n" +" %int1 = torch.constant.int 1\n" +" %int2 = torch.constant.int 2\n" +" %0 = torch.aten.__getitem__.t %arg0, %int0 : !torch.list, !torch.int -> !torch.int\n" +" %1 = torch.aten.__getitem__.t %arg0, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %2 = torch.aten.__getitem__.t %arg1, %int1 : !torch.list, !torch.int -> !torch.int\n" +" %3 = torch.aten.__getitem__.t %arg1, %int2 : !torch.list, !torch.int -> !torch.int\n" +" %4 = torch.prim.ListConstruct %0, %1, %2, %3 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list\n" +" return %4 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.prims.collapse\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.int) -> !torch.list {\n" " %true = torch.constant.bool true\n" " %str = torch.constant.str \"AssertionError: start must be less than or equal to end\"\n" @@ -9795,6 +9806,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.grid_sampler\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.int, %arg3: !torch.int, %arg4: !torch.bool) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" return %0#1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.reflection_pad1d\"(%arg0: !torch.tuple, %arg1: !torch.list) -> !torch.int {\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: padding size expected to be 2\"\n" diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index 1a87bbb6bee1..403d124ad927 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -287,6 +287,10 @@ def aten〇_make_per_tensor_quantized_tensor〡shape(self: List[int], scale: flo def prims〇convert_element_type〡shape(a: List[int], dtype: int) -> List[int]: return upstream_shape_functions.unary(a) +def aten〇grid_sampler〡shape(input: List[int], grid: List[int], interpolation_mode: int, padding_mode: int, align_corners: bool) -> List[int]: + output = [input[0],input[1],grid[1],grid[2]] + return output + def prims〇collapse〡shape(a: List[int], start: int, end: int) -> List[int]: # Obtained through trial and error on a few examples in PyTorch: assert start < len(a), "start out of bounds" @@ -2152,6 +2156,10 @@ def aten〇constant_pad_nd〡dtype(self_rank_dtype: Tuple[int, int], pad: List[i self_rank, self_dtype = self_rank_dtype return self_dtype +def aten〇grid_sampler〡dtype(input_rank_dtype: Tuple[int, int], grid_rank_dtype: Tuple[int, int], interpolation_mode: int, padding_mode: int, align_corners: bool) -> int: + input_rank, input_dtype = input_rank_dtype + grid_rank, grid_dtype = input_rank_dtype + return input_dtype @check_dtype_function([ErrorInvocation(TensorOfShape(2, 3, 4), padding=1), ErrorInvocation(TensorOfShape(2, 3, 4), padding=[]), diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index 64f03add759e..51c196421b78 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -714,6 +714,7 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::as_strided_scatter : (Tensor, Tensor, int[], int[], int?) -> (Tensor)") emit("aten::upsample_nearest2d : (Tensor, int[], float?, float?) -> (Tensor)") emit("aten::scaled_dot_product_attention : (Tensor, Tensor, Tensor, Tensor?, float, bool, float?) -> (Tensor)") + emit("aten::grid_sampler : (Tensor, Tensor, int, int, bool) -> (Tensor)") # Dict ops. emit("aten::__contains__.str : (Dict(str, t), str) -> (bool)", has_folder=True) diff --git a/test/Conversion/TorchToLinalg/gridsampler.mlir b/test/Conversion/TorchToLinalg/gridsampler.mlir new file mode 100644 index 000000000000..d392860fa2c1 --- /dev/null +++ b/test/Conversion/TorchToLinalg/gridsampler.mlir @@ -0,0 +1,60 @@ +// RUN: torch-mlir-opt <%s -convert-torch-to-linalg -split-input-file -verify-diagnostics | FileCheck %s + +// CHECK: #map +// CHECK-LABEL: func @grid_sampler +// CHECK-DAG: %[[TC0:.*]] = torch_c.to_builtin_tensor %[[ARG0:.*]] : !torch.vtensor<[4,10,10,4],f32> -> tensor<4x10x10x4xf32> +// CHECK-DAG: %[[TC1:.*]] = torch_c.to_builtin_tensor %[[ARG1:.*]] : !torch.vtensor<[4,6,8,2],f32> -> tensor<4x6x8x2xf32> +// CHECK-DAG: %[[FALSE:.*]] = torch.constant.bool false +// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index +// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index +// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index +// CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 +// CHECK-DAG: %[[CST1:.*]] = arith.constant 1.000000e+00 : f32 +// CHECK-DAG: %[[CST2:.*]] = arith.constant 2.000000e+00 : f32 +// CHECK-DAG: %[[C2_3:.*]] = arith.constant 2 : index +// CHECK-DAG: %[[DIM:.*]] = tensor.dim %[[TC0]], %[[C2_3]] : tensor<4x10x10x4xf32> +// CHECK-DAG: %[[C3:.*]] = arith.constant 3 : index +// CHECK-DAG: %[[DIM_4:.*]] = tensor.dim %[[TC0]], %[[C3]] : tensor<4x10x10x4xf32> +// CHECK-DAG: %[[X2:.*]] = arith.subi %[[DIM:.*]], %[[C1]] : index +// CHECK-DAG: %[[X3:.*]] = arith.subi %[[DIM_4]], %[[C1:.*]] : index +// CHECK-DAG: %[[X4:.*]] = arith.index_cast %[[X2]] : index to i64 +// CHECK-DAG: %[[X5:.*]] = arith.index_cast %[[X3]] : index to i64 +// CHECK-DAG: %[[X6:.*]] = arith.sitofp %[[X4]] : i64 to f32 +// CHECK-DAG: %[[X7:.*]] = arith.sitofp %[[X5]] : i64 to f32 +// CHECK-DAG: %[[X8:.*]] = arith.divf %[[X6]], %[[CST2]] : f32 +// CHECK-DAG: %[[X9:.*]] = arith.divf %[[X7]], %[[CST2]] : f32 +func.func @grid_sampler(%arg0: !torch.vtensor<[4,10,10,4],f32>, %arg1: !torch.vtensor<[4,6,8,2],f32>) -> !torch.vtensor<[?,?,?,?],f32> { + %true = torch.constant.bool 0 + %int0 = torch.constant.int 0 + %int1 = torch.constant.int 0 + %4 = torch.aten.grid_sampler %arg0, %arg1, %int0, %int1, %true : !torch.vtensor<[4,10,10,4],f32>, !torch.vtensor<[4,6,8,2],f32>, !torch.int, !torch.int, !torch.bool -> !torch.vtensor<[?,?,?,?],f32> + return %4 : !torch.vtensor<[?,?,?,?],f32> +} + +// ----- + +// CHECK-LABEL: func @grid_sampler2 +// CHECK: #map +// CHECK-DAG: %[[X15:.*]] = arith.mulf %[[X13:.*]], %[[X8:.*]] : f32 +// CHECK-DAG: %[[X16:.*]] = arith.mulf %[[X14:.*]], %[[X9:.*]] : f32 +// CHECK-DAG: %[[X40:.*]] = arith.mulf %[[EXTRACTED:.*]], %[[X39:.*]] : f32 +// CHECK-DAG: %[[X41:.*]] = arith.mulf %[[X31:.*]], %[[X37:.*]] : f32 +// CHECK-DAG: %[[X42:.*]] = arith.addf %[[X40:.*]], %[[X41]] : f32 +// CHECK-DAG: %[[X43:.*]] = arith.subf %[[CST_1:.*]], %[[X37]] : f32 +// CHECK-DAG: %[[X45:.*]] = arith.mulf %[[X34:.*]], %[[X37]] : f32 +// CHECK-DAG: %[[X46:.*]] = arith.addf %[[X44:.*]], %[[X45]] : f32 +// CHECK-DAG: %[[X47:.*]] = arith.subf %[[CST_1]], %[[X38:.*]] : f32 +// CHECK-DAG: %[[X48:.*]] = arith.mulf %[[X42]], %[[XX47:.*]] : f32 +// CHECK-DAG: %[[X49:.*]] = arith.mulf %[[X46]], %[[XX38:.*]] : f32 +// CHECK-DAG: %[[X50:.*]] = arith.addf %[[X48]], %[[X49]] : f32 +// CHECK-DAG: linalg.yield %[[X50]] : f32 +// CHECK: } -> tensor +// CHECK: %[[X12:.*]] = torch_c.from_builtin_tensor %[[X11:.*]] : tensor -> !torch.vtensor<[?,?,?,?],f32> +// CHECK: return %[[X12]] : !torch.vtensor<[?,?,?,?],f32> +func.func @grid_sampler2(%arg0: !torch.vtensor<[?,?,?,?],f32>, %arg1: !torch.vtensor<[?,?,?,?],f32>) -> !torch.vtensor<[?,?,?,?],f32> { + %true = torch.constant.bool 0 + %int0 = torch.constant.int 0 + %int1 = torch.constant.int 0 + %4 = torch.aten.grid_sampler %arg0, %arg1, %int0, %int1, %true : !torch.vtensor<[?,?,?,?],f32>, !torch.vtensor<[?,?,?,?],f32>, !torch.int, !torch.int, !torch.bool -> !torch.vtensor<[?,?,?,?],f32> + return %4 : !torch.vtensor<[?,?,?,?],f32> +} \ No newline at end of file diff --git a/test/Conversion/TorchToLinalg/pooling.mlir b/test/Conversion/TorchToLinalg/pooling.mlir index 8ed75f648f5e..8a359ed5627d 100644 --- a/test/Conversion/TorchToLinalg/pooling.mlir +++ b/test/Conversion/TorchToLinalg/pooling.mlir @@ -71,6 +71,5 @@ func.func @forward_max_pool3d(%arg0: !torch.vtensor<[?,?,?,?,?],f32>) -> !torch. // CHECK-NEXT: %[[MAXF:.*]] = arith.maximumf %[[CURRENT_VALUE:.*]], %[[ACC_OUT:.*]] : f32 // CHECK-NEXT: linalg.yield %[[MAXF:.*]] : f32 // CHECK: } -> tensor - return %4 : !torch.vtensor<[?,?,?,?,?],f32> } From 4147b280cef981e53dfaa171572cf783c0fe37c2 Mon Sep 17 00:00:00 2001 From: Aart Bik <39774503+aartbik@users.noreply.github.com> Date: Fri, 23 Feb 2024 11:57:20 -0800 Subject: [PATCH 220/283] [torch-mlir][sparse] add block sparsity to mlir lowering (#2942) Also note that we are in the process of proposing SparseTensorMetadata to PyTorch FX graph export (see https://github.com/pytorch/pytorch/pull/117907). This will hopefully eventually replace the current data structures in torch-mlir. --- python/torch_mlir/extras/fx_importer.py | 35 ++++++++++---- test/python/fx_importer/sparse_test.py | 63 +++++++++++++++++-------- 2 files changed, 69 insertions(+), 29 deletions(-) diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py index 89a3caa16843..aea76c621d46 100644 --- a/python/torch_mlir/extras/fx_importer.py +++ b/python/torch_mlir/extras/fx_importer.py @@ -216,14 +216,20 @@ @dataclass(frozen=True) class SparsityMeta: - """Class for keeping track of sparsity meta data.""" + """ + Class for keeping track of sparsity meta data. + + NOTE: this will be fully replaced by + torch.fx.passes.shape_prop.SparseTensorMetadata + """ layout: torch.layout batch_dim: int sparse_dim: int dense_dim: int - pos_width: int - crd_width: int + blocksize: Optional[tuple[int, int]] + pos_dtype: torch.dtype + crd_dtype: torch.dtype def sparsity_encoding(shape: torch.Size, sparsity: SparsityMeta) -> str: @@ -240,21 +246,31 @@ def sparsity_encoding(shape: torch.Size, sparsity: SparsityMeta) -> str: ) dim = batch_dim + sparse_dim + dense_dim assert dim == len(shape) + blocksize = sparsity.blocksize dims = ",".join(f"d{d}" for d in range(0, dim)) if sparsity.layout is torch.sparse_coo: - assert sparse_dim == 2 # TODO: deeper sparse dims + assert sparse_dim == 2 and blocksize is None # TODO: deeper sparse dims lvls = f"d{batch_dim}:compressed(nonunique),d{batch_dim+1}:singleton" elif sparsity.layout is torch.sparse_csr: - assert sparse_dim == 2 + assert sparse_dim == 2 and blocksize is None lvls = f"d{batch_dim}:dense,d{batch_dim+1}:compressed" elif sparsity.layout is torch.sparse_csc: - assert sparse_dim == 2 + assert sparse_dim == 2 and blocksize is None lvls = f"d{batch_dim+1}:dense,d{batch_dim}:compressed" else: - # TODO: block format (derive block size!) - raise RuntimeError(f"Unsupported sparse layout {sparse_layout}") + assert sparse_dim == 2 and blocksize is not None + if sparsity.layout is torch.sparse_bsr: + i, j = batch_dim, batch_dim + 1 + else: + assert sparsity.layout is torch.sparse_bsc + j, i = batch_dim, batch_dim + 1 + m, n = blocksize + lvls = ( + f"d{i} floordiv {m}:dense,d{j} floordiv {n}:compressed," + f"d{i} mod {m}:dense,d{j} mod {n}:dense" + ) if batch_dim > 0: batch = ",".join(f"d{d}:dense" for d in range(0, batch_dim)) @@ -264,7 +280,8 @@ def sparsity_encoding(shape: torch.Size, sparsity: SparsityMeta) -> str: dense = ",".join(f"d{d}:dense" for d in range(batch_dim + sparse_dim, dim)) lvls = f"{lvls},{dense}" - posw, crdw = sparsity.pos_width, sparsity.crd_width + posw = torch.iinfo(sparsity.pos_dtype).bits + crdw = torch.iinfo(sparsity.crd_dtype).bits return f"#sparse_tensor.encoding<{{map=({dims})->({lvls}),posWidth={posw},crdWidth={crdw}}}>" diff --git a/test/python/fx_importer/sparse_test.py b/test/python/fx_importer/sparse_test.py index e936e40cb039..87eecb2977d5 100644 --- a/test/python/fx_importer/sparse_test.py +++ b/test/python/fx_importer/sparse_test.py @@ -31,50 +31,49 @@ ] -def sparse_overhead_width(d: torch.dtype) -> int: - """Returns bit-width for admissible overhead type.""" - if d is torch.int64: - return 64 - if d is torch.int32: - return 32 - if d is torch.int16: - return 16 - if d is torch.int8: - return 8 - raise RuntimeError(f"Unsupported overhead type {d}") - - def sparse_metadata(a: torch.Tensor) -> SparsityMeta: - """Returns a meta data tuple for the given sparse tensor.""" + """ + Returns a meta data tuple for the given sparse tensor. + + NOTE: this will be fully replaced by fx graph SparseTensorMetadata + """ sparse_dim = a.sparse_dim() dense_dim = a.dense_dim() batch_dim = a.ndim - dense_dim - sparse_dim + blocksize = None if a.layout is torch.sparse_coo: return SparsityMeta( a.layout, batch_dim, sparse_dim, dense_dim, - sparse_overhead_width(a.indices().dtype), - sparse_overhead_width(a.indices().dtype), + blocksize, + a.indices().dtype, + a.indices().dtype, ) elif a.layout is torch.sparse_csr or a.layout is torch.sparse_bsr: + if a.layout is torch.sparse_bsr: + blocksize = a.values().shape[batch_dim + 1 : batch_dim + 3] return SparsityMeta( a.layout, batch_dim, sparse_dim, dense_dim, - sparse_overhead_width(a.crow_indices().dtype), - sparse_overhead_width(a.col_indices().dtype), + blocksize, + a.crow_indices().dtype, + a.col_indices().dtype, ) elif a.layout is torch.sparse_csc or a.layout is torch.sparse_bsc: + if a.layout is torch.sparse_bsc: + blocksize = a.values().shape[batch_dim + 1 : batch_dim + 3] return SparsityMeta( a.layout, batch_dim, sparse_dim, dense_dim, - sparse_overhead_width(a.ccol_indices().dtype), - sparse_overhead_width(a.row_indices().dtype), + blocksize, + a.ccol_indices().dtype, + a.row_indices().dtype, ) else: raise RuntimeError(f"Unsupported sparse layout for {a}") @@ -214,6 +213,30 @@ def forward(self, x): print("torch.mlir =", res2) +@run +# CHECK-LABEL: test_sparse_SpMV +# CHECK: #[[$BSR:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 floordiv 2 : dense, d1 floordiv 2 : compressed, d0 mod 2 : dense, d1 mod 2 : dense), posWidth = 64, crdWidth = 64 }> +# CHECK: func.func @main( +# CHECK-SAME: %[[A:.*0]]: !torch.vtensor<[10,10],f32,#[[$BSR]]>, +# CHECK-SAME: %[[B:.*1]]: !torch.vtensor<[10],f32>) -> !torch.vtensor<[10],f32> { +# CHECK: %[[R:.*]] = torch.aten.mv %[[A]], %[[B]] : !torch.vtensor<[10,10],f32,#[[$BSR]]>, !torch.vtensor<[10],f32> -> !torch.vtensor<[10],f32> +# CHECK: return %[[R]] : !torch.vtensor<[10],f32> +# CHECK: } +def test_sparse_SpMV(): + class SpMVNet(torch.nn.Module): + def __init__(self): + super(SpMVNet, self).__init__() + + def forward(self, x, v): + return torch.mv(x, v) + + dense_vector = torch.ones(10) + dense_input = torch.ones(10, 10) + sparse_input = dense_input.to_sparse_bsr(blocksize=(2, 2)) + m = export_and_import(SpMVNet(), sparse_input, dense_vector) + print(m) + + @run # CHECK-LABEL: test_sparse_SpMM # CHECK: #[[$COO:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton), posWidth = 64, crdWidth = 64 }> From ec2b80b433c9a1b56352f9851d5258218f8740ab Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Fri, 23 Feb 2024 13:13:54 -0800 Subject: [PATCH 221/283] [ci] Fix mpmath 1.4.0 error by forcing 1.3.0 (#2946) `mpmath 1.4.0` changes some import locations breaking `torch`. Changing to `1.3.0` to avoid breaking on `python 3.11` --- test-requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/test-requirements.txt b/test-requirements.txt index c8e8e2bc6e5a..b21e8dfcd021 100644 --- a/test-requirements.txt +++ b/test-requirements.txt @@ -2,3 +2,4 @@ pillow dill multiprocess onnx==1.15.0 +mpmath==1.3.0 From 89e02c195b910621246c55003fca86558162c6da Mon Sep 17 00:00:00 2001 From: Stella Laurenzo Date: Fri, 23 Feb 2024 15:52:27 -0800 Subject: [PATCH 222/283] Make a typing dependency that is not in older PyTorch backwards compatible. (#2948) This was found in a downstream that is pegged to an older PyTorch version. --- python/torch_mlir/extras/fx_importer.py | 25 +++++++++++++++++++++++-- 1 file changed, 23 insertions(+), 2 deletions(-) diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py index aea76c621d46..91f3c27ee263 100644 --- a/python/torch_mlir/extras/fx_importer.py +++ b/python/torch_mlir/extras/fx_importer.py @@ -16,7 +16,18 @@ import re from dataclasses import dataclass from types import BuiltinMethodType, BuiltinFunctionType -from typing import Any, Callable, Dict, List, Optional, Sequence, Set, Tuple, Union +from typing import ( + Any, + Callable, + Dict, + List, + Optional, + Sequence, + Set, + Tuple, + TYPE_CHECKING, + Union, +) import weakref import numpy as np @@ -45,6 +56,16 @@ Node, ) +try: + from torch.export.graph_signature import InputSpec as TypingInputSpec +except ModuleNotFoundError: + # PyTorch prior to 2.3 is missing certain things we use in typing + # signatures. Just make them be Any. + if not TYPE_CHECKING: + TypingInputSpec = Any + else: + raise + try: import ml_dtypes except ModuleNotFoundError: @@ -299,7 +320,7 @@ class InputInfo: """Provides additional metadata when resolving inputs.""" program: torch.export.ExportedProgram - input_spec: torch.export.graph_signature.InputSpec + input_spec: TypingInputSpec node: Node ir_type: IrType mutable_producer_node_name: Optional[str] = None From c5a1da1910f8e1a5dac748eb2806833bd4f1b0c2 Mon Sep 17 00:00:00 2001 From: ptrifunovic98 <156185835+ptrifunovic98@users.noreply.github.com> Date: Mon, 26 Feb 2024 17:46:56 +0100 Subject: [PATCH 223/283] Implement lowering of torch.aten.norm.Scalar (#2899) Closes [nod-ai/SHARK-Turbine#365](https://github.com/nod-ai/SHARK-Turbine/issues/365) --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 25 +++++++++ lib/Conversion/TorchToLinalg/Reduction.cpp | 53 ++++++++++++++++--- lib/Dialect/Torch/IR/TorchOps.cpp | 36 +++++++++++++ .../Transforms/AbstractInterpLibrary.cpp | 32 +++++++++++ projects/pt1/e2e_testing/xfail_sets.py | 1 + .../build_tools/abstract_interp_lib_gen.py | 18 +++++++ .../build_tools/torch_ods_gen.py | 1 + .../test_suite/reduction.py | 19 +++++++ 8 files changed, 177 insertions(+), 8 deletions(-) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index cc8be7c6910b..dc1203de9471 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -6325,6 +6325,31 @@ def Torch_AtenLayerNormOp : Torch_Op<"aten.layer_norm", [ }]; } +def Torch_AtenNormScalarOp : Torch_Op<"aten.norm.Scalar", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::norm.Scalar : (Tensor, Scalar) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchScalarType:$p + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenNormScalarOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenNormScalarOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; + let hasVerifier = 1; +} + def Torch_AtenNormScalarOptDimOp : Torch_Op<"aten.norm.ScalarOpt_dim", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Conversion/TorchToLinalg/Reduction.cpp b/lib/Conversion/TorchToLinalg/Reduction.cpp index a21615ad84c4..e050764993e6 100644 --- a/lib/Conversion/TorchToLinalg/Reduction.cpp +++ b/lib/Conversion/TorchToLinalg/Reduction.cpp @@ -275,7 +275,8 @@ static Value createInitElementForReduceOp(OpBuilder &b, Location loc, elementType.getIntOrFloatBitWidth()))); } - if (isa(op) || isa(op)) + if (isa(op) || isa(op) || + isa(op)) return b.create(loc, b.getZeroAttr(elementType)); if (isa(op)) { @@ -341,6 +342,26 @@ static Value createLinalgPayloadForReduceOp(OpBuilder &b, Location loc, if (intType.isSigned()) return b.create(loc, self, result); } + } else if (isa(op)) { + // This creates payload for only the first of the two linalg.generic ops. + // TODO: Short-circuit operations if `p` is zero or one. + Value elem = payloadArgs[0]; + Value result = payloadArgs[1]; + + // TODO: Fix this part to support complex elements. + if (elem.getType().isa()) { + op->emitError("lowering of complex input type for torch.aten.norm.Scalar " + "is currently unimplemented"); + return nullptr; + } + + Value self = convertScalarToDtype(b, loc, elem, resultElementType); + + auto abs = b.create(loc, self); + AtenNormScalarOp::Adaptor adaptor(operands); + Value p = convertScalarToDtype(b, loc, adaptor.getP(), resultElementType); + auto pow = b.create(loc, abs, p); + return b.create(loc, pow, result); } else if (isa(op)) { // This creates payload for only the first of the two linalg.generic ops. // TODO: Short-circuit operations if `ord` is zero or one. @@ -433,7 +454,7 @@ class ConvertReductionOp : public ConversionPattern { ConversionPatternRewriter &rewriter) const { auto opInfo = torch_to_linalg::ReductionOpInfo{false, Value{}, {}}; - if (isa(op)) { + if (isa(op)) { opInfo.tensorOperand = operands[0]; auto inputType = opInfo.tensorOperand.getType().cast(); @@ -484,10 +505,12 @@ class ConvertReductionOp : public ConversionPattern { return err ? Value{} : powOp; } - FailureOr createSecondReductionForVectorNormOp( - Location loc, Type elemType, AtenLinalgVectorNormOp op, Value ordOp, - Value firstReduction, const torch_to_linalg::ReductionOpInfo &opInfo, - ConversionPatternRewriter &rewriter) const { + template + FailureOr + createSecondReductionForNormOp(Location loc, Type elemType, TOp op, + Value ordOp, Value firstReduction, + const torch_to_linalg::ReductionOpInfo &opInfo, + ConversionPatternRewriter &rewriter) const { // Cast `ord` to float so that we can readily pass it math.powf. Value ordValue = convertScalarToDtype(rewriter, loc, ordOp, elemType); @@ -544,13 +567,15 @@ class ConvertReductionOp : public ConversionPattern { LogicalResult validateReductionElementType(Operation *op, Type elemType, ConversionPatternRewriter &rewriter) const { - if ((isa(op) || isa(op)) && + if ((isa(op) || isa(op) || + isa(op)) && !elemType.isa()) return rewriter.notifyMatchFailure( op, "only float types are valid for vector norm ops"); if (isa(op) && elemType.isa() && elemType.getIntOrFloatBitWidth() == 8) return rewriter.notifyMatchFailure(op, "uint8 is not supported"); + // No checks for all other reduction operations return success(); } @@ -587,11 +612,22 @@ class ConvertReductionOp : public ConversionPattern { return rewriter.notifyMatchFailure( op, "failed to create linalg.generic operation for reduction"); + // If this is aten.norm.Scalar op, then we need to generate another + // linalg.generic op that references the first linalg.generic op. + if (isa(op)) { + AtenNormScalarOp::Adaptor adaptor(operands); + FailureOr secondReduceOp = createSecondReductionForNormOp( + loc, elemType, op, adaptor.getP(), reduceOp, *opInfo, rewriter); + if (failed(secondReduceOp)) + return secondReduceOp; + reduceOp = *secondReduceOp; + } + // If this is aten.linalg_vector_norm op, then we need to generate another // linalg.generic op that references the first linalg.generic op. if (auto normOp = dyn_cast(op)) { AtenLinalgVectorNormOp::Adaptor adaptor(operands); - FailureOr secondReduceOp = createSecondReductionForVectorNormOp( + FailureOr secondReduceOp = createSecondReductionForNormOp( loc, elemType, normOp, adaptor.getOrd(), reduceOp, *opInfo, rewriter); if (failed(secondReduceOp)) return secondReduceOp; @@ -627,6 +663,7 @@ void mlir::torch::torch_to_linalg::populateReductionPatternsAndLegality( target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); patterns.add(typeConverter, context); diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index da6f71015942..ef3098eb1c12 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -3767,6 +3767,42 @@ LogicalResult ShapeCalculateYieldShapesOp::verify() { return success(); } +//===----------------------------------------------------------------------===// +// AtenNormScalarOp +//===----------------------------------------------------------------------===// + +LogicalResult AtenNormScalarOp::verify() { + + // Verificaion of input type for torch.aten.norm.Scalar. + // Per PyTorch docs, only float and complex types are valid for norm + // operation. + + auto inTensor = getSelf().getType().cast(); + + // If no dtype is specified, it will default to a float one. + if (!inTensor.hasDtype()) { + return success(); + } + + auto inTensorDtype = inTensor.getDtype(); + + // Check if dtype is one of those supported by norm operation. + // ComplexType will match any torch complex types, but each float must be + // checked individually. + if (!inTensorDtype.isa()) { + return emitOpError( + "expected a float or complex type for input tensor, but got ") + << inTensorDtype; + } + + return success(); +} + +//===----------------------------------------------------------------------===// +// AtenPermuteOp +//===----------------------------------------------------------------------===// + LogicalResult AtenPermuteOp::verify() { // Verification of the permute op for input & output dimensions with diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index bfc2fc6a1d0c..a8327b0e0da6 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -9339,6 +9339,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %2 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %0, %arg2, %1) : (!torch.list, !torch.optional>, !torch.bool, !torch.any) -> !torch.list\n" " return %2 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.norm.Scalar\"(%arg0: !torch.list, %arg1: !torch.float) -> !torch.list {\n" +" %false = torch.constant.bool false\n" +" %none = torch.constant.none\n" +" %0 = torch.derefine %none : !torch.none to !torch.optional>\n" +" %1 = torch.derefine %none : !torch.none to !torch.any\n" +" %2 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %0, %false, %1) : (!torch.list, !torch.optional>, !torch.bool, !torch.any) -> !torch.list\n" +" return %2 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.norm.ScalarOpt_dim\"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.list, %arg3: !torch.bool) -> !torch.list {\n" " %int0 = torch.constant.int 0\n" " %0 = torch.derefine %arg2 : !torch.list to !torch.optional>\n" @@ -12038,6 +12046,30 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %4 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.norm.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.number) -> !torch.int {\n" +" %true = torch.constant.bool true\n" +" %int5 = torch.constant.int 5\n" +" %int8 = torch.constant.int 8\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %3 = torch.aten.eq.int %0#1, %int8 : !torch.int, !torch.int -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.int) {\n" +" torch.prim.If.yield %int5 : !torch.int\n" +" } else {\n" +" %5 = func.call @\"__torch_mlir_dtype_fn.aten.std\"(%arg0, %true) : (!torch.tuple, !torch.bool) -> !torch.int\n" +" torch.prim.If.yield %5 : !torch.int\n" +" }\n" +" return %4 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.tensor.float\"(%arg0: !torch.float, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.bool) -> !torch.int {\n" " %int6 = torch.constant.int 6\n" " %none = torch.constant.none\n" diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index e749b5834cc6..70f26fe421e0 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -1667,6 +1667,7 @@ "NllLossModule_ignore_index_out_of_bounds_basic", "NllLossModule_mean_basic", "NllLossModule_sum_basic", + "NormScalarModule_basic", "NormScalarOptDimKeepDimModule_basic", "NormScalarOptDimModule_basic", "NormalFunctionalModule_basic", diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index 403d124ad927..99f4f2200d35 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -1722,6 +1722,9 @@ def aten〇linalg_vector_norm〡shape(self: List[int], ord: float = 2, dim: Opti def aten〇frobenius_norm〇dim〡shape(self: List[int], dim: List[int], keepdim: bool = False) -> List[int]: return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, 0) +def aten〇norm〇Scalar〡shape(self: List[int], p: float = 2) -> List[int]: + return upstream_shape_functions.sum_mean_dim(self, None, False, None) + def aten〇norm〇ScalarOpt_dim〡shape(self: List[int], p: Optional[float], dim: List[int], keepdim: bool = False) -> List[int]: return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, 0) @@ -3924,6 +3927,21 @@ def aten〇linalg_vector_norm〡dtype(self_rank_dtype: Tuple[int, int], ord: Uni return dtype return aten〇std〡dtype(self_rank_dtype) +@check_dtype_function( + _check_tensors_with_the_same_dtype( + num_of_tensors=1, + error_types={torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64})) +def aten〇norm〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], p: Union[int, float, complex] = 2) -> int: + self_rank, self_dtype = self_rank_dtype + assert not is_integer_dtype(self_dtype) + # The following check is added because aten〇std〡dtype + # does not handle complex32 transformation to float, + # so it is done manually (torch.half == torch.float16). + # Should possibly be added to aten〇std〡dtype. + if self_dtype == torch.complex32: + return torch.half + return aten〇std〡dtype(self_rank_dtype) + @check_dtype_function([Invocation(0.0), Invocation(0.0, dtype=torch.int32), Invocation(0.0, dtype=torch.float16), diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index 51c196421b78..cc41a99be228 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -449,6 +449,7 @@ def emit_with_mutating_variants(key, **kwargs): emit( "aten::layer_norm : (Tensor, int[], Tensor?, Tensor?, float, bool) -> (Tensor)" ) + emit("aten::norm.Scalar : (Tensor, Scalar) -> (Tensor)", has_verifier=True) emit( "aten::norm.ScalarOpt_dim : (Tensor, Scalar?, int[], bool) -> (Tensor)" ) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py index 2c61524bd797..d0d6c2ea2dfa 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py @@ -1100,6 +1100,25 @@ def ReduceL3NormKeepDimModule_basic(module, tu: TestUtils): # ============================================================================== +class NormScalarModule(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + self.p = 3.0 + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True), + ]) + def forward(self, a): + return torch.ops.aten.norm(a, self.p) + +@register_test_case(module_factory=lambda: NormScalarModule()) +def NormScalarModule_basic(module, tu: TestUtils): + module.forward(tu.rand(3, 4, 5)) + +# ============================================================================== + class NormScalarOptDimModule(torch.nn.Module): def __init__(self) -> None: super().__init__() From 3cbe6c98ec9a67964ecb5947f7664e34e9ba4b5b Mon Sep 17 00:00:00 2001 From: Sambhav Jain Date: Mon, 26 Feb 2024 10:08:14 -0800 Subject: [PATCH 224/283] Expose `func_name` to the main fx import API (#2949) As titled. --- python/torch_mlir/extras/fx_importer.py | 4 ++-- python/torch_mlir/fx.py | 5 +++-- test/python/fx_importer/basic_test.py | 21 +++++++++++++++++++++ 3 files changed, 26 insertions(+), 4 deletions(-) diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py index 91f3c27ee263..e6d0f03deda4 100644 --- a/python/torch_mlir/extras/fx_importer.py +++ b/python/torch_mlir/extras/fx_importer.py @@ -623,7 +623,7 @@ def import_program( node_importer.return_node_values(loc, user_outputs) self.symbol_table.insert(func_op) - def import_frozen_program(self, prog: torch.export.ExportedProgram): + def import_frozen_program(self, prog: torch.export.ExportedProgram, func_name: str = "main"): """Imports a consolidated torch.export.ExportedProgram instance. If using the new torch.export path (vs a lower level precursor), then this is @@ -702,7 +702,7 @@ def import_frozen_program(self, prog: torch.export.ExportedProgram): node.replace_all_uses_with(replacement) g.erase_node(node) - self.import_stateless_graph(g) + self.import_stateless_graph(g, func_name) def import_graph_module(self, gm: GraphModule): """Low-level import of a GraphModule assuming that it has been functionalized. diff --git a/python/torch_mlir/fx.py b/python/torch_mlir/fx.py index 1f5aa8f74add..76cd91f82e0a 100644 --- a/python/torch_mlir/fx.py +++ b/python/torch_mlir/fx.py @@ -23,6 +23,7 @@ def export_and_import( constraints: Optional[torch.export.Constraint] = None, experimental_support_mutation: bool = False, hooks: Optional[FxImporterHooks] = None, + func_name: str = "main", **kwargs, ): context = ir.Context() @@ -36,8 +37,8 @@ def export_and_import( if experimental_support_mutation: if torch.__version__ < "2.3.0.dev20240207": warnings.warn("Mutable program import only supported on PyTorch 2.3+") - fx_importer.import_program(prog) + fx_importer.import_program(prog, func_name=func_name) else: - fx_importer.import_frozen_program(prog) + fx_importer.import_frozen_program(prog, func_name=func_name) return fx_importer.module_op diff --git a/test/python/fx_importer/basic_test.py b/test/python/fx_importer/basic_test.py index 36c554862506..fc5b2030b648 100644 --- a/test/python/fx_importer/basic_test.py +++ b/test/python/fx_importer/basic_test.py @@ -56,3 +56,24 @@ def forward(self, x): m = fx.export_and_import(Basic(), torch.randn(3, 4)) print(m) + + +@run +# CHECK-LABEL: test_import_frozen_exported_program_with_func_name +# CHECK: func.func @test_net(%[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> +def test_import_frozen_exported_program_with_func_name(): + @torch._dynamo.assume_constant_result + def get_a(): + return torch.randn(1, 4) + + class Basic(nn.Module): + def __init__(self): + super().__init__() + self.b = torch.randn(3, 1) + self.p = nn.Parameter(torch.randn(1, 1)) + + def forward(self, x): + return torch.tanh(x) * get_a() * self.b * self.p + + m = fx.export_and_import(Basic(), torch.randn(3, 4), func_name="test_net") + print(m) From d81747eadbbdc0f97b64dd2964aedb6497de4435 Mon Sep 17 00:00:00 2001 From: Vivek Khandelwal Date: Tue, 27 Feb 2024 11:02:05 +0530 Subject: [PATCH 225/283] [MLIR][TORCH] Extend support for OnnxToLinalg lowering for Dropout and Div op (#2938) Fixes https://github.com/nod-ai/SHARK-Turbine/issues/451, https://github.com/nod-ai/SHARK-Turbine/issues/452 --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 2 + .../TorchOnnxToTorch/DefaultDomainAtoF.cpp | 38 ++++++++++++--- .../TorchToLinalg/TensorScalarInterop.cpp | 20 +++++--- .../TorchToLinalg/Uncategorized.cpp | 14 ++++-- lib/Dialect/Torch/IR/TorchOps.cpp | 32 +++++++++++++ projects/pt1/e2e_testing/xfail_sets.py | 9 +++- .../build_tools/torch_ods_gen.py | 4 +- .../torch_mlir_e2e_test/test_suite/basic.py | 44 ++++++++++++++++++ .../test_suite/elementwise.py | 46 +++++++++++++++++++ .../TorchOnnxToTorch/simple_ops_a_to_f.mlir | 9 ++++ test/Dialect/Torch/canonicalize.mlir | 46 +++++++++++++++++++ 11 files changed, 243 insertions(+), 21 deletions(-) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index dc1203de9471..57b15ed18f4e 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -11231,6 +11231,7 @@ def Torch_AtenIntImplicitOp : Torch_Op<"aten.IntImplicit", [ printDefaultTorchOp(printer, *this, 1, 1); } }]; + let hasCanonicalizer = 1; } def Torch_AtenFloatImplicitOp : Torch_Op<"aten.FloatImplicit", [ @@ -11254,6 +11255,7 @@ def Torch_AtenFloatImplicitOp : Torch_Op<"aten.FloatImplicit", [ printDefaultTorchOp(printer, *this, 1, 1); } }]; + let hasCanonicalizer = 1; } def Torch_AtenTensorFloatOp : Torch_Op<"aten.tensor.float", [ diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp index 99a3985a2993..1c356db890db 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp @@ -1339,12 +1339,38 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( Value ratio, trainingMode; if (numOperands == 3) { ratio = rewriter.create(loc, operands[1]); - Value trainingModeScalar = - rewriter.create(loc, operands[2]); - Value cstOne = rewriter.create( - loc, rewriter.getI64IntegerAttr(1)); - trainingMode = rewriter.create( - loc, trainingModeScalar, cstOne); + Value trainVal = operands[2]; + auto trainTensorType = + trainVal.getType().dyn_cast(); + if (!trainTensorType) + return rewriter.notifyMatchFailure(binder.op, + "train tensor must have a type"); + + Type inputDtype = trainTensorType.getOptionalDtype(); + if (!inputDtype || !inputDtype.isInteger(1)) + return rewriter.notifyMatchFailure( + binder.op, + "train tensor must have an integer dtype of width 1"); + + std::optional inputRank = Torch::getTensorRank(trainVal); + if (!inputRank || *inputRank != 0) + return rewriter.notifyMatchFailure(binder.op, + "train tensor must have rank 0"); + + if (auto valueTensorLiteralOp = + trainVal.getDefiningOp()) { + auto val = valueTensorLiteralOp.getValue() + .cast() + .getSplatValue(); + trainingMode = rewriter.create(loc, val); + } else { + Value trainingModeScalar = + rewriter.create(loc, operands[2]); + Value cstOne = rewriter.create( + loc, rewriter.getI64IntegerAttr(1)); + trainingMode = rewriter.create( + loc, trainingModeScalar, cstOne); + } } else if (numOperands == 2) { ratio = rewriter.create(loc, operands[1]); trainingMode = rewriter.create(loc, false); diff --git a/lib/Conversion/TorchToLinalg/TensorScalarInterop.cpp b/lib/Conversion/TorchToLinalg/TensorScalarInterop.cpp index a1e8e5fb72d9..58e6daa9bca8 100644 --- a/lib/Conversion/TorchToLinalg/TensorScalarInterop.cpp +++ b/lib/Conversion/TorchToLinalg/TensorScalarInterop.cpp @@ -191,12 +191,14 @@ class ConvertPrimNumToTensorScalarOp } // namespace namespace { -class ConvertAtenScalarImplicitOp - : public OpConversionPattern { +// Converts a tensor with one element to a scalar value. +template +class ConvertAtenImplicitLikeOp : public OpConversionPattern { public: - using OpConversionPattern::OpConversionPattern; + using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(AtenScalarImplicitOp op, OpAdaptor adaptor, + matchAndRewrite(OpTy op, + typename OpConversionPattern::OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { rewriter.replaceOpWithNewOp(op, adaptor.getA()); return success(); @@ -224,6 +226,12 @@ void mlir::torch::torch_to_linalg:: patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); - patterns.add(typeConverter, context); - target.addIllegalOp(); + patterns.add>(typeConverter, + context); + patterns.add>(typeConverter, + context); + patterns.add>(typeConverter, + context); + target.addIllegalOp(); } diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index ed6883000cf9..87163fc95c4a 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -725,13 +725,17 @@ static Value createLinalgPayloadCalculationForElementwiseOp( Type dtype = converter->convertType(div.getType()) .cast() .getElementType(); - if (!dtype.isa()) { - div.emitError("unimplemented: non-floating point dtype"); - return nullptr; - } Value lhs = convertScalarToDtype(b, loc, payloadArgs[0], dtype); Value rhs = convertScalarToDtype(b, loc, payloadArgs[1], dtype); - return b.create(loc, lhs, rhs); + if (dtype.isa()) + return b.create(loc, lhs, rhs); + else if (dtype.isa()) { + if (dtype.isUnsignedInteger()) + return b.create(loc, lhs, rhs); + return b.create(loc, lhs, rhs); + } + div.emitError("unimplemented: non-floating point and non-integer dtype"); + return nullptr; } if (auto divTensorMode = dyn_cast(op)) { AtenDivTensorModeOp::Adaptor adaptor(operands); diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index ef3098eb1c12..2f0884b1344e 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -1568,6 +1568,38 @@ void AtenScalarImplicitOp::getCanonicalizationPatterns( }); } +//===----------------------------------------------------------------------===// +// AtenFloatImplicitOp +//===----------------------------------------------------------------------===// +void AtenFloatImplicitOp::getCanonicalizationPatterns( + RewritePatternSet &patterns, MLIRContext *context) { + patterns.add(+[](AtenFloatImplicitOp op, PatternRewriter &rewriter) { + Location loc = op.getLoc(); + Value a = op.getA(); + Value scalarValue = getScalarFloatValue(a, loc, rewriter); + if (!scalarValue) + return failure(); + rewriter.replaceOp(op, scalarValue); + return success(); + }); +} + +//===----------------------------------------------------------------------===// +// AtenIntImplicitOp +//===----------------------------------------------------------------------===// +void AtenIntImplicitOp::getCanonicalizationPatterns(RewritePatternSet &patterns, + MLIRContext *context) { + patterns.add(+[](AtenIntImplicitOp op, PatternRewriter &rewriter) { + Location loc = op.getLoc(); + Value a = op.getA(); + Value scalarValue = getScalarIntValue(a, loc, rewriter); + if (!scalarValue) + return failure(); + rewriter.replaceOp(op, scalarValue); + return success(); + }); +} + //===----------------------------------------------------------------------===// // AtenSizeOp //===----------------------------------------------------------------------===// diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 70f26fe421e0..82c1a0759e2e 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -335,6 +335,9 @@ # Dynamo not supporting conv_tbc "ConvTbcModule_basic", + + "FloatImplicitModule_basic", + "IntImplicitModule_basic", } TORCHDYNAMO_CRASHING_SET = { @@ -989,6 +992,8 @@ "ElementwiseCloneContiguousModule_basic", "ElementwiseCloneModule_basic", "ElementwiseDivScalarModule_basic", + "ElementwiseDivTensorIntegerModule_basic", + "ElementwiseDivTensorUnsignedIntegerModule_basic", "ElementwiseEluModule_basic", "ElementwiseEluNonDefaultModule_basic", "ElementwiseEqBoolScalarModule_basic", @@ -2146,8 +2151,6 @@ "ElementwiseSigmoidIntModule_basic", # Failure - unknown - "ChunkListUnpackUneven_Module_basic", - "ChunkListUnpack_Module_basic", "Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier", "CopyWithDifferentDTypesAndSizesModule_basic", "CopyWithDifferentDTypesModule_basic", @@ -2168,6 +2171,8 @@ "ReduceMinAlongDimUnsignedInt_basic", "TensorsStackNegativeDimModule_basic", "TensorsStackPromoteDTypeModule_basic", + "FloatImplicitModule_basic", + "IntImplicitModule_basic", } ONNX_CRASHING_SET = { } diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index cc41a99be228..c81f543b5dc9 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -669,8 +669,8 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::gather : (Tensor, int, Tensor, bool) -> (Tensor)") emit_with_mutating_variants("aten::scatter_add : (Tensor, int, Tensor, Tensor) -> (Tensor)") emit_with_mutating_variants("aten::scatter_reduce.two : (Tensor, int, Tensor, Tensor, str, bool) -> (Tensor)") - emit("aten::IntImplicit : (Tensor) -> (int)") - emit("aten::FloatImplicit : (Tensor) -> (float)") + emit("aten::IntImplicit : (Tensor) -> (int)", has_canonicalizer=True) + emit("aten::FloatImplicit : (Tensor) -> (float)", has_canonicalizer=True) emit("aten::tensor.float : (float, int?, Device?, bool) -> (Tensor)") emit("aten::Int.Tensor : (Tensor) -> (int)", has_folder=True) emit("aten::Float.Tensor : (Tensor) -> (float)", has_folder=True) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py index 7e707893911a..c5ef92d41637 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py @@ -3719,6 +3719,50 @@ def ScalarImplicitIntModule_basic(module, tu: TestUtils): module.forward(tu.randint(low=-100, high=100)) +# ============================================================================== + + +class FloatImplicitModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([], torch.float64, True), + ]) + def forward(self, x): + return float(torch.ops.aten.FloatImplicit(x)) + + +@register_test_case(module_factory=lambda: FloatImplicitModule()) +def FloatImplicitModule_basic(module, tu: TestUtils): + module.forward(tu.rand().double()) + + +# ============================================================================== + + +class IntImplicitModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([], torch.int64, True), + ]) + def forward(self, x): + return float(torch.ops.aten.IntImplicit(x)) + + +@register_test_case(module_factory=lambda: IntImplicitModule()) +def IntImplicitModule_basic(module, tu: TestUtils): + module.forward(tu.randint()) + + # ============================================================================== class PowIntFloat(torch.nn.Module): diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py index ad4abd9f1752..611effdb2338 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py @@ -2595,6 +2595,52 @@ def ElementwiseDivTensorFloatModule_basic(module, tu: TestUtils): # ============================================================================== +class ElementwiseDivTensorIntegerModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.int64, True), + ([-1, -1], torch.int32, True), + ]) + def forward(self, a, b): + return torch.div(a, b) + + +@register_test_case(module_factory=lambda: ElementwiseDivTensorIntegerModule()) +def ElementwiseDivTensorIntegerModule_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 4, low=-10, high=10), tu.randint(3, 4, low=-10, high=10).type(torch.int32)) + + +# ============================================================================== + + +class ElementwiseDivTensorUnsignedIntegerModule(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1], torch.uint8, True), + ([-1, -1], torch.uint8, True), + ]) + def forward(self, a, b): + return torch.div(a, b) + + +@register_test_case(module_factory=lambda: ElementwiseDivTensorUnsignedIntegerModule()) +def ElementwiseDivTensorUnsignedIntegerModule_basic(module, tu: TestUtils): + module.forward(tu.randint(3, 4, low=0, high=10).to(torch.uint8), tu.randint(3, 4, low=0, high=10).type(torch.uint8)) + + +# ============================================================================== + + class ElementwiseDivRoundingModeTruncModule(torch.nn.Module): def __init__(self): diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir index 525583b7660e..7dc262228f1a 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir @@ -706,6 +706,15 @@ func.func @test_div(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3 // ----- +// CHECK-LABEL: @test_div_int32 +func.func @test_div_int32(%arg0: !torch.vtensor<[3,4,5],si32>, %arg1: !torch.vtensor<[3,4,5],si32>) -> !torch.vtensor<[3,4,5],si32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.div.Tensor %arg0, %arg1 : !torch.vtensor<[3,4,5],si32>, !torch.vtensor<[3,4,5],si32> -> !torch.vtensor<[3,4,5],si32> + %0 = torch.operator "onnx.Div"(%arg0, %arg1) : (!torch.vtensor<[3,4,5],si32>, !torch.vtensor<[3,4,5],si32>) -> !torch.vtensor<[3,4,5],si32> + return %0 : !torch.vtensor<[3,4,5],si32> +} + +// ----- + // CHECK-LABEL: @test_div_uint8 func.func @test_div_uint8(%arg0: !torch.vtensor<[3,4,5],ui8>, %arg1: !torch.vtensor<[3,4,5],ui8>) -> !torch.vtensor<[3,4,5],ui8> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.div.Tensor %arg0, %arg1 : !torch.vtensor<[3,4,5],ui8>, !torch.vtensor<[3,4,5],ui8> -> !torch.vtensor<[3,4,5],ui8> diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir index 4df52cfb174b..85b95eb1cdba 100644 --- a/test/Dialect/Torch/canonicalize.mlir +++ b/test/Dialect/Torch/canonicalize.mlir @@ -2145,6 +2145,52 @@ func.func @torch.aten.ScalarImplicit$canonicalize_literal_0d() -> !torch.number return %1 : !torch.number } +// ----- + +// CHECK-LABEL: func.func @torch.aten.FloatImplicit$canonicalize_numtotensor_0d() -> !torch.float { +// CHECK: %[[FLOAT1:.*]] = torch.constant.float 1.000000e+00 +// CHECK: return %[[FLOAT1]] : !torch.float +func.func @torch.aten.FloatImplicit$canonicalize_numtotensor_0d() -> !torch.float { + %float1 = torch.constant.float 1.0 + %0 = torch.prim.NumToTensor.Scalar %float1 : !torch.float -> !torch.vtensor<[],f64> + %1 = torch.aten.FloatImplicit %0 : !torch.vtensor<[],f64> -> !torch.float + return %1 : !torch.float +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.FloatImplicit$canonicalize_literal_0d() -> !torch.float { +// CHECK: %[[FLOAT1:.*]] = torch.constant.float 1.000000e+00 +// CHECK: return %[[FLOAT1]] : !torch.float +func.func @torch.aten.FloatImplicit$canonicalize_literal_0d() -> !torch.float { + %0 = torch.vtensor.literal(dense<1.0> : tensor) : !torch.vtensor<[],f64> + %1 = torch.aten.FloatImplicit %0 : !torch.vtensor<[],f64> -> !torch.float + return %1 : !torch.float +} + +// ----- + +// CHECK-LABEL: func.func @torch.aten.IntImplicit$canonicalize_numtotensor_0d() -> !torch.int { +// CHECK: %[[INT1:.*]] = torch.constant.int 1 +// CHECK: return %[[INT1]] : !torch.int +func.func @torch.aten.IntImplicit$canonicalize_numtotensor_0d() -> !torch.int { + %int1 = torch.constant.int 1 + %0 = torch.prim.NumToTensor.Scalar %int1 : !torch.int -> !torch.vtensor<[],si64> + %1 = torch.aten.IntImplicit %0 : !torch.vtensor<[],si64> -> !torch.int + return %1 : !torch.int +} + +// CHECK-LABEL: func.func @torch.aten.IntImplicit$canonicalize_literal_0d() -> !torch.int { +// CHECK: %[[INT1:.*]] = torch.constant.int 1 +// CHECK: return %[[INT1]] : !torch.int +func.func @torch.aten.IntImplicit$canonicalize_literal_0d() -> !torch.int { + %0 = torch.vtensor.literal(dense<1> : tensor) : !torch.vtensor<[],si64> + %1 = torch.aten.IntImplicit %0 : !torch.vtensor<[],si64> -> !torch.int + return %1 : !torch.int +} + +// ----- + // CHECK-LABEL: func.func @torch.prims.view_of$fold( // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[3,4,2],f32>) -> !torch.vtensor<[3,4,2],f32> { // CHECK-NEXT: return %[[ARG]] : !torch.vtensor<[3,4,2],f32> From d628b5fd060eaff4c9ae858fa5fe79356b4018fa Mon Sep 17 00:00:00 2001 From: Vivek Khandelwal Date: Tue, 27 Feb 2024 19:26:01 +0530 Subject: [PATCH 226/283] [MLIR][TORCH] Add support for tanh approximation for Gelu op (#2941) Fixes https://github.com/nod-ai/SHARK-Turbine/issues/461 Signed-Off By: Vivek Khandelwal --- .../TorchToLinalg/Uncategorized.cpp | 39 +++++++++++++++++-- projects/pt1/e2e_testing/xfail_sets.py | 1 + .../test_suite/elementwise.py | 23 +++++++++++ 3 files changed, 59 insertions(+), 4 deletions(-) diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index 87163fc95c4a..657ea460f76e 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -511,11 +511,42 @@ static Value createLinalgPayloadCalculationForElementwiseOp( } // TODO: Take approximation into account. std::string approximate; - if (!matchPattern(gelu.getApproximate(), m_TorchConstantStr(approximate)) || - approximate != "none") + if (!matchPattern(gelu.getApproximate(), m_TorchConstantStr(approximate))) { + gelu.emitError( + "unimplemented: expected approximate to be a constant str"); return nullptr; - Value cdf = buildUnitNormalCdf(b, loc, payloadArgs[0]); - return b.create(loc, payloadArgs[0], cdf); + } + if (approximate == "none") { + Value multiplier = buildUnitNormalCdf(b, loc, payloadArgs[0]); + return b.create(loc, payloadArgs[0], multiplier); + } + if (approximate == "tanh") { + // GELU(x)=0.5∗x∗(1+Tanh((2/Ï€)^1/2 * (x+0.044715∗x^3))) + // Ref: https://pytorch.org/docs/stable/generated/torch.nn.GELU.html + Value cstThree = b.create( + loc, IntegerAttr::get(IntegerType::get(op->getContext(), 64), 3)); + Value xCube = b.create(loc, payloadArgs[0], cstThree); + Type elementType = payloadArgs[0].getType(); + Value cstAlpha = b.create( + loc, FloatAttr::get(elementType, 0.044715)); + Value xCubeMulAlpha = b.create(loc, xCube, cstAlpha); + Value xPlusXCubeMulAlpha = + b.create(loc, payloadArgs[0], xCubeMulAlpha); + Value cstBeta = b.create( + loc, FloatAttr::get(elementType, 0.7977240352174656)); + Value betaMulX = + b.create(loc, cstBeta, xPlusXCubeMulAlpha); + Value tanh = b.create(loc, betaMulX); + Value cstOne = + b.create(loc, FloatAttr::get(elementType, 1.0)); + Value onePlusTanh = b.create(loc, cstOne, tanh); + Value cstHalf = + b.create(loc, FloatAttr::get(elementType, 0.5)); + Value multiplier = b.create(loc, cstHalf, onePlusTanh); + return b.create(loc, payloadArgs[0], multiplier); + } + gelu.emitError("unimplemented: approximate value should be none or tanh"); + return nullptr; } if (auto geluBackward = dyn_cast(op)) { if (!geluBackward.getType() diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 82c1a0759e2e..195a5e42f249 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -518,6 +518,7 @@ "ElementwiseFloorIntModule_basic", "ElementwiseFloorModule_basic", "ElementwiseGeluModule_basic", + "ElementwiseGeluApproximateTanhModule_basic", "ElementwiseLeakyReluStaticModule_basic", "ElementwiseLogModule_basic", "ElementwiseNanToNumModule_Basic", diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py index 611effdb2338..24bbe29194a2 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py @@ -853,6 +853,29 @@ def ElementwiseGeluModule_basic(module, tu: TestUtils): # ============================================================================== +class ElementwiseGeluApproximateTanhModule(torch.nn.Module): + + def __init__(self): + super().__init__() + self.gelu = torch.nn.GELU(approximate="tanh") + + @export + @annotate_args([ + None, + ([-1, -1], torch.float32, True), + ]) + def forward(self, x): + return self.gelu(x) + + +@register_test_case(module_factory=lambda: ElementwiseGeluApproximateTanhModule()) +def ElementwiseGeluApproximateTanhModule_basic(module, tu: TestUtils): + module.forward(tu.rand(5, 3, low=-0.5, high=0.5)) + + +# ============================================================================== + + class ElementwiseSeluModule(torch.nn.Module): def __init__(self): From e30a083affb65c301066eda3df7112c06f4291da Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Tue, 27 Feb 2024 11:46:57 -0800 Subject: [PATCH 227/283] [torch] Rework lowering to tm_tensor.scatter to stop serialization (#2940) We collapsed and broadcasted scatter indices to a single element version. We should instead upport `tm_tensor.scatter`s support for multiple indices and the implicitly broadcasted behavior. This avoids the serialization and materializing a needlessly large indices tensor. --- .../Dialect/TMTensor/IR/TMTensorOps.td | 1 + .../TorchToTMTensor/TorchToTMTensor.cpp | 668 ++++++++---------- lib/Dialect/TMTensor/IR/TMTensorOps.cpp | 73 +- .../TorchConversion/Transforms/Passes.cpp | 1 + .../torch_mlir_e2e_test/test_suite/scatter.py | 1 + test/Dialect/TMTensor/bufferize.mlir | 8 +- test/Dialect/TMTensor/convert_to_loops.mlir | 17 +- test/Dialect/TMTensor/invalid.mlir | 50 +- 8 files changed, 379 insertions(+), 440 deletions(-) diff --git a/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.td b/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.td index 50dc0c1a0403..12a74faa44d3 100644 --- a/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.td +++ b/include/torch-mlir-dialects/Dialect/TMTensor/IR/TMTensorOps.td @@ -137,6 +137,7 @@ def TMTensor_ScatterOp : TMTensor_Op<"scatter", let arguments = (ins Variadic:$inputs, Variadic:$outputs, + DenseI64ArrayAttr:$dimension_map, DefaultValuedAttr:$unique_indices ); let results = (outs Variadic:$results); diff --git a/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp b/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp index 4aa82420c38e..ac6d731bf73d 100644 --- a/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp +++ b/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp @@ -200,15 +200,30 @@ convertTorchScatterIndexAndSrcToTMScatterIndexAndSrc(PatternRewriter &rewriter, scatterInputsVector[indexType.getRank()]); } +static llvm::SmallVector createDefaultDimMap(Value indices) { + llvm::SmallVector dmap; + if (auto iTy = dyn_cast(indices.getType())) + dmap.resize(iTy.getSizes()[1]); + + if (auto iTy = dyn_cast(indices.getType())) + dmap.resize(iTy.getDimSize(1)); + + for (int i = 0, s = dmap.size(); i < s; ++i) + dmap[i] = i; + + return dmap; +} + static Value createTMTensorScatterOp( OpBuilder &b, Location loc, Value updates, Value indices, Value original, - bool uniqueIndices, + llvm::ArrayRef dimensionsMap, bool uniqueIndices, function_ref bodyBuild) { + auto dimensionsMapAttr = b.getDenseI64ArrayAttr(dimensionsMap); auto originalTensorType = original.getType().cast(); Type originalElementType = originalTensorType.getElementType(); auto scatterOp = b.create( loc, originalTensorType, ValueRange{updates, indices}, - ValueRange{original}, uniqueIndices); + ValueRange{original}, dimensionsMapAttr, uniqueIndices); Region &scatterOpRegion = scatterOp.getRegion(); auto &scatterOpBlock = scatterOpRegion.emplaceBlock(); @@ -334,7 +349,7 @@ class ConvertAtenScatterSrcOp : public OpConversionPattern { src, dim); Value scatterOp = createTMTensorScatterOp( rewriter, loc, updates, indices, self, - /*uniqueIndices=*/false, + /*dimensionsMap=*/createDefaultDimMap(indices), /*uniqueIndices=*/false, [&](OpBuilder &b, Location loc, Value updatesElement, Value inputElement) { b.create(loc, updatesElement); @@ -455,7 +470,7 @@ class ConvertAtenBincountOp : public OpConversionPattern { Value scatterOp = createTMTensorScatterOp( rewriter, loc, updatesTensor, indices, bincountTensor, - /*uniqueIndices=*/false, + /*dimensionsMap=*/createDefaultDimMap(indices), /*uniqueIndices=*/false, [&](OpBuilder &b, Location loc, Value _, Value bincountElem) { Value add = b.create(loc, bincountElem, constantOne); b.create(loc, add); @@ -466,235 +481,200 @@ class ConvertAtenBincountOp : public OpConversionPattern { }; } // namespace -// """Create a map from each dimension of the input tensor to the -// subspace that dimension corresponds to in the result shape one gets -// from indexing the tensor with the optional index tensors. -// -// Note: Index tensors are first broadcasted to a common shape before -// creating the mapping. So the index of every index tensor will map to -// the same dimensions in the result shape. -// -// For example: -// indices = [None, None, torch.randint(4, (6, 1)), torch.randint(5, (7,))] -// indexBroadcastShapeValue = [6, 7] -// map = {0: [0], 1: [1], 2: [2, 3], 3: [2, 3]} -static SmallVector> -getInputShapeToOutputShapeMap(SmallVector optionalIndices, - SmallVector indexBroadcastShapeValue) { - SmallVector indices; - for (Value index : optionalIndices) { - if (!index.getType().isa()) - indices.push_back(index); +namespace { + +Value combinePutIndices(Location loc, llvm::ArrayRef indicesRef, + OpBuilder b) { + llvm::SmallVector indices(indicesRef); + // Declare commonly used constants up front: + Value torchCstZero = + b.create(loc, b.getI64IntegerAttr(0)); + Value torchCstOne = + b.create(loc, b.getI64IntegerAttr(1)); + Value torchCstNegOne = + b.create(loc, b.getI64IntegerAttr(-1)); + + // Determine the broadcast sizes and materialize missing implicit end + // dimensions: + int64_t indicesRank = 0; + for (auto index : indices) { + auto indexTy = cast(index.getType()); + int64_t rank = indexTy.getSizes().size(); + indicesRank = std::max(rank, indicesRank); } - unsigned broadcastRank = indexBroadcastShapeValue.size(); - unsigned numIndexTensors = indices.size(); - int64_t indexOfFirstIndexTensor = -1; - SmallVector> result; - - for (unsigned i = 0; i < optionalIndices.size(); i++) { - if (optionalIndices[i].getType().isa()) { - unsigned val = i; - if (indexOfFirstIndexTensor >= 0) - val += broadcastRank - numIndexTensors; - result.push_back({val}); - } else { - if (indexOfFirstIndexTensor < 0) - indexOfFirstIndexTensor = i; - SmallVector outputIndices; - for (unsigned j = indexOfFirstIndexTensor; - j < (indexOfFirstIndexTensor + broadcastRank); j++) - outputIndices.push_back(j); - result.push_back(outputIndices); + auto maxDim = [](int64_t dim0, int64_t dim1) { + if (dim0 == Torch::kUnknownSize || dim1 == Torch::kUnknownSize) + return Torch::kUnknownSize; + return std::max(dim0, dim1); + }; + + llvm::SmallVector broadcastSizes(indicesRank, torchCstOne); + llvm::SmallVector broadcastShape(indicesRank, 0); + for (auto index : indices) { + auto indexTy = cast(index.getType()); + auto shape = indexTy.getSizes(); + int32_t rank = shape.size(); + + for (int32_t j = 0; j < rank; ++j) { + Value dim = b.create(loc, b.getI64IntegerAttr(j)); + auto sizeOp = b.create(loc, index, dim); + auto size = shape[j]; + + int32_t idx = broadcastShape.size() - rank + j; + broadcastSizes[idx] = + b.create(loc, sizeOp, broadcastSizes[idx]); + broadcastShape[idx] = maxDim(size, broadcastShape[idx]); } } - return result; -} -static std::tuple, SmallVector> -getIndicesFinalShape(ConversionPatternRewriter &rewriter, Location loc, - Value input, SmallVector optionalIndices, - SmallVector inputShapeInt, - SmallVector inputShapeValue, - SmallVector indexBroadcastShapeInt, - SmallVector indexBroadcastShapeValue) { - SmallVector result; - SmallVector resultInt; - bool handledIndexTensorSpace = false; - - for (unsigned i = 0; i < inputShapeValue.size(); i++) { - if (optionalIndices[i].getType().isa()) { - result.push_back(inputShapeValue[i]); - resultInt.push_back(inputShapeInt[i]); - } else { - if (!handledIndexTensorSpace) { - handledIndexTensorSpace = true; - for (unsigned j = 0; j < indexBroadcastShapeValue.size(); j++) { - result.push_back(indexBroadcastShapeValue[j]); - resultInt.push_back(indexBroadcastShapeInt[j]); - } - } - } + auto mulDim = [](int64_t dim0, int64_t dim1) { + if (dim0 == Torch::kUnknownSize || dim1 == Torch::kUnknownSize) + return Torch::kUnknownSize; + return dim0 * dim1; + }; + + int64_t scatterBatchCount = 1; + for (auto dim : broadcastShape) { + scatterBatchCount = mulDim(scatterBatchCount, dim); } - return std::make_tuple(result, resultInt); -} -static FailureOr -getScatterIndices(Aten_IndexPutImplOp op, ConversionPatternRewriter &rewriter, - Type indicesDtype, SmallVector optionalIndices, - SmallVector indexBroadcastShapeInt, - SmallVector indexBroadcastShapeValue) { - Location loc = op.getLoc(); - MLIRContext *context = op->getContext(); - Value input = op.getSelf(); - - SmallVector> shapeMap = - getInputShapeToOutputShapeMap(optionalIndices, indexBroadcastShapeValue); - - SmallVector inputShapeInt{ - input.getType().cast().getSizes()}; - int64_t inputRank = inputShapeInt.size(); - SmallVector inputShapeValue; - for (unsigned i = 0; i < inputShapeInt.size(); i++) { - Value dim = rewriter.create( - loc, rewriter.getI64IntegerAttr(i)); - inputShapeValue.push_back( - rewriter.createOrFold(loc, input, dim)); + // Broadcast together and flatten to batch values: + Value broadcastSizeList = b.create( + loc, Torch::ListType::get(b.getType()), broadcastSizes); + for (Value &index : indices) { + auto indexTy = cast(index.getType()); + auto expandTy = b.getType( + broadcastShape, indexTy.getOptionalDtype()); + index = b.create(loc, expandTy, index, + broadcastSizeList); + + auto flattenTy = b.getType( + scatterBatchCount, indexTy.getOptionalDtype()); + index = b.create( + loc, flattenTy, index, torchCstZero, torchCstNegOne); } - auto finalShapeResult = getIndicesFinalShape( - rewriter, loc, input, optionalIndices, inputShapeInt, inputShapeValue, - indexBroadcastShapeInt, indexBroadcastShapeValue); - SmallVector finalShapeValue = std::get<0>(finalShapeResult); - SmallVector finalShapeInt = std::get<1>(finalShapeResult); + // Unsqueeze so we have a 1 dim to concat along: + for (Value &tensor : indices) { + auto btt = cast(tensor.getType()); + if (!btt.hasSizes()) + return nullptr; - Value torchCstNone = rewriter.create(loc); - Value torchCstZero = - rewriter.create(loc, rewriter.getI64IntegerAttr(0)); - Value torchCstOne = - rewriter.create(loc, rewriter.getI64IntegerAttr(1)); - - Value indexBroadcastShapeTorchList = rewriter.create( - loc, Torch::ListType::get(Torch::IntType::get(context)), - indexBroadcastShapeValue); - - // Calculating index count. - int64_t indexCount = 1; - if (llvm::all_of(finalShapeInt, - [](int64_t shape) { return shape != kUnknownSize; })) { - for (int64_t i : finalShapeInt) - indexCount *= i; - } else { - indexCount = kUnknownSize; + llvm::SmallVector shape(btt.getSizes()); + shape.push_back(1); + + auto unsqueezeTy = b.getType(shape, btt.getDtype()); + Value unsqueezed = + b.create(loc, unsqueezeTy, tensor, torchCstOne); + tensor = unsqueezed; } - Value indexCountValue = finalShapeValue[0]; - for (unsigned i = 1; i < finalShapeValue.size(); i++) - indexCountValue = - rewriter.create(loc, indexCountValue, finalShapeValue[i]); - - ValueTensorType flattenIndicesType = - ValueTensorType::get(context, llvm::ArrayRef(indexCount), indicesDtype); - Value flattenEndDim = rewriter.create( - loc, rewriter.getI64IntegerAttr(finalShapeInt.size() - 1)); - - SmallVector broadcastedIndices; - for (unsigned i = 0; i < optionalIndices.size(); i++) { - Value broadcastedIndexTensor; - if (optionalIndices[i].getType().isa()) { - Value torchCstDim = rewriter.create( - loc, rewriter.getI64IntegerAttr(i)); - Value inputDim = rewriter.create(loc, input, torchCstDim); - ValueTensorType tensorType = ValueTensorType::get( - context, llvm::ArrayRef(inputShapeInt[i]), indicesDtype); - broadcastedIndexTensor = rewriter.create( - loc, tensorType, /*start=*/torchCstZero, /*end=*/inputDim, - /*step=*/torchCstOne, - /*dtype=*/torchCstNone, - /*layout=*/torchCstNone, - /*device=*/torchCstNone, - /*pin_memory=*/torchCstNone); - } else { - ValueTensorType tensorType = ValueTensorType::get( - context, llvm::ArrayRef(indexBroadcastShapeInt), indicesDtype); - broadcastedIndexTensor = rewriter.create( - loc, tensorType, optionalIndices[i], indexBroadcastShapeTorchList); - } + BaseTensorType unsqueezedTensorType = + indices[0].getType().cast(); + Value indicesTorchList = b.create( + loc, Torch::ListType::get(unsqueezedTensorType), indices); + llvm::SmallVector concatShape{ + unsqueezedTensorType.getSizes()[0], static_cast(indices.size())}; + ValueTensorType concatIndicesType = b.getType( + llvm::ArrayRef(concatShape), unsqueezedTensorType.getDtype()); + return b.create(loc, concatIndicesType, indicesTorchList, + torchCstOne); +} - // spotlight_indices(final_shape, shape_map[i]): - // Turn all values in `final_shape` to `1` except for those with index in - // `indices`. - // for j in range(len(final_shape)): - // if j not in indices: - // final_shape[j] = 1 - // This is equivalent to unsqueezing the index tensor at the dimension `j` - // not in indices. - for (unsigned j = 0; j < finalShapeInt.size(); j++) { - if (llvm::find(shapeMap[i], j) == shapeMap[i].end()) { - Value unsqueezeDim = rewriter.create( - loc, rewriter.getI64IntegerAttr(j)); - auto unsqueezedInfo = - unsqueezeTensor(rewriter, op, broadcastedIndexTensor, - /*dim=*/unsqueezeDim); - if (failed(unsqueezedInfo)) { - return rewriter.notifyMatchFailure( - op, "cannot generate unsqueeze tensor op"); - } - broadcastedIndexTensor = *unsqueezedInfo; - } - } +// Helper that collapses the batch dimensions together and moves it to the front +// of the array. +static Value collapseAndMoveBatchDims(Location loc, Value values, int64_t batch, + int64_t count, OpBuilder b) { + if (batch == 0 && count == 1) + return values; + + auto valuesTy = cast(values.getType()); + auto inShape = valuesTy.getSizes(); - // Performing broadcast to final shape. - Value broadcastShapeTorchList = rewriter.create( - loc, Torch::ListType::get(Torch::IntType::get(context)), - finalShapeValue); - ValueTensorType broadcastTensorType = ValueTensorType::get( - context, llvm::ArrayRef(finalShapeInt), indicesDtype); - broadcastedIndexTensor = rewriter.create( - loc, broadcastTensorType, broadcastedIndexTensor, - broadcastShapeTorchList); - - // Flattening the tensor. - broadcastedIndexTensor = rewriter.create( - loc, flattenIndicesType, broadcastedIndexTensor, torchCstZero, - flattenEndDim); - - broadcastedIndices.push_back(broadcastedIndexTensor); + llvm::SmallVector outShape; + llvm::SmallVector outDims; + + // We need a length-1 dim at the start to transpose the batch to: + if (batch != 0) { + outDims.push_back(b.create(loc, 1)); + outShape.push_back(1); } - // Stacking broadcasted indices. - Value scatterIndices; - // The operation torch.stack([a, b], dim=0) is decomposed into: - // torch.cat([a.unsqueeze(dim=0), b.unsqueeze(dim=0)], dim=0) - // Unsqueeze all tensors before concatenating. - SmallVector unsqueezedIndexTensors; - for (Value tensor : broadcastedIndices) { - auto unsqueezedInfo = - unsqueezeTensor(rewriter, op, tensor, /*dim=*/torchCstZero); - if (failed(unsqueezedInfo)) { - return rewriter.notifyMatchFailure(op, - "cannot generate unsqueeze tensor op"); - } - unsqueezedIndexTensors.push_back(*unsqueezedInfo); + // Dimensions before the batch stay the same: + for (int i = 0; i <= batch; i++) { + auto k = b.create(loc, b.getI64IntegerAttr(i)); + auto dim = b.create(loc, values, k); + outDims.push_back(dim); + outShape.push_back(inShape[i]); } - BaseTensorType unsqueezedTensorType = - unsqueezedIndexTensors[0].getType().cast(); - Value concatIndicesTorchList = rewriter.create( - loc, Torch::ListType::get(unsqueezedTensorType), unsqueezedIndexTensors); - ValueTensorType concatIndicesType = ValueTensorType::get( - context, llvm::ArrayRef({inputRank, indexCount}), indicesDtype); - scatterIndices = rewriter.create( - loc, concatIndicesType, concatIndicesTorchList, torchCstZero); - - ValueTensorType transposedIndicesType = ValueTensorType::get( - context, llvm::ArrayRef({indexCount, inputRank}), indicesDtype); - scatterIndices = rewriter.create( - loc, transposedIndicesType, scatterIndices, torchCstZero, torchCstOne); - return scatterIndices; + auto mulI = [](int64_t dim0, int64_t dim1) { + if (dim0 == Torch::kUnknownSize || dim1 == Torch::kUnknownSize) + return Torch::kUnknownSize; + return dim0 * dim1; + }; + + // Determine the collapse size of the batch dimension: + for (int i = 1; i < count; i++) { + outShape.back() = mulI(outShape.back(), inShape[batch + i]); + + auto k = + b.create(loc, b.getI64IntegerAttr(batch + i)); + auto dim = b.create(loc, values, k); + outDims.back() = b.create(loc, dim, outDims.back()); + } + + // Add the dimensions after the batch dims: + for (int i = batch + count, s = inShape.size(); i < s; ++i) { + auto k = b.create(loc, b.getI64IntegerAttr(i)); + auto dim = b.create(loc, values, k); + outDims.push_back(dim); + outShape.push_back(inShape[i]); + } + + Value outDimsList = b.create( + loc, Torch::ListType::get(b.getType()), outDims); + + valuesTy = + b.getType(outShape, valuesTy.getOptionalDtype()); + values = b.create(loc, valuesTy, values, outDimsList); + + if (batch == 0) + return values; + + // Batch is already at the front, no need to transpose: + std::swap(outDims[0], outDims[batch + 1]); + std::swap(outShape[0], outShape[batch + 1]); + + Value dim0 = b.create(loc, b.getI64IntegerAttr(0)); + Value dimB = + b.create(loc, b.getI64IntegerAttr(batch + 1)); + + valuesTy = + b.getType(outShape, valuesTy.getOptionalDtype()); + values = + b.create(loc, valuesTy, values, dim0, dimB); + + outDims.clear(); + outShape.clear(); + auto transposeShape = valuesTy.getSizes(); + int64_t transposeRank = transposeShape.size(); + for (int i = 0; i < transposeRank; ++i) { + if (i == batch + 1) + continue; + Value k = b.create(loc, b.getI64IntegerAttr(i)); + outDims.push_back(b.create(loc, values, k)); + outShape.push_back(transposeShape[i]); + } + + valuesTy = + b.getType(outShape, valuesTy.getOptionalDtype()); + outDimsList = b.create( + loc, Torch::ListType::get(b.getType()), outDims); + return b.create(loc, valuesTy, values, outDimsList); } -namespace { class ConvertAten_IndexPutImplOp : public OpConversionPattern { public: @@ -706,11 +686,11 @@ class ConvertAten_IndexPutImplOp return failure(); Location loc = op.getLoc(); MLIRContext *context = op->getContext(); - Value input = adaptor.getSelf(); - Value values = adaptor.getValues(); - RankedTensorType inputType = input.getType().cast(); - RankedTensorType valuesType = values.getType().cast(); - int64_t inputRank = inputType.getRank(); + Value input = op.getSelf(); + Value values = op.getValues(); + auto inputType = cast(input.getType()); + auto valuesType = cast(values.getType()); + int64_t inputRank = inputType.getSizes().size(); auto valuesTensorType = op.getValues().getType().cast(); auto resultType = typeConverter->convertType(op->getResult(0).getType()) .cast(); @@ -737,190 +717,107 @@ class ConvertAten_IndexPutImplOp op, "Expected accumulate to be constant bool."); // The element type of the `input` and `values` should be same. - if (inputType.getElementType() != valuesType.getElementType()) + if (inputType.getDtype() != valuesType.getDtype()) return rewriter.notifyMatchFailure( op, "Input element type should be same as the values element type."); SmallVector optionalIndicesList; getListConstructElements(op.getIndices(), optionalIndicesList); + int64_t optionalIndicesCount = optionalIndicesList.size(); // The size of the list of the index tensors should not be greater than the // input rank. - if ((int64_t)optionalIndicesList.size() > inputRank) + if (optionalIndicesCount > inputRank) return rewriter.notifyMatchFailure( op, "Indices list size should not be greater than the input rank."); - Value torchCstNone = rewriter.create(loc); - unsigned sizeOptionalIndicesList = optionalIndicesList.size(); - SmallVector nonNoneIndexTensorDim; - unsigned numNonNoneIndices; - - if (sizeOptionalIndicesList == 0) + if (optionalIndicesCount == 0) return rewriter.notifyMatchFailure(op, "Indices list must not be empty."); - for (unsigned i = 0; i < optionalIndicesList.size(); i++) { - if (!optionalIndicesList[i].getType().isa()) { - nonNoneIndexTensorDim.push_back(i); - } - } - - numNonNoneIndices = nonNoneIndexTensorDim.size(); - if (numNonNoneIndices > 2) { - return rewriter.notifyMatchFailure( - op, "unimplemented: non none index tensors less than or equal to 2 " - "supported only"); - } else if (numNonNoneIndices == 2 && - nonNoneIndexTensorDim[0] != nonNoneIndexTensorDim[1] - 1) { - return rewriter.notifyMatchFailure( - op, "unimplemented: case of 2 non none index tensors is supported " - "only when both the tensors are along consecutive dimensions"); - } - - // Padding the indices list with none values. - if (sizeOptionalIndicesList < inputRank) { - for (unsigned i = 0; i < (inputRank - sizeOptionalIndicesList); i++) - optionalIndicesList.push_back(torchCstNone); + // Filter to available indices and get the indicesMap: + SmallVector indicesList; + SmallVector indicesMap; + int64_t numBatchDims = 0; + for (int i = 0, s = optionalIndicesList.size(); i < s; ++i) { + if (isa(optionalIndicesList[i].getType())) + continue; + indicesList.push_back(optionalIndicesList[i]); + indicesMap.push_back(i); + + auto indexTy = cast(indicesList.back().getType()); + numBatchDims = std::max(static_cast(indexTy.getSizes().size()), + numBatchDims); } - SmallVector indexBroadcastShapeInt{ - optionalIndicesList[nonNoneIndexTensorDim[0]] - .getType() - .cast() - .getSizes()}; - SmallVector indexBroadcastShapeValue; - if (numNonNoneIndices == 2) { - computeBroadcastShape(rewriter, loc, - optionalIndicesList[nonNoneIndexTensorDim[0]], - optionalIndicesList[nonNoneIndexTensorDim[1]], - indexBroadcastShapeInt, indexBroadcastShapeValue); - } else { - // It means there's only one index tensor and broadcast shape is same as - // that index tensor' shape. - for (unsigned i = 0; i < indexBroadcastShapeInt.size(); i++) { - Value dim = rewriter.create( - loc, rewriter.getI64IntegerAttr(i)); - indexBroadcastShapeValue.push_back(rewriter.createOrFold( - loc, optionalIndicesList[nonNoneIndexTensorDim[0]], dim)); + // Value broadcasting semantics require batch dimensions to be up front if + // the indices are not sequential, otherwise they are sequentially at their + // location: + int64_t batchDim = 0; + for (int s = optionalIndicesList.size(); batchDim < s; ++batchDim) + if (!isa(optionalIndicesList[batchDim].getType())) + break; + + int64_t nextNone = batchDim; + for (int s = optionalIndicesList.size(); nextNone < s; ++nextNone) + if (isa(optionalIndicesList[nextNone].getType())) + break; + + for (int s = optionalIndicesList.size(); nextNone < s; ++nextNone) + if (!isa(optionalIndicesList[nextNone].getType())) + batchDim = 0; + + // Indices are extended, catted, and collapsed into a [batch, depth] tensor: + Value indices = combinePutIndices(loc, indicesList, rewriter); + + // Bove batch dimensions to the front and collapse into a single dim: + values = + collapseAndMoveBatchDims(loc, values, batchDim, numBatchDims, rewriter); + valuesType = cast(values.getType()); + + // Materialize out the length-1 dimensions: + Value zero = rewriter.create( + loc, rewriter.getI64IntegerAttr(0)); + Value one = rewriter.create( + loc, rewriter.getI64IntegerAttr(1)); + llvm::SmallVector valuesShape{valuesType.getSizes().front()}; + llvm::SmallVector valuesDims; + valuesDims.push_back( + rewriter.create(loc, values, zero)); + + int vDim = 1; + for (int i = 0, s = inputType.getSizes().size(); i < s; ++i) { + if (i < optionalIndicesCount && + !isa(optionalIndicesList[i].getType())) { + valuesDims.push_back(one); + valuesShape.push_back(1); + continue; } - } - Type indicesDtype = optionalIndicesList[nonNoneIndexTensorDim[0]] - .getType() - .cast() - .getDtype(); - - // This implementation is done to get the scatter indices: - - // def get_broadcast_shape(tensors): - // return list(torch.broadcast_tensors(*tensors)[0].shape) - - // def get_input_shape_to_output_shape_map(optional_index_tensors: - // list[Optional[torch.Tensor]]): - // index_tensors = list(filter(lambda x: x is not None, - // optional_index_tensors)) broadcast_rank = - // len(get_broadcast_shape(index_tensors)) num_of_index_tensors = - // len(index_tensors) index_of_first_index_tensor: Optional[int] = None - // result = {} - // for i, index in enumerate(optional_index_tensors): - // if index is None: - // val = i - // if index_of_first_index_tensor is not None: - // val += broadcast_rank - num_of_index_tensors - // result[i] = [val] - // else: - // if index_of_first_index_tensor is None: - // index_of_first_index_tensor = i - // output_indices = list(range(index_of_first_index_tensor, - // index_of_first_index_tensor + - // broadcast_rank)) - // result[i] = output_indices - // return result - - // def spotlight_indices(shape, indices: list[int]): - // """Turn all values in `shape` to `1` except for those with index in - // `indices`.""" shape = shape.copy() for i in range(len(shape)): - // if i not in indices: - // shape[i] = 1 - // return shape - - // def get_final_shape(input, optional_index_tensors: - // list[Optional[torch.Tensor]]): - // index_tensors = list(filter(lambda x: x is not None, - // optional_index_tensors)) index_tensors_broadcast_shape = - // get_broadcast_shape(index_tensors) result = [] - // handled_index_tensor_space = False - // for e, i in enumerate(input.shape): - // if optional_index_tensors[e] is None: - // result.append(i) - // else: - // if not handled_index_tensor_space: - // handled_index_tensor_space = True - // result += index_tensors_broadcast_shape - // return result - - // def get_scatter_indices(input, optional_index_tensors: - // list[Optional[torch.Tensor]]): - // assert len(input.size()) == len(optional_index_tensors), "Pad indices - // with None" shape_map = - // get_input_shape_to_output_shape_map(optional_index_tensors) - // index_tensors = list(filter(lambda x: x is not None, - // optional_index_tensors)) index_tensors_broadcast_shape = - // get_broadcast_shape(index_tensors) final_shape = - // get_final_shape(input, optional_index_tensors) - - // broadcasted_index_tensors = [] - // for e, optional_index_tensor in enumerate(optional_index_tensors): - // if optional_index_tensor is None: - // tensor_to_broadcast = torch.arange(0, input.size(e)) - // else: - // tensor_to_broadcast = - // optional_index_tensor.broadcast_to(index_tensors_broadcast_shape) - - // broadcasted_index_tensor = \ - // tensor_to_broadcast.reshape(spotlight_indices(final_shape, shape_map[e]))\ - // .broadcast_to(final_shape)\ - // .flatten() - // broadcasted_index_tensors.append(broadcasted_index_tensor) - - // return torch.stack(broadcasted_index_tensors, dim=0).t() - - auto scatterIndicesInfo = - getScatterIndices(op, rewriter, indicesDtype, optionalIndicesList, - indexBroadcastShapeInt, indexBroadcastShapeValue); - if (failed(scatterIndicesInfo)) { - return rewriter.notifyMatchFailure( - op, "cannot generate scatter indices for index put op"); + Value k = rewriter.create( + loc, rewriter.getI64IntegerAttr(vDim)); + valuesDims.push_back( + rewriter.create(loc, values, k)); + valuesShape.push_back(inputType.getSizes()[i]); + vDim++; } - Value indexTensor = *scatterIndicesInfo; - // Flattening the values tensor. - Value torchCstZero = rewriter.create( - loc, rewriter.getI64IntegerAttr(0)); - Value flattenedValuesTensorLastDim = rewriter.create( - loc, - rewriter.getI64IntegerAttr(valuesTensorType.getSizes().size() - 1)); - SmallVector valuesShapeInt{valuesTensorType.getSizes()}; - int64_t valuesCount = 1; - if (llvm::all_of(valuesShapeInt, - [](int64_t shape) { return shape != kUnknownSize; })) { - for (int64_t i : valuesShapeInt) - valuesCount *= i; - } else { - valuesCount = kUnknownSize; - } - auto flattenedValuesTensorType = ValueTensorType::get( - context, llvm::ArrayRef(valuesCount), valuesTensorType.getDtype()); - Value flattenedValuesTensor = rewriter.create( - loc, flattenedValuesTensorType, op.getValues(), torchCstZero, - flattenedValuesTensorLastDim); - values = typeConverter->materializeTargetConversion( - rewriter, loc, - typeConverter->convertType(flattenedValuesTensor.getType()), - flattenedValuesTensor); + Value valuesDimsList = rewriter.create( + loc, Torch::ListType::get(rewriter.getType()), + valuesDims); + + valuesType = rewriter.getType( + valuesShape, valuesType.getOptionalDtype()); + values = + rewriter.create(loc, valuesType, values, valuesDimsList); // `TMTensor::ScatterOp` expects indices of element type i32. - Value indices = convertTensorToDtype( - rewriter, loc, indexTensor, + indices = convertTensorToDtype( + rewriter, loc, indices, mlir::IntegerType::get(context, 32, mlir::IntegerType::Signed)); + + input = typeConverter->materializeTargetConversion( + rewriter, loc, typeConverter->convertType(input.getType()), input); + values = typeConverter->materializeTargetConversion( + rewriter, loc, typeConverter->convertType(values.getType()), values); indices = typeConverter->materializeTargetConversion( rewriter, loc, typeConverter->convertType(indices.getType()), indices); @@ -931,7 +828,8 @@ class ConvertAten_IndexPutImplOp // 3.) `input` is mapped to `original` in scatter op. bool invalidInputTypeFound = false; Value scatterOp = createTMTensorScatterOp( - rewriter, loc, values, indices, input, /*uniqueIndices=*/false, + rewriter, loc, values, indices, input, indicesMap, + /*uniqueIndices=*/false, [&](OpBuilder &b, Location loc, Value valuesElement, Value inputElement) { Value yieldValue = valuesElement; @@ -1150,6 +1048,7 @@ class ConvertAtenMaxPool2dWithIndicesBackwardOp Value scatterOp = createTMTensorScatterOp( rewriter, loc, /*updates=*/gradOutputFlattened, /*indices=*/indicesCollapsed, /*original=*/outputTensor, + /*dimensionsMap=*/createDefaultDimMap(indicesCollapsed), /*uniqueIndices=*/false, [&](OpBuilder &b, Location loc, Value valuesElement, Value inputElement) { @@ -1292,6 +1191,7 @@ class ConvertAtenScatterReduceTwoOp srcType.getElementType(), /*init_element=*/normalizationValue); self = createTMTensorScatterOp( rewriter, loc, normalizations, indices, self, + /*dimensionsMap=*/createDefaultDimMap(indices), /*uniqueIndices=*/false, [&](OpBuilder &b, Location loc, Value update, Value current) { b.create(loc, update); @@ -1299,6 +1199,7 @@ class ConvertAtenScatterReduceTwoOp if (reduceEnum == torch_upstream::ReductionType::MEAN) { counts = createTMTensorScatterOp( rewriter, loc, normalizations, indices, counts, + /*dimensionsMap=*/createDefaultDimMap(indices), /*uniqueIndices=*/false, [&](OpBuilder &b, Location loc, Value update, Value current) { b.create(loc, update); @@ -1309,7 +1210,7 @@ class ConvertAtenScatterReduceTwoOp // Create final operation Value scatterOp = createTMTensorScatterOp( rewriter, loc, updates, indices, self, - /*uniqueIndices=*/false, + /*dimensionsMap=*/createDefaultDimMap(indices), /*uniqueIndices=*/false, [&](OpBuilder &b, Location loc, Value update, Value current) { Value result; if (reduceEnum == torch_upstream::ReductionType::SUM || @@ -1353,6 +1254,7 @@ class ConvertAtenScatterReduceTwoOp if (reduceEnum == torch_upstream::ReductionType::MEAN) { counts = createTMTensorScatterOp( rewriter, loc, updates, indices, counts, + /*dimensionsMap=*/createDefaultDimMap(indices), /*uniqueIndices=*/false, [&](OpBuilder &b, Location loc, Value update, Value current) { Value result; diff --git a/lib/Dialect/TMTensor/IR/TMTensorOps.cpp b/lib/Dialect/TMTensor/IR/TMTensorOps.cpp index 0b827893cae3..7b8a17682a9e 100644 --- a/lib/Dialect/TMTensor/IR/TMTensorOps.cpp +++ b/lib/Dialect/TMTensor/IR/TMTensorOps.cpp @@ -509,12 +509,32 @@ LogicalResult ScanOp::fold(FoldAdaptor adaptor, //===----------------------------------------------------------------------===// // ScatterOp //===----------------------------------------------------------------------===// +static Type getComplexElementTypeOrSelf(Type ty) { + if (auto complex = dyn_cast_or_null(ty)) + return complex.getElementType(); + return ty; +} + +static bool isInvalid(ArrayRef dimsPos, int64_t rank) { + // early exit. + if (static_cast(dimsPos.size()) > rank) + return true; + DenseSet uniqued; + for (int64_t dim : dimsPos) + uniqued.insert(dim); + if (static_cast(dimsPos.size()) != uniqued.size()) + return true; + return llvm::any_of( + dimsPos, [rank](int64_t dimPos) { return dimPos < 0 || dimPos >= rank; }); +} + LogicalResult ScatterOp::verify() { + Operation *op = getOperation(); if (getInputs().size() != 2) { - return emitOpError("expected two input operands"); + return op->emitOpError("expected two input operands"); } if (getOutputs().size() != 1) { - return emitOpError("expected one output operand"); + return op->emitOpError("expected one output operand"); } auto checkDimensionsMatch = [&](ShapedType t1, ShapedType t2, unsigned dim) { return t1.getShape()[dim] == t2.getShape()[dim]; @@ -526,10 +546,19 @@ LogicalResult ScatterOp::verify() { return emitOpError("expected indices to be of rank 2 of i32 element type"); } auto indexDepth = getIndexDepth(); - if (indexDepth == ShapedType::kDynamic) { + if (ShapedType::isDynamic(indexDepth)) { return emitOpError("expected index depth is static"); } + ArrayRef dimMap = getDimensionMap(); + if (static_cast(dimMap.size()) != indexDepth) { + return op->emitOpError("invalid number of dimension map entries "); + } + + auto originalType = getOriginalType(); + if (isInvalid(dimMap, originalType.getRank())) + return op->emitOpError("dimension map is invalid"); + // The first dimension of the indices should match the first dimension of the // output. They indicate to the number of updates. auto updateType = getUpdateType(); @@ -540,7 +569,6 @@ LogicalResult ScatterOp::verify() { return emitOpError( "mismatch in shape of indices and update value at dim#0"); } - auto originalType = getOriginalType(); if (updateType.getRank() - 1 > originalType.getRank()) { return emitOpError( "update value rank exceeds the rank of the original value"); @@ -553,7 +581,7 @@ LogicalResult ScatterOp::verify() { "index depth and update value does not cover rank of original value"); } - // Validate the non-indexed update dims covier the full slice size of the + // Validate the non-indexed update dims cover the full slice size of the // original tensor. int64_t fullSliceDims = originalType.getRank() - indexDepth; for (auto it : @@ -562,10 +590,11 @@ LogicalResult ScatterOp::verify() { updateType.getRank()))) { int64_t originalDim = std::get<0>(it); int64_t updateDim = std::get<1>(it); - if (updateType.getDimSize(updateDim) != - originalType.getDimSize(originalDim)) { - return emitOpError("mismatch in shape of update value dim#") - << updateDim << " and original value at dim#" << originalDim; + if (!originalType.isDynamicDim(originalDim) && + updateType.getDimSize(updateDim) > + originalType.getDimSize(originalDim)) { + return op->emitOpError("shape of update value dim#") + << updateDim << " exceeds original value at dim#" << originalDim; } } @@ -576,23 +605,25 @@ LogicalResult ScatterOp::verify() { llvm::seq(1, updateType.getRank() - fullSliceDims))) { int64_t originalDim = std::get<0>(it); int64_t updateDim = std::get<1>(it); - if (updateType.getDimSize(updateDim) > - originalType.getDimSize(originalDim)) { - return emitOpError("indexed shape of update value dim#") + if (!originalType.isDynamicDim(originalDim) && + updateType.getDimSize(updateDim) > + originalType.getDimSize(originalDim)) { + return op->emitOpError("indexed shape of update value dim#") << updateDim << " exceeds original value at dim#" << originalDim << " " << updateType.getDimSize(updateDim) << " " << originalType.getDimSize(originalDim); } } - Region &thisRegion = getRegion(); - Block *body = &thisRegion.front(); + Region ®ion = this->getRegion(); + Block *body = ®ion.front(); if (body->getNumArguments() != 2) { - return emitOpError("expected region to have two arguments"); + return op->emitOpError("expected region to have two arguments"); } Type arg0Type = body->getArgument(0).getType(); Type arg1Type = body->getArgument(1).getType(); - if (!arg0Type.isIntOrFloat() || !arg1Type.isIntOrFloat()) { + if (!getComplexElementTypeOrSelf(arg0Type).isIntOrFloat() || + !getComplexElementTypeOrSelf(arg1Type).isIntOrFloat()) { return emitOpError( "expected region to have scalar argument of integer or float types"); } @@ -684,14 +715,16 @@ LogicalResult ScatterOp::generateScalarImplementation(OpBuilder &b, starts[it.index() + offset] = it.value(); } + ArrayRef dimMap = getDimensionMap(); for (auto i : llvm::seq(0, indexDepth)) { loadIndices.back() = b.create(loc, i); Value idx = b.create(loc, indices(), loadIndices); - Value cast = b.create(loc, b.getIndexType(), idx); + Value ret = b.create(loc, b.getIndexType(), idx); - if (starts[i]) - cast = b.create(loc, cast, starts[i]); - starts[i] = cast; + auto dim = dimMap[i]; + if (starts[dim]) + ret = b.create(loc, ret, starts[dim]); + starts[dim] = ret; } Value init = b.create(loc, original(), starts); diff --git a/lib/Dialect/TorchConversion/Transforms/Passes.cpp b/lib/Dialect/TorchConversion/Transforms/Passes.cpp index 9ff447371a76..7ac95ab6c4e9 100644 --- a/lib/Dialect/TorchConversion/Transforms/Passes.cpp +++ b/lib/Dialect/TorchConversion/Transforms/Passes.cpp @@ -75,6 +75,7 @@ void TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline( // (e.g. dimensions which must be constant in a ranked programming model) // and those constants get somewhat obscured by TorchToArith. pm.addNestedPass(createConvertTorchToTMTensorPass()); + pm.addNestedPass(createCanonicalizerPass()); pm.addNestedPass(createConvertTorchToLinalgPass()); pm.addNestedPass(createConvertTorchToSCFPass()); pm.addNestedPass(createConvertTorchToArithPass()); diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/scatter.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/scatter.py index 8a84961a5f55..89b8b10eb444 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/scatter.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/scatter.py @@ -1166,3 +1166,4 @@ def forward(self, input, index1, index2, value): module_factory=lambda: IndexPutImplIndexWithNoneModule()) def IndexPutImplIndexWithNoneModule_basic(module, tu: TestUtils): module.forward(tu.rand(2, 3, 4, 5), tu.randint(6, 1, high=4), tu.randint(7, high=5), tu.rand(2, 3, 6, 7)) + diff --git a/test/Dialect/TMTensor/bufferize.mlir b/test/Dialect/TMTensor/bufferize.mlir index 3e60814fabdd..f36a2f521ad1 100644 --- a/test/Dialect/TMTensor/bufferize.mlir +++ b/test/Dialect/TMTensor/bufferize.mlir @@ -64,7 +64,7 @@ func.func @scan_1d_exclusive(%in: tensor<128xi32>, %out: tensor<128xi32>, %acc: // CHECK: %[[ORIG_MEMREF:.*]] = bufferization.to_memref %[[ORIG_TENSOR]] : memref<8xi32> // CHECK: %[[ORIG_MEMREF_NEW:.*]] = memref.alloc() : memref<8xi32> // CHECK: memref.copy %[[ORIG_MEMREF]], %[[ORIG_MEMREF_NEW]] : memref<8xi32> to memref<8xi32> -// CHECK: tm_tensor.scatter unique_indices(true) ins(%[[UPDATES_MEMREF]], %[[INDICES_MEMREF]] +// CHECK: tm_tensor.scatter {dimension_map = array} unique_indices(true) ins(%[[UPDATES_MEMREF]], %[[INDICES_MEMREF]] // CHECK-SAME: : memref<3xi32>, memref<3x1xi32>) outs(%[[ORIG_MEMREF_NEW]] : memref<8xi32>) { // CHECK: ^bb0(%[[UPDATE_SCALAR:.*]]: i32, %[[ORIG_SCALAR:.*]]: i32): // CHECK: tm_tensor.yield %[[UPDATE_SCALAR]] : i32 @@ -74,7 +74,7 @@ func.func @scan_1d_exclusive(%in: tensor<128xi32>, %out: tensor<128xi32>, %acc: func.func @scatter_update_scalar_1D( %original: tensor<8xi32>, %indices: tensor<3x1xi32>, %updates: tensor<3xi32>) -> tensor<8xi32> { - %0 = tm_tensor.scatter unique_indices(true) + %0 = tm_tensor.scatter {dimension_map = array} unique_indices(true) ins(%updates, %indices : tensor<3xi32>, tensor<3x1xi32>) outs(%original : tensor<8xi32>) { ^bb0(%update: i32, %orig: i32): // no predecessors @@ -92,7 +92,7 @@ func.func @scatter_update_scalar_1D( // CHECK: %[[ORIG_MEMREF:.*]] = bufferization.to_memref %[[ORIG_TENSOR]] : memref<8xi32> // CHECK: %[[ORIG_MEMREF_NEW:.*]] = memref.alloc() : memref<8xi32> // CHECK: memref.copy %[[ORIG_MEMREF]], %[[ORIG_MEMREF_NEW]] : memref<8xi32> to memref<8xi32> -// CHECK: tm_tensor.scatter unique_indices(true) ins(%[[UPDATES_MEMREF]], %[[INDICES_MEMREF]] +// CHECK: tm_tensor.scatter {dimension_map = array} unique_indices(true) ins(%[[UPDATES_MEMREF]], %[[INDICES_MEMREF]] // CHECK-SAME: : memref<3xi32>, memref<3x1xi32>) outs(%[[ORIG_MEMREF_NEW]] : memref<8xi32>) { // CHECK: ^bb0(%[[UPDATE_SCALAR:.*]]: i32, %[[ORIG_SCALAR:.*]]: i32): // CHECK: %[[CST1:.*]] = arith.constant 1 : i32 @@ -104,7 +104,7 @@ func.func @scatter_update_scalar_1D( func.func @scatter_add_scalar_1D( %original: tensor<8xi32>, %indices: tensor<3x1xi32>, %updates: tensor<3xi32>) -> tensor<8xi32> { - %0 = tm_tensor.scatter unique_indices(true) + %0 = tm_tensor.scatter {dimension_map = array} unique_indices(true) ins(%updates, %indices : tensor<3xi32>, tensor<3x1xi32>) outs(%original : tensor<8xi32>) { ^bb0(%update: i32, %orig: i32): // no predecessors diff --git a/test/Dialect/TMTensor/convert_to_loops.mlir b/test/Dialect/TMTensor/convert_to_loops.mlir index e9c160f99e94..7901cf505f2a 100644 --- a/test/Dialect/TMTensor/convert_to_loops.mlir +++ b/test/Dialect/TMTensor/convert_to_loops.mlir @@ -105,7 +105,7 @@ func.func @scan_2d(%0: memref<16x32xi32>, %1: memref<16x32xi32>) { func.func @scatter_update_scalar_1D( %original: memref<8xi32>, %indices: memref<3x1xi32>, %updates: memref<3xi32>) { - tm_tensor.scatter unique_indices(true) + tm_tensor.scatter {dimension_map= array} unique_indices(true) ins(%updates, %indices : memref<3xi32>, memref<3x1xi32>) outs(%original : memref<8xi32>) { ^bb0(%arg0: i32, %arg1: i32): // no predecessors @@ -131,7 +131,7 @@ func.func @scatter_update_scalar_1D( func.func @scatter_add_scalar_2D( %original: memref<4x3xi32>, %indices: memref<3x2xi32>, %updates: memref<3xi32>) { - tm_tensor.scatter unique_indices(true) + tm_tensor.scatter {dimension_map= array} unique_indices(true) ins(%updates, %indices : memref<3xi32>, memref<3x2xi32>) outs(%original : memref<4x3xi32>) { ^bb0(%arg0: i32, %arg1: i32): // no predecessors @@ -162,7 +162,7 @@ func.func @scatter_add_scalar_2D( func.func @scatter_update_slice_2D( %original: memref<4x3xi32>, %indices: memref<2x1xi32>, %updates: memref<2x3xi32>) { - tm_tensor.scatter unique_indices(true) + tm_tensor.scatter {dimension_map= array} unique_indices(true) ins(%updates, %indices : memref<2x3xi32>, memref<2x1xi32>) outs(%original : memref<4x3xi32>) { ^bb0(%arg0: i32, %arg1: i32): // no predecessors @@ -192,7 +192,7 @@ func.func @scatter_update_slice_2D( func.func @scatter_add_scalar_1D( %original: memref<8xi32>, %indices: memref<3x1xi32>, %updates: memref<3xi32>) { - tm_tensor.scatter unique_indices(true) + tm_tensor.scatter {dimension_map= array} unique_indices(true) ins(%updates, %indices : memref<3xi32>, memref<3x1xi32>) outs(%original : memref<8xi32>) { ^bb0(%arg0: i32, %arg1: i32): // no predecessors @@ -221,7 +221,7 @@ func.func @scatter_add_scalar_1D( func.func @scatter_add_slice_2D( %original: memref<4x3xi32>, %indices: memref<2x1xi32>, %updates: memref<2x3xi32>) { - tm_tensor.scatter unique_indices(true) + tm_tensor.scatter {dimension_map= array} unique_indices(true) ins(%updates, %indices : memref<2x3xi32>, memref<2x1xi32>) outs(%original : memref<4x3xi32>) { ^bb0(%arg0: i32, %arg1: i32): // no predecessors @@ -251,7 +251,7 @@ func.func @scatter_add_slice_2D( func.func @scatter_update_scalar_dynamic_1D( %original: memref, %indices: memref, %updates: memref) { - tm_tensor.scatter unique_indices(true) + tm_tensor.scatter {dimension_map= array} unique_indices(true) ins(%updates, %indices : memref, memref) outs(%original : memref) { ^bb0(%arg0: i32, %arg1: i32): // no predecessors @@ -277,7 +277,7 @@ func.func @scatter_update_scalar_dynamic_1D( func.func @scatter_add_scalar_dynamic_2D( %original: memref, %indices: memref, %updates: memref) { - tm_tensor.scatter unique_indices(true) + tm_tensor.scatter {dimension_map= array} unique_indices(true) ins(%updates, %indices : memref, memref) outs(%original : memref) { ^bb0(%arg0: i32, %arg1: i32): // no predecessors @@ -308,7 +308,7 @@ func.func @scatter_add_scalar_dynamic_2D( func.func @scatter_update_slice_dynamic_2D( %original: memref, %indices: memref, %updates: memref) { - tm_tensor.scatter unique_indices(true) + tm_tensor.scatter {dimension_map= array} unique_indices(true) ins(%updates, %indices : memref, memref) outs(%original : memref) { ^bb0(%arg0: i32, %arg1: i32): // no predecessors @@ -335,6 +335,7 @@ func.func @scatter_update_slice_dynamic_2D( func.func @scatter_partial_slices(%arg0: memref<2x64x12xf32>, %arg1: memref<2x3xi32>, %arg2: memref<2x1x12xf32>) { tm_tensor.scatter + {dimension_map= array} unique_indices(true) ins(%arg2, %arg1 : memref<2x1x12xf32>, memref<2x3xi32>) outs(%arg0 : memref<2x64x12xf32>) { diff --git a/test/Dialect/TMTensor/invalid.mlir b/test/Dialect/TMTensor/invalid.mlir index bfcd1adb8152..6653d944a059 100644 --- a/test/Dialect/TMTensor/invalid.mlir +++ b/test/Dialect/TMTensor/invalid.mlir @@ -4,7 +4,7 @@ func.func @scatter_mixed_tensor_memref( %update : memref, %indices : tensor, %original : tensor) -> tensor { // expected-error @+1 {{expected inputs and outputs to be RankedTensorType or scalar}} - %0 = tm_tensor.scatter unique_indices(true) + %0 = tm_tensor.scatter {dimension_map= array} unique_indices(true) ins(%update, %indices : memref, tensor) outs(%original : tensor) { ^bb0(%arg1: f32, %arg2: f32): @@ -20,7 +20,7 @@ func.func @scatter_mixed_tensor_memref( %update : tensor, %indices : memref, %original : tensor) -> tensor { // expected-error @+1 {{expected inputs and outputs to be RankedTensorType or scalar}} - %0 = tm_tensor.scatter unique_indices(true) + %0 = tm_tensor.scatter {dimension_map= array} unique_indices(true) ins(%update, %indices : tensor, memref) outs(%original : tensor) { ^bb0(%arg1: f32, %arg2: f32): @@ -36,7 +36,7 @@ func.func @scatter_extra_outputs( %update : tensor, %indices : tensor, %original : tensor) -> (tensor, tensor) { // expected-error @+1 {{expected number of outputs to be same as the number of results}} - %0, %1 = tm_tensor.scatter unique_indices(true) + %0, %1 = tm_tensor.scatter {dimension_map= array} unique_indices(true) ins(%update, %indices : tensor, tensor) outs(%original : tensor) { ^bb0(%arg1: f32, %arg2: f32): @@ -52,7 +52,7 @@ func.func @scatter_mixed_tensor_memref( %update : tensor, %indices : tensor, %original : memref) -> tensor { // expected-error @+1 {{expected inputs and outputs to be RankedTensorType or scalar}} - %0 = tm_tensor.scatter unique_indices(true) + %0 = tm_tensor.scatter {dimension_map= array} unique_indices(true) ins(%update, %indices : tensor, tensor) outs(%original : memref) { ^bb0(%arg1: f32, %arg2: f32): @@ -68,7 +68,7 @@ func.func @scatter_output_type_mismatch( %update : tensor, %indices : tensor, %original : tensor) -> tensor<4x?xf32> { // expected-error @+1 {{expected type of `outs` operand #0 'tensor' to be same as result type 'tensor<4x?xf32>'}} - %0 = tm_tensor.scatter unique_indices(true) + %0 = tm_tensor.scatter {dimension_map= array} unique_indices(true) ins(%update, %indices : tensor, tensor) outs(%original : tensor) { ^bb0(%arg1: f32, %arg2: f32): @@ -84,7 +84,7 @@ func.func @scatter_mixed_tensor_memref( %update : memref, %indices : tensor, %original : memref) { // expected-error @+1 {{expected inputs and outputs to be MemRefType or scalar}} - tm_tensor.scatter unique_indices(true) + tm_tensor.scatter {dimension_map= array} unique_indices(true) ins(%update, %indices : memref, tensor) outs(%original : memref) { ^bb0(%arg1: f32, %arg2: f32): @@ -100,7 +100,7 @@ func.func @scatter_mixed_tensor_memref( %update : memref, %indices : memref, %original : tensor) { // expected-error @+1 {{expected inputs and outputs to be MemRefType or scalar}} - tm_tensor.scatter unique_indices(true) + tm_tensor.scatter {dimension_map= array} unique_indices(true) ins(%update, %indices : memref, memref) outs(%original : tensor) { ^bb0(%arg1: f32, %arg2: f32): @@ -116,7 +116,7 @@ func.func @scatter_dim_mismatch( %update : tensor, %indices : tensor<48x1xi32>, %original : tensor) -> tensor { // expected-error @+1 {{mismatch in shape of indices and update value at dim#0}} - %0 = tm_tensor.scatter unique_indices(true) + %0 = tm_tensor.scatter {dimension_map= array} unique_indices(true) ins(%update, %indices : tensor, tensor<48x1xi32>) outs(%original : tensor) { ^bb0(%arg1: f32, %arg2: f32): @@ -132,7 +132,7 @@ func.func @scatter_dim_mismatch( %update : tensor<64x?xf32>, %indices : tensor<48x1xi32>, %original : tensor) -> tensor { // expected-error @+1 {{mismatch in shape of indices and update value at dim#0}} - %0 = tm_tensor.scatter unique_indices(true) + %0 = tm_tensor.scatter {dimension_map= array} unique_indices(true) ins(%update, %indices : tensor<64x?xf32>, tensor<48x1xi32>) outs(%original : tensor) { ^bb0(%arg1: f32, %arg2: f32): @@ -148,7 +148,7 @@ func.func @scatter_dim_mismatch( %update : tensor, %indices : tensor, %original : tensor) -> tensor { // expected-error @+1 {{op update value rank exceeds the rank of the original value}} - %0 = tm_tensor.scatter unique_indices(true) + %0 = tm_tensor.scatter {dimension_map= array} unique_indices(true) ins(%update, %indices : tensor, tensor) outs(%original : tensor) { ^bb0(%arg1: f32, %arg2: f32): @@ -162,16 +162,16 @@ func.func @scatter_dim_mismatch( func.func @scatter_dim_mismatch( %update : tensor, %indices : tensor, - %original : tensor) -> tensor { - // expected-error @+1 {{mismatch in shape of update value dim#1 and original value at dim#1}} - %0 = tm_tensor.scatter unique_indices(true) + %original : tensor) -> tensor { + // expected-error @+1 {{shape of update value dim#1 exceeds original value at dim#1}} + %0 = tm_tensor.scatter {dimension_map= array} unique_indices(true) ins(%update, %indices : tensor, tensor) - outs(%original : tensor) { + outs(%original : tensor) { ^bb0(%arg1: f32, %arg2: f32): %1 = arith.addf %arg1, %arg2 : f32 tm_tensor.yield %1 : f32 - } -> tensor - return %0 : tensor + } -> tensor + return %0 : tensor } // ----- @@ -180,7 +180,7 @@ func.func @scatter_region_type_mismatch( %update : tensor, %indices : tensor, %original : tensor) -> tensor { // expected-error @+1 {{expected region to have scalar argument of integer or float types}} - %0 = tm_tensor.scatter unique_indices(true) + %0 = tm_tensor.scatter {dimension_map= array} unique_indices(true) ins(%update, %indices : tensor, tensor) outs(%original : tensor) { ^bb0(%arg1: index, %arg2: index): @@ -197,7 +197,7 @@ func.func @scatter_region_type_mismatch( %update : tensor, %indices : tensor, %original : tensor) -> tensor { // expected-error @+1 {{mismatch in argument 0 of region 'i64' and element type of update value 'i32'}} - %0 = tm_tensor.scatter unique_indices(true) + %0 = tm_tensor.scatter {dimension_map= array} unique_indices(true) ins(%update, %indices : tensor, tensor) outs(%original : tensor) { ^bb0(%arg1: i64, %arg2: i32): @@ -214,7 +214,7 @@ func.func @scatter_region_type_mismatch( %update : tensor, %indices : tensor, %original : tensor) -> tensor { // expected-error @+1 {{mismatch in argument 1 of region 'i64' and element type of original value 'i32'}} - %0 = tm_tensor.scatter unique_indices(true) + %0 = tm_tensor.scatter {dimension_map= array} unique_indices(true) ins(%update, %indices : tensor, tensor) outs(%original : tensor) { ^bb0(%arg1: i32, %arg2: i64): @@ -231,7 +231,7 @@ func.func @scatter_region_type_mismatch( %update : tensor, %indices : tensor, %original : tensor) -> tensor { // expected-error @+1 {{mismatch in region argument types 'i32' and 'i64'}} - %0 = tm_tensor.scatter unique_indices(true) + %0 = tm_tensor.scatter {dimension_map= array} unique_indices(true) ins(%update, %indices : tensor, tensor) outs(%original : tensor) { ^bb0(%arg1: i32, %arg2: i64): @@ -248,7 +248,7 @@ func.func @scatter_region_type_mismatch( %update : tensor, %indices : tensor, %original : tensor) -> tensor { // expected-error @+1 {{expected region to have two arguments}} - %0 = tm_tensor.scatter unique_indices(true) + %0 = tm_tensor.scatter {dimension_map= array} unique_indices(true) ins(%update, %indices : tensor, tensor) outs(%original : tensor) { ^bb0(%arg1: i64, %arg2: i64, %arg3 : i64): @@ -264,7 +264,7 @@ func.func @scatter_region_type_mismatch( func.func @scatter_yield_mismatch( %update : tensor, %indices : tensor, %original : tensor) -> tensor { - %0 = tm_tensor.scatter unique_indices(true) + %0 = tm_tensor.scatter {dimension_map= array} unique_indices(true) ins(%update, %indices : tensor, tensor) outs(%original : tensor) { ^bb0(%arg1: i64, %arg2: i64): @@ -281,7 +281,7 @@ func.func @scatter_yield_mismatch( func.func @scatter_yield_mismatch( %update : tensor, %indices : tensor, %original : tensor) -> tensor { - %0 = tm_tensor.scatter unique_indices(true) + %0 = tm_tensor.scatter {dimension_map= array} unique_indices(true) ins(%update, %indices : tensor, tensor) outs(%original : tensor) { ^bb0(%arg1: i64, %arg2: i64): @@ -299,7 +299,7 @@ func.func @scatter_index_depth_dynamic( %update : tensor, %indices : tensor, %original : tensor) -> tensor { // expected-error @+1 {{expected index depth is static}} - %0 = tm_tensor.scatter unique_indices(true) + %0 = tm_tensor.scatter {dimension_map= array} unique_indices(true) ins(%update, %indices : tensor, tensor) outs(%original : tensor) { ^bb0(%arg1: i64, %arg2: i64): @@ -316,7 +316,7 @@ func.func @scatter_original_rank_mismatch( %update : tensor, %indices : tensor, %original : tensor) -> tensor { // expected-error @+1 {{op index depth and update value does not cover rank of original value}} - %0 = tm_tensor.scatter unique_indices(true) + %0 = tm_tensor.scatter {dimension_map= array} unique_indices(true) ins(%update, %indices : tensor, tensor) outs(%original : tensor) { ^bb0(%arg1: i64, %arg2: i64): From 30212547a9d750d6406383f6c706e6eae3395e37 Mon Sep 17 00:00:00 2001 From: Aart Bik <39774503+aartbik@users.noreply.github.com> Date: Tue, 27 Feb 2024 11:49:32 -0800 Subject: [PATCH 228/283] [torch-mlir][sparse] add JIT test for block sparse SpMV (#2955) This required adding a "decompose" pass to the torch lowering, since torch.mv was not directly handled by lowering to linalg --- test/python/fx_importer/sparse_test.py | 36 +++++++++++++++++++------- 1 file changed, 26 insertions(+), 10 deletions(-) diff --git a/test/python/fx_importer/sparse_test.py b/test/python/fx_importer/sparse_test.py index 87eecb2977d5..138942b07092 100644 --- a/test/python/fx_importer/sparse_test.py +++ b/test/python/fx_importer/sparse_test.py @@ -139,7 +139,11 @@ def sparse_jit(f, *args, **kwargs): module = export_and_import(f, *args, *kwargs) run_pipeline_with_repro_report( module, - "builtin.module(torch-backend-to-linalg-on-tensors-backend-pipeline)", + ( + "builtin.module(" + "func.func(torch-decompose-complex-ops)," + "torch-backend-to-linalg-on-tensors-backend-pipeline)" + ), "Lowering TorchFX IR -> Linalg IR", enable_ir_printing=False, ) @@ -200,13 +204,13 @@ def __init__(self): def forward(self, x): return x.sum() + net = SumNet() dense_input = torch.ones(64, 64) sparse_input = dense_input.to_sparse_csr() - m = export_and_import(SumNet(), sparse_input) + m = export_and_import(net, sparse_input) print(m) # Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit. - net = SumNet() res1 = net(sparse_input) res2 = sparse_jit(net, sparse_input) print("torch.sparse =", res1) @@ -222,6 +226,10 @@ def forward(self, x): # CHECK: %[[R:.*]] = torch.aten.mv %[[A]], %[[B]] : !torch.vtensor<[10,10],f32,#[[$BSR]]>, !torch.vtensor<[10],f32> -> !torch.vtensor<[10],f32> # CHECK: return %[[R]] : !torch.vtensor<[10],f32> # CHECK: } +# +# CHECK: torch.sparse = tensor([55., 55., 55., 55., 55., 55., 55., 55., 55., 55.]) +# CHECK: torch.mlir = [55. 55. 55. 55. 55. 55. 55. 55. 55. 55.] +# def test_sparse_SpMV(): class SpMVNet(torch.nn.Module): def __init__(self): @@ -230,12 +238,19 @@ def __init__(self): def forward(self, x, v): return torch.mv(x, v) - dense_vector = torch.ones(10) + net = SpMVNet() + dense_vector = torch.arange(1, 11, dtype=torch.float32) dense_input = torch.ones(10, 10) sparse_input = dense_input.to_sparse_bsr(blocksize=(2, 2)) - m = export_and_import(SpMVNet(), sparse_input, dense_vector) + m = export_and_import(net, sparse_input, dense_vector) print(m) + # Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit. + res1 = net(sparse_input, dense_vector) + res2 = sparse_jit(net, sparse_input, dense_vector) + print("torch.sparse =", res1) + print("torch.mlir =", res2) + @run # CHECK-LABEL: test_sparse_SpMM @@ -264,15 +279,15 @@ def __init__(self): def forward(self, x, y): return torch.matmul(x, y) + net = MatMulNet() dense_input = torch.ones(8, 8) sparse_input = dense_input.to_sparse_coo() - m = export_and_import(MatMulNet(), sparse_input, dense_input) + m = export_and_import(net, sparse_input, dense_input) print(m) # Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit. # TODO: run with COO, right now only CSR works sparse_input = dense_input.to_sparse_csr() - net = MatMulNet() res1 = net(sparse_input, dense_input) res2 = sparse_jit(net, sparse_input, dense_input) print("torch.sparse") @@ -311,6 +326,7 @@ def forward(self, x, y): # ... # CHECK: [-61. -62.] # CHECK: [-63. -64.]{{\]\]}} +# def test_sparse_eltwise(): class EltNet(torch.nn.Module): def __init__(self): @@ -319,18 +335,19 @@ def __init__(self): def forward(self, x): return -x + net = EltNet() dense_input = torch.reshape( torch.arange(1, 65, dtype=torch.float32), shape=(8, 4, 2) ) # This yields a **batched** CSR. sparse_input = dense_input.to_sparse_csr(dense_dim=0) - m = export_and_import(EltNet(), sparse_input) + m = export_and_import(net, sparse_input) print(m) # This yields a plain CSR with dense **sub**tensor sparse_input = dense_input.to_sparse_csr(dense_dim=1) - m = export_and_import(EltNet(), sparse_input) + m = export_and_import(net, sparse_input) print(m) # Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit. @@ -339,7 +356,6 @@ def forward(self, x): # (1) since we do not propagate sparsity into elt-wise, MLIR returns dense result # (2) for dense_dim=0, this will need a dense(batched) property sparse_input = dense_input.to_sparse_csr(dense_dim=1) - net = EltNet() res1 = net(sparse_input) res2 = sparse_jit(net, sparse_input) print("torch.sparse") From d541779f3754471f95d8ece8daa63e959168e22f Mon Sep 17 00:00:00 2001 From: Abhishek-TyRnT <58800592+Abhishek-TyRnT@users.noreply.github.com> Date: Wed, 28 Feb 2024 03:10:55 +0530 Subject: [PATCH 229/283] Add support for torch arange float module (#2749) Added Support for float dtype in in torch.arange in TOSA Dialect This resolves the following issue :- https://github.com/llvm/torch-mlir/issues/2762 The following test cases are passing after this change 1. ArangeDtypeIntModule_basic 2. ArangeFloatModule_basic 3. ArangeNegativeStartFloatModule_basic 4. ArangeStartFloatModule_basic 5. ArangeStartNegativeStepFloatModule_basic 6. ArangeStartOutDtypeModule_basic 7. ArangeStartStepFloatModule_basic --------- Co-authored-by: James Newling --- lib/Conversion/TorchToTosa/TorchToTosa.cpp | 153 ++++++++++++++++++--- projects/pt1/e2e_testing/xfail_sets.py | 8 ++ 2 files changed, 139 insertions(+), 22 deletions(-) diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index b49c9af8adce..ce0a1af2f834 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -8,24 +8,23 @@ //===----------------------------------------------------------------------===// #include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h" -#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h" -#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h" -#include "torch-mlir/Conversion/Utils/Utils.h" - #include "../PassDetail.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" -#include "mlir/Dialect/Traits.h" #include "mlir/IR/Matchers.h" #include "mlir/Support/LogicalResult.h" #include "mlir/Transforms/DialectConversion.h" +#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeCommon.h" +#include "torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h" +#include "torch-mlir/Conversion/Utils/Utils.h" #include "torch-mlir/Dialect/Torch/IR/TorchDialect.h" #include "torch-mlir/Dialect/Torch/IR/TorchOps.h" #include "torch-mlir/Dialect/Torch/IR/TorchTypes.h" #include "torch-mlir/Dialect/Torch/Utils/Utils.h" #include "torch-mlir/Dialect/TorchConversion/IR/TorchConversionDialect.h" #include "torch-mlir/Dialect/TorchConversion/Transforms/BackendTypeConversion.h" +#include using namespace mlir; using namespace mlir::torch; @@ -4067,28 +4066,138 @@ LogicalResult ConvertAtenOp::matchAndRewrite( op, "unimplemented: pin_memory must be either None or false"); } - int64_t start, step, end; - if (!matchPattern(op.getStart(), m_TorchConstantInt(&start))) + // Stores a range value (a start, end, or step value) and whether or not it + // was initiated with a constant integer, an constant float or neither. + class ConstRangeValue { + public: + explicit ConstRangeValue(double v) + : vDouble(v), fromDouble(true), vInt(static_cast(v)), + fromInt(false) {} + + explicit ConstRangeValue(int64_t v) + : vDouble(static_cast(v)), fromDouble(false), vInt(v), + fromInt(true) {} + + // Constructor for the case where there is no constant value to use. + ConstRangeValue() + : vDouble(0), fromDouble(false), vInt(0), fromInt(false) {} + + static ConstRangeValue fromValue(Value v) { + int64_t intVal{0}; + double floatVal{0.0}; + if (matchPattern(v, m_TorchConstantFloat(&floatVal))) { + return ConstRangeValue(floatVal); + } else if (matchPattern(v, m_TorchConstantInt(&intVal))) { + return ConstRangeValue(intVal); + } + return ConstRangeValue(); + } + + bool hasConstInt() const { return fromInt; } + bool hasConstDouble() const { return fromDouble; } + bool hasConst() const { return fromInt || fromDouble; } + double getDouble() const { return vDouble; } + int64_t getInt() const { return vInt; } + + private: + double vDouble; + bool fromDouble; + int64_t vInt; + bool fromInt; + }; + + auto start = ConstRangeValue::fromValue(op.getStart()); + if (!start.hasConst()) { return rewriter.notifyMatchFailure( - op, "unimplemented: value `start` should be a torch constant int"); + op, "unimplemented: case where `start` is not a constant int or float"); + } - if (!matchPattern(op.getEnd(), m_TorchConstantInt(&end))) + auto end = ConstRangeValue::fromValue(op.getEnd()); + if (!end.hasConst()) { return rewriter.notifyMatchFailure( - op, "unimplemented: value `end` should be a torch constant int"); + op, + "unimplemented: case where value `end` is not a constant int or float"); + } - if (!matchPattern(op.getStep(), m_TorchConstantInt(&step))) + auto step = ConstRangeValue::fromValue(op.getStep()); + if (!step.hasConst()) { + return rewriter.notifyMatchFailure(op, + "unimplemented: case where value `step` " + "is not a constant int or float"); + } + + auto getRange = [](auto start, auto end, auto step) { + // Initialize a small vector of the same type as start: + using T = decltype(start); + SmallVector values; + + uint64_t counter{0}; + if (start == end) { + return values; + } + assert(step != T(0)); + values.reserve( + 1 + static_cast(std::abs((end - start) / std::abs(step)))); + if (step > 0) { + while (start + T(counter) * step < end) { + values.push_back(start + counter * step); + counter++; + } + } else { + while (start + T(counter) * step > end) { + values.push_back(start + counter * step); + counter++; + } + } + return values; + }; + + const auto isIntType = + resultType.getElementType().dyn_cast_or_null(); + + const auto isDoubleType = + resultType.getElementType().dyn_cast_or_null(); + + auto maybeResult = [&]() -> std::optional { + // Integer output type, and start / end / range are all integers. + if (isIntType && start.hasConstInt() && end.hasConstInt() && + step.hasConstInt()) { + auto values = getRange(start.getInt(), end.getInt(), step.getInt()); + return tosa::getConstTensor(rewriter, op, values, values.size()); + } + + // Get a double range. + auto values = + getRange(start.getDouble(), end.getDouble(), step.getDouble()); + if (isIntType) { + SmallVector values_i64; + values_i64.reserve(values.size()); + for (auto v : values) { + values_i64.push_back(static_cast(v)); + } + return tosa::getConstTensor(rewriter, op, values_i64, + values.size()); + } + + if (!isDoubleType) { + return {}; + } + + SmallVector values_f32; + values_f32.reserve(values.size()); + for (auto v : values) { + values_f32.push_back(static_cast(v)); + } + auto vs = tosa::getConstTensor(rewriter, op, values_f32, + values_f32.size()); + return vs; + }(); + + if (!maybeResult.has_value()) { return rewriter.notifyMatchFailure( - op, "unimplemented: value `step` should be a torch constant int"); - - // The result will always be a 1-d tensor. - // The size of the result is calculated as follows: - // ceil((end - start)/step) - int64_t resultShape = ceil((float)(end - start) / (float)step); - SmallVector values(resultShape, start); - for (unsigned i = 1; i < resultShape; i++) - values[i] += i * step; - Value result = - tosa::getConstTensor(rewriter, op, values, resultShape).value(); + op, "failed to generate constant tensor for arange"); + } + auto result = maybeResult.value(); rewriter.replaceOpWithNewOp(op, resultType, result); return success(); diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 195a5e42f249..74f7300c9274 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -892,6 +892,14 @@ "ArangeStartOutViewModule_basic", "ArangeStartStepIntModule_basic", "ArangeZeroElementOutputModule_basic", + "ArangeDtypeIntModule_basic", + "ArangeFalsePinMemoryModule_basic", + "ArangeFloatModule_basic", + "ArangeNegativeStartFloatModule_basic", + "ArangeStartFloatModule_basic", + "ArangeStartNegativeStepFloatModule_basic", + "ArangeStartOutDtypeModule_basic", + "ArangeStartStepFloatModule_basic", "ArgmaxModule_keepDim", "ArgmaxModule_with_dim", "AtenComplex64Module_basic", From 4a7a7d76f8870cad43a1803312efce7a8ae8643b Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Tue, 27 Feb 2024 22:48:07 -0800 Subject: [PATCH 230/283] [onnx] Fix ReduceMean lowering to torch (#2956) Torch lowering only supported the most recent version. Refactored the lowering so more easily handle default values and optional operands / attributes. --- .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 228 ++++---- lib/Conversion/TorchToLinalg/Reduction.cpp | 97 ++-- .../Torch/Transforms/DecomposeComplexOps.cpp | 52 ++ .../TorchConversion/Transforms/Passes.cpp | 1 + projects/pt1/e2e_testing/xfail_sets.py | 493 ++++++++++-------- .../TorchOnnxToTorch/simple_ops_q_to_z.mlir | 186 ++++--- 6 files changed, 608 insertions(+), 449 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index bc2cde573967..adf6d1cb639a 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -1104,129 +1104,145 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( Value axes; int64_t keepDims; int64_t noop_with_empty_axes; - // Deal with case when no axes arg is passed - if (binder.op->getNumOperands() == 1) { - if (binder.tensorOperand(data) || - binder.tensorResultType(resultType) || - binder.s64IntegerAttr(keepDims, "keepdims", 1) || - binder.s64IntegerAttr(noop_with_empty_axes, - "noop_with_empty_axes", 0)) - return failure(); - if (noop_with_empty_axes == 0) { - Value keepDimsConstInt = rewriter.create( - binder.getLoc(), rewriter.getType(), - rewriter.getIntegerAttr(rewriter.getIntegerType(64), keepDims)); - Value keepDimsBool = rewriter.create( - binder.getLoc(), keepDimsConstInt); - int64_t numDims = dyn_cast(data.getType()) - .getSizes() - .size(); - SmallVector axesList; - for (int i = 0; i < numDims; i++) { - Value curr = rewriter.create( - binder.getLoc(), rewriter.getType(), - rewriter.getIntegerAttr(rewriter.getIntegerType(64), i)); - axesList.push_back(curr); - } - Value axesValueList = rewriter.create( - binder.getLoc(), - Torch::ListType::get( - Torch::IntType::get(binder.op->getContext())), - axesList); - rewriter.replaceOpWithNewOp( - binder.op, resultType, data, axesValueList, keepDimsBool); - } else { - rewriter.replaceOp(binder.op, data); - } - return success(); - } - if (binder.tensorOperands(data, axes) || + if (binder.tensorOperandAtIndex(data, 0) || binder.tensorResultType(resultType) || binder.s64IntegerAttr(keepDims, "keepdims", 1) || binder.s64IntegerAttr(noop_with_empty_axes, "noop_with_empty_axes", 0)) return failure(); - Torch::BaseTensorType axesType = - axes.getType().cast(); - SmallVector dimList; - SmallVector selectSizes; - selectSizes.push_back(1); - Type selectResultType = axesType.getWithSizesAndDtype( - llvm::ArrayRef(selectSizes), axesType.getOptionalDtype()); - auto sizes = - dyn_cast(axes.getType()).getSizes(); - // deal with case when axes is empty - if (sizes.size() == 1 && sizes[0] == 0) { - if (noop_with_empty_axes == 0) { - // create dims list with all dims [0, data.getSizes().size()) - Value keepDimsConstInt = rewriter.create( - binder.getLoc(), rewriter.getType(), - rewriter.getIntegerAttr(rewriter.getIntegerType(64), keepDims)); - Value keepDimsBool = rewriter.create( - binder.getLoc(), keepDimsConstInt); - int64_t numDims = dyn_cast(data.getType()) - .getSizes() - .size(); - for (int i = 0; i < numDims; i++) { - Value curr = rewriter.create( - binder.getLoc(), rewriter.getType(), - rewriter.getIntegerAttr(rewriter.getIntegerType(64), i)); - dimList.push_back(curr); + + auto dataTy = cast(data.getType()); + Torch::IntType torchIntTy = rewriter.getType(); + + // If any of the input dims are 0 we set to the upper limit: + if (llvm::any_of(dataTy.getSizes(), [](int64_t d) { return d == 0; }) && + (llvm::any_of(dataTy.getSizes(), + [](int64_t d) { return d == Torch::kUnknownSize; }) || + keepDims)) { + auto dty = dataTy.getDtype(); + Value scalar; + if (FloatType fpTy = dyn_cast(dty)) { + auto inf = APFloat::getInf(fpTy.getFloatSemantics()); + scalar = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getFloatAttr(rewriter.getF64Type(), + inf.convertToDouble())); + } + + if (IntegerType intTy = dyn_cast(dty)) { + auto mx = + intTy.isSigned() + ? APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth()) + : APInt::getMaxValue(intTy.getIntOrFloatBitWidth()); + scalar = rewriter.create( + binder.getLoc(), torchIntTy, + rewriter.getIntegerAttr(rewriter.getIntegerType(64), + mx.getSExtValue())); + } + + llvm::SmallVector fillDims; + for (int i = 0, s = resultType.getSizes().size(); i < s; ++i) { + auto staticDim = resultType.getSizes()[i]; + if (staticDim != Torch::kUnknownSize) { + fillDims.push_back(rewriter.create( + binder.getLoc(), torchIntTy, + rewriter.getI64IntegerAttr(staticDim))); + continue; } - Value dimValueList = rewriter.create( - binder.getLoc(), - Torch::ListType::get( - Torch::IntType::get(binder.op->getContext())), - dimList); - rewriter.replaceOpWithNewOp( - binder.op, resultType, data, dimValueList, keepDimsBool); - } else { - rewriter.replaceOp(binder.op, data); + + Value iv = rewriter.create( + binder.getLoc(), torchIntTy, rewriter.getI64IntegerAttr(i)); + fillDims.push_back(rewriter.create( + binder.getLoc(), torchIntTy, data, iv)); } + + Value none = rewriter.create(binder.getLoc()); + Value fillDimsList = rewriter.create( + binder.getLoc(), Torch::ListType::get(torchIntTy), fillDims); + rewriter.replaceOpWithNewOp( + binder.op, resultType, fillDimsList, scalar, none, none, none, + none); return success(); } + + // Previous version of the operation had the axes as an attribute: + SmallVector axesList; + llvm::SmallVector axesAttr; + if (!binder.s64IntegerArrayAttr(axesAttr, "axes", {})) { + for (int i = 0, s = axesAttr.size(); i < s; ++i) { + axesList.push_back(rewriter.create( + binder.getLoc(), torchIntTy, + rewriter.getI64IntegerAttr(axesAttr[i]))); + } + } + + // Extract the axes values from the axes operand: + if (!binder.tensorOperandAtIndex(axes, 1)) { + Torch::BaseTensorType axesType = + axes.getType().cast(); + SmallVector selectSizes{1}; + Type selectResultType = axesType.getWithSizesAndDtype( + selectSizes, axesType.getOptionalDtype()); + auto sizes = axesType.getSizes(); + + Value zero = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); + + // Extract the value of each axes: + for (int i = 0; i < sizes[0]; i++) { + // Go through the axes list and get each dim in the list + Value selectIndex = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), i)); + Value extract = rewriter.create( + binder.getLoc(), selectResultType, axes, zero, selectIndex); + Value dim = rewriter.create( + binder.getLoc(), rewriter.getType(), extract); + axesList.push_back(dim); + } + } + + // Handle the noop case: + if (axesList.empty() && noop_with_empty_axes) { + rewriter.replaceOp(binder.op, data); + return success(); + } + + // Deal with case when no axes arg is passed but not a noop: + if (axesList.empty()) { + int64_t numDims = dyn_cast(data.getType()) + .getSizes() + .size(); + for (int i = 0; i < numDims; i++) { + Value curr = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), i)); + axesList.push_back(curr); + } + } + + // Handle negative axis: + Value rankVal = rewriter.create(binder.getLoc(), + torchIntTy, data); Value zero = rewriter.create( binder.getLoc(), rewriter.getType(), - rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); - int64_t adjustmentInt = - cast(data.getType()).getSizes().size(); - Value adjustment = rewriter.create( - binder.getLoc(), rewriter.getType(), - rewriter.getIntegerAttr(rewriter.getIntegerType(64), - adjustmentInt)); - // convert axes (tensor) into torch int list while dealing with neg axis - for (int i = 0; i < sizes[0]; i++) { - // Go through the axes list and get each dim in the list - Value selectIndex = rewriter.create( - binder.getLoc(), rewriter.getType(), - rewriter.getIntegerAttr(rewriter.getIntegerType(64), i)); - Value extract = rewriter.create( - binder.getLoc(), selectResultType, axes, zero, selectIndex); - Value dim = rewriter.create( - binder.getLoc(), rewriter.getType(), extract); - // deal with neg axis: if (axis < 0) axis += rank + rewriter.getI64IntegerAttr(0)); + for (Value &axes : axesList) { Value isNegative = - rewriter.create(binder.getLoc(), dim, zero); + rewriter.create(binder.getLoc(), axes, zero); isNegative = rewriter.create(binder.getLoc(), isNegative); Value finalOffset = rewriter.create( - binder.getLoc(), isNegative, adjustment); - Value finalDim = rewriter.create( - binder.getLoc(), dim, finalOffset); - dimList.push_back(finalDim); + binder.getLoc(), isNegative, rankVal); + axes = rewriter.create(binder.getLoc(), axes, + finalOffset); } + Value dimValueList = rewriter.create( - binder.getLoc(), - Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), - dimList); - Value keepDimBool; - if (keepDims == 1) { - keepDimBool = - rewriter.create(binder.getLoc(), true); - } else { - keepDimBool = - rewriter.create(binder.getLoc(), false); - } + binder.getLoc(), Torch::ListType::get(torchIntTy), axesList); + Value keepDimBool = + rewriter.create(binder.getLoc(), keepDims); rewriter.replaceOpWithNewOp( binder.op, resultType, data, dimValueList, keepDimBool); return success(); diff --git a/lib/Conversion/TorchToLinalg/Reduction.cpp b/lib/Conversion/TorchToLinalg/Reduction.cpp index e050764993e6..92f50523c764 100644 --- a/lib/Conversion/TorchToLinalg/Reduction.cpp +++ b/lib/Conversion/TorchToLinalg/Reduction.cpp @@ -60,18 +60,15 @@ class ConvertAtenMinMaxDimOp : public OpConversionPattern { Location loc = op.getLoc(); Value input = adaptor.getSelf(); - RankedTensorType valResultType = - getTypeConverter() - ->convertType(op.getResult(0).getType()) - .template cast(); - - RankedTensorType idxResultType = - this->getTypeConverter() - ->convertType(op.getResult(1).getType()) - .template cast(); + auto typec = this->getTypeConverter(); + auto valResultType = + cast(typec->convertType(op.getResult(0).getType())); + auto idxResultType = + cast(typec->convertType(op.getResult(1).getType())); RankedTensorType inputType = input.getType().template cast(); - Type idxElementType = idxResultType.getElementType(); + Type idxElementType = + getElementTypeOrSelf(typec->convertType(idxResultType)); if (!idxElementType.isa()) return rewriter.notifyMatchFailure( op, opName + " to linalg.* requires integer-like result type"); @@ -109,14 +106,12 @@ class ConvertAtenMinMaxDimOp : public OpConversionPattern { } // Constant op to account for the reduction along dim. - auto c1 = rewriter.create(loc, /*value=*/1); SmallVector resultShape; for (int64_t i = 0; i < inputType.getRank(); i++) { if (dim != i) { auto currentDimSize = rewriter.create(loc, input, i); resultShape.push_back(currentDimSize); - } else if (keepDim) - resultShape.push_back(c1); + } } // First fill the output buffer for the index. Value filledTensorIdx = @@ -146,27 +141,23 @@ class ConvertAtenMinMaxDimOp : public OpConversionPattern { Value filledTensorVal = rewriter.create(loc, fillValue, initTensorVal).result(); + SmallVector iteratorTypes( + inputType.getRank(), utils::IteratorType::parallel); + iteratorTypes[dim] = utils::IteratorType::reduction; + // Create the affine expressions that will be used to // iterate over the input and output tensors. // Here we also set the type of iterator: parallel or reduction. + SmallVector exprs; - SmallVector iteratorTypes; SmallVector resultExprs; for (auto size : llvm::enumerate(makeShapeTorchCompatible(inputType.getShape()))) { exprs.push_back(rewriter.getAffineDimExpr(size.index())); - - if (unsigned(dim) == size.index()) { - iteratorTypes.push_back(utils::IteratorType::reduction); - // If `keepDim`, create affine map to the first element - // in the current dimension. - if (keepDim) - resultExprs.push_back(rewriter.getAffineConstantExpr(0)); - } else { - iteratorTypes.push_back(utils::IteratorType::parallel); + if (unsigned(dim) != size.index()) resultExprs.push_back(rewriter.getAffineDimExpr(size.index())); - } } + auto maps = AffineMap::inferFromExprList({exprs, resultExprs, resultExprs}, rewriter.getContext()); auto linalgOp = rewriter.create( @@ -219,12 +210,58 @@ class ConvertAtenMinMaxDimOp : public OpConversionPattern { nestedLoc, ValueRange({resultVal, resultIndex})); }); - // This cast is required to fix the shape in the case of keepDim=True - Value valuesCast = rewriter.create(loc, valResultType, - linalgOp.getResult(0)); - Value idxCast = rewriter.create(loc, idxResultType, - linalgOp.getResult(1)); - rewriter.replaceOp(op, {valuesCast, idxCast}); + if (!keepDim) { + Value rVal = rewriter.create(loc, valResultType, + linalgOp.getResult(0)); + Value rIdx = rewriter.create(loc, idxResultType, + linalgOp.getResult(1)); + llvm::SmallVector res{rVal, rIdx}; + rewriter.replaceOp(op, res); + return success(); + } + + llvm::SmallVector valShape(valResultType.getShape()); + llvm::SmallVector idxShape(idxResultType.getShape()); + for (int i = dim, s = valShape.size() - 1; i < s; ++i) { + valShape[i] = valShape[i + 1]; + idxShape[i] = idxShape[i + 1]; + } + + valShape.resize(valShape.size() - 1); + idxShape.resize(idxShape.size() - 1); + + Value rVal = rewriter.create( + loc, valResultType.clone(valShape), linalgOp.getResult(0)); + Value rIdx = rewriter.create( + loc, idxResultType.clone(idxShape), linalgOp.getResult(1)); + + SmallVector reassociation(valShape.size()); + if (reassociation.size() > 0) { + for (int i = 0; i < dim; ++i) + reassociation[i].push_back(i); + reassociation[std::max(0, dim - 1)].push_back(dim); + for (int i = dim, s = reassociation.size(); i < s; ++i) + reassociation[i].push_back(i + 1); + } + + valShape.push_back(0); + idxShape.push_back(0); + for (int i = dim, s = valShape.size() - 1; i < s; ++i) { + valShape[i + 1] = valShape[i]; + idxShape[i + 1] = idxShape[i]; + } + + valShape[dim] = 1; + idxShape[dim] = 1; + + Value unsqueezeVal = rewriter.create( + loc, valResultType, rVal, reassociation); + + Value unsqueezeIdx = rewriter.create( + loc, idxResultType, rIdx, reassociation); + + llvm::SmallVector unsqueezes = {unsqueezeVal, unsqueezeIdx}; + rewriter.replaceOp(op, unsqueezes); return success(); } }; diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index f9c1f63b568c..51a710d940e9 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -1316,6 +1316,57 @@ class DecomposeAten_LogSoftmaxBackwardDataOp }; } // namespace +namespace { +class DecomposeAtenAMinMaxOp : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(Torch::AtenAminOp op, + PatternRewriter &rewriter) const override { + llvm::SmallVector dimList; + if (!matchPattern(op.getDim(), m_TorchListOfConstantInts(dimList))) { + return rewriter.notifyMatchFailure(op, "dims not foldable constants"); + } + + bool keepdim; + if (!matchPattern(op.getKeepdim(), m_TorchConstantBool(&keepdim))) { + return rewriter.notifyMatchFailure(op, "keepdims not foldable constants"); + } + + auto loc = op.getLoc(); + std::sort(dimList.begin(), dimList.end(), std::greater()); + + Value reduction = op.getSelf(); + auto resultTy = cast(op.getType()); + auto reductionTy = cast(reduction.getType()); + llvm::SmallVector reductionShape(reductionTy.getSizes()); + + for (auto dim : dimList) { + auto dimValue = rewriter.create( + loc, rewriter.getI64IntegerAttr(dim)); + reductionShape[dim] = 1; + if (!keepdim) { + for (int i = dim, s = reductionShape.size() - 1; i < s; ++i) + reductionShape[i] = reductionShape[i + 1]; + reductionShape.resize(reductionShape.size() - 1); + } + + reductionTy = rewriter.getType( + reductionShape, resultTy.getOptionalDtype()); + auto idxTy = rewriter.getType( + reductionShape, rewriter.getIntegerType(32, /*is_signed*/ true)); + llvm::SmallVector types{reductionTy, idxTy}; + reduction = rewriter + .create(loc, types, reduction, + dimValue, op.getKeepdim()) + .getResult(0); + } + + rewriter.replaceOp(op, reduction); + return success(); + } +}; +} // namespace + // Decompose `AtenArgMaxOp` into `AtenMaxDimOp` as well as `AtenArgMinOp` into // `AtenMinDimOp` namespace { @@ -6867,6 +6918,7 @@ class DecomposeComplexOpsPass addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); diff --git a/lib/Dialect/TorchConversion/Transforms/Passes.cpp b/lib/Dialect/TorchConversion/Transforms/Passes.cpp index 7ac95ab6c4e9..55bedc1192eb 100644 --- a/lib/Dialect/TorchConversion/Transforms/Passes.cpp +++ b/lib/Dialect/TorchConversion/Transforms/Passes.cpp @@ -77,6 +77,7 @@ void TorchConversion::createTorchBackendToLinalgOnTensorsBackendPipeline( pm.addNestedPass(createConvertTorchToTMTensorPass()); pm.addNestedPass(createCanonicalizerPass()); pm.addNestedPass(createConvertTorchToLinalgPass()); + pm.addNestedPass(createCanonicalizerPass()); pm.addNestedPass(createConvertTorchToSCFPass()); pm.addNestedPass(createConvertTorchToArithPass()); pm.addNestedPass(createConvertTorchToTensorPass()); diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 74f7300c9274..a4ac58b1d909 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -1472,6 +1472,62 @@ } ONNX_XFAIL_SET = { + # Failure - cast error + "MeanDimNoneDimModule_basic", + "MeanDtypeModule_basic", + "MeanDynamicSizesModule_basic", + "MeanModule_basic", + "MseLossMeanReductionModule_basic", + "PermuteNegativeIndexModule_basic", + "StdBiasedModule_basic", + "VarBiasedModule_basic", + "VarMeanBiasedModule_basic", + + # Failure - constant int lowering + "SplitTensorGetItem_Module_basic", + "SplitTensorLastSmallerModule_basic", + "SplitTensorListUnpackModule_basic", + "SplitTensorNegativeDimModule_basic", + "SplitWithSizesListUnpackModule_basic", + "UnbindIntGetItem_Module_basic", + "UnbindIntListUnpack_Module_basic", + + # Failure - incorrect numerics + "AdaptiveAvgPool1dUnitOutputSizeDynamicModule_basic", + "AdaptiveAvgPool2dUnitOutputSizeDynamicModule_basic", + "ElementwiseAtan2TensorIntModule_basic", + "ElementwiseLog10IntModule_basic", + "ElementwiseLog2IntModule_basic", + "ElementwiseSeluModule_basic", + "FlipModuleStaticShape_basic", + "FlipNegativeIndexModule_basic", + "HardsigmoidModule_basic", + "HardsigmoidRandomModule_basic", + "IndexSelectDynamicInputSizeModule_basic", + "IndexSelectWholeDimensionModule_basic", + "IndexSelectWholeTensorModule_basic", + "IndexTensorStaticModule_basic", + "IndexTensorStaticNonContiguousWithNoneModule_basic", + "PixelShuffleModuleStaticRank4Float32_basic", + "ResNet18Module_basic", + "SliceCopyEndGreaterThanDimSize_Module_basic", + "SliceCopyNegative_Module_basic", + "SliceCopyNonZeroDim_Module_basic", + "SliceCopy_Module_basic", + "TupleModule_basic", + + # Failure - incorrect shape + "ArangeStartOutDtypeModule_basic", + "ArangeStartOutViewModule_basic", + "BroadcastDynamicDimModule_basic", + "BroadcastToModule_basic", + "ExpandModule_basic", + "MoveDimIntNegativeIndexModule_basic", + "ReduceAmaxKeepDim_basic", + "ReduceMaxKeepDimReturnBoth_basic", + "ReduceMaxNegativeDim_basic", + "ViewSizeFromOtherTensor_basic", + # Failure - onnx_export "AdaptiveAvgPool1dGeneralDynamic_basic", "AdaptiveAvgPool1dNonUnitOutputSizeDynamicModule_basic", @@ -1594,6 +1650,7 @@ "EmptyStridedSizeIntStrideModule_basic", "EqIntModule_basic", "ExponentialModule_basic", + "FloatImplicitModule_basic", "GeFloatIntModule_basic", "GeFloatModule_basic", "GeIntModule_basic", @@ -1613,6 +1670,7 @@ "IndexPutImpl3DFloatNonAccumulateModule_basic", "IndexPutImplIndexWithNoneModule_basic", "IntFloatModule_basic", + "IntImplicitModule_basic", "IouOfModule_basic", "IsFloatingPointFloat_True", "IsFloatingPointInt_False", @@ -1818,13 +1876,8 @@ "_ConvolutionDeprecated2DCudnnModule_basic", "_ConvolutionDeprecated2DDeterministicModule_basic", "_SoftmaxModule_basic", - + # Failure - onnx_import - "BucketizeTensorFloatModule_basic", - "BucketizeTensorModule_basic", - "BucketizeTensorOutInt32RightModule_basic", - "BucketizeTensorStaticFloatModule_basic", - "BucketizeTensorStaticModule_basic", "DiagonalModule_basic", "DiagonalModule_nonsquare", "DiagonalModule_transposed", @@ -1832,31 +1885,6 @@ "DiagonalModule_with_dims_and_offset", "DiagonalModule_with_negative_dims", "DiagonalModule_with_offset", - "ElementwiseClampMaxModule_basic", - "ElementwiseClampMinModule_basic", - "ElementwiseClampMinTensorFloatModule_basic", - "ElementwiseClampMinTensorIntModule_basic", - "ElementwiseClampModule_basic", - "ElementwiseClampTensorFloatModule_basic", - "ElementwiseClampTensorInt8Module_basic", - "ElementwiseClampTensorIntModule_basic", - "HBC_basic", - "IndexPut1DFloatAccumulateModule_basic", - "IndexPut1DIntAccumulateModule_basic", - "IndexPut2DFloatAccumulateModule_basic", - "IndexPut2DIntAccumulateModule_basic", - "IndexPut3DFloatAccumulateModule_basic", - "IndexPut3DIntAccumulateModule_basic", - "IndexPutHackedTwin1DFloatAccumulateModule_basic", - "IndexPutHackedTwin1DIntAccumulateModule_basic", - "IndexPutHackedTwin2DFloatAccumulateModule_basic", - "IndexPutHackedTwin2DIntAccumulateModule_basic", - "IndexPutHackedTwin3DFloatAccumulateModule_basic", - "IndexPutHackedTwin3DIntAccumulateModule_basic", - "NormalizeModule_basic", - "PadWithNoneValModule_basic", - "QuantizedMLP_basic", - "RandModule_basic", "ScatterReduceFloatMaxModuleIncludeSelf", "ScatterReduceFloatMinModuleIncludeSelf", "ScatterReduceFloatProdModuleIncludeSelf", @@ -1867,21 +1895,11 @@ "ScatterReduceIntSumModuleIncludeSelf", "TileBigDimsSizeModule_basic", "TileSmallDimsSizeModule_basic", - "UpSampleNearest2dDynamicSize_basic", - "UpSampleNearest2dStaticSize_basic", - - # Failure - onnx_lowering + + # Failure - onnx_lowering: onnx.AveragePool "AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic", "AdaptiveAvgPool1dStaticEvenMultiple_basic", "AdaptiveAvgPool2dNonUnitOutputSizeStaticModule_basic", - "AtenMmFloatTypes_basic", - "AtenMmIntTypes_basic", - "AtenTrilModule_basic", - "AtenTrilWithNegDiagonalModule_basic", - "AtenTrilWithPosDiagonalModule_basic", - "AtenTriuModule_basic", - "AtenTriuWithNegDiagonalModule_basic", - "AtenTriuWithPosDiagonalModule_basic", "AvgPool1dFloatModule_basic", "AvgPool1dIntModule_basic", "AvgPool1dStaticModule_basic", @@ -1890,78 +1908,73 @@ "AvgPool2dFloatModule_basic", "AvgPool2dIntModule_basic", "AvgPool2dStaticModule_basic", - "BernoulliFloatModule_basic", - "BernoulliModule_basic", - "BernoulliPModule_basic", - "BernoulliTensorModule_basic", - "ConstantPad2dStaticModule_basic", - "ConstantPadNdModule_basic", - "ConstantPadNdPartialStaticModule_basic", - "ConstantPadNdStaticModule_basic", - "CrossEntropyLossModule_basic", - "CrossEntropyLossNoReductionModule_basic", - "DropoutTrainModule_basic", - "DropoutTrainStaticShapeModule_basic", + + # Failure - onnx_lowering: onnx.Cast + "BucketizeTensorOutInt32RightModule_basic", + "ElementwiseToDtypeI64ToI8Module_basic", + "ElementwiseToDtypeI64ToUI8Module_basic", + "HBC_basic", + "QuantizedMLP_basic", + "TypeConversionI1ToI32Module_basic", + "TypeConversionI64ToI32Module_basic", + + # Failure - onnx_lowering: onnx.Clip + "ElementwiseClampMaxModule_basic", + "ElementwiseClampMinModule_basic", + "ElementwiseClampMinTensorFloatModule_basic", + "ElementwiseClampMinTensorIntModule_basic", + "ElementwiseClampModule_basic", + "ElementwiseClampTensorFloatModule_basic", + "ElementwiseClampTensorInt8Module_basic", + "ElementwiseClampTensorIntModule_basic", + "NormalizeModule_basic", + + # Failure - onnx_lowering: onnx.Einsum "EinsumStaticContractRhsModule_basic", "EinsumStaticFourDimensionModule_basic", "EinsumStaticModule_basic", - "ElementwiseMishModule_basic", - "ElementwiseRemainderScalarModule_Bool_basic", - "ElementwiseRemainderScalarModule_Int_basic", - "ElementwiseToDtypeI64ToI8Module_basic", - "ElementwiseToDtypeI64ToUI8Module_basic", - "GroupNormModule_basic", - "GroupNormNoWeightAndBiasModule_basic", + + # Failure - onnx_lowering: onnx.Gemm + "AtenMmFloatTypes_basic", + "AtenMmIntTypes_basic", + "MmDagModule_basic", + "MmModule_basic", + "MmModule_chained", + "MmTanhModule_basic", + + # Failure - onnx_lowering: onnx.HardSwish "HardswishModule_basic", "HardswishRandomModule_basic", - "IndexPut1DFloatNonAccumulateModule_basic", - "IndexPut1DIntNonAccumulateModule_basic", - "IndexPut2DFloatNonAccumulateModule_basic", - "IndexPut2DIntNonAccumulateModule_basic", - "IndexPut3DFloatNonAccumulateModule_basic", - "IndexPut3DIntNonAccumulateModule_basic", - "IndexPutHackedTwin1DFloatNonAccumulateModule_basic", - "IndexPutHackedTwin1DIntNonAccumulateModule_basic", - "IndexPutHackedTwin2DFloatNonAccumulateModule_basic", - "IndexPutHackedTwin2DIntNonAccumulateModule_basic", - "IndexPutHackedTwin3DFloatNonAccumulateModule_basic", - "IndexPutHackedTwin3DIntNonAccumulateModule_basic", + "MobilenetV3Module_basic", + + # Failure - onnx_lowering: onnx.LogSoftmax "LogSoftmaxIntModule_basic", + "_LogSoftmaxModuleStable_basic", + "_LogSoftmaxModule_basic", + + # Failure - onnx_lowering: onnx.MaxPool "MaxPool2dWithIndicesAllNegativeValuesModule_basic", "MaxPool2dWithIndicesNonDefaultPaddingModule_basic", "MaxPool2dWithIndicesStaticModule_basic", - "MmDagModule_basic", - "MmModule_basic", - "MmModule_chained", - "MmTanhModule_basic", - "MobilenetV3Module_basic", - "MseLossSumReductionWithDifferentElemTypeModule_basic", - "NativeDropoutTrainModule_basic", - "NativeDropoutTrainStaticShapeModule_basic", + + # Failure - onnx_lowering: onnx.Mod + "ElementwiseRemainderScalarModule_Bool_basic", + "ElementwiseRemainderScalarModule_Int_basic", + "UnflattenIntNegativeOneDimStaticModule_basic", + "UnflattenIntNegativeOneSizeStaticModule_basic", + "UnflattenIntStaticModule_basic", + "UnflattenStaticModule_basic", + + # Failure - onnx_lowering: onnx.OneHot "OneHotModule_basic", + + # Failure - onnx_lowering: onnx.Pad + "ConstantPad2dStaticModule_basic", + "ConstantPadNdModule_basic", + "ConstantPadNdPartialStaticModule_basic", + "ConstantPadNdStaticModule_basic", "PadModule_basic", - "RandIntLowDtypeModule_basic", - "RandIntLowModule_basic", - "RandLikeDtypeModule_basic", - "RandLikeModule_basic", - "RandnDtypeDeviceModule_basic", - "RandnGeneratorF64Module_basic", - "RandnGeneratorModule_basic", - "RandnLikeDtypeModule_basic", - "RandnLikeModule_basic", - "RandnModule_basic", - "ReduceL1NormModule_basic", - "ReduceL1NormWithDTypeModule_basic", - "ReduceL2NormModule_basic", - "ReduceL3NormAllDimsModule_basic", - "ReduceL3NormKeepDimModule_basic", - "ReduceProdDimIntFloatModule_basic", - "ReduceSumDtypeFloatModule_basic", - "ReduceSumDtypeIntModule_basic", - "ReduceSumElementTypeBoolModule_basic", - "ReduceSumFloatModule_basic", - "ReduceSumSignedIntModule_basic", - "ReduceSumUnsignedIntModule_basic", + "PadWithNoneValModule_basic", "ReflectionPad1dModule2dInput_Right", "ReflectionPad1dModule2dInput_basic", "ReflectionPad1dModule3dInput_Left", @@ -1976,19 +1989,43 @@ "ReplicationPad2dModule_left0", "ReplicationPad2dModule_right0", "ReplicationPad2dModule_top0", - "ScatterSrcModule_basic", - "ScatterSrcStaticModule_basic", - "ScatterValueFloatModule_basic", - "ScatterValueIntModule_basic", - "SoftplusModule_basic", - "SortTensorDescending_basic", - "SortTensorInteger_basic", - "SortTensorNegativeDimension_basic", - "SortTensorSpecificDimension_basic", - "SortTensor_basic", - "SqueezeModule_allUnitDim", - "SqueezeModule_broadcast", - "SqueezeModule_static", + + # Failure - onnx_lowering: onnx.RandomNormal + "RandnDtypeDeviceModule_basic", + "RandnGeneratorF64Module_basic", + "RandnGeneratorModule_basic", + "RandnModule_basic", + + # Failure - onnx_lowering: onnx.RandomNormalLike + "RandnLikeDtypeModule_basic", + "RandnLikeModule_basic", + + # Failure - onnx_lowering: onnx.RandomUniform + "RandIntLowDtypeModule_basic", + "RandIntLowModule_basic", + + # Failure - onnx_lowering: onnx.RandomUniformLike + "BernoulliFloatModule_basic", + "BernoulliPModule_basic", + "BernoulliTensorModule_basic", + "RandLikeDtypeModule_basic", + "RandLikeModule_basic", + "RandModule_basic", + + # Failure - onnx_lowering: onnx.ReduceL1 + "ReduceL1NormModule_basic", + "ReduceL1NormWithDTypeModule_basic", + + # Failure - onnx_lowering: onnx.ReduceL2 + "ReduceL2NormModule_basic", + + # Failure - onnx_lowering: onnx.ReduceProd + "BernoulliModule_basic", + "DropoutTrainModule_basic", + "DropoutTrainStaticShapeModule_basic", + "NativeDropoutTrainModule_basic", + "NativeDropoutTrainStaticShapeModule_basic", + "ReduceProdDimIntFloatModule_basic", "StdCorrectionAllDimReduceModule_basic", "StdCorrectionKeepDimModule_basic", "StdCorrectionLargeInputModule_basic", @@ -1999,14 +2036,6 @@ "StdDimKeepDimTrueModule_basic", "StdDimNoneDimModule_basic", "StdUnbiasedModule_basic", - "TriuBroadcastModule_basic", - "TriuModule_basic", - "TypeConversionI1ToI32Module_basic", - "TypeConversionI64ToI32Module_basic", - "UnflattenIntNegativeOneDimStaticModule_basic", - "UnflattenIntNegativeOneSizeStaticModule_basic", - "UnflattenIntStaticModule_basic", - "UnflattenStaticModule_basic", "VarCorrectionAllDimReduceModule_basic", "VarCorrectionKeepDimModule_basic", "VarCorrectionLargeInputModule_basic", @@ -2025,58 +2054,85 @@ "VarMeanDimModule_basic", "VarMeanUnbiasedModule_basic", "VarUnbiasedModule_basic", - "_LogSoftmaxModuleStable_basic", - "_LogSoftmaxModule_basic", - - # Failure - cast_error - "MeanDimNoneDimModule_basic", - "MeanDtypeModule_basic", - "MeanDynamicSizesModule_basic", - "MeanModule_basic", - "MseLossMeanReductionModule_basic", - "StdBiasedModule_basic", - "VarBiasedModule_basic", - "VarMeanBiasedModule_basic", - - # Failure - constant_int - "ReduceMinAlongDimNegative_basic", - "ReduceMinAlongDimSignedInt_basic", - "ReduceMinAlongDim_basic", - "ReduceMinFloatModule_basic", - "ReduceMinKeepDimReturnBoth_basic", - "ReduceMinSignedIntModule_basic", - "ReduceMinUnsignedIntModule_basic", - "SplitTensorGetItem_Module_basic", - "SplitTensorLastSmallerModule_basic", - "SplitTensorListUnpackModule_basic", - "SplitTensorNegativeDimModule_basic", - "SplitWithSizesListUnpackModule_basic", - "UnbindIntGetItem_Module_basic", - "UnbindIntListUnpack_Module_basic", - - # Failure - operand_type - "ElementwiseAcosIntModule_basic", - "ElementwiseAsinIntModule_basic", - "ElementwiseAtanTensorIntModule_basic", - "ElementwiseCosIntModule_basic", - "ElementwiseErfIntModule_basic", - "ElementwiseExpIntModule_basic", - "ElementwiseLog10IntModule_basic", - "ElementwiseLog2IntModule_basic", - "ElementwiseLogIntModule_basic", - "ElementwiseSinIntModule_basic", - "ElementwiseTanIntModule_basic", - "ElementwiseUnaryIntModule_basic", - - # Failure - expand_multidim - "IndexTensorHackedTwinModule3dInput_basic", - "IndexTensorHackedTwinModule_basic", - "IndexTensorModule3dInput_basic", - "IndexTensorModule_basic", - "IndexTensorMultiInputContiguousOneDimDynamic_basic", - "IndexTensorMultiInputNonContiguousOneDimDynamic_basic", - - # Failure - rankless_return + + # Failure - onnx_lowering: onnx.ReduceSum + "MseLossSumReductionWithDifferentElemTypeModule_basic", + "ReduceL3NormAllDimsModule_basic", + "ReduceL3NormKeepDimModule_basic", + "ReduceSumDtypeFloatModule_basic", + "ReduceSumDtypeIntModule_basic", + "ReduceSumElementTypeBoolModule_basic", + "ReduceSumFloatModule_basic", + "ReduceSumSignedIntModule_basic", + "ReduceSumUnsignedIntModule_basic", + + # Failure - onnx_lowering: onnx.Resize + "UpSampleNearest2dDynamicSize_basic", + "UpSampleNearest2dStaticSize_basic", + + # Failure - onnx_lowering: onnx.ScatterElements + "ScatterSrcModule_basic", + "ScatterSrcStaticModule_basic", + "ScatterValueFloatModule_basic", + "ScatterValueIntModule_basic", + + # Failure - onnx_lowering: onnx.ScatterND + "IndexPut1DFloatAccumulateModule_basic", + "IndexPut1DFloatNonAccumulateModule_basic", + "IndexPut1DIntAccumulateModule_basic", + "IndexPut1DIntNonAccumulateModule_basic", + "IndexPut2DFloatAccumulateModule_basic", + "IndexPut2DFloatNonAccumulateModule_basic", + "IndexPut2DIntAccumulateModule_basic", + "IndexPut2DIntNonAccumulateModule_basic", + "IndexPut3DFloatAccumulateModule_basic", + "IndexPut3DFloatNonAccumulateModule_basic", + "IndexPut3DIntAccumulateModule_basic", + "IndexPut3DIntNonAccumulateModule_basic", + "IndexPutHackedTwin1DFloatAccumulateModule_basic", + "IndexPutHackedTwin1DFloatNonAccumulateModule_basic", + "IndexPutHackedTwin1DIntAccumulateModule_basic", + "IndexPutHackedTwin1DIntNonAccumulateModule_basic", + "IndexPutHackedTwin2DFloatAccumulateModule_basic", + "IndexPutHackedTwin2DFloatNonAccumulateModule_basic", + "IndexPutHackedTwin2DIntAccumulateModule_basic", + "IndexPutHackedTwin2DIntNonAccumulateModule_basic", + "IndexPutHackedTwin3DFloatAccumulateModule_basic", + "IndexPutHackedTwin3DFloatNonAccumulateModule_basic", + "IndexPutHackedTwin3DIntAccumulateModule_basic", + "IndexPutHackedTwin3DIntNonAccumulateModule_basic", + + # Failure - onnx_lowering: onnx.SoftmaxCrossEntropyLoss + "CrossEntropyLossModule_basic", + "CrossEntropyLossNoReductionModule_basic", + + # Failure - onnx_lowering: onnx.Softplus + "ElementwiseMishModule_basic", + "SoftplusModule_basic", + + # Failure - onnx_lowering: onnx.Squeeze + "SqueezeModule_allUnitDim", + "SqueezeModule_broadcast", + "SqueezeModule_static", + + # Failure - onnx_lowering: onnx.TopK + "SortTensorDescending_basic", + "SortTensorInteger_basic", + "SortTensorNegativeDimension_basic", + "SortTensorSpecificDimension_basic", + "SortTensor_basic", + + # Failure - onnx_lowering: onnx.Trilu + "AtenTrilModule_basic", + "AtenTrilWithNegDiagonalModule_basic", + "AtenTrilWithPosDiagonalModule_basic", + "AtenTriuModule_basic", + "AtenTriuWithNegDiagonalModule_basic", + "AtenTriuWithPosDiagonalModule_basic", + "TriuBroadcastModule_basic", + "TriuModule_basic", + + # Failure - rankless return "ReduceAmaxMultiDim_basic", "ReduceAmaxOutOfOrderDim_basic", "ReduceAmaxSingleDim_basic", @@ -2088,8 +2144,8 @@ "ReduceMaxFloatModule_basic", "ReduceMaxSignedIntModule_basic", "ReduceMaxUnsignedIntModule_basic", - - # Failure - view_lowering + + # Failure - torch.aten.view lower "AddSizeIntModule_basic", "ElementwiseFlattenBroadcastModule_basic", "FlattenRank0Module_basic", @@ -2097,13 +2153,11 @@ "IndexTensorDyanmicInputNonContiguousWithNoneModule_basic", "IndexTensorHackedTwinMultiInputNonContiguousMultipleStaticDims_basic", "IndexTensorMultiInputContiguousCenter_basic", - "IndexTensorMultiInputNonContiguousDynamic_basic", "IndexTensorMultiInputNonContiguousMultipleStaticDims_basic", "IndexTensorMultiInputNonContiguous_basic", "IndexTensorMultiInputOneDim_basic", "IndexTensorMultiInputThreeIndexers_basic", "IndexTensorMultiInput_basic", - "IndexTensorSelectDimModule_basic", "IndexTensorStaticContiguousWithNoneModule_basic", "RepeatModule_basic", "SelectIntModule_basic", @@ -2116,63 +2170,50 @@ "ViewSizeDimLedAndFollowedByExpandedOnesModule_basic", "ViewSizeDimLedByCollapsedOnesModule_basic", "ViewSizeDimLedByExpandedOnesModule_basic", - - # Failure - numerical - "AdaptiveAvgPool1dUnitOutputSizeDynamicModule_basic", - "AdaptiveAvgPool2dUnitOutputSizeDynamicModule_basic", - "ElementwiseSeluModule_basic", - "EmbeddingModule1DIndices_basic", - "FlipNegativeIndexModule_basic", - "HardsigmoidModule_basic", - "HardsigmoidRandomModule_basic", - "IndexSelectDynamicIndexSizeModule_basic", - "IndexSelectDynamicInputSizeModule_basic", - "IndexSelectDynamicModulebasic", - "IndexSelectWholeDimensionModule_basic", - "IndexSelectWholeTensorModule_basic", - "IndexTensorStaticModule_basic", - "IndexTensorStaticNonContiguousWithNoneModule_basic", - "PixelShuffleModuleStaticRank4Float32_basic", - "ResNet18Module_basic", - "SliceCopyEndGreaterThanDimSize_Module_basic", - "SliceCopyNegative_Module_basic", - "SliceCopyNonZeroDim_Module_basic", - "SliceCopy_Module_basic", - "TupleModule_basic", - - # Failure - shape - "ArangeStartOutDtypeModule_basic", - "ArangeStartOutViewModule_basic", - "BroadcastDynamicDimModule_basic", - "BroadcastToModule_basic", - "EmbeddingModuleF16_basic", - "EmbeddingModuleI32_basic", - "EmbeddingModuleI64_basic", - "ExpandModule_basic", - "MoveDimIntNegativeIndexModule_basic", - "PermuteNegativeIndexModule_basic", - "ReduceAmaxKeepDim_basic", - "ReduceMaxKeepDimReturnBoth_basic", - "ReduceMaxNegativeDim_basic", - "ViewSizeFromOtherTensor_basic", - # Failure - onnx traces differently - "ElementwiseSigmoidIntModule_basic", - # Failure - unknown + "BucketizeTensorFloatModule_basic", + "BucketizeTensorModule_basic", + "BucketizeTensorStaticFloatModule_basic", + "BucketizeTensorStaticModule_basic", "Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier", "CopyWithDifferentDTypesAndSizesModule_basic", "CopyWithDifferentDTypesModule_basic", "CosineSimilarityStaticBroadcastModule_basic", "CumsumInputDtypeInt32Module_basic", - "ElementwiseAtan2TensorIntModule_basic", + "ElementwiseAcosIntModule_basic", + "ElementwiseAsinIntModule_basic", + "ElementwiseAtanTensorIntModule_basic", + "ElementwiseCosIntModule_basic", "ElementwiseDivRoundingModeTruncModule_basic", + "ElementwiseErfIntModule_basic", + "ElementwiseExpIntModule_basic", + "ElementwiseLogIntModule_basic", "ElementwisePreluModule_basic", + "ElementwiseSigmoidIntModule_basic", + "ElementwiseSinIntModule_basic", + "ElementwiseTanIntModule_basic", + "ElementwiseUnaryIntModule_basic", "ElementwiseUnsqueezeNegDimsModule_basic", "ElementwiseWhereScalarModule_basic", + "EmbeddingModule1DIndices_basic", + "EmbeddingModuleF16_basic", + "EmbeddingModuleI32_basic", + "EmbeddingModuleI64_basic", "FlattenDynamicModule_basic", - "FlipModuleStaticShape_basic", "GluStaticModule_basic", + "GroupNormModule_basic", + "GroupNormNoWeightAndBiasModule_basic", + "IndexSelectDynamicIndexSizeModule_basic", + "IndexSelectDynamicModulebasic", + "IndexTensorHackedTwinModule3dInput_basic", + "IndexTensorHackedTwinModule_basic", + "IndexTensorModule3dInput_basic", + "IndexTensorModule_basic", + "IndexTensorMultiInputContiguousOneDimDynamic_basic", + "IndexTensorMultiInputNonContiguousDynamic_basic", + "IndexTensorMultiInputNonContiguousOneDimDynamic_basic", + "IndexTensorSelectDimModule_basic", "MaskedFillTensorFloatValueModule_basic", "ReduceAllDimEmpty_basic", "ReduceAllDimFloat_basic", @@ -2180,8 +2221,6 @@ "ReduceMinAlongDimUnsignedInt_basic", "TensorsStackNegativeDimModule_basic", "TensorsStackPromoteDTypeModule_basic", - "FloatImplicitModule_basic", - "IntImplicitModule_basic", } ONNX_CRASHING_SET = { } diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir index 704e03acc1e2..42be32166c5f 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir @@ -926,107 +926,121 @@ func.func @test_reduce_mean_negative_axes_keepdims_example(%arg0: !torch.vtensor // ----- +// CHECK-LABEL: func.func @test_reduce_min_empty_set_fp +func.func @test_reduce_min_empty_set_fp(%arg0: !torch.vtensor<[2,0,4],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[2,1,4],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK-DAG: %[[INF:.+]] = torch.constant.float 0x7FF0000000000000 + // CHECK-DAG: %[[INT2:.+]] = torch.constant.int 2 + // CHECK-DAG: %[[INT1:.+]] = torch.constant.int 1 + // CHECK-DAG: %[[INT4:.+]] = torch.constant.int 4 + // CHECK-DAG: %[[NONE:.+]] = torch.constant.none + // CHECK-DAG: %[[LIST:.+]] = torch.prim.ListConstruct %[[INT2]], %[[INT1]], %[[INT4]] + // CHECK-DAG: %[[FULL:.+]] = torch.aten.full %[[LIST]], %[[INF]], %[[NONE]], %[[NONE]], %[[NONE]] + // CHECK: return %[[FULL]] + %0 = torch.operator "onnx.ReduceMin"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[2,0,4],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[2,1,4],f32> + return %0 : !torch.vtensor<[2,1,4],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_reduce_min_empty_set_int +func.func @test_reduce_min_empty_set_int(%arg0: !torch.vtensor<[2,0,4],si32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[2,1,4],si32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK-DAG: %[[INF:.+]] = torch.constant.int 2147483647 + // CHECK-DAG: %[[INT2:.+]] = torch.constant.int 2 + // CHECK-DAG: %[[INT1:.+]] = torch.constant.int 1 + // CHECK-DAG: %[[INT4:.+]] = torch.constant.int 4 + // CHECK-DAG: %[[NONE:.+]] = torch.constant.none + // CHECK-DAG: %[[LIST:.+]] = torch.prim.ListConstruct %[[INT2]], %[[INT1]], %[[INT4]] + // CHECK-DAG: %[[FULL:.+]] = torch.aten.full %[[LIST]], %[[INF]], %[[NONE]], %[[NONE]], %[[NONE]] + // CHECK: return %[[FULL]] + %0 = torch.operator "onnx.ReduceMin"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[2,0,4],si32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[2,1,4],si32> + return %0 : !torch.vtensor<[2,1,4],si32> +} + +// ----- + + // CHECK-LABEL: func.func @test_reduce_min_bool_inputs func.func @test_reduce_min_bool_inputs(%arg0: !torch.vtensor<[4,2],i1>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[4,1],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[INT0:.+]] = torch.constant.int 0 - // CHECK: %[[INT2:.+]] = torch.constant.int 2 - // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 - // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int - // CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool - // CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int - // CHECK: torch.aten.mul.int %3, %int2 : !torch.int, !torch.int -> !torch.int - // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int - // CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list + // CHECK: %[[IDX:.+]] = torch.constant.int 0 + // CHECK: %[[SZ:.+]] = torch.constant.int 0 + // CHECK: %[[SEL:.+]] = torch.aten.select.int %arg1, %[[IDX]], %[[SZ]] + // CHECK: %[[ITEM:.+]] = torch.aten.item %[[SEL]] + // CHECK: %[[DIM:.+]] = torch.aten.dim %arg0 : !torch.vtensor<[4,2],i1> -> !torch.int + // CHECK: %[[C0:.+]] = torch.constant.int 0 + // CHECK: %[[LT:.+]] = torch.aten.lt.int %[[ITEM]], %[[C0]] : !torch.int, !torch.int -> !torch.bool + // CHECK: %[[BOOL:.+]] = torch.aten.Int.bool %[[LT]] : !torch.bool -> !torch.int + // CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[BOOL]], %[[DIM]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[ADD:.+]] = torch.aten.add.int %[[ITEM]], %[[MUL]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[LST:.+]] = torch.prim.ListConstruct %6 : (!torch.int) -> !torch.list // CHECK: %[[TRUE:.+]] = torch.constant.bool true - // CHECK: torch.aten.amin %arg0, %6, %true : !torch.vtensor<[4,2],i1>, !torch.list, !torch.bool -> !torch.vtensor<[4,1],i1> + // CHECK: %[[AMIN:.+]] = torch.aten.amin %arg0, %[[LST]], %[[TRUE]] : !torch.vtensor<[4,2],i1>, !torch.list, !torch.bool -> !torch.vtensor<[4,1],i1> + // CHECK: return %[[AMIN]] : !torch.vtensor<[4,1],i1> %0 = torch.operator "onnx.ReduceMin"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[4,2],i1>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[4,1],i1> return %0 : !torch.vtensor<[4,1],i1> } -// CHECK-LABEL: func.func @test_reduce_min_default_axes_keepdims_example -func.func @test_reduce_min_default_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[INT1:.+]] = torch.constant.int 1 - // CHECK: torch.aten.Bool.int %int1 : !torch.int -> !torch.bool - // CHECK: %[[INT0:.+]] = torch.constant.int 0 - // CHECK: %[[INT1_0:.+]] = torch.constant.int 1 - // CHECK: %[[INT2:.+]] = torch.constant.int 2 - // CHECK: torch.prim.ListConstruct %int0, %int1_0, %int2 : (!torch.int, !torch.int, !torch.int) -> !torch.list - // CHECK: torch.aten.amin %arg0, %1, %0 : !torch.vtensor<[3,2,2],f32>, !torch.list, !torch.bool -> !torch.vtensor<[1,1,1],f32> - %0 = torch.operator "onnx.ReduceMin"(%arg0) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>) -> !torch.vtensor<[1,1,1],f32> - return %0 : !torch.vtensor<[1,1,1],f32> -} +// ----- -// CHECK-LABEL: func.func @test_reduce_min_do_not_keepdims_example -func.func @test_reduce_min_do_not_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[INT0:.+]] = torch.constant.int 0 - // CHECK: %[[INT3:.+]] = torch.constant.int 3 - // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 - // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int - // CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool - // CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int - // CHECK: torch.aten.mul.int %3, %int3 : !torch.int, !torch.int -> !torch.int - // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int - // CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list +// CHECK-LABEL: func.func @test_reduce_min_bool_inputs_nokeepdims +func.func @test_reduce_min_bool_inputs_nokeepdims(%arg0: !torch.vtensor<[4,2],i1>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[4],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[IDX:.+]] = torch.constant.int 0 + // CHECK: %[[SZ:.+]] = torch.constant.int 0 + // CHECK: %[[SEL:.+]] = torch.aten.select.int %arg1, %[[IDX]], %[[SZ]] + // CHECK: %[[ITEM:.+]] = torch.aten.item %[[SEL]] + // CHECK: %[[DIM:.+]] = torch.aten.dim %arg0 : !torch.vtensor<[4,2],i1> -> !torch.int + // CHECK: %[[C0:.+]] = torch.constant.int 0 + // CHECK: %[[LT:.+]] = torch.aten.lt.int %[[ITEM]], %[[C0]] : !torch.int, !torch.int -> !torch.bool + // CHECK: %[[BOOL:.+]] = torch.aten.Int.bool %[[LT]] : !torch.bool -> !torch.int + // CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[BOOL]], %[[DIM]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[ADD:.+]] = torch.aten.add.int %[[ITEM]], %[[MUL]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[LST:.+]] = torch.prim.ListConstruct %6 : (!torch.int) -> !torch.list // CHECK: %[[FALSE:.+]] = torch.constant.bool false - // CHECK: torch.aten.amin %arg0, %6, %false : !torch.vtensor<[3,2,2],f32>, !torch.list, !torch.bool -> !torch.vtensor<[3,2],f32> - %0 = torch.operator "onnx.ReduceMin"(%arg0, %arg1) {torch.onnx.keepdims = 0 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32> - return %0 : !torch.vtensor<[3,2],f32> + // CHECK: %[[AMIN:.+]] = torch.aten.amin %arg0, %[[LST]], %[[FALSE]] : !torch.vtensor<[4,2],i1>, !torch.list, !torch.bool -> !torch.vtensor<[4],i1> + // CHECK: return %[[AMIN]] : !torch.vtensor<[4],i1> + %0 = torch.operator "onnx.ReduceMin"(%arg0, %arg1) {torch.onnx.keepdims = 0 : si64} : (!torch.vtensor<[4,2],i1>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[4],i1> + return %0 : !torch.vtensor<[4],i1> } -// CHECK-LABEL: func.func @test_reduce_min_empty_set -func.func @test_reduce_min_empty_set(%arg0: !torch.vtensor<[2,0,4],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[2,1,4],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[INT0:.+]] = torch.constant.int 0 - // CHECK: %[[INT3:.+]] = torch.constant.int 3 - // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 - // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int - // CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool - // CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int - // CHECK: torch.aten.mul.int %3, %int3 : !torch.int, !torch.int -> !torch.int - // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int - // CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list - // CHECK: %[[TRUE:.+]] = torch.constant.bool true - // CHECK: torch.aten.amin %arg0, %6, %true : !torch.vtensor<[2,0,4],f32>, !torch.list, !torch.bool -> !torch.vtensor<[2,1,4],f32> - %0 = torch.operator "onnx.ReduceMin"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[2,0,4],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[2,1,4],f32> - return %0 : !torch.vtensor<[2,1,4],f32> -} +// ----- -// CHECK-LABEL: func.func @test_reduce_min_keepdims_example -func.func @test_reduce_min_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[INT0:.+]] = torch.constant.int 0 - // CHECK: %[[INT3:.+]] = torch.constant.int 3 - // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 - // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int - // CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool - // CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int - // CHECK: torch.aten.mul.int %3, %int3 : !torch.int, !torch.int -> !torch.int - // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int - // CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list - // CHECK: %[[TRUE:.+]] = torch.constant.bool true - // CHECK: torch.aten.amin %arg0, %6, %true : !torch.vtensor<[3,2,2],f32>, !torch.list, !torch.bool -> !torch.vtensor<[3,1,2],f32> - %0 = torch.operator "onnx.ReduceMin"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32> - return %0 : !torch.vtensor<[3,1,2],f32> +// CHECK-LABEL: func.func @test_reduce_all_dims_default +func.func @test_reduce_all_dims_default(%arg0: !torch.vtensor<[4,2],i1>) -> !torch.vtensor<[],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[I0:.+]] = torch.constant.int 0 + // CHECK: %[[I1:.+]] = torch.constant.int 1 + // CHECK: %[[RANK:.+]] = torch.aten.dim %arg0 : !torch.vtensor<[4,2],i1> -> !torch.int + // CHECK: %[[C0:.+]] = torch.constant.int 0 + // CHECK: %[[LT:.+]] = torch.aten.lt.int %[[I0]], %[[C0]] : !torch.int, !torch.int -> !torch.bool + // CHECK: %[[BOOL:.+]] = torch.aten.Int.bool %[[LT]] : !torch.bool -> !torch.int + // CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[BOOL]], %[[RANK]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[A0:.+]] = torch.aten.add.int %[[I0]], %[[MUL]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[LT:.+]] = torch.aten.lt.int %[[I1]], %[[C0]] : !torch.int, !torch.int -> !torch.bool + // CHECK: %[[BOOL:.+]] = torch.aten.Int.bool %[[LT]] : !torch.bool -> !torch.int + // CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[BOOL]], %[[RANK]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[A1:.+]] = torch.aten.add.int %[[I1]], %[[MUL]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[A0]], %[[A1]] + // CHECK: %[[FALSE:.+]] = torch.constant.bool false + // CHECK: %[[MIN:.+]] = torch.aten.amin %arg0, %[[LIST]], %[[FALSE]] : !torch.vtensor<[4,2],i1>, !torch.list, !torch.bool -> !torch.vtensor<[],i1> + // CHECK: return %[[MIN]] : !torch.vtensor<[],i1> + %0 = torch.operator "onnx.ReduceMin"(%arg0) {torch.onnx.keepdims = 0 : si64} : (!torch.vtensor<[4,2],i1>) -> !torch.vtensor<[],i1> + return %0 : !torch.vtensor<[],i1> } -// CHECK-LABEL: func.func @test_reduce_min_negative_axes_keepdims_example -func.func @test_reduce_min_negative_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { +// ----- + +func.func @test_reduce_min_attr(%arg0: !torch.vtensor<[4,2],i1>) -> !torch.vtensor<[4],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT1:.+]] = torch.constant.int 1 + // CHECK: %[[DIM:.+]] = torch.aten.dim %arg0 : !torch.vtensor<[4,2],i1> -> !torch.int // CHECK: %[[INT0:.+]] = torch.constant.int 0 - // CHECK: %[[INT3:.+]] = torch.constant.int 3 - // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 - // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int - // CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool - // CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int - // CHECK: torch.aten.mul.int %3, %int3 : !torch.int, !torch.int -> !torch.int - // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int - // CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list - // CHECK: %[[TRUE:.+]] = torch.constant.bool true - // CHECK: torch.aten.amin %arg0, %6, %true : !torch.vtensor<[3,2,2],f32>, !torch.list, !torch.bool -> !torch.vtensor<[3,1,2],f32> - %0 = torch.operator "onnx.ReduceMin"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32> - return %0 : !torch.vtensor<[3,1,2],f32> + // CHECK: %[[LT:.+]] = torch.aten.lt.int %[[INT1]], %[[INT0]] : !torch.int, !torch.int -> !torch.bool + // CHECK: %[[BOOL:.+]] = torch.aten.Int.bool %[[LT]] : !torch.bool -> !torch.int + // CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[BOOL]], %[[DIM]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[ADD:.+]] = torch.aten.add.int %[[INT1]], %[[MUL]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[ADD]] : (!torch.int) -> !torch.list + // CHECK: %[[FALSE:.+]] = torch.constant.bool false + // CHECK: %[[AMIN:.+]] = torch.aten.amin %arg0, %[[LIST]], %[[FALSE]] : !torch.vtensor<[4,2],i1>, !torch.list, !torch.bool -> !torch.vtensor<[4],i1> + // CHECK: return %[[AMIN]] + %0 = torch.operator "onnx.ReduceMin"(%arg0) {torch.onnx.keepdims = 0 : si64, torch.onnx.axes=[1 : si64]} : (!torch.vtensor<[4,2],i1>) -> !torch.vtensor<[4],i1> + return %0 : !torch.vtensor<[4],i1> } // ----- From 08bc013fcd3232cbf01ad029f057b2fc022e56e1 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Wed, 28 Feb 2024 09:46:58 -0800 Subject: [PATCH 231/283] [tosa] Fix TOSA batch matmul lowering to correct transpose ordering (#2959) The corrective transpose at the end is computed incorrectly. Is it actually computin the inverse transpose. Inverting the permutations fixes the issue. --- lib/Conversion/TorchToTosa/TorchToTosa.cpp | 30 ++++++++++++++-------- projects/pt1/e2e_testing/xfail_sets.py | 2 ++ 2 files changed, 21 insertions(+), 11 deletions(-) diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index ce0a1af2f834..93fe9dc1c4e8 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -1234,7 +1234,7 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern { return false; }; - SmallVector commonElems, lhsSqueezedElems, rhsSqueezedElems; + SmallVector batchElems, lhsSqueezedElems, rhsSqueezedElems; if (!performBatchDimBroadcast) { // Simple with no broadcasting artifacts. Just reshape up to 3D @@ -1288,7 +1288,7 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern { if (isDynamicDim || lhsBroadcastedShape[dim] == rhsBroadcastedShape[dim]) { commonValue *= lhsBroadcastedShape[dim]; - commonElems.push_back({dim, lhsBroadcastedShape[dim]}); + batchElems.push_back({dim, lhsBroadcastedShape[dim]}); } } commonValue = commonValue < 0 ? kUnknownSize : commonValue; @@ -1315,9 +1315,9 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern { // Step: Create the tosa.transpose array. If this array has a // non-monotonic series of dims, perform transpose. // First the common_elems - for (uint32_t i = 0; i < commonElems.size(); i++) { - transposedLhsShape.push_back(commonElems[i].shape); - transposedLhsDims.push_back(commonElems[i].dim); + for (uint32_t i = 0; i < batchElems.size(); i++) { + transposedLhsShape.push_back(batchElems[i].shape); + transposedLhsDims.push_back(batchElems[i].dim); } // then the lhs_squeezed elems for (uint32_t i = 0; i < lhsSqueezedElems.size(); i++) { @@ -1373,9 +1373,9 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern { // Step: Create the RHS transpose sequence // RHS = {common, matmul_dim, rhs_squeezed} // first the common_dims - for (uint32_t i = 0; i < commonElems.size(); i++) { - transposedRhsShape.push_back(commonElems[i].shape); - transposedRhsDims.push_back(commonElems[i].dim); + for (uint32_t i = 0; i < batchElems.size(); i++) { + transposedRhsShape.push_back(batchElems[i].shape); + transposedRhsDims.push_back(batchElems[i].dim); } // The matmul_dim of RHS transposedRhsDims.push_back(maxInputRank - 2); @@ -1497,9 +1497,9 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern { // Step: Construct the output transpose/reshape information // First the common_dims - for (uint32_t i = 0; i < commonElems.size(); i++) { - reshapedOpShape.push_back(commonElems[i].shape); - transposedOpDims.push_back(commonElems[i].dim); + for (uint32_t i = 0; i < batchElems.size(); i++) { + reshapedOpShape.push_back(batchElems[i].shape); + transposedOpDims.push_back(batchElems[i].dim); } // Then the LHS squeezed dims @@ -1532,6 +1532,14 @@ class ConvertAtenMatmulBaseOp : public OpConversionPattern { transposedOpDims.push_back(maxInputRank - 1); } + // The transposition order is the inverse of what we actually want, + // inversing should fix this: + llvm::SmallVector inverseTransposeDims(transposedOpDims.size()); + for (int i = 0, s = transposedOpDims.size(); i < s; ++i) + inverseTransposeDims[transposedOpDims[i]] = i; + + transposedOpDims = inverseTransposeDims; + // Final transposed output shape construction for (uint32_t i = 0; i < maxInputRank - 2; i++) { if (lhsBroadcastedTy.isDynamicDim(i)) { diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index a4ac58b1d909..67a4f175ddb6 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -1125,6 +1125,7 @@ "Matmul4dStatic_basic", "Matmul_3d", "Matmul_dot", + "MatmulStaticBroadcast_basic", "MaxPool2dEmptyStrideStaticModule_basic", "MaxPool2dStaticCeilModeTrueModule_basic", "MaxPool2dStaticModule_basic", @@ -1303,6 +1304,7 @@ # Dynamic shape, has extra unsupported broadcast ops "Matmul_3d", + "MatmulStaticBroadcast_basic", # failed to legalize operation 'torch.aten.max_pool2d_with_indices "MaxPool2dEmptyStrideStaticModule_basic", From dd673cfa8de6f215e27eedca54e53fb2f114b65d Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Wed, 28 Feb 2024 09:47:06 -0800 Subject: [PATCH 232/283] [torch] Add edgecase for aten.shape_to_tensor for rank-0 input (#2962) Currently lowering uses `tensor.from_elements` which does not allow zero inputs. In this case we return a `tensor.empty` operation. --- lib/Conversion/TorchToTensor/TorchToTensor.cpp | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/lib/Conversion/TorchToTensor/TorchToTensor.cpp b/lib/Conversion/TorchToTensor/TorchToTensor.cpp index 1b5341028c6d..8b934ccb0484 100644 --- a/lib/Conversion/TorchToTensor/TorchToTensor.cpp +++ b/lib/Conversion/TorchToTensor/TorchToTensor.cpp @@ -84,6 +84,12 @@ class ConvertAtenShapeToTensorPatternOp getTypeConverter()->convertType(op.getType()).cast(); int64_t rank = operandTy.getRank(); + if (rank == 0) { + rewriter.replaceOpWithNewOp(op, resultTy.getShape(), + resultTy.getElementType()); + return success(); + } + SmallVector dims; for (int i = 0; i < rank; ++i) { Value dim = rewriter.createOrFold(loc, operand, i); From 73b6df9007d8691aca328e1f95991ffe9691ace4 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Wed, 28 Feb 2024 10:27:19 -0800 Subject: [PATCH 233/283] [torch] Fix DecomposeAtenInstanceNorm decomposition (#2960) The decomposition only suports a NCHW lowering however the operation can support arbitrary spatial dimensions. Updated the lowering to better support spatial dimensions. --- .../Torch/Transforms/DecomposeComplexOps.cpp | 64 ++++++++----------- 1 file changed, 28 insertions(+), 36 deletions(-) diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 51a710d940e9..736d66544e2d 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -4025,25 +4025,20 @@ class DecomposeAtenInstanceNormOp auto inputTy = op.getInput().getType().cast(); int64_t inputRank = inputTy.getSizes().size(); - auto reduceDimInts = - llvm::SmallVector({inputRank - 2, inputRank - 1}); - SmallVector reducedShape(inputTy.getSizes()); - reducedShape[inputRank - 1] = 1; - reducedShape[inputRank - 2] = 1; + SmallVector reduceDimInts; + SmallVector reduceDimVals; + for (int i = 2; i < inputRank; ++i) { + reducedShape[i] = 1; + reduceDimVals.push_back(rewriter.create( + loc, rewriter.getI64IntegerAttr(i))); + } Type dtype = inputTy.getOptionalDtype(); Type reducedTy = ValueTensorType::get(op.getContext(), llvm::ArrayRef(reducedShape), dtype); auto sizeListType = ListType::get(IntType::get(context)); - SmallVector reduceDimVals; - reduceDimVals.reserve(reduceDimInts.size()); - std::transform(reduceDimInts.begin(), reduceDimInts.end(), - std::back_inserter(reduceDimVals), [&](int64_t d) { - return rewriter.create( - loc, rewriter.getI64IntegerAttr(d)); - }); Value reduceDimList = rewriter.create(loc, sizeListType, reduceDimVals); Value cstTrue = rewriter.create(loc, true); @@ -4069,9 +4064,12 @@ class DecomposeAtenInstanceNormOp loc, reducedTy, inputSubMeanSquare, reduceDimList, cstTrue, /*dtype=*/none); + int64_t elemCount = 1; + for (int i = 2; i < inputRank; ++i) + elemCount *= inputTy.getSizes()[i]; + Value hw = rewriter.create( - loc, rewriter.getI64IntegerAttr(inputTy.getSizes()[inputRank - 1] * - inputTy.getSizes()[inputRank - 2])); + loc, rewriter.getI64IntegerAttr(elemCount)); Value inputVar = rewriter.create(loc, reducedTy, variancesum, hw); @@ -4104,19 +4102,14 @@ class DecomposeAtenInstanceNormOp op.getContext(), llvm::ArrayRef(newWeightShape), dtype); weight = rewriter.create(loc, newWeightTy, weight, zero); - Value two = rewriter.create( - loc, rewriter.getI64IntegerAttr(2)); - newWeightShape.push_back(1); - newWeightTy = ValueTensorType::get(op.getContext(), - llvm::ArrayRef(newWeightShape), dtype); - weight = rewriter.create(loc, newWeightTy, weight, two); - - Value three = rewriter.create( - loc, rewriter.getI64IntegerAttr(3)); - newWeightShape.push_back(1); - newWeightTy = ValueTensorType::get(op.getContext(), - llvm::ArrayRef(newWeightShape), dtype); - weight = rewriter.create(loc, newWeightTy, weight, three); + while (static_cast(newWeightShape.size()) < inputRank) { + Value i = rewriter.create( + loc, rewriter.getI64IntegerAttr(newWeightShape.size())); + newWeightShape.push_back(1); + newWeightTy = ValueTensorType::get(op.getContext(), + llvm::ArrayRef(newWeightShape), dtype); + weight = rewriter.create(loc, newWeightTy, weight, i); + } Value weightExpanded = rewriter.create(loc, inputTy, weight, op.getInput()); @@ -4134,15 +4127,14 @@ class DecomposeAtenInstanceNormOp llvm::ArrayRef(newBiasShape), dtype); bias = rewriter.create(loc, newBiasTy, bias, zero); - newBiasShape.push_back(1); - newBiasTy = ValueTensorType::get(op.getContext(), - llvm::ArrayRef(newBiasShape), dtype); - bias = rewriter.create(loc, newBiasTy, bias, two); - - newBiasShape.push_back(1); - newBiasTy = ValueTensorType::get(op.getContext(), - llvm::ArrayRef(newBiasShape), dtype); - bias = rewriter.create(loc, newBiasTy, bias, three); + while (static_cast(newBiasShape.size()) < inputRank) { + Value i = rewriter.create( + loc, rewriter.getI64IntegerAttr(newBiasShape.size())); + newBiasShape.push_back(1); + newBiasTy = ValueTensorType::get(op.getContext(), + llvm::ArrayRef(newBiasShape), dtype); + bias = rewriter.create(loc, newBiasTy, bias, i); + } Value biasExpanded = rewriter.create(loc, inputTy, bias, op.getInput()); From 6f3d62ab04e91bbe67d51b3c0b467f12fc3ed870 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Wed, 28 Feb 2024 12:04:52 -0800 Subject: [PATCH 234/283] [torch] Fix folders and `cat` and `view` torch lowerings (#2963) A bunch of small fixes are interlinked and trigger crashes if not addressed as a group. This includes: - aten view when expand from a rank-0 tensor - slice folder with negative indices - `aten._shape_as_tensor` folder on a rank-0 tensor - `aten.cat` of a tensor with a length-0 tensor --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 1 + lib/Conversion/TorchToLinalg/DataMovement.cpp | 38 ++++++---- lib/Dialect/Torch/IR/TorchOps.cpp | 76 ++++++++++++------- projects/pt1/e2e_testing/xfail_sets.py | 11 --- .../build_tools/torch_ods_gen.py | 2 +- test/Dialect/Torch/canonicalize.mlir | 16 +++- 6 files changed, 89 insertions(+), 55 deletions(-) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 57b15ed18f4e..9d245723fd84 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -8925,6 +8925,7 @@ def Torch_Aten_ShapeAsTensorOp : Torch_Op<"aten._shape_as_tensor", [ printDefaultTorchOp(printer, *this, 1, 1); } }]; + let hasFolder = 1; } def Torch_AtenIsnanOp : Torch_Op<"aten.isnan", [ diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp index d9132317e32f..42aacceab0b4 100644 --- a/lib/Conversion/TorchToLinalg/DataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp @@ -799,10 +799,15 @@ class ConvertAtenViewOp : public OpConversionPattern { return rewriter.notifyMatchFailure(op, "result shape of rank 0 is invalid"); - // TODO: add support for case inputRank 0 expanded to size 1 - if (inputRank == 0) - return rewriter.notifyMatchFailure( - op, "unimplemented: input rank 0 is not supported"); + if (inputRank == 0) { + Value expanded = + rewriter + .create(loc, resultType, input, + ArrayRef()) + .getResult(); + rewriter.replaceOp(op, expanded); + return success(); + } // Extract the desired output size as a list of integers. This list should // have been created using the operation `torch.prim.ListConstruct`. @@ -1500,6 +1505,14 @@ class ConvertAtenCatOp : public OpConversionPattern { RankedTensorType newResultType = typeConverter->convertType(op.getType()).cast(); + int rank = newResultType.getRank(); + Value dimValue = op.getDim(); + int64_t dim; + if (!matchPattern(dimValue, m_TorchConstantInt(&dim))) + return op.emitError("unimplemented: dim is not constant"); + dim = toPositiveDim(dim, rank); + if (!isValidDim(dim, rank)) + return rewriter.notifyMatchFailure(op, "dim is statically invalid"); auto outElemType = newResultType.getElementType(); for (size_t i = 0; i < tensors.size(); ++i) { @@ -1510,17 +1523,16 @@ class ConvertAtenCatOp : public OpConversionPattern { } } - int rank = newResultType.getRank(); - Value dimValue = op.getDim(); - int64_t dim; - if (!matchPattern(dimValue, m_TorchConstantInt(&dim))) - return op.emitError("unimplemented: dim is not constant"); - dim = toPositiveDim(dim, rank); - if (!isValidDim(dim, rank)) - return rewriter.notifyMatchFailure(op, "dim is statically invalid"); + llvm::SmallVector filteredTensors; + for (auto tensor : tensors) { + auto inputType = cast(tensor.getType()); + if (inputType.getDimSize(dim) != 0) { + filteredTensors.push_back(tensor); + } + } rewriter.replaceOpWithNewOp(op, newResultType, dim, - tensors); + filteredTensors); return success(); } }; diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index 2f0884b1344e..6120fd6f0e32 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -2972,8 +2972,10 @@ OpFoldResult AtenSliceTensorOp::fold(FoldAdaptor adaptor) { unaryNonDim &= inType.getSizes()[i] == 1 || i == dimInt; } if (unaryNonDim) { - Attribute value = - input.getValues()[start.getValue().getSExtValue()]; + int64_t idx = start.getValue().getSExtValue(); + if (idx < 0) + idx += input.getNumElements(); + Attribute value = input.getValues()[idx]; return DenseElementsAttr::get( outType.toBuiltinTensor().clone(inType.getDtype()), value); } @@ -3237,6 +3239,34 @@ OpFoldResult AtenTensorOp::fold(FoldAdaptor adaptor) { return nullptr; } +//===----------------------------------------------------------------------===// +// AtenTensorOp +//===----------------------------------------------------------------------===// + +OpFoldResult Aten_ShapeAsTensorOp::fold(FoldAdaptor adaptor) { + auto selfTy = dyn_cast(getSelf().getType()); + auto resultTy = dyn_cast(getType()); + if (!selfTy || !resultTy || !selfTy.hasSizes() || !resultTy.hasDtype() || + !resultTy.hasSizes()) + return {}; + + llvm::SmallVector values(selfTy.getSizes()); + if (llvm::any_of(values, [](int64_t d) { return d == Torch::kUnknownSize; })) + return {}; + + auto dty = dyn_cast(resultTy.getDtype()); + if (!dty) + return {}; + + llvm::SmallVector attrs; + for (auto val : values) { + attrs.push_back(IntegerAttr::get(dty, val)); + } + + auto attrty = RankedTensorType::get(resultTy.getSizes(), dty); + return DenseElementsAttr::get(attrty, attrs); +} + //===----------------------------------------------------------------------===// // AtenIntTensorOp //===----------------------------------------------------------------------===// @@ -3409,25 +3439,25 @@ OpFoldResult AtenItemOp::fold(FoldAdaptor adaptor) { OpFoldResult AtenOnesOp::fold(FoldAdaptor adaptor) { SmallVector sizes; if (!matchPattern(getSize(), m_TorchListOfConstantInts(sizes))) { - LLVM_DEBUG(llvm::dbgs() << "Failing to fold AtenOnesOp: size operand is " - "not a list of constant integers.\n"); return nullptr; } Type resultType = getResult().getType(); BaseTensorType resultTensorType = resultType.dyn_cast(); if (!resultTensorType || !resultTensorType.hasDtype()) { - LLVM_DEBUG(llvm::dbgs() << "Failing to fold AtenOnesOp: result type is not " - "a tensor type or does not have a dtype.\n"); return nullptr; } + int64_t ct = sizes.size(); + if (resultTensorType.getSizes().size() != 1) + return nullptr; + if (resultTensorType.getSizes()[0] != ct) + return nullptr; + ShapedType shapedty = mlir::RankedTensorType::get( // convert Torch type to builtin ShapedType sizes, resultTensorType.getDtype()); if (!shapedty) { - LLVM_DEBUG(llvm::dbgs() - << "Failing to fold AtenOnesOp: ShapedType cast failed.\n"); return nullptr; } auto elementType = shapedty.getElementType(); @@ -3439,33 +3469,31 @@ OpFoldResult AtenOnesOp::fold(FoldAdaptor adaptor) { Attribute attribute = FloatAttr::get(elementType, 1.0); return DenseElementsAttr::get(shapedty, attribute); } - LLVM_DEBUG(llvm::dbgs() << "Failing to fold AtenOnesOp: element type is " - "not integer or float.\n"); return nullptr; } OpFoldResult AtenZerosOp::fold(FoldAdaptor adaptor) { SmallVector sizes; if (!matchPattern(getSize(), m_TorchListOfConstantInts(sizes))) { - LLVM_DEBUG(llvm::dbgs() << "Failing to fold AtenZerosOp: size operand is " - "not a list of constant integers.\n"); return nullptr; } Type resultType = getResult().getType(); BaseTensorType resultTensorType = resultType.dyn_cast(); if (!resultTensorType || !resultTensorType.hasDtype()) { - LLVM_DEBUG(llvm::dbgs() << "Failing to fold AtenZerosOp: result type is " - "not a tensor type or does not have a dtype.\n"); return nullptr; } + int64_t ct = sizes.size(); + if (resultTensorType.getSizes().size() != 1) + return nullptr; + if (resultTensorType.getSizes()[0] != ct) + return nullptr; + ShapedType shapedty = mlir::RankedTensorType::get( // convert Torch type to builtin ShapedType sizes, resultTensorType.getDtype()); if (!shapedty) { - LLVM_DEBUG(llvm::dbgs() - << "Failing to fold AtenZerosOp: ShapedType cast failed.\n"); return nullptr; } @@ -3479,33 +3507,31 @@ OpFoldResult AtenZerosOp::fold(FoldAdaptor adaptor) { return DenseElementsAttr::get(shapedty, attribute); } - LLVM_DEBUG(llvm::dbgs() << "Failing to fold AtenZerosOp: element type is " - "not integer or float.\n"); return nullptr; } OpFoldResult AtenFullOp::fold(FoldAdaptor adaptor) { SmallVector sizes; if (!matchPattern(getSize(), m_TorchListOfConstantInts(sizes))) { - LLVM_DEBUG(llvm::dbgs() << "Failing to fold AtenFullOp: size operand is " - "not a list of constant integers.\n"); return nullptr; } Type resultType = getResult().getType(); BaseTensorType resultTensorType = resultType.dyn_cast(); if (!resultTensorType || !resultTensorType.hasDtype()) { - LLVM_DEBUG(llvm::dbgs() << "Failing to fold AtenFullOp: result type is not " - "a tensor type or does not have a dtype.\n"); return nullptr; } + int64_t ct = sizes.size(); + if (resultTensorType.getSizes().size() != 1) + return nullptr; + if (resultTensorType.getSizes()[0] != ct) + return nullptr; + ShapedType shapedty = mlir::RankedTensorType::get( // convert Torch type to builtin ShapedType sizes, resultTensorType.getDtype()); if (!shapedty) { - LLVM_DEBUG(llvm::dbgs() - << "Failing to fold AtenFullOp: ShapedType cast failed.\n"); return nullptr; } auto elementType = shapedty.getElementType(); @@ -3523,8 +3549,6 @@ OpFoldResult AtenFullOp::fold(FoldAdaptor adaptor) { return DenseElementsAttr::get(shapedty, attribute); } } - LLVM_DEBUG(llvm::dbgs() << "Failing to fold AtenFullOp: element type is " - "not integer or float.\n"); return nullptr; } //===----------------------------------------------------------------------===// diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 67a4f175ddb6..60b08c02539e 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -64,14 +64,6 @@ # See also: https://github.com/pytorch/torchdynamo/issues/327 "AtenEmbeddingBagSumExample_basic", - # error: failed to legalize operation 'torch.valsem.aten.bernoulli.float' that was explicitly marked illegal - "BernoulliFloatModule_basic", - "BernoulliPModule_basic", - # error: failed to legalize operation 'torch.aten.view' that was explicitly marked illegal - "ElementwiseFlattenBroadcastModule_basic", - "FlattenRank0Module_basic", - "UniformModule_basic", - "UniformStaticShapeModule_basic", # error: unsupported by backend contract: tensor with unknown rank # note: see current operation: %1 = "torch.tensor_static_info_cast"(%arg0) : (!torch.vtensor<[5,4,3,2,1],f32>) -> !torch.vtensor<*,f32> "ElementwisePreluModule_basic", @@ -2150,7 +2142,6 @@ # Failure - torch.aten.view lower "AddSizeIntModule_basic", "ElementwiseFlattenBroadcastModule_basic", - "FlattenRank0Module_basic", "IndexTensorDyanmicInputContiguousWithNoneModule_basic", "IndexTensorDyanmicInputNonContiguousWithNoneModule_basic", "IndexTensorHackedTwinMultiInputNonContiguousMultipleStaticDims_basic", @@ -2163,7 +2154,6 @@ "IndexTensorStaticContiguousWithNoneModule_basic", "RepeatModule_basic", "SelectIntModule_basic", - "SelectIntNegativeDimAndIndexStaticModule_basic", "SliceSingleIdxModule_basic", "ViewFlattenAndExpandModule_basic", "ViewSizeDimFollowedByCollapsedOnesModule_basic", @@ -2205,7 +2195,6 @@ "FlattenDynamicModule_basic", "GluStaticModule_basic", "GroupNormModule_basic", - "GroupNormNoWeightAndBiasModule_basic", "IndexSelectDynamicIndexSizeModule_basic", "IndexSelectDynamicModulebasic", "IndexTensorHackedTwinModule3dInput_basic", diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index c81f543b5dc9..fed048a64340 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -582,7 +582,7 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::tensor.bool : (bool, int?, Device?, bool) -> (Tensor)") emit("aten::tensor.int : (int, int?, Device?, bool) -> (Tensor)") emit("aten::scalar_tensor : (Scalar, int?, int?, Device?, bool?) -> (Tensor)") - emit("aten::_shape_as_tensor : (Tensor) -> (Tensor)") + emit("aten::_shape_as_tensor : (Tensor) -> (Tensor)", has_folder=True) emit("aten::isnan : (Tensor) -> (Tensor)") emit("aten::isinf : (Tensor) -> (Tensor)") emit("aten::isneginf : (Tensor) -> (Tensor)") diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir index 85b95eb1cdba..b3dd4c6f0641 100644 --- a/test/Dialect/Torch/canonicalize.mlir +++ b/test/Dialect/Torch/canonicalize.mlir @@ -2081,20 +2081,18 @@ func.func @torch.aten.slice.tensor$fold_dim_1() -> (!torch.vtensor<[1, 1],si64>, // CHECK-LABEL: func.func @torch.aten.slice.tensor$fold_dim_0() -> (!torch.vtensor<[1,1],f32>, !torch.vtensor<[1,1],f32>) { // CHECK-NOT: torch.aten.slice.Tensor // CHECK: %[[RET_0:.*]] = torch.vtensor.literal(dense<1.600000e+01> : tensor<1x1xf32>) : !torch.vtensor<[1,1],f32> -// CHECK-NOT: torch.aten.slice.Tensor // CHECK: %[[RET_1:.*]] = torch.vtensor.literal(dense<6.400000e+01> : tensor<1x1xf32>) : !torch.vtensor<[1,1],f32> -// CHECK-NOT: torch.aten.slice.Tensor // CHECK: return %[[RET_0]], %[[RET_1]] : !torch.vtensor<[1,1],f32>, !torch.vtensor<[1,1],f32> func.func @torch.aten.slice.tensor$fold_dim_0() -> (!torch.vtensor<[1, 1],f32>, !torch.vtensor<[1, 1],f32>) { %tensor = torch.vtensor.literal(dense<[[2.0],[4.0],[8.0],[16.0],[32.0],[64.0],[128.0],[256.0],[512.0],[1024.0]]> : tensor<10x1xf32>) : !torch.vtensor<[10, 1],f32> %int0 = torch.constant.int 0 %int1 = torch.constant.int 1 - %int3 = torch.constant.int 3 + %intn7 = torch.constant.int -7 %int4 = torch.constant.int 4 %int5 = torch.constant.int 5 %int6 = torch.constant.int 6 %dim = torch.constant.int 0 - %0 = torch.aten.slice.Tensor %tensor, %dim, %int3, %int4, %int1 : !torch.vtensor<[10, 1], f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1, 1], f32> + %0 = torch.aten.slice.Tensor %tensor, %dim, %intn7, %int4, %int1 : !torch.vtensor<[10, 1], f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1, 1], f32> %1 = torch.aten.slice.Tensor %tensor, %dim, %int5, %int6, %int1 : !torch.vtensor<[10, 1], f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1, 1], f32> return %0, %1 : !torch.vtensor<[1, 1],f32>, !torch.vtensor<[1, 1], f32> } @@ -2655,3 +2653,13 @@ func.func @aten_eq_tensor_dense_int() -> !torch.vtensor<[4],i1> { return %0 : !torch.vtensor<[4],i1> } +// ----- + +// CHECK-LABEL: @aten_shape_to_tensor +func.func @aten_shape_to_tensor(%arg0 : !torch.vtensor<[4,5,6],f32>) -> !torch.vtensor<[3],si32> { + // CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<[4, 5, 6]> : tensor<3xsi32>) : !torch.vtensor<[3],si32> + %0 = torch.aten._shape_as_tensor %arg0 : !torch.vtensor<[4,5,6],f32> -> !torch.vtensor<[3],si32> + // CHECK: return %[[CST]] + return %0 : !torch.vtensor<[3],si32> +} + From e48fe4588631e7a37a2899f9d4cd5c4cbc967481 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Wed, 28 Feb 2024 12:18:02 -0800 Subject: [PATCH 235/283] [onnx] Import `onnx` import to pass remaining tests (#2951) Finish supporting importing the vast majority of `onnx` operations. This includes: - region support - region value inherentance - `torch.string` support - `torch.list` support - `torch.optional` support --- .../torch-mlir/Dialect/Torch/IR/TorchOps.td | 13 +- .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 3 +- lib/Dialect/Torch/IR/TorchTypes.cpp | 6 + .../Transforms/AbstractInterpLibrary.cpp | 22 +-- python/torch_mlir/extras/onnx_importer.py | 176 +++++++++++++++--- .../python/onnx_importer/import_smoke_test.py | 145 --------------- 6 files changed, 185 insertions(+), 180 deletions(-) diff --git a/include/torch-mlir/Dialect/Torch/IR/TorchOps.td b/include/torch-mlir/Dialect/Torch/IR/TorchOps.td index c86244f5f1e3..f5214db58f19 100644 --- a/include/torch-mlir/Dialect/Torch/IR/TorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/TorchOps.td @@ -843,12 +843,23 @@ def Torch_OperatorOp : Torch_Op<"operator", [ let arguments = (ins StrAttr:$name, Variadic:$operands); let results = (outs Variadic:$results); + let regions = (region VariadicRegion:$regions); let assemblyFormat = [{ - $name `(` $operands `)` attr-dict `:` functional-type($operands, $results) + $name `(` $operands `)` attr-dict `:` functional-type($operands, $results) $regions }]; } +def Torch_OperatorTerminatorOp : Torch_Op<"operator_terminator", [Terminator, + HasParent<"::mlir::torch::Torch::OperatorOp">]> { + let summary = "Implicit terminator for torch.operator"; + + let arguments = (ins Variadic:$operands); + let results = (outs); + + let assemblyFormat = "$operands attr-dict `:` type($operands)"; +} + def Torch_LinearParamsCreateOp : Torch_Op<"linear_params.create", [ AllowsTypeRefinement, AllowedInModuleInitializer, diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index adf6d1cb639a..3deba85a6c77 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -181,7 +181,8 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( c = rewriter .create(binder.getLoc(), cTy, newOperands, - newAttributes) + newAttributes, + binder.op->getRegions().size()) .getResult(0); Value outScale = rewriter.create( diff --git a/lib/Dialect/Torch/IR/TorchTypes.cpp b/lib/Dialect/Torch/IR/TorchTypes.cpp index 7e3f37a7b870..b22c82b8a28f 100644 --- a/lib/Dialect/Torch/IR/TorchTypes.cpp +++ b/lib/Dialect/Torch/IR/TorchTypes.cpp @@ -190,6 +190,12 @@ static bool isValidTorchDtype(Type dtype) { // Builtin floating point types. if (dtype.isa()) return true; + if (dtype.isa()) + return true; + + if (dtype.isa()) + return true; // Builtin integer types. if (IntegerType type = dtype.dyn_cast()) { if (type.isSignless() && type.getWidth() == 1) diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index a8327b0e0da6..b40ca2cdeb89 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -122,7 +122,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: \"\n" " %int0 = torch.constant.int 0\n" -" %0 = torch.operator \"aten.ge\"(%arg0, %int0) : (!torch.union, !torch.int) -> !torch.bool\n" +" %0 = torch.operator \"aten.ge\"(%arg0, %int0) : (!torch.union, !torch.int) -> !torch.bool \n" " torch.prim.If %0 -> () {\n" " torch.prim.If.yield\n" " } else {\n" @@ -138,14 +138,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: \"\n" " %int0 = torch.constant.int 0\n" -" %0 = torch.operator \"aten.ge\"(%arg1, %int0) : (!torch.union, !torch.int) -> !torch.bool\n" +" %0 = torch.operator \"aten.ge\"(%arg1, %int0) : (!torch.union, !torch.int) -> !torch.bool \n" " torch.prim.If %0 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %1 = torch.operator \"aten.ge\"(%arg1, %arg0) : (!torch.union, !torch.union) -> !torch.bool\n" +" %1 = torch.operator \"aten.ge\"(%arg1, %arg0) : (!torch.union, !torch.union) -> !torch.bool \n" " torch.prim.If %1 -> () {\n" " torch.prim.If.yield\n" " } else {\n" @@ -162,16 +162,16 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: \"\n" " %int0 = torch.constant.int 0\n" -" %0 = torch.operator \"aten.ne\"(%arg2, %int0) : (!torch.union, !torch.int) -> !torch.bool\n" +" %0 = torch.operator \"aten.ne\"(%arg2, %int0) : (!torch.union, !torch.int) -> !torch.bool \n" " torch.prim.If %0 -> () {\n" " torch.prim.If.yield\n" " } else {\n" " torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" " torch.prim.If.yield\n" " }\n" -" %1 = torch.operator \"aten.lt\"(%arg2, %int0) : (!torch.union, !torch.int) -> !torch.bool\n" +" %1 = torch.operator \"aten.lt\"(%arg2, %int0) : (!torch.union, !torch.int) -> !torch.bool \n" " torch.prim.If %1 -> () {\n" -" %6 = torch.operator \"aten.ge\"(%arg0, %arg1) : (!torch.union, !torch.union) -> !torch.bool\n" +" %6 = torch.operator \"aten.ge\"(%arg0, %arg1) : (!torch.union, !torch.union) -> !torch.bool \n" " torch.prim.If %6 -> () {\n" " torch.prim.If.yield\n" " } else {\n" @@ -180,7 +180,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " torch.prim.If.yield\n" " } else {\n" -" %6 = torch.operator \"aten.ge\"(%arg1, %arg0) : (!torch.union, !torch.union) -> !torch.bool\n" +" %6 = torch.operator \"aten.ge\"(%arg1, %arg0) : (!torch.union, !torch.union) -> !torch.bool \n" " torch.prim.If %6 -> () {\n" " torch.prim.If.yield\n" " } else {\n" @@ -5507,12 +5507,12 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " %15 = torch.aten.__getitem__.t %arg0, %int2 : !torch.list, !torch.int -> !torch.int\n" " %16 = torch.aten.__getitem__.t %11, %int0 : !torch.list, !torch.int -> !torch.float\n" -" %17 = torch.operator \"aten.mul.int_float\"(%15, %16) : (!torch.int, !torch.float) -> !torch.float\n" +" %17 = torch.operator \"aten.mul.int_float\"(%15, %16) : (!torch.int, !torch.float) -> !torch.float \n" " %18 = torch.aten.Int.float %17 : !torch.float -> !torch.int\n" " %19 = torch.aten.append.t %1, %18 : !torch.list, !torch.int -> !torch.list\n" " %20 = torch.aten.__getitem__.t %arg0, %int3 : !torch.list, !torch.int -> !torch.int\n" " %21 = torch.aten.__getitem__.t %11, %int1 : !torch.list, !torch.int -> !torch.float\n" -" %22 = torch.operator \"aten.mul.int_float\"(%20, %21) : (!torch.int, !torch.float) -> !torch.float\n" +" %22 = torch.operator \"aten.mul.int_float\"(%20, %21) : (!torch.int, !torch.float) -> !torch.float \n" " %23 = torch.aten.Int.float %22 : !torch.float -> !torch.int\n" " %24 = torch.aten.append.t %1, %23 : !torch.list, !torch.int -> !torch.list\n" " torch.prim.If.yield\n" @@ -7246,7 +7246,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %3 = torch.prim.If %2 -> (!torch.list) {\n" " %5 = torch.prim.ListConstruct %int1 : (!torch.int) -> !torch.list\n" " %6 = torch.aten.sub.int %1, %0 : !torch.int, !torch.int -> !torch.int\n" -" %7 = torch.operator \"aten.mul.left_t\"(%5, %6) : (!torch.list, !torch.int) -> !torch.list\n" +" %7 = torch.operator \"aten.mul.left_t\"(%5, %6) : (!torch.list, !torch.int) -> !torch.list \n" " %8 = torch.aten.add.t %7, %arg1 : !torch.list, !torch.list -> !torch.list\n" " torch.prim.If.yield %8 : !torch.list\n" " } else {\n" @@ -9304,7 +9304,7 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " func.func @__torch__.hacky_get_unknown_dimension_size() -> !torch.int {\n" " %0 = torch.prim.CreateObject !torch.nn.Module<\"__torch__.DummyClassType\">\n" " %1 = torch.prim.CallMethod %0[\"__init__\"] () : !torch.nn.Module<\"__torch__.DummyClassType\">, () -> !torch.none\n" -" %2 = torch.operator \"prim.id\"(%0) : (!torch.nn.Module<\"__torch__.DummyClassType\">) -> !torch.int\n" +" %2 = torch.operator \"prim.id\"(%0) : (!torch.nn.Module<\"__torch__.DummyClassType\">) -> !torch.int \n" " return %2 : !torch.int\n" " }\n" " func.func @__torch__.DummyClassType.__init__(%arg0: !torch.nn.Module<\"__torch__.DummyClassType\">) -> !torch.none {\n" diff --git a/python/torch_mlir/extras/onnx_importer.py b/python/torch_mlir/extras/onnx_importer.py index a0cfbf26ed30..91ee4c14c75a 100644 --- a/python/torch_mlir/extras/onnx_importer.py +++ b/python/torch_mlir/extras/onnx_importer.py @@ -128,7 +128,7 @@ def __init__(self, model_info: ModelInfo, graph_proto: onnx.GraphProto): # Generate the effective input map, which for old models can be a # subset of the input map. - if model_info.config.elide_initialized_inputs: + if model_info and model_info.config.elide_initialized_inputs: self.input_map = { k: v for k, v in self.declared_input_map.items() @@ -151,9 +151,8 @@ def find_type_proto_for_name(self, name: str) -> onnx.TypeProto: value_info = self.value_info_map.get(name) or self.output_map.get(name) if value_info is not None: return value_info.type - raise OnnxImportError( - f"No type information associated with '{name}'. Run shape inference?" - ) + # No type information is associated, this can occur when the value is unused: + return "" class OnnxImportError(Exception): @@ -252,7 +251,7 @@ def _populate_graph_attrs(self, container_op: Operation): "torch.onnx_meta.producer_version" ] = StringAttr.get(m.producer_version) - def import_all(self): + def import_all(self, func=True): """Imports all nodes topologically.""" # TODO: Consider pulling in initializers on demand since there can be so # much unused crap. @@ -272,7 +271,12 @@ def import_all(self): f"Non topologically produced ONNX graph output '{output_name}'" ) with InsertionPoint(self._b), Location.unknown(): - func_dialect.ReturnOp(outputs) + if func: + func_dialect.ReturnOp(outputs) + else: + Operation.create( + name="torch.operator_terminator", + operands=outputs) def get_none(self): if '' in self._nv_map: @@ -315,23 +319,24 @@ def import_node(self, node: onnx.NodeProto): for n in output_names ] - # TODO: Attributes. - attrs = { - "name": StringAttr.get(f"onnx.{op_type}"), - } - self.import_attributes(node.attribute, attrs) + attrs = self.import_attributes(node.attribute) + attrs["name"] = StringAttr.get(f"onnx.{op_type}") + regions = self.count_regions(node.attribute) + custom_op = Operation.create( name="torch.operator", results=output_types, operands=input_values, attributes=attrs, + regions=regions ) + + self.import_regions(node.attribute, custom_op) for output_name, output_value in zip(output_names, custom_op.results): self._nv_map[output_name] = output_value - def import_attributes( - self, onnx_attrs: list[onnx.AttributeProto], attrs: dict[str, Attribute] - ): + def import_attributes(self, onnx_attrs: list[onnx.AttributeProto]): + attrs = {} for onnx_attr in onnx_attrs: attr_type = onnx_attr.type if attr_type not in ATTRIBUTE_TYPE_HANDLERS: @@ -351,6 +356,38 @@ def import_attributes( ) result = handler(onnx_attr, self._cc) attrs[f"torch.onnx.{onnx_attr.name}"] = result + return attrs + + def count_regions(self, onnx_attrs: list[onnx.AttributeProto]): + count = 0 + for onnx_attr in onnx_attrs: + if onnx_attr.type == onnx.AttributeProto.AttributeType.GRAPH: + count += 1 + return count + + def import_regions(self, onnx_attrs: list[onnx.AttributeProto], op): + attr_map = {} + for onnx_attr in onnx_attrs: + attr_type = onnx_attr.type + if attr_type != onnx.AttributeProto.AttributeType.GRAPH: + continue + attr_map[onnx_attr.name] = onnx_attr + + for name, region in zip(sorted(attr_map.keys()), op.regions): + attr = attr_map[name] + block_types = [self._cc.type_proto_to_type(input.type) for input in attr.g.input] + block_names = [input.name for input in attr.g.input] + region.blocks.append(*block_types, arg_locs=[op.location] * len(block_types)) + block = region.blocks[0] + graph_info = GraphInfo(None, attr.g) + imp = NodeImporter(graph_info, parent_op=op, block=block, context_cache=self._cc) + + for node_name, input_value in zip(block_names, block.arguments): + imp._nv_map[node_name] = input_value + for k in self._nv_map: + imp._nv_map[k] = self._nv_map[k] + + imp.import_all(False) def import_initializer(self, initializer: onnx.TensorProto, extern_name: str = None) -> Value: # If an explicitly specified name is given, use that; otherwise, pick @@ -414,12 +451,16 @@ class ContextCache: __slots__ = [ "_c", "_elem_type_map", + "_list_type_map", + "_optional_type_map", "_vtensor_type_map", ] def __init__(self, context: Context): self._c = context self._elem_type_map: dict[int, IrType] = {} + self._list_type_map:dict[str, IrType] = {} + self._optional_type_map:dict[str, IrType] = {} self._vtensor_type_map: dict[tuple[tuple[Optional[int]], IrType], IrType] = {} def tensor_element_type(self, elem_type: int) -> IrType: @@ -436,6 +477,67 @@ def tensor_element_type(self, elem_type: int) -> IrType: def get_none_type(self): return IrType.parse("!torch.none", context=self._c) + def get_list_type(self, element_type: IrType) -> IrType: + key = str(element_type) + t = self._list_type_map.get(key) + if t is None: + asm = f"!torch.list<{str(element_type)}>" + try: + t = IrType.parse(asm, context=self._c) + except MLIRError as e: + raise OnnxImportError( + f"Unparseable torch type (MLIR asm format bug?): {asm}" + ) from e + self._list_type_map[key] = t + return t + + + def get_optional_type(self, element_type: IrType) -> IrType: + key = str(element_type) + t = self._optional_type_map.get(key) + if t is None: + asm = f"!torch.optional<{str(element_type)}>" + try: + t = IrType.parse(asm, context=self._c) + except MLIRError as e: + raise OnnxImportError( + f"Unparseable torch type (MLIR asm format bug?): {asm}" + ) from e + self._optional_type_map[key] = t + return t + + + def get_list_element_type(self, tp: onnx.TypeProto) -> IrType: + tt = tp.tensor_type + if tt.elem_type: + element_type = self.tensor_element_type(tt.elem_type) + dims = tuple( + (d.dim_value if not d.dim_param else None) for d in tt.shape.dim + ) + shape_asm = ",".join("?" if d is None else str(d) for d in dims) + return f"vtensor<[{shape_asm}],{element_type}>" + + raise OnnxImportError( + f"Unsupport list element type") + + def get_optional_element_type(self, tp: onnx.TypeProto) -> IrType: + st = tp.sequence_type + tt = tp.tensor_type + if tt.elem_type: + element_type = self.tensor_element_type(tt.elem_type) + dims = tuple( + (d.dim_value if not d.dim_param else None) for d in tt.shape.dim + ) + shape_asm = ",".join("?" if d is None else str(d) for d in dims) + return f"vtensor<[{shape_asm}],{element_type}>" + + if st.elem_type: + element_type = self.get_list_element_type(st.elem_type) + return f"list<{element_type}>" + + raise OnnxImportError( + f"Unsupport optional element type") + def get_vtensor_type( self, dims: tuple[Optional[int]], element_type: IrType ) -> IrType: @@ -461,11 +563,20 @@ def tensor_proto_to_builtin_type(self, tp: onnx.TensorProto) -> IrType: element_type = self.tensor_element_type(tp.data_type) # TODO: Fixme upstream: RankedTensorType.get should not require a location. with Location.unknown(): - return RankedTensorType.get(tuple(tp.dims), element_type) + try: + return RankedTensorType.get(tuple(tp.dims), element_type) + except TypeError as e: + raise OnnxImportError( + f"Unsupported builtin tensor type" + ) from e + def type_proto_to_type(self, tp: onnx.TypeProto) -> IrType: - if tp.tensor_type: - tt = tp.tensor_type + if tp == "": + return self.get_none_type() + + tt = tp.tensor_type + if tt.elem_type: if not tt.shape: raise OnnxImportError( f"Unsupported Tensor type without shape (run shape inference?): {tp}" @@ -475,10 +586,20 @@ def type_proto_to_type(self, tp: onnx.TypeProto) -> IrType: (d.dim_value if not d.dim_param else None) for d in tt.shape.dim ) return self.get_vtensor_type(dims, element_type) - else: - # TODO: Others if ever needed. Or we consider ourselves DNN-only. - # See TypeProto: sequence_type, map_type, optional_type, sparse_tensor_type. - raise OnnxImportError(f"Unsupported ONNX TypeProto: {tp}") + + st = tp.sequence_type + if len(str(st.elem_type)) > 0: + element_type = self.get_list_element_type(st.elem_type) + return self.get_list_type(element_type) + + ot = tp.optional_type + if len(str(ot.elem_type)) > 0: + element_type = self.get_optional_element_type(ot.elem_type) + return self.get_optional_type(element_type) + + # TODO: Others if ever needed. Or we consider ourselves DNN-only. + # See TypeProto: sequence_type, map_type, optional_type, sparse_tensor_type. + raise OnnxImportError(f"Unsupported ONNX TypeProto: {tp}") def _sanitize_name(self, name): if not name.isidentifier(): @@ -524,6 +645,7 @@ def tensor_proto_to_attr(self, tp: onnx.TensorProto) -> Attribute: onnx.TensorProto.DataType.FLOAT8E4M3FNUZ: lambda: Float8E5M2FNUZType.get(), onnx.TensorProto.DataType.FLOAT8E5M2: lambda: Float8E5M2Type.get(), onnx.TensorProto.DataType.FLOAT8E5M2FNUZ: lambda: Float8E5M2FNUZType.get(), + onnx.TensorProto.DataType.STRING: lambda: "!torch.str", # Ommitted: STRING, } @@ -546,6 +668,16 @@ def tensor_proto_to_attr(self, tp: onnx.TensorProto) -> Attribute: onnx.TensorProto.DataType.FLOAT: lambda tp: DenseElementsAttr.get( np.asarray(tp.float_data, dtype=np.float32).reshape(tp.dims), signless=False ), + onnx.TensorProto.DataType.BOOL: lambda tp: DenseElementsAttr.get( + np.packbits(np.asarray(tp.int32_data, dtype=np.bool_).reshape(tp.dims), + axis=None, bitorder="little"), signless=False + ), + onnx.TensorProto.DataType.INT8: lambda tp: DenseElementsAttr.get( + np.asarray(tp.int32_data, dtype=np.int8).reshape(tp.dims), signless=False + ), + onnx.TensorProto.DataType.INT16: lambda tp: DenseElementsAttr.get( + np.asarray(tp.int32_data, dtype=np.int16).reshape(tp.dims), signless=False + ), onnx.TensorProto.DataType.INT32: lambda tp: DenseElementsAttr.get( np.asarray(tp.int32_data, dtype=np.int32).reshape(tp.dims), signless=False ), @@ -605,7 +737,7 @@ def tensor_proto_to_attr(self, tp: onnx.TensorProto) -> Attribute: onnx.AttributeProto.AttributeType.TENSOR: lambda a, cc: cc.tensor_proto_to_attr( a.t ), - onnx.AttributeProto.AttributeType.GRAPH: False, + onnx.AttributeProto.AttributeType.GRAPH: None, onnx.AttributeProto.AttributeType.SPARSE_TENSOR: False, onnx.AttributeProto.AttributeType.TYPE_PROTO: False, onnx.AttributeProto.AttributeType.FLOATS: lambda a, cc: ArrayAttr.get( diff --git a/test/python/onnx_importer/import_smoke_test.py b/test/python/onnx_importer/import_smoke_test.py index 22d460050cae..bd687ae37049 100644 --- a/test/python/onnx_importer/import_smoke_test.py +++ b/test/python/onnx_importer/import_smoke_test.py @@ -44,155 +44,10 @@ OUTPUT_PATH.mkdir(parents=True, exist_ok=True) TEST_CAST_XFAILS = [ - "light_light_bvlc_alexnet", - "light_light_inception_v1", - "light_light_squeezenet", - "light_light_vgg19", - "node_test_affine_grid_2d_align_corners_expanded_model", - "node_test_affine_grid_2d_expanded_model", - "node_test_affine_grid_3d_align_corners_expanded_model", - "node_test_affine_grid_3d_expanded_model", - "node_test_ai_onnx_ml_label_encoder_string_int_model", - "node_test_ai_onnx_ml_label_encoder_string_int_no_default_model", "node_test_ai_onnx_ml_label_encoder_tensor_mapping_model", - "node_test_ai_onnx_ml_label_encoder_tensor_value_only_mapping_model", - "node_test_cast_FLOAT16_to_FLOAT8E4M3FNUZ_model", - "node_test_cast_FLOAT16_to_FLOAT8E4M3FN_model", - "node_test_cast_FLOAT16_to_FLOAT8E5M2FNUZ_model", - "node_test_cast_FLOAT16_to_FLOAT8E5M2_model", - "node_test_cast_FLOAT8E4M3FNUZ_to_FLOAT16_model", - "node_test_cast_FLOAT8E4M3FNUZ_to_FLOAT_model", - "node_test_cast_FLOAT8E4M3FN_to_FLOAT16_model", - "node_test_cast_FLOAT8E4M3FN_to_FLOAT_model", - "node_test_cast_FLOAT8E5M2FNUZ_to_FLOAT16_model", - "node_test_cast_FLOAT8E5M2FNUZ_to_FLOAT_model", - "node_test_cast_FLOAT8E5M2_to_FLOAT16_model", - "node_test_cast_FLOAT8E5M2_to_FLOAT_model", - "node_test_cast_FLOAT_to_FLOAT8E4M3FNUZ_model", - "node_test_cast_FLOAT_to_FLOAT8E4M3FN_model", - "node_test_cast_FLOAT_to_FLOAT8E5M2FNUZ_model", - "node_test_cast_FLOAT_to_FLOAT8E5M2_model", - "node_test_cast_FLOAT_to_STRING_model", - "node_test_cast_STRING_to_FLOAT_model", - "node_test_cast_no_saturate_FLOAT16_to_FLOAT8E4M3FNUZ_model", - "node_test_cast_no_saturate_FLOAT16_to_FLOAT8E4M3FN_model", - "node_test_cast_no_saturate_FLOAT16_to_FLOAT8E5M2FNUZ_model", - "node_test_cast_no_saturate_FLOAT16_to_FLOAT8E5M2_model", - "node_test_cast_no_saturate_FLOAT_to_FLOAT8E4M3FNUZ_model", - "node_test_cast_no_saturate_FLOAT_to_FLOAT8E4M3FN_model", - "node_test_cast_no_saturate_FLOAT_to_FLOAT8E5M2FNUZ_model", - "node_test_cast_no_saturate_FLOAT_to_FLOAT8E5M2_model", - "node_test_castlike_FLOAT8E4M3FNUZ_to_FLOAT_expanded_model", - "node_test_castlike_FLOAT8E4M3FNUZ_to_FLOAT_model", - "node_test_castlike_FLOAT8E4M3FN_to_FLOAT_expanded_model", - "node_test_castlike_FLOAT8E4M3FN_to_FLOAT_model", - "node_test_castlike_FLOAT8E5M2FNUZ_to_FLOAT_expanded_model", - "node_test_castlike_FLOAT8E5M2FNUZ_to_FLOAT_model", - "node_test_castlike_FLOAT8E5M2_to_FLOAT_expanded_model", - "node_test_castlike_FLOAT8E5M2_to_FLOAT_model", - "node_test_castlike_FLOAT_to_FLOAT8E4M3FNUZ_expanded_model", - "node_test_castlike_FLOAT_to_FLOAT8E4M3FNUZ_model", - "node_test_castlike_FLOAT_to_FLOAT8E4M3FN_expanded_model", - "node_test_castlike_FLOAT_to_FLOAT8E4M3FN_model", - "node_test_castlike_FLOAT_to_FLOAT8E5M2FNUZ_expanded_model", - "node_test_castlike_FLOAT_to_FLOAT8E5M2FNUZ_model", - "node_test_castlike_FLOAT_to_FLOAT8E5M2_expanded_model", - "node_test_castlike_FLOAT_to_FLOAT8E5M2_model", - "node_test_castlike_FLOAT_to_STRING_expanded_model", - "node_test_castlike_FLOAT_to_STRING_model", - "node_test_castlike_STRING_to_FLOAT_expanded_model", - "node_test_castlike_STRING_to_FLOAT_model", - "node_test_dequantizelinear_e4m3fn_model", - "node_test_dequantizelinear_e4m3fn_zero_point_model", - "node_test_dequantizelinear_e5m2_model", - "node_test_equal_string_broadcast_model", - "node_test_equal_string_model", - "node_test_gru_defaults_model", - "node_test_gru_seq_length_model", - "node_test_gru_with_initial_bias_model", - "node_test_identity_opt_model", - "node_test_identity_sequence_model", - "node_test_if_model", "node_test_if_opt_model", - "node_test_if_seq_model", - "node_test_loop11_model", - "node_test_loop13_seq_model", - "node_test_loop16_seq_none_model", - "node_test_lstm_defaults_model", - "node_test_lstm_with_initial_bias_model", - "node_test_lstm_with_peepholes_model", - "node_test_optional_get_element_optional_sequence_model", - "node_test_optional_get_element_optional_tensor_model", - "node_test_optional_get_element_sequence_model", - "node_test_optional_has_element_empty_optional_input_model", - "node_test_optional_has_element_optional_input_model", - "node_test_optional_has_element_tensor_input_model", - "node_test_quantizelinear_e4m3fn_model", - "node_test_quantizelinear_e5m2_model", - "node_test_range_float_type_positive_delta_expanded_model", - "node_test_range_int32_type_negative_delta_expanded_model", - "node_test_regex_full_match_basic_model", - "node_test_regex_full_match_email_domain_model", - "node_test_regex_full_match_empty_model", - "node_test_rnn_seq_length_model", - "node_test_scan9_sum_model", - "node_test_scan_sum_model", - "node_test_sequence_insert_at_back_model", - "node_test_sequence_insert_at_front_model", - "node_test_sequence_map_add_1_sequence_1_tensor_expanded_model", - "node_test_sequence_map_add_1_sequence_1_tensor_model", - "node_test_sequence_map_add_2_sequences_expanded_model", - "node_test_sequence_map_add_2_sequences_model", - "node_test_sequence_map_extract_shapes_expanded_model", - "node_test_sequence_map_extract_shapes_model", - "node_test_sequence_map_identity_1_sequence_1_tensor_expanded_model", - "node_test_sequence_map_identity_1_sequence_1_tensor_model", - "node_test_sequence_map_identity_1_sequence_expanded_model", - "node_test_sequence_map_identity_1_sequence_model", - "node_test_sequence_map_identity_2_sequences_expanded_model", - "node_test_sequence_map_identity_2_sequences_model", - "node_test_simple_rnn_defaults_model", - "node_test_simple_rnn_with_initial_bias_model", - "node_test_split_to_sequence_1_model", - "node_test_split_to_sequence_2_model", - "node_test_split_to_sequence_nokeepdims_model", - "node_test_string_concat_broadcasting_model", - "node_test_string_concat_empty_string_model", - "node_test_string_concat_model", - "node_test_string_concat_utf8_model", - "node_test_string_concat_zero_dimensional_model", - "node_test_string_split_basic_model", - "node_test_string_split_consecutive_delimiters_model", - "node_test_string_split_empty_string_delimiter_model", - "node_test_string_split_empty_tensor_model", - "node_test_string_split_maxsplit_model", - "node_test_string_split_no_delimiter_model", - "node_test_strnormalizer_export_monday_casesensintive_lower_model", - "node_test_strnormalizer_export_monday_casesensintive_nochangecase_model", - "node_test_strnormalizer_export_monday_casesensintive_upper_model", - "node_test_strnormalizer_export_monday_empty_output_model", - "node_test_strnormalizer_export_monday_insensintive_upper_twodim_model", - "node_test_strnormalizer_nostopwords_nochangecase_model", - "simple_test_sequence_model1_model", - "simple_test_sequence_model2_model", - "simple_test_sequence_model3_model", - "simple_test_sequence_model4_model", - "simple_test_sequence_model5_model", - "simple_test_sequence_model6_model", - "simple_test_sequence_model7_model", - "simple_test_sequence_model8_model", - "simple_test_strnorm_model_monday_casesensintive_lower_model", - "simple_test_strnorm_model_monday_casesensintive_nochangecase_model", - "simple_test_strnorm_model_monday_casesensintive_upper_model", - "simple_test_strnorm_model_monday_empty_output_model", - "simple_test_strnorm_model_monday_insensintive_upper_twodim_model", - "simple_test_strnorm_model_nostopwords_nochangecase_model", ] - - - - class ImportSmokeTest(unittest.TestCase): @classmethod def setUpClass(cls): From 5437f32193888f4cae3b4ae02123bd8828564ad6 Mon Sep 17 00:00:00 2001 From: Andreas Falkenberg <149819731+afalkenberg1@users.noreply.github.com> Date: Wed, 28 Feb 2024 13:52:15 -0800 Subject: [PATCH 236/283] [onnx][torch] Lower `onnx.grid_sampler` to the `torch` equivalents (#2952) This is the lowering of gridsampler from onnx to torch using our prior implementation of AtenGridSamplerOp. Here are several checks for cornercases implemented. We may decide to have part of these checks in AtenGridSamplerOp instead of the onnx lowering portion. --- .../TorchOnnxToTorch/DefaultDomainGtoP.cpp | 67 +++++++++++++++++++ .../TorchOnnxToTorch/simple_ops_g_to_p.mlir | 12 ++++ 2 files changed, 79 insertions(+) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index 12b7ab559f4f..4d1aaf42d679 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -92,6 +92,73 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( operand, vApproximate); return success(); }); + patterns.onOp( + "GridSample", 20, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value input; + Value grid; + if (binder.tensorOperandAtIndex(input, 0) || + binder.tensorOperandAtIndex(grid, 1) || + binder.tensorResultType(resultType)) + return rewriter.notifyMatchFailure( + binder.op, "operand grid_sampler bind failure"); + + auto inputTensorType = input.getType().cast(); + ArrayRef inputShape = inputTensorType.getSizes(); + uint32_t inputRank = inputShape.size(); + auto gridTensorType = grid.getType().cast(); + ArrayRef gridShape = gridTensorType.getSizes(); + uint32_t gridRank = gridShape.size(); + + if (inputRank != 4) + return rewriter.notifyMatchFailure(binder.op, + "only input rank 4 supported"); + if (gridRank != 4) + return rewriter.notifyMatchFailure(binder.op, + "only grid rank 4 supported"); + if (inputShape[0] != gridShape[0]) + return rewriter.notifyMatchFailure( + binder.op, "N must be same for input and grid"); + if (gridShape[3] != 2) + return rewriter.notifyMatchFailure(binder.op, + "gridShape[3] expected to be 2"); + std::string mode; + if (binder.customOpNameStringAttr(mode, "mode", "bilinear")) + return rewriter.notifyMatchFailure(binder.op, "mode bind failure"); + if (mode != "bilinear") + return rewriter.notifyMatchFailure( + binder.op, "currently only mode : bilinear supported"); + std::string padding; + if (binder.customOpNameStringAttr(padding, "padding_mode", "zeros")) + return rewriter.notifyMatchFailure(binder.op, + "padding_mode bind failure"); + if (padding != "zeros") + return rewriter.notifyMatchFailure( + binder.op, "currently only padding_mode : zeros supported"); + int64_t align; + if (binder.s64IntegerAttr(align, "align_corners", 0)) + return rewriter.notifyMatchFailure(binder.op, + "align_corners bind failure"); + if (align != 0) + return rewriter.notifyMatchFailure( + binder.op, "currently only align_corners : 0 supported"); + + Value interpolationMode = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); + Value paddingMode = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); + Value alignCorners = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getBoolAttr(false)); + + rewriter.replaceOpWithNewOp( + binder.op, resultType, input, grid, interpolationMode, paddingMode, + alignCorners); + return success(); + }); patterns.onOp("Less", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir index 8729e7f2dd5a..d5a47aba353d 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir @@ -395,6 +395,18 @@ func.func @test_gelu_tanh_2(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtenso // ----- +// CHECK-LABEL: @test_grid_sampler +// CHECK: %[[INT0:.*]] = torch.constant.int 0 +// CHECK: %[[INT0_0:.*]] = torch.constant.int 0 +// CHECK: %[[B0:.*]] = torch.constant.bool false +// CHECK: %[[A0:.*]] = torch.aten.grid_sampler %arg0, %arg1, %[[INT0]], %[[INT0_0]], %[[B0]] : !torch.vtensor<[5,10,10,4],f32> +func.func @test_grid_sampler(%arg0: !torch.vtensor<[5,10,10,4],f32>, %arg1: !torch.vtensor<[5,7,8,2],f32>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + %4 = torch.operator "onnx.GridSample" (%arg0, %arg1) {torch.onnx.align_corners = 0 : si64, torch.onnx.padding_mode = "zeros"} : (!torch.vtensor<[5,10,10,4],f32>, !torch.vtensor<[5,7,8,2],f32>) -> !torch.vtensor<[?,?,?,?],f32> + return %4 : !torch.vtensor<[?,?,?,?],f32> +} + +// ----- + // CHECK-LABEL: func.func @test_less_or_equal func.func @test_less_or_equal(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],i1> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 16 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: !torch.vtensor<[3,4,5],f32> From ed6e75908b959129c0abb571b849597a792f44bf Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Wed, 28 Feb 2024 14:13:26 -0800 Subject: [PATCH 237/283] Bump LLVM to llvm/llvm-project@e5ed7b6e2fd368b722b6359556cd0125881e7638 (#2964) --- externals/llvm-project | 2 +- projects/pt1/e2e_testing/xfail_sets.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/externals/llvm-project b/externals/llvm-project index bb180856ec28..e5ed7b6e2fd3 160000 --- a/externals/llvm-project +++ b/externals/llvm-project @@ -1 +1 @@ -Subproject commit bb180856ec28efe305dc77ca4bb3db12d8932edf +Subproject commit e5ed7b6e2fd368b722b6359556cd0125881e7638 diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 60b08c02539e..adb406f60405 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -989,6 +989,7 @@ "ElementwiseClampMaxModule_basic", "ElementwiseClampMinModule_basic", "ElementwiseClampModule_basic", + "ElementwiseClampTensorInt8Module_basic", "ElementwiseCloneChannelsLastMemoryFormatModule_basic", "ElementwiseCloneContiguousModule_basic", "ElementwiseCloneModule_basic", From e85a2a87c5662c26e047a2d93d3ff6216cabdcec Mon Sep 17 00:00:00 2001 From: Peiming Liu <36770114+PeimingLiu@users.noreply.github.com> Date: Wed, 28 Feb 2024 16:08:37 -0800 Subject: [PATCH 238/283] [torch-mlir][sparse] support e2e sparse kernels with COO inputs. (#2939) --- python/torch_mlir/extras/fx_importer.py | 2 +- test/python/fx_importer/sparse_test.py | 13 +++++++++---- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py index e6d0f03deda4..2edfeb6cf340 100644 --- a/python/torch_mlir/extras/fx_importer.py +++ b/python/torch_mlir/extras/fx_importer.py @@ -273,7 +273,7 @@ def sparsity_encoding(shape: torch.Size, sparsity: SparsityMeta) -> str: if sparsity.layout is torch.sparse_coo: assert sparse_dim == 2 and blocksize is None # TODO: deeper sparse dims - lvls = f"d{batch_dim}:compressed(nonunique),d{batch_dim+1}:singleton" + lvls = f"d{batch_dim}:compressed(nonunique),d{batch_dim+1}:singleton(soa)" elif sparsity.layout is torch.sparse_csr: assert sparse_dim == 2 and blocksize is None lvls = f"d{batch_dim}:dense,d{batch_dim+1}:compressed" diff --git a/test/python/fx_importer/sparse_test.py b/test/python/fx_importer/sparse_test.py index 138942b07092..f4eebf6b1ca6 100644 --- a/test/python/fx_importer/sparse_test.py +++ b/test/python/fx_importer/sparse_test.py @@ -161,7 +161,13 @@ def sparse_jit(f, *args, **kwargs): for a in args: if a.layout is torch.sparse_coo: xargs.append(a.values().numpy()) - xargs.append(a.indices().numpy()) + # Construct the additional position array required by MLIR with data + # array([0, nnz]). + xargs.append(torch.tensor([0, a._nnz()], dtype=a.indices().dtype).numpy()) + # Transform a tensor into [tensor x ndim] to conform + # MLIR SoA COO representation. + for idx in a.indices(): + xargs.append(idx.numpy()) elif a.layout is torch.sparse_csr or a.layout is torch.sparse_bsr: xargs.append(a.values().numpy()) xargs.append(a.crow_indices().numpy()) @@ -254,7 +260,7 @@ def forward(self, x, v): @run # CHECK-LABEL: test_sparse_SpMM -# CHECK: #[[$COO:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton), posWidth = 64, crdWidth = 64 }> +# CHECK: #[[$COO:.*]] = #sparse_tensor.encoding<{ map = (d0, d1) -> (d0 : compressed(nonunique), d1 : singleton(soa)), posWidth = 64, crdWidth = 64 }> # CHECK: func.func @main( # CHECK-SAME: %[[A:.*0]]: !torch.vtensor<[8,8],f32,#[[$COO]]>, # CHECK-SAME: %[[B:.*1]]: !torch.vtensor<[8,8],f32>) -> !torch.vtensor<[8,8],f32> { @@ -286,8 +292,7 @@ def forward(self, x, y): print(m) # Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit. - # TODO: run with COO, right now only CSR works - sparse_input = dense_input.to_sparse_csr() + net = MatMulNet() res1 = net(sparse_input, dense_input) res2 = sparse_jit(net, sparse_input, dense_input) print("torch.sparse") From f21b76b68a411819df0795a2fe483b8eeb40d0f0 Mon Sep 17 00:00:00 2001 From: Aart Bik <39774503+aartbik@users.noreply.github.com> Date: Wed, 28 Feb 2024 17:14:00 -0800 Subject: [PATCH 239/283] [torch-mlir][sparse] fixed merge conflict (#2967) --- test/python/fx_importer/sparse_test.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/python/fx_importer/sparse_test.py b/test/python/fx_importer/sparse_test.py index f4eebf6b1ca6..6d801a1d8799 100644 --- a/test/python/fx_importer/sparse_test.py +++ b/test/python/fx_importer/sparse_test.py @@ -292,7 +292,6 @@ def forward(self, x, y): print(m) # Run it with PyTorch torch.sparse and with TORCH-MLIR sparse_jit. - net = MatMulNet() res1 = net(sparse_input, dense_input) res2 = sparse_jit(net, sparse_input, dense_input) print("torch.sparse") From 76b81e0ccdefd05bd7d6026ee04c060e82b49751 Mon Sep 17 00:00:00 2001 From: mmakevic <150796284+mmakevic@users.noreply.github.com> Date: Thu, 29 Feb 2024 06:52:03 +0100 Subject: [PATCH 240/283] Implement lowering of torch.aten.fmod.Tensor (#2767) Closing https://github.com/nod-ai/SHARK-Turbine/issues/351 --- lib/Conversion/TorchToLinalg/Linear.cpp | 22 +++---- .../TorchToLinalg/Uncategorized.cpp | 62 +++++++++++++------ .../Transforms/AbstractInterpLibrary.cpp | 12 ++++ projects/pt1/e2e_testing/xfail_sets.py | 3 + .../build_tools/abstract_interp_lib_gen.py | 11 ++++ .../test_suite/elementwise.py | 62 +++++++++++++++++++ 6 files changed, 142 insertions(+), 30 deletions(-) diff --git a/lib/Conversion/TorchToLinalg/Linear.cpp b/lib/Conversion/TorchToLinalg/Linear.cpp index 6c04dd12f55a..44ac95ce0429 100644 --- a/lib/Conversion/TorchToLinalg/Linear.cpp +++ b/lib/Conversion/TorchToLinalg/Linear.cpp @@ -882,14 +882,14 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { if (bias.getType().isa()) { Value c0; if (resultDTy.isa()) { - c0 = rewriter.create( - loc, FloatAttr::get(resultDTy, 0.0)); + c0 = rewriter.create(loc, + FloatAttr::get(resultDTy, 0.0)); } else if (resultDTy.isa()) { - c0 = rewriter.create( - loc, IntegerAttr::get(resultDTy, 0)); + c0 = rewriter.create(loc, + IntegerAttr::get(resultDTy, 0)); } - outputTensor = rewriter.create(loc, c0, initTensor) - .getResult(0); + outputTensor = + rewriter.create(loc, c0, initTensor).getResult(0); } else { auto biasType = bias.getType().cast(); @@ -1058,11 +1058,11 @@ class ConvertAtenConvolutionOp : public OpConversionPattern { loc, collapsedType, weight, collapsedDims); conv = rewriter - .create( - loc, outputTensor.getType(), - ValueRange{paddedInput, collapsedWeight}, outputTensor, - stridesAttr, dilationAttr) - .getResult(0); + .create( + loc, outputTensor.getType(), + ValueRange{paddedInput, collapsedWeight}, outputTensor, + stridesAttr, dilationAttr) + .getResult(0); Type newResultType = getTypeConverter()->convertType(op.getType()); rewriter.replaceOpWithNewOp(op, newResultType, conv); diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index 657ea460f76e..d28369cc560e 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -1274,6 +1274,29 @@ static Value createLinalgPayloadCalculationForElementwiseOp( return result; } + if (auto fmod = dyn_cast(op)) { + Type newResultType = converter->convertType(fmod.getType()) + .cast() + .getElementType(); + + Value self = convertScalarToDtype(b, loc, payloadArgs[0], newResultType); + Value other = convertScalarToDtype(b, loc, payloadArgs[1], newResultType); + Value result; + + if (newResultType.isa()) { + Value n = b.create(loc, self, other); + n = b.create(loc, n); + Value n_y = b.create(loc, n, other); + result = b.create(loc, self, n_y); + } else if (newResultType.isa()) { + Value n = b.create(loc, self, other); + Value n_y = b.create(loc, n, other); + result = b.create(loc, self, n_y); + } else { + fmod.emitError("Unsupported type encountered for AtenFmodTensorOp."); + } + return result; + } if (auto reciprocal = dyn_cast(op)) { Type dtype = converter->convertType(reciprocal.getType()) .cast() @@ -1541,22 +1564,22 @@ class ConvertElementwiseOp : public ConversionPattern { AtenMulScalarOp, AtenLogOp, AtenErfOp, AtenSqrtOp, AtenFloorOp, AtenPowScalarOp, AtenPowTensorScalarOp, AtenPowTensorTensorOp, AtenLog2Op, AtenLog10Op, AtenLog1pOp, AtenRsqrtOp, AtenDivScalarOp, - AtenRemainderScalarOp, AtenRemainderTensorOp, AtenAbsOp, - AtenReciprocalOp, AtenBitwiseAndTensorOp, AtenBitwiseAndScalarOp, - AtenBitwiseOrTensorOp, AtenBitwiseXorTensorOp, - AtenBitwiseLeftShiftTensorOp, AtenBitwiseRightShiftTensorOp, - AtenGtScalarOp, AtenGeScalarOp, AtenEqScalarOp, AtenLtScalarOp, - AtenLeScalarOp, AtenWhereSelfOp, AtenCeilOp, AtenGtTensorOp, - AtenGeTensorOp, AtenEqTensorOp, AtenNeTensorOp, AtenLtTensorOp, - AtenLeTensorOp, AtenSubScalarOp, AtenAddScalarOp, AtenThresholdOp, - AtenThresholdBackwardOp, AtenHardtanhBackwardOp, AtenCloneOp, - AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenNegOp, - AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp, - AtenLogicalXorOp, AtenLogicalNotOp, AtenIsinfOp, AtenTriuOp, - AtenTrilOp, AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, - AtenFillTensorOp, AtenAtanOp, AtenAcosOp, AtenAtanhOp, AtenAcoshOp, - AtenAsinOp, AtenAsinhOp, AtenRealOp, AtenImagOp, - AtenDequantizeSelfOp, AtenDequantizeTensorOp, + AtenRemainderScalarOp, AtenRemainderTensorOp, AtenFmodTensorOp, + AtenAbsOp, AtenReciprocalOp, AtenBitwiseAndTensorOp, + AtenBitwiseAndScalarOp, AtenBitwiseOrTensorOp, + AtenBitwiseXorTensorOp, AtenBitwiseLeftShiftTensorOp, + AtenBitwiseRightShiftTensorOp, AtenGtScalarOp, AtenGeScalarOp, + AtenEqScalarOp, AtenLtScalarOp, AtenLeScalarOp, AtenWhereSelfOp, + AtenCeilOp, AtenGtTensorOp, AtenGeTensorOp, AtenEqTensorOp, + AtenNeTensorOp, AtenLtTensorOp, AtenLeTensorOp, AtenSubScalarOp, + AtenAddScalarOp, AtenThresholdOp, AtenThresholdBackwardOp, + AtenHardtanhBackwardOp, AtenCloneOp, AtenSinOp, AtenCosOp, + AtenNeScalarOp, AtenNegOp, AtenMaskedFillTensorOp, AtenLogicalOrOp, + AtenLogicalAndOp, AtenLogicalXorOp, AtenLogicalNotOp, AtenIsinfOp, + AtenTriuOp, AtenTrilOp, AtenBitwiseNotOp, AtenRoundOp, + AtenFillScalarOp, AtenFillTensorOp, AtenAtanOp, AtenAcosOp, + AtenAtanhOp, AtenAcoshOp, AtenAsinOp, AtenAsinhOp, AtenRealOp, + AtenImagOp, AtenDequantizeSelfOp, AtenDequantizeTensorOp, AtenQuantizePerTensorOp>(op)) return rewriter.notifyMatchFailure(op, "not a supported elementwise op"); @@ -2584,9 +2607,10 @@ void mlir::torch::torch_to_linalg::populateUncategorizedPatternsAndLegality( AtenCloneOp, AtenSinOp, AtenCosOp, AtenNeScalarOp, AtenMaskedFillTensorOp, AtenLogicalOrOp, AtenLogicalAndOp, AtenAtanOp, AtenAcosOp, AtenLogicalXorOp, AtenLogicalNotOp, AtenIsinfOp, AtenTriuOp, AtenTrilOp, - AtenRemainderScalarOp, AtenRemainderTensorOp, AtenBitwiseNotOp, - AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp, AtenRealOp, AtenImagOp, - AtenDequantizeSelfOp, AtenDequantizeTensorOp, AtenQuantizePerTensorOp>(); + AtenRemainderScalarOp, AtenFmodTensorOp, AtenRemainderTensorOp, + AtenBitwiseNotOp, AtenRoundOp, AtenFillScalarOp, AtenFillTensorOp, + AtenRealOp, AtenImagOp, AtenDequantizeSelfOp, AtenDequantizeTensorOp, + AtenQuantizePerTensorOp>(); patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index b40ca2cdeb89..f9ad383d1000 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -6865,6 +6865,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.fmod.Tensor\"(%arg0: !torch.list, %arg1: !torch.list) -> !torch.list {\n" +" %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.floor_divide.Scalar\"(%arg0: !torch.list, %arg1: !torch.float) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -11395,6 +11399,14 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%1, %3) : (!torch.list>, !torch.list) -> !torch.int\n" " return %4 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.fmod.Tensor\"(%arg0: !torch.tuple, %arg1: !torch.tuple) -> !torch.int {\n" +" %0:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %1#0, %0#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.prim.ListConstruct %1#1, %0#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %4 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %3) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %4 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.floor_divide.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.number) -> !torch.int {\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: \"\n" diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index adb406f60405..e2b198839436 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -1641,6 +1641,9 @@ "ElementwiseOrTensorStaticShapeModule_basic", "ElementwiseQuantizePerTensorModule_basic", "ElementwiseRemainderTensorModule_Int_basic", + "ElementwiseFmodTensor_Float_basic", + "ElementwiseFmodTensor_Int_Float_basic", + "ElementwiseFmodTensor_Int_basic", "EmptyStridedModule_basic", "EmptyStridedSizeIntStrideModule_basic", "EqIntModule_basic", diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index 99f4f2200d35..a9bf5640d5e3 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -438,6 +438,9 @@ def aten〇remainder〇Scalar〡shape(self: List[int], other: float) -> List[int def aten〇remainder〇Tensor〡shape(self: List[int], other: List[int]) -> List[int]: return upstream_shape_functions.unary(self) +def aten〇fmod〇Tensor〡shape(self: List[int], other: List[int]) -> List[int]: + return upstream_shape_functions.unary(self) + def aten〇floor_divide〇Scalar〡shape(self: List[int], other: float) -> List[int]: return upstream_shape_functions.unary(self) @@ -3491,6 +3494,14 @@ def aten〇fmod〇Scalar〡dtype(self_rank_dtype: Tuple[int, int], other: Union[ dtypes = [self_dtype, get_dtype_of_scalar(other)] return promote_dtypes(ranks, dtypes) +@check_dtype_function(_check_two_tensor_op()) +def aten〇fmod〇Tensor〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int]) -> int: + other_rank, other_dtype = other_rank_dtype + self_rank, self_dtype = self_rank_dtype + ranks: List[Optional[int]] = [self_rank, other_rank] + dtypes = [self_dtype, other_dtype] + return promote_dtypes(ranks, dtypes) + @check_dtype_function( _check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.complex64, torch.complex128}, other=1) + _check_tensors_with_the_same_dtype(num_of_tensors=1, error_types={torch.complex64, torch.complex128}, other=1.0)) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py index 24bbe29194a2..d9921d23d677 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py @@ -2526,6 +2526,68 @@ def ElementwiseRemainderScalarModule_Bool_basic(module, tu: TestUtils): # ============================================================================== + +class ElementwiseFmodTensor_Float(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1], torch.float32, True), + ([-1], torch.float32, True) + ]) + def forward(self, x, y): + return torch.fmod(x, y) + + +@register_test_case(module_factory=lambda: ElementwiseFmodTensor_Float()) +def ElementwiseFmodTensor_Float_basic(module, tu: TestUtils): + module.forward(tu.rand(100, low=-10, high=10), tu.rand(100, low=-10, high=10)) + +# ============================================================================== + +class ElementwiseFmodTensor_Int_Float(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1], torch.int32, True), + ([-1], torch.float32, True) + ]) + def forward(self, x, y): + return torch.fmod(x, y) + + +@register_test_case(module_factory=lambda: ElementwiseFmodTensor_Int_Float()) +def ElementwiseFmodTensor_Int_Float_basic(module, tu: TestUtils): + module.forward(tu.randint(100, low=-10, high=10).to(torch.int32), tu.rand(100, low=-10, high=10)) + +# ============================================================================== + +class ElementwiseFmodTensor_Int(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1], torch.int32, True), + ([-1], torch.int32, True), + ]) + def forward(self, x, y): + return torch.fmod(x, y) + + +@register_test_case(module_factory=lambda: ElementwiseFmodTensor_Int()) +def ElementwiseFmodTensor_Int_basic(module, tu: TestUtils): + module.forward(tu.randint(100, low=0, high=1000).to(torch.int32), tu.randint(100, low=1, high=1000).to(torch.int32)) + # ============================================================================== class ElementwiseRemainderTensorModule_Int_Float(torch.nn.Module): From 579ac8b66628b5707ca1a7c4c41fbf4c829b30b5 Mon Sep 17 00:00:00 2001 From: Vivek Khandelwal Date: Thu, 29 Feb 2024 21:48:46 +0530 Subject: [PATCH 241/283] [MLIR][TORCH] Fix OnnxToLinalg lowering issue for sub and sum op (#2954) This commit adds the support for scalar conversion to byte. This commit also fixes the OnnxToLinalg lowering issue for Onnx.Sub and Onnx.Sum op. Fixes https://github.com/nod-ai/SHARK-Turbine/issues/466 Fixes https://github.com/nod-ai/SHARK-Turbine/issues/467 Signed-Off By: Vivek Khandelwal --- include/torch-mlir/Conversion/Utils/Utils.h | 3 +- .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 10 ++- .../TorchToLinalg/Uncategorized.cpp | 3 +- lib/Conversion/Utils/Utils.cpp | 33 +++++++-- .../TorchOnnxToTorch/simple_ops_q_to_z.mlir | 69 +++++++++++++++++-- 5 files changed, 104 insertions(+), 14 deletions(-) diff --git a/include/torch-mlir/Conversion/Utils/Utils.h b/include/torch-mlir/Conversion/Utils/Utils.h index 516954b88fbc..b76efe869a0f 100644 --- a/include/torch-mlir/Conversion/Utils/Utils.h +++ b/include/torch-mlir/Conversion/Utils/Utils.h @@ -88,7 +88,8 @@ mlir::RankedTensorType GetTypeFromTensorShape(llvm::ArrayRef shape, // should be converted builtin types. Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype, std::optional srcOriginalDtype = std::nullopt, - std::optional dstOriginalDtype = std::nullopt); + std::optional dstOriginalDtype = std::nullopt, + std::optional originalScalar = std::nullopt); Value toPositiveValidDim(ConversionPatternRewriter &rewriter, Location loc, Value torchOptionalInt, Value builtinInt, diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 3deba85a6c77..b697a4fa2c48 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -489,8 +489,6 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( return success(); } // When binder.op->getNumOperands() > 2 - auto baseType = Torch::ValueTensorType::getWithLeastStaticInformation( - binder.op->getContext()); Value curr = rewriter.create( binder.getLoc(), resultType, valList[0], valList[1], const1); for (int i = 2; i < numOperands; i++) { @@ -498,6 +496,14 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( curr = rewriter.create( binder.getLoc(), resultType, curr, valList[i], const1); } else { + SmallVector resultBroadcastShapeInt; + SmallVector resultBroadcastShapeValue; + Torch::computeBroadcastShape(rewriter, binder.getLoc(), curr, + valList[i], resultBroadcastShapeInt, + resultBroadcastShapeValue); + auto baseType = Torch::ValueTensorType::get( + binder.op->getContext(), resultBroadcastShapeInt, + resultType.getOptionalDtype()); curr = rewriter.create( binder.getLoc(), baseType, curr, valList[i], const1); } diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index d28369cc560e..8b4297a62e17 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -645,7 +645,8 @@ static Value createLinalgPayloadCalculationForElementwiseOp( /*dstOriginalDtype=*/resultElementType); Value alpha = convertScalarToDtype(b, loc, adaptor.getAlpha(), dtype, /*srcOriginalDtype=*/std::nullopt, - /*dstOriginalDtype=*/resultElementType); + /*dstOriginalDtype=*/resultElementType, + /*originalScalar=*/sub.getAlpha()); if (dtype.isa()) { Value scaled = b.create(loc, rhs, alpha); return b.create(loc, lhs, scaled); diff --git a/lib/Conversion/Utils/Utils.cpp b/lib/Conversion/Utils/Utils.cpp index 3df9da94b735..064215c51da0 100644 --- a/lib/Conversion/Utils/Utils.cpp +++ b/lib/Conversion/Utils/Utils.cpp @@ -245,12 +245,20 @@ mlir::RankedTensorType GetTypeFromTensorShape(llvm::ArrayRef shape, elementType, encoding); } +static std::optional getIntegerValue(Value scalar) { + if (auto constOp = scalar.getDefiningOp()) { + return std::optional(constOp.getValue()); + } + return std::optional(); +} + // Convert a scalar value to the target type. The scalar value can be an element // from a tensor or a scalar in the pytorch dialect. Both the scalar and dtype // should be converted builtin types. Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype, std::optional srcOriginalDtype, - std::optional dstOriginalDtype) { + std::optional dstOriginalDtype, + std::optional originalScalar) { Type scalarType = scalar.getType(); if (scalarType == dtype) return scalar; @@ -262,7 +270,8 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype, return false; }; - // We don't support conversion to Byte dtype. + // We support conversion to Byte dtype only if the original scalar is an + // integer constant with value lying between 0 - 63. if (isByteOrChar(dtype)) { if (!dstOriginalDtype.has_value()) { mlir::emitError(loc) @@ -271,10 +280,22 @@ Value convertScalarToDtype(OpBuilder &b, Location loc, Value scalar, Type dtype, return nullptr; } if (dstOriginalDtype->isUnsignedInteger()) { - mlir::emitError(loc) - << "unsupported: conversion to byte type for convertScalarToDtype " - << scalarType << "(scalar type) -> " << dtype << "(dtype)"; - return nullptr; + if (originalScalar.has_value()) { + std::optional optConstVal = + getIntegerValue(originalScalar.value()); + if (optConstVal.has_value()) { + int64_t constVal = optConstVal.value(); + if (constVal < 0 || constVal > 63) { + // Do the conversion only if the original integer value is between + // 0 - 63. + mlir::emitError(loc) + << "unsupported: conversion to byte type for " + "convertScalarToDtype " + << scalarType << "(scalar type) -> " << dtype << "(dtype)"; + return nullptr; + } + } + } } } diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir index 42be32166c5f..58b4287a41c5 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir @@ -223,6 +223,8 @@ func.func @test_scatter_elements_with_axis(%arg0: !torch.vtensor<[1,5],f32>, %ar return %0 : !torch.vtensor<[1,5],f32> } +// ----- + // CHECK-LABEL: func.func @test_scatter_elements_with_duplicate_indices func.func @test_scatter_elements_with_duplicate_indices(%arg0: !torch.vtensor<[1,5],f32>, %arg1: !torch.vtensor<[1,2],si64>, %arg2: !torch.vtensor<[1,2],f32>) -> !torch.vtensor<[1,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT1:.*]] = torch.constant.int 1 @@ -232,6 +234,8 @@ func.func @test_scatter_elements_with_duplicate_indices(%arg0: !torch.vtensor<[1 return %0 : !torch.vtensor<[1,5],f32> } +// ----- + // CHECK-LABEL: func.func @test_scatter_elements_without_axis func.func @test_scatter_elements_without_axis(%arg0: !torch.vtensor<[3,3],f32>, %arg1: !torch.vtensor<[2,3],si64>, %arg2: !torch.vtensor<[2,3],f32>) -> !torch.vtensor<[3,3],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT0:.*]] = torch.constant.int 0 @@ -240,6 +244,8 @@ func.func @test_scatter_elements_without_axis(%arg0: !torch.vtensor<[3,3],f32>, return %0 : !torch.vtensor<[3,3],f32> } +// ----- + // CHECK-LABEL: func.func @test_scatter_elements_with_reduction_mul func.func @test_scatter_elements_with_reduction_mul(%arg0: !torch.vtensor<[1,5],f32>, %arg1: !torch.vtensor<[1,2],si64>, %arg2: !torch.vtensor<[1,2],f32>) -> !torch.vtensor<[1,5],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT1:.*]] = torch.constant.int 1 @@ -295,6 +301,8 @@ func.func @test_sub_bcast(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vten return %0 : !torch.vtensor<[3,4,5],f32> } +// ----- + // CHECK-LABEL: func.func @test_sub_example func.func @test_sub_example(%arg0: !torch.vtensor<[3],f32>, %arg1: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT1:.*]] = torch.constant.int 1 @@ -303,6 +311,8 @@ func.func @test_sub_example(%arg0: !torch.vtensor<[3],f32>, %arg1: !torch.vtenso return %0 : !torch.vtensor<[3],f32> } +// ----- + // CHECK-LABEL: func.func @test_sub func.func @test_sub(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT1:.*]] = torch.constant.int 1 @@ -311,6 +321,8 @@ func.func @test_sub(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3 return %0 : !torch.vtensor<[3,4,5],f32> } +// ----- + // CHECK-LABEL: func.func @test_sub_uint8 func.func @test_sub_uint8(%arg0: !torch.vtensor<[3,4,5],ui8>, %arg1: !torch.vtensor<[3,4,5],ui8>) -> !torch.vtensor<[3,4,5],ui8> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 14 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT1:.*]] = torch.constant.int 1 @@ -324,19 +336,23 @@ func.func @test_sub_uint8(%arg0: !torch.vtensor<[3,4,5],ui8>, %arg1: !torch.vten // CHECK-LABEL: func.func @test_sum_example func.func @test_sum_example(%arg0: !torch.vtensor<[3],f32>, %arg1: !torch.vtensor<[3],f32>, %arg2: !torch.vtensor<[3],f32>, %arg3: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT1:.*]] = torch.constant.int 1 - // CHECK: torch.aten.add.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor<[3],f32> - // CHECK: torch.aten.add.Tensor %0, %arg2, %int1 : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor - // CHECK: torch.aten.add.Tensor %1, %arg3, %int1 : !torch.vtensor, !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor<[3],f32> + // CHECK: %[[SUM:.*]] = torch.aten.add.Tensor %arg0, %arg1, %int1 : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor<[3],f32> + // CHECK: %[[SUM_1:.*]] = torch.aten.add.Tensor %[[SUM]], %arg2, %int1 : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor<[3],f32> + // CHECK: %[[SUM_2:.*]] = torch.aten.add.Tensor %[[SUM_1]], %arg3, %int1 : !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.int -> !torch.vtensor<[3],f32> %0 = torch.operator "onnx.Sum"(%arg0, %arg1, %arg2, %arg3) : (!torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>, !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> return %0 : !torch.vtensor<[3],f32> } +// ----- + // CHECK-LABEL: func.func @test_sum_one_input func.func @test_sum_one_input(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { %0 = torch.operator "onnx.Sum"(%arg0) : (!torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> return %0 : !torch.vtensor<[3],f32> } +// ----- + // CHECK-LABEL: func.func @test_sum_two_inputs func.func @test_sum_two_inputs(%arg0: !torch.vtensor<[3],f32>, %arg1: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT1:.*]] = torch.constant.int 1 @@ -370,6 +386,8 @@ func.func @test_xor2d(%arg0: !torch.vtensor<[3,4],i1>, %arg1: !torch.vtensor<[3, return %0 : !torch.vtensor<[3,4],i1> } +// ----- + // CHECK-LABEL: func.func @test_xor3d func.func @test_xor3d(%arg0: !torch.vtensor<[3,4,5],i1>, %arg1: !torch.vtensor<[3,4,5],i1>) -> !torch.vtensor<[3,4,5],i1> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 7 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.logical_xor %arg0, %arg1 : !torch.vtensor<[3,4,5],i1>, !torch.vtensor<[3,4,5],i1> -> !torch.vtensor<[3,4,5],i1> @@ -377,6 +395,8 @@ func.func @test_xor3d(%arg0: !torch.vtensor<[3,4,5],i1>, %arg1: !torch.vtensor<[ return %0 : !torch.vtensor<[3,4,5],i1> } +// ----- + // CHECK-LABEL: func.func @test_xor4d func.func @test_xor4d(%arg0: !torch.vtensor<[3,4,5,6],i1>, %arg1: !torch.vtensor<[3,4,5,6],i1>) -> !torch.vtensor<[3,4,5,6],i1> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 7 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.logical_xor %arg0, %arg1 : !torch.vtensor<[3,4,5,6],i1>, !torch.vtensor<[3,4,5,6],i1> -> !torch.vtensor<[3,4,5,6],i1> @@ -384,6 +404,8 @@ func.func @test_xor4d(%arg0: !torch.vtensor<[3,4,5,6],i1>, %arg1: !torch.vtensor return %0 : !torch.vtensor<[3,4,5,6],i1> } +// ----- + // CHECK-LABEL: func.func @test_xor_bcast3v1d func.func @test_xor_bcast3v1d(%arg0: !torch.vtensor<[3,4,5],i1>, %arg1: !torch.vtensor<[5],i1>) -> !torch.vtensor<[3,4,5],i1> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 7 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.logical_xor %arg0, %arg1 : !torch.vtensor<[3,4,5],i1>, !torch.vtensor<[5],i1> -> !torch.vtensor<[3,4,5],i1> @@ -391,6 +413,8 @@ func.func @test_xor_bcast3v1d(%arg0: !torch.vtensor<[3,4,5],i1>, %arg1: !torch.v return %0 : !torch.vtensor<[3,4,5],i1> } +// ----- + // CHECK-LABEL: func.func @test_xor_bcast4v4d func.func @test_xor_bcast4v4d(%arg0: !torch.vtensor<[1,4,1,6],i1>, %arg1: !torch.vtensor<[3,1,5,6],i1>) -> !torch.vtensor<[3,4,5,6],i1> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 7 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.logical_xor %arg0, %arg1 : !torch.vtensor<[1,4,1,6],i1>, !torch.vtensor<[3,1,5,6],i1> -> !torch.vtensor<[3,4,5,6],i1> @@ -417,6 +441,8 @@ func.func @test_squeeze(%arg0: !torch.vtensor<[1,3,4,5],f32>, %arg1: !torch.vten return %0 : !torch.vtensor<[3,4,5],f32> } +// ----- + // CHECK-LABEL: func.func @test_squeeze_two_axes func.func @test_squeeze_two_axes(%arg0: !torch.vtensor<[3,1,4,5,1],f32>, %arg1: !torch.vtensor<[2],si64>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT0:.*]] = torch.constant.int 0 @@ -467,6 +493,8 @@ func.func @test_unsqueeze_axis_0(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !tor return %0 : !torch.vtensor<[1,3,4,5],f32> } +// ----- + // CHECK-LABEL: func.func @test_unsqueeze_axis_1 func.func @test_unsqueeze_axis_1(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT0:.*]] = torch.constant.int 0 @@ -491,6 +519,8 @@ func.func @test_unsqueeze_axis_1(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !tor return %0 : !torch.vtensor<[3,1,4,5],f32> } +// ----- + // CHECK-LABEL: func.func @test_unsqueeze_axis_2 func.func @test_unsqueeze_axis_2(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,4,1,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT0:.*]] = torch.constant.int 0 @@ -515,6 +545,8 @@ func.func @test_unsqueeze_axis_2(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !tor return %0 : !torch.vtensor<[3,4,1,5],f32> } +// ----- + // CHECK-LABEL: func.func @test_unsqueeze_negative_axes func.func @test_unsqueeze_negative_axes(%arg0: !torch.vtensor<[1,3,1,5],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[1,3,1,1,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT0:.*]] = torch.constant.int 0 @@ -539,6 +571,8 @@ func.func @test_unsqueeze_negative_axes(%arg0: !torch.vtensor<[1,3,1,5],f32>, %a return %0 : !torch.vtensor<[1,3,1,1,5],f32> } +// ----- + // CHECK-LABEL: func.func @test_unsqueeze_three_axes func.func @test_unsqueeze_three_axes(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3],si64>) -> !torch.vtensor<[3,4,1,5,1,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT0:.*]] = torch.constant.int 0 @@ -585,6 +619,8 @@ func.func @test_unsqueeze_three_axes(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: return %0 : !torch.vtensor<[3,4,1,5,1,1],f32> } +// ----- + // CHECK-LABEL: func.func @test_unsqueeze_unsorted_axes func.func @test_unsqueeze_unsorted_axes(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3],si64>) -> !torch.vtensor<[3,4,1,5,1,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT0:.*]] = torch.constant.int 0 @@ -642,6 +678,8 @@ func.func @test_softmax_axis_0(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vte return %0 : !torch.vtensor<[3,4,5],f32> } +// ----- + // CHECK-LABEL: func.func @test_softmax_axis_1 func.func @test_softmax_axis_1(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT1:.*]] = torch.constant.int 1 @@ -651,6 +689,8 @@ func.func @test_softmax_axis_1(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vte return %0 : !torch.vtensor<[3,4,5],f32> } +// ----- + // CHECK-LABEL: func.func @test_softmax_axis_2 func.func @test_softmax_axis_2(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT2:.*]] = torch.constant.int 2 @@ -660,6 +700,8 @@ func.func @test_softmax_axis_2(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vte return %0 : !torch.vtensor<[3,4,5],f32> } +// ----- + // CHECK-LABEL: func.func @test_softmax_default_axis func.func @test_softmax_default_axis(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT2:.*]] = torch.constant.int 2 @@ -669,6 +711,8 @@ func.func @test_softmax_default_axis(%arg0: !torch.vtensor<[3,4,5],f32>) -> !tor return %0 : !torch.vtensor<[3,4,5],f32> } +// ----- + // CHECK-LABEL: func.func @test_softmax_large_number func.func @test_softmax_large_number(%arg0: !torch.vtensor<[2,4],f32>) -> !torch.vtensor<[2,4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT1:.*]] = torch.constant.int 1 @@ -678,6 +722,8 @@ func.func @test_softmax_large_number(%arg0: !torch.vtensor<[2,4],f32>) -> !torch return %0 : !torch.vtensor<[2,4],f32> } +// ----- + // CHECK-LABEL: func.func @test_softmax_negative_axis func.func @test_softmax_negative_axis(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[INT2:.*]] = torch.constant.int 2 @@ -773,6 +819,8 @@ func.func @test_reduce_sum_default_axes_keepdims_example(%arg0: !torch.vtensor<[ return %0 : !torch.vtensor<[1,1,1],f32> } +// ----- + // CHECK-LABEL: func.func @test_reduce_sum_do_not_keepdims_example func.func @test_reduce_sum_do_not_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[NONE:.+]] = torch.constant.none @@ -792,12 +840,16 @@ func.func @test_reduce_sum_do_not_keepdims_example(%arg0: !torch.vtensor<[3,2,2] return %0 : !torch.vtensor<[3,2],f32> } +// ----- + // CHECK-LABEL: func.func @test_reduce_sum_empty_axes_input_noop_example func.func @test_reduce_sum_empty_axes_input_noop_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[0],si64>) -> !torch.vtensor<[3,2,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { %0 = torch.operator "onnx.ReduceSum"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64, torch.onnx.noop_with_empty_axes = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[0],si64>) -> !torch.vtensor<[3,2,2],f32> return %0 : !torch.vtensor<[3,2,2],f32> } +// ----- + // CHECK-LABEL: func.func @test_reduce_sum_empty_set_non_reduced_axis_zero func.func @test_reduce_sum_empty_set_non_reduced_axis_zero(%arg0: !torch.vtensor<[2,0,4],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[2,0,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[NONE:.+]] = torch.constant.none @@ -817,6 +869,8 @@ func.func @test_reduce_sum_empty_set_non_reduced_axis_zero(%arg0: !torch.vtensor return %0 : !torch.vtensor<[2,0,1],f32> } +// ----- + // CHECK-LABEL: func.func @test_reduce_sum_keepdims_example func.func @test_reduce_sum_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[NONE:.+]] = torch.constant.none @@ -836,6 +890,8 @@ func.func @test_reduce_sum_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, return %0 : !torch.vtensor<[3,1,2],f32> } +// ----- + // CHECK-LABEL: func.func @test_reduce_sum_negative_axes_keepdims_example func.func @test_reduce_sum_negative_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[NONE:.+]] = torch.constant.none @@ -867,6 +923,8 @@ func.func @test_reduce_mean_default_axes_keepdims_example(%arg0: !torch.vtensor< return %0 : !torch.vtensor<[1,1,1],f32> } +// ----- + // CHECK-LABEL: func.func @test_reduce_mean_do_not_keepdims_example func.func @test_reduce_mean_do_not_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[NONE:.+]] = torch.constant.none @@ -886,6 +944,8 @@ func.func @test_reduce_mean_do_not_keepdims_example(%arg0: !torch.vtensor<[3,2,2 return %0 : !torch.vtensor<[3,2],f32> } +// ----- + // CHECK-LABEL: func.func @test_reduce_mean_keepdims_example func.func @test_reduce_mean_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[NONE:.+]] = torch.constant.none @@ -905,6 +965,8 @@ func.func @test_reduce_mean_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, return %0 : !torch.vtensor<[3,1,2],f32> } +// ----- + // CHECK-LABEL: func.func @test_reduce_mean_negative_axes_keepdims_example func.func @test_reduce_mean_negative_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[NONE:.+]] = torch.constant.none @@ -958,7 +1020,6 @@ func.func @test_reduce_min_empty_set_int(%arg0: !torch.vtensor<[2,0,4],si32>, %a // ----- - // CHECK-LABEL: func.func @test_reduce_min_bool_inputs func.func @test_reduce_min_bool_inputs(%arg0: !torch.vtensor<[4,2],i1>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[4,1],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[IDX:.+]] = torch.constant.int 0 From e7d90a4b82be35ae7aed9bd801048205abe7de38 Mon Sep 17 00:00:00 2001 From: Scott Todd Date: Thu, 29 Feb 2024 13:01:13 -0800 Subject: [PATCH 242/283] [onnx] Fix type on create_module() in onnx_importer.py. (#2968) The type returned was changed in https://github.com/llvm/torch-mlir/pull/2795. This led to errors in the downstream IREE project: https://github.com/openxla/iree/pull/16622. --- python/torch_mlir/extras/onnx_importer.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/python/torch_mlir/extras/onnx_importer.py b/python/torch_mlir/extras/onnx_importer.py index 91ee4c14c75a..289e5722efce 100644 --- a/python/torch_mlir/extras/onnx_importer.py +++ b/python/torch_mlir/extras/onnx_importer.py @@ -99,12 +99,12 @@ def __init__(self, model_proto: onnx.ModelProto, *, config: Config = Config()): assert model_proto.graph, "Model must contain a main Graph" self.main_graph = GraphInfo(self, model_proto.graph) - def create_module(self, context: Optional[Context] = None) -> Operation: + def create_module(self, context: Optional[Context] = None) -> Module: if not context: context = Context() - module_op = Module.create(Location.unknown(context)) + module = Module.create(Location.unknown(context)) # TODO: Populate module level metadata from the ModelProto - return module_op + return module class GraphInfo: From d030bffc624860b57d43dc918e3bd2a55d33e077 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Fri, 1 Mar 2024 12:31:07 -0800 Subject: [PATCH 243/283] [torch] Support `aten.view` rank-0 collapse (#2965) Collapsing to a rank-0 tensor using `aten.view` was currently bailing out. Added the special case. --- lib/Conversion/TorchToLinalg/DataMovement.cpp | 22 +++++++++++-------- projects/pt1/e2e_testing/xfail_sets.py | 1 - 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp index 42aacceab0b4..e6ae601dc855 100644 --- a/lib/Conversion/TorchToLinalg/DataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp @@ -795,17 +795,21 @@ class ConvertAtenViewOp : public OpConversionPattern { auto resultType = typeConverter->convertType(op.getType()).cast(); int64_t resultRank = resultType.getRank(); - if (resultRank == 0) - return rewriter.notifyMatchFailure(op, - "result shape of rank 0 is invalid"); + if (resultRank == 0) { + rewriter + .replaceOpWithNewOp( + op, resultType, input, ArrayRef()) + .getResult(); + return success(); + } if (inputRank == 0) { - Value expanded = - rewriter - .create(loc, resultType, input, - ArrayRef()) - .getResult(); - rewriter.replaceOp(op, expanded); + llvm::SmallVector outshape(resultRank, 1); + auto expandTy = + RankedTensorType::get(outshape, resultType.getElementType()); + Value expand = rewriter.create( + op.getLoc(), expandTy, input, ArrayRef()); + rewriter.replaceOpWithNewOp(op, resultType, expand); return success(); } diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index e2b198839436..05ca09922ac6 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -2144,7 +2144,6 @@ "ReduceMaxUnsignedIntModule_basic", # Failure - torch.aten.view lower - "AddSizeIntModule_basic", "ElementwiseFlattenBroadcastModule_basic", "IndexTensorDyanmicInputContiguousWithNoneModule_basic", "IndexTensorDyanmicInputNonContiguousWithNoneModule_basic", From 61f0a5facf6d6bde55eecc4200d048aa55690b64 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Fri, 1 Mar 2024 21:41:12 -0800 Subject: [PATCH 244/283] [torch] Add an `aten.cat` length-0 canonicalization (#2966) If an input is length-0 along the dimension of canonicalization we can remove the tensor from the list --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 1 + lib/Dialect/Torch/IR/TorchOps.cpp | 35 ++++++++++++++++++- projects/pt1/e2e_testing/xfail_sets.py | 1 - .../build_tools/torch_ods_gen.py | 2 +- test/Dialect/Torch/canonicalize.mlir | 10 ++++++ 5 files changed, 46 insertions(+), 3 deletions(-) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 9d245723fd84..7b698793ae10 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -12594,6 +12594,7 @@ def Torch_AtenCatOp : Torch_Op<"aten.cat", [ } }]; let hasFolder = 1; + let hasCanonicalizer = 1; } def Torch_AtenStackOp : Torch_Op<"aten.stack", [ diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index 6120fd6f0e32..1aae3735d0d5 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -2900,13 +2900,46 @@ OpFoldResult AtenSubIntOp::fold(FoldAdaptor adaptor) { OpFoldResult AtenCatOp::fold(FoldAdaptor adaptor) { auto list = getOperand(0).getDefiningOp(); - if (!list || !list->hasOneUse() || list.getElements().size() != 1) + if (!list || list.getElements().size() != 1) return nullptr; if (list.getElements()[0].getType() != getResult().getType()) return nullptr; return list.getElements()[0]; } +void AtenCatOp::getCanonicalizationPatterns(RewritePatternSet &patterns, + MLIRContext *context) { + patterns.add(+[](AtenCatOp op, PatternRewriter &rewriter) { + auto list = op.getTensors().getDefiningOp(); + auto resultTy = dyn_cast(op.getType()); + if (!list || !resultTy) + return failure(); + + int64_t dim; + if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) + return failure(); + + llvm::SmallVector filtered; + for (auto operand : list.getOperands()) { + auto operandTy = dyn_cast(operand.getType()); + if (!operandTy || !operandTy.hasSizes()) + return failure(); + int64_t adim = dim < 0 ? dim + operandTy.getSizes().size() : dim; + if (operandTy.getSizes()[adim] != 0) + filtered.push_back(operand); + } + + if (filtered.size() == list.getNumOperands()) + return failure(); + + auto newlist = rewriter.create( + op.getLoc(), list.getType(), filtered); + rewriter.replaceOpWithNewOp(op, op.getType(), newlist, + op.getDim()); + return success(); + }); +} + //===----------------------------------------------------------------------===// // AtenBroadcastToOp //===----------------------------------------------------------------------===// diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 05ca09922ac6..55bcc4a33620 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -2144,7 +2144,6 @@ "ReduceMaxUnsignedIntModule_basic", # Failure - torch.aten.view lower - "ElementwiseFlattenBroadcastModule_basic", "IndexTensorDyanmicInputContiguousWithNoneModule_basic", "IndexTensorDyanmicInputNonContiguousWithNoneModule_basic", "IndexTensorHackedTwinMultiInputNonContiguousMultipleStaticDims_basic", diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index fed048a64340..3f6cf1b51dd4 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -727,7 +727,7 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::Delete.Dict_str : (Dict(str, t), str) -> ()") # List ops. - emit("aten::cat : (Tensor[], int) -> (Tensor)", has_folder=True) + emit("aten::cat : (Tensor[], int) -> (Tensor)", has_canonicalizer=True, has_folder=True) emit("aten::stack : (Tensor[], int) -> (Tensor)") emit("aten::append.t : (t[], t) -> (t[])") emit("aten::add.t : (t[], t[]) -> (t[])", has_canonicalizer=True) diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir index b3dd4c6f0641..952c9f78d0c5 100644 --- a/test/Dialect/Torch/canonicalize.mlir +++ b/test/Dialect/Torch/canonicalize.mlir @@ -2663,3 +2663,13 @@ func.func @aten_shape_to_tensor(%arg0 : !torch.vtensor<[4,5,6],f32>) -> !torch.v return %0 : !torch.vtensor<[3],si32> } +// ----- + +// CHECK-LABEL: @aten_cat_zero +func.func @aten_cat_zero(%arg0 : !torch.vtensor<[4,5,6],f32>, %arg1 : !torch.vtensor<[4,0,6],f32>) -> !torch.vtensor<[4,5,6],f32> { + // CHECK: return %arg0 : !torch.vtensor<[4,5,6],f32> + %list = torch.prim.ListConstruct %arg0, %arg1 : (!torch.vtensor<[4,5,6],f32>, !torch.vtensor<[4,0,6],f32>) -> !torch.list + %dim = torch.constant.int -2 + %0 = torch.aten.cat %list, %dim : !torch.list, !torch.int -> !torch.vtensor<[4,5,6],f32> + return %0 : !torch.vtensor<[4,5,6],f32> +} From 916554f270bebcb8a2195ce58eea7dd2c04c47e1 Mon Sep 17 00:00:00 2001 From: Yuanqiang Liu Date: Mon, 4 Mar 2024 23:31:54 +0800 Subject: [PATCH 245/283] [Stablehlo] add torch_to_stablehlo::getBackendTypeForScalarType (#2975) --- lib/Conversion/TorchToStablehlo/Basic.cpp | 12 +++----- .../TorchToStablehlo/CMakeLists.txt | 1 + lib/Conversion/TorchToStablehlo/Utils.cpp | 30 +++++++++++++++++++ lib/Conversion/TorchToStablehlo/Utils.h | 25 ++++++++++++++++ 4 files changed, 60 insertions(+), 8 deletions(-) create mode 100644 lib/Conversion/TorchToStablehlo/Utils.cpp create mode 100644 lib/Conversion/TorchToStablehlo/Utils.h diff --git a/lib/Conversion/TorchToStablehlo/Basic.cpp b/lib/Conversion/TorchToStablehlo/Basic.cpp index bee6c529bacb..d902202e8202 100644 --- a/lib/Conversion/TorchToStablehlo/Basic.cpp +++ b/lib/Conversion/TorchToStablehlo/Basic.cpp @@ -11,6 +11,7 @@ #include "../PassDetail.h" #include "PopulatePatterns.h" +#include "Utils.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Complex/IR/Complex.h" @@ -1662,19 +1663,14 @@ LogicalResult ConvertAtenOp::matchAndRewrite( if (!matchPattern(op.getDtype(), m_TorchConstantInt(&dtypeInt))) return rewriter.notifyMatchFailure( op, "unimplemented: dtype must be a constant integer or none"); - FailureOr maybeResultElementType = getTypeForScalarType( - op->getContext(), (torch_upstream::ScalarType)dtypeInt); + FailureOr maybeResultElementType = + torch_to_stablehlo::getBackendTypeForScalarType( + op->getContext(), (torch_upstream::ScalarType)dtypeInt); if (failed(maybeResultElementType)) { return rewriter.notifyMatchFailure( op, "unable to convert `dtypeInt` to builtin type"); } resultElementType = *maybeResultElementType; - // The stablehlo backend expects signed integers to be signless. - if (resultElementType.isSignedInteger()) { - resultElementType = IntegerType::get( - op->getContext(), resultElementType.getIntOrFloatBitWidth(), - IntegerType::Signless); - } } // Create an uninitialized tensor of `resultSize` shape. diff --git a/lib/Conversion/TorchToStablehlo/CMakeLists.txt b/lib/Conversion/TorchToStablehlo/CMakeLists.txt index 07ef1e2ea661..566f1d15b6ad 100644 --- a/lib/Conversion/TorchToStablehlo/CMakeLists.txt +++ b/lib/Conversion/TorchToStablehlo/CMakeLists.txt @@ -7,6 +7,7 @@ add_mlir_conversion_library(TorchMLIRTorchToStablehlo ViewLike.cpp Reduction.cpp Pooling.cpp + Utils.cpp ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/include/torch-mlir/Conversion/TorchToStablehlo diff --git a/lib/Conversion/TorchToStablehlo/Utils.cpp b/lib/Conversion/TorchToStablehlo/Utils.cpp new file mode 100644 index 000000000000..390888750110 --- /dev/null +++ b/lib/Conversion/TorchToStablehlo/Utils.cpp @@ -0,0 +1,30 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#include "./Utils.h" +#include "torch-mlir/Dialect/Torch/Utils/Utils.h" + +using namespace mlir; +using namespace torch; + +FailureOr torch_to_stablehlo::getBackendTypeForScalarType( + MLIRContext *context, torch_upstream::ScalarType dtypeInt) { + FailureOr maybeType = Torch::getTypeForScalarType( + context, (torch_upstream::ScalarType)dtypeInt); + if (failed(maybeType)) { + return failure(); + } + Type type = *maybeType; + // The stablehlo backend expects signed integers to be signless. + if (type.isSignedInteger()) { + type = IntegerType::get(context, type.getIntOrFloatBitWidth(), + IntegerType::Signless); + } + return type; +} diff --git a/lib/Conversion/TorchToStablehlo/Utils.h b/lib/Conversion/TorchToStablehlo/Utils.h new file mode 100644 index 000000000000..16788e3955c4 --- /dev/null +++ b/lib/Conversion/TorchToStablehlo/Utils.h @@ -0,0 +1,25 @@ +//===----------------------------------------------------------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// Also available under a BSD-style license. See LICENSE. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Transforms/DialectConversion.h" +#include "torch-mlir/Dialect/Torch/Utils/TorchUpstream.h" + +namespace mlir { +namespace torch { +namespace torch_to_stablehlo { + +// Convert a scalar type to the corresponding builtin type in the +// stablehlo backend. +FailureOr +getBackendTypeForScalarType(MLIRContext *context, + torch_upstream::ScalarType dtypeInt); + +} // namespace torch_to_stablehlo +} // namespace torch +} // namespace mlir From d51e80b648cf114165a466a36b237dd5e2949009 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Mon, 4 Mar 2024 08:25:19 -0800 Subject: [PATCH 246/283] [onnx] Fix onnx.gather lowering for rank-0 indices (#2973) We assumed rank was atleast 1 however it can be rank-0, generating an illegal pair of flatten / unflatten operations. Corrected this. --- .../TorchOnnxToTorch/DefaultDomainGtoP.cpp | 51 ++++++++++++++----- projects/pt1/e2e_testing/xfail_sets.py | 2 - .../TorchOnnxToTorch/simple_ops_g_to_p.mlir | 33 ++++++++++-- 3 files changed, 69 insertions(+), 17 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index 4d1aaf42d679..6855dcc3df33 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -572,10 +572,12 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( auto ctx = binder.op->getContext(); auto indicesTy = cast(indices.getType()); auto dataTy = cast(data.getType()); - if (!dataTy || !dataTy.hasSizes()) + if (!dataTy || !dataTy.hasSizes() || !indicesTy.hasSizes()) return failure(); - if (axis < 0) - axis += dataTy.getSizes().size(); + + int64_t dataRank = dataTy.getSizes().size(); + int64_t indicesRank = indicesTy.getSizes().size(); + axis = axis < 0 ? axis + dataRank : axis; Value index = rewriter.create( loc, Torch::IntType::get(ctx), rewriter.getI64IntegerAttr(axis)); @@ -599,8 +601,16 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( auto intListTy = rewriter.getType( rewriter.getType()); - auto indicesSize = - rewriter.create(loc, intListTy, indices); + + llvm::SmallVector indicesDims; + for (int i = 0, s = indicesTy.getSizes().size(); i < s; ++i) { + Value k = rewriter.create(binder.getLoc(), i); + indicesDims.push_back(rewriter.create( + binder.getLoc(), indices, k)); + } + + Value indicesSizeList = rewriter.create( + binder.getLoc(), intListTy, indicesDims); // Determine the collapsed dim size: auto indicesCt = 1; @@ -615,20 +625,37 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( auto flattenTy = rewriter.getType( SmallVector{indicesCt}, indicesTy.getOptionalDtype()); - Value rank = rewriter.create(loc, intTy, indices); - Value end = rewriter.create(loc, rank, one); - indices = rewriter.create( - loc, flattenTy, indices, zero, end); + + if (indicesRank == 0) { + indices = rewriter.create( + binder.getLoc(), flattenTy, indices, zero); + } else if (indicesRank > 1) { + Value rank = rewriter.create(loc, intTy, indices); + Value end = rewriter.create(loc, rank, one); + indices = rewriter.create( + loc, flattenTy, indices, zero, end); + } llvm::SmallVector gatherShape(dataTy.getSizes()); gatherShape[axis] = indicesCt; - auto gatherTy = rewriter.getType( gatherShape, dataTy.getOptionalDtype()); Value gather = rewriter.create( loc, gatherTy, data, index, indices); - rewriter.replaceOpWithNewOp( - binder.op, resultType, gather, index, indicesSize); + + if (indicesRank == 1) { + rewriter.replaceOp(binder.op, gather); + return success(); + } + + if (indicesRank > 1) { + gather = rewriter.replaceOpWithNewOp( + binder.op, resultType, gather, index, indicesSizeList); + return success(); + } + + rewriter.replaceOpWithNewOp(binder.op, resultType, + gather); return success(); }); patterns.onOp( diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 55bcc4a33620..6e94adedadff 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -2190,14 +2190,12 @@ "ElementwiseUnaryIntModule_basic", "ElementwiseUnsqueezeNegDimsModule_basic", "ElementwiseWhereScalarModule_basic", - "EmbeddingModule1DIndices_basic", "EmbeddingModuleF16_basic", "EmbeddingModuleI32_basic", "EmbeddingModuleI64_basic", "FlattenDynamicModule_basic", "GluStaticModule_basic", "GroupNormModule_basic", - "IndexSelectDynamicIndexSizeModule_basic", "IndexSelectDynamicModulebasic", "IndexTensorHackedTwinModule3dInput_basic", "IndexTensorHackedTwinModule_basic", diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir index d5a47aba353d..ee93f13e2c40 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir @@ -37,8 +37,8 @@ func.func @test_less(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[ // ----- -// CHECK-LABEL: func.func @test_gather -func.func @test_gather(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[8,10,20,40], si64>) -> !torch.vtensor<[8,10,20,40,4,5],f32> attributes {torch.onnx_meta.opset_version = 13 : si64} { +// CHECK-LABEL: func.func @test_gather_nd +func.func @test_gather_nd(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[8,10,20,40], si64>) -> !torch.vtensor<[8,10,20,40,4,5],f32> attributes {torch.onnx_meta.opset_version = 13 : si64} { // CHECK: %[[AXIS:.+]] = torch.constant.int 0 // CHECK: %[[ZERO:.+]] = torch.constant.int 0 // CHECK: %[[ONE:.+]] = torch.constant.int 1 @@ -46,7 +46,15 @@ func.func @test_gather(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor // CHECK: %[[SZ:.+]] = torch.aten.size.int %arg0, %[[AXIS]] // CHECK: %[[ADD:.+]] = torch.aten.add.Scalar %arg1, %[[SZ]], %[[ONE]] // CHECK: %[[SEL:.+]] = torch.aten.where.self %[[LT]], %[[ADD]], %arg1 - // CHECK: %[[SZ:.+]] = torch.aten.size %[[SEL]] + // CHECK: %[[D0:.+]] = torch.constant.int 0 + // CHECK: %[[SZ0:.+]] = torch.aten.size.int %[[SEL]], %[[D0]] + // CHECK: %[[D1:.+]] = torch.constant.int 1 + // CHECK: %[[SZ1:.+]] = torch.aten.size.int %[[SEL]], %[[D1]] + // CHECK: %[[D2:.+]] = torch.constant.int 2 + // CHECK: %[[SZ2:.+]] = torch.aten.size.int %[[SEL]], %[[D2]] + // CHECK: %[[D3:.+]] = torch.constant.int 3 + // CHECK: %[[SZ3:.+]] = torch.aten.size.int %[[SEL]], %[[D3]] + // CHECK: %[[SZ:.+]] = torch.prim.ListConstruct %[[SZ0]], %[[SZ1]], %[[SZ2]], %[[SZ3]] // CHECK: %[[DIM:.+]] = torch.aten.dim %[[SEL]] // CHECK: %[[SUB:.+]] = torch.aten.sub.int %[[DIM]], %[[ONE]] // CHECK: %[[FLAT:.+]] = torch.aten.flatten.using_ints %[[SEL]], %[[ZERO]], %[[SUB]] @@ -59,6 +67,25 @@ func.func @test_gather(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor // ----- +// CHECK-LABEL: func.func @test_gather_scalar +func.func @test_gather_scalar(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[], si64>) -> !torch.vtensor<[4,5],f32> attributes {torch.onnx_meta.opset_version = 13 : si64} { + // CHECK: %[[AXIS:.+]] = torch.constant.int 0 + // CHECK: %[[ZERO:.+]] = torch.constant.int 0 + // CHECK: %[[ONE:.+]] = torch.constant.int 1 + // CHECK: %[[LT:.+]] = torch.aten.le.Scalar %arg1, %[[ZERO]] + // CHECK: %[[SZ:.+]] = torch.aten.size.int %arg0, %[[AXIS]] + // CHECK: %[[ADD:.+]] = torch.aten.add.Scalar %arg1, %[[SZ]], %[[ONE]] + // CHECK: %[[SEL:.+]] = torch.aten.where.self %[[LT]], %[[ADD]], %arg1 + // CHECK: %[[FLAT:.+]] = torch.aten.unsqueeze %[[SEL]], %[[ZERO]] : !torch.vtensor<[],si64>, !torch.int -> !torch.vtensor<[1],si64> + // CHECK: %[[ISEL:.+]] = torch.aten.index_select %arg0, %[[AXIS]], %[[FLAT]] + // CHECK: %[[RES:.+]] = torch.aten.squeeze %[[ISEL]] : !torch.vtensor<[1,4,5],f32> -> !torch.vtensor<[4,5],f32> + // CHECK: return %[[RES]] + %0 = torch.operator "onnx.Gather"(%arg0, %arg1) {torch.onnx.axis = 0 : si64} : (!torch.vtensor<[3,4,5],f32>, !torch.vtensor<[], si64>) -> !torch.vtensor<[4,5],f32> + return %0 : !torch.vtensor<[4,5],f32> +} + +// ----- + // CHECK-LABEL: func.func @test_gather_elements func.func @test_gather_elements(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[3,4,5], si64>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.opset_version = 13 : si64} { // CHECK-DAG: %[[INT0:.+]] = torch.constant.int 0 From 19d488827859d0a1611255e1b49666186aa0cd0f Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Mon, 4 Mar 2024 10:17:42 -0800 Subject: [PATCH 247/283] [torch] Make torch.aten.unflatten lower directly to linalg (#2971) Existing lowering via aten.view does not work as well for dynamic shapes as the lowering to tensor.expand must re-infer dynamic shape matching. Better to directly lower. --- lib/Conversion/TorchToLinalg/DataMovement.cpp | 64 +++++++++++++++++++ .../Transforms/LowerToBackendContract.cpp | 1 - projects/pt1/python/torch_mlir/torchscript.py | 2 +- .../onnx_backends/linalg_on_tensors.py | 4 +- 4 files changed, 68 insertions(+), 3 deletions(-) diff --git a/lib/Conversion/TorchToLinalg/DataMovement.cpp b/lib/Conversion/TorchToLinalg/DataMovement.cpp index e6ae601dc855..e4bf1886bb91 100644 --- a/lib/Conversion/TorchToLinalg/DataMovement.cpp +++ b/lib/Conversion/TorchToLinalg/DataMovement.cpp @@ -638,6 +638,68 @@ class ConvertAtenFlattenUsingIntsOp }; } // namespace +// Lower aten.unflatten.int into tensor.expand_shape +namespace { +class ConvertAtenUnflattenIntOp + : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(AtenUnflattenIntOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + Value self = op.getSelf(); + BaseTensorType outputTensorType = op.getType().cast(); + if (!outputTensorType.hasSizes()) + return rewriter.notifyMatchFailure( + op, "unimplemented: output must have known sizes"); + + std::optional maybeRank = getTensorRank(self); + if (!maybeRank) + return rewriter.notifyMatchFailure(op, "unimplemented: unranked tensor"); + auto inputTensorType = self.getType().cast(); + if (!inputTensorType || !inputTensorType.hasSizes()) { + return rewriter.notifyMatchFailure(op, + "Expected input type having sizes"); + } + int inputRank = inputTensorType.getSizes().size(); + int outputRank = outputTensorType.getSizes().size(); + + int64_t dimInt; + if (!matchPattern(op.getDim(), m_TorchConstantInt(&dimInt))) + return rewriter.notifyMatchFailure( + op, "unimplemented: requires dim to be constants"); + + dimInt = toPositiveDim(dimInt, inputRank); + if (!isValidDim(dimInt, inputRank)) + return rewriter.notifyMatchFailure(op, "dim is not a valid dim"); + + auto sizesOp = op.getSizes().getDefiningOp(); + int numSizes = sizesOp.getNumOperands(); + + SmallVector reassociations(inputRank); + if (inputRank > 0) { + for (int i = 0; i < dimInt; ++i) + reassociations[i].push_back(i); + + for (int i = 0; i < numSizes; ++i) + reassociations[dimInt].push_back(i + dimInt); + + for (int i = dimInt + numSizes; i < outputRank; ++i) + reassociations[i - numSizes + 1].push_back(i); + } + + auto expandTy = getTypeConverter()->convertType(outputTensorType); + auto expand = rewriter + .create( + loc, expandTy, adaptor.getSelf(), reassociations) + .getResult(); + rewriter.replaceOp(op, expand); + return success(); + } +}; +} // namespace + namespace { /// The `ConvertAtenViewOp` conversion pattern converts `aten.View` op to /// one `linalg.TensorExpandShape` op for all expanded dimensions and one @@ -2043,6 +2105,8 @@ void mlir::torch::torch_to_linalg::populateDataMovementPatternsAndLegality( target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); + patterns.add(typeConverter, context); + target.addIllegalOp(); patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp index 5d3488b11aed..e123f1dee4ac 100644 --- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp +++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp @@ -379,7 +379,6 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); - target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); diff --git a/projects/pt1/python/torch_mlir/torchscript.py b/projects/pt1/python/torch_mlir/torchscript.py index f3412b83addb..f52a24360afc 100644 --- a/projects/pt1/python/torch_mlir/torchscript.py +++ b/projects/pt1/python/torch_mlir/torchscript.py @@ -248,7 +248,7 @@ def _get_for_tracing( # compiler where each backend can "own" its set of legal ops. BACKEND_LEGAL_OPS = { OutputType.TOSA: ['aten.flatten.using_ints', 'aten.native_layer_norm', 'aten.linear'], - OutputType.LINALG_ON_TENSORS: ['aten.flatten.using_ints','aten.adaptive_avg_pool1d'], + OutputType.LINALG_ON_TENSORS: ['aten.flatten.using_ints','aten.adaptive_avg_pool1d', 'aten.unflatten.int'], OutputType.STABLEHLO: [], } diff --git a/projects/pt1/python/torch_mlir_e2e_test/onnx_backends/linalg_on_tensors.py b/projects/pt1/python/torch_mlir_e2e_test/onnx_backends/linalg_on_tensors.py index 0e5073fdd89d..449e6bb40f01 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/onnx_backends/linalg_on_tensors.py +++ b/projects/pt1/python/torch_mlir_e2e_test/onnx_backends/linalg_on_tensors.py @@ -50,9 +50,11 @@ def compile(self, imported_module: Module): f"builtin.module(func.func({ONNX_TO_TORCH_FUNC_PIPELINE}))", "Lowering Onnx backend contract to Linalg-on-Tensors backend contract") + backend_legal_ops = ['aten.flatten.using_ints','aten.adaptive_avg_pool1d', 'aten.unflatten.int'] + option_string = "{backend-legal-ops=" + ",".join(backend_legal_ops) + "}" run_pipeline_with_repro_report( imported_module, - f"builtin.module(torch-lower-to-backend-contract)", + f"builtin.module(torch-lower-to-backend-contract{option_string})", "Lowering TorchFX IR -> Torch Backend IR", ) From 09875fabd1b37b8c15822088cebe57a6e866c528 Mon Sep 17 00:00:00 2001 From: Chi_Liu <22491986+AmosLewis@users.noreply.github.com> Date: Mon, 4 Mar 2024 11:07:03 -0800 Subject: [PATCH 248/283] [MLIR][ONNX] Add ONNX ReduceProd support (#2943) Alternatives to https://github.com/llvm/torch-mlir/pull/2908 Fix https://github.com/nod-ai/SHARK-Turbine/issues/353 --- .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 157 ++++++++++++++++++ projects/pt1/e2e_testing/xfail_sets.py | 7 - .../TorchOnnxToTorch/simple_ops_q_to_z.mlir | 52 ++++++ 3 files changed, 209 insertions(+), 7 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index b697a4fa2c48..2b08630705c4 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -1777,6 +1777,163 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( data, dimValueList); return success(); }); + patterns.onOp( + "ReduceProd", 13, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + // ReduceProd allows us to pass a list of dims but AtenProdDimIn only + // allow one dim as input. + Torch::ValueTensorType resultType; + Value data; + Value axes; + int64_t keepDims; + int64_t noop_with_empty_axes; + if (binder.tensorOperandAtIndex(data, 0) || + binder.tensorResultType(resultType) || + binder.s64IntegerAttr(keepDims, "keepdims", 1) || + binder.s64IntegerAttr(noop_with_empty_axes, "noop_with_empty_axes", + 0)) + return failure(); + + auto dataTy = cast(data.getType()); + Torch::IntType torchIntTy = rewriter.getType(); + + if (!resultType.hasSizes() || !resultType.areAllSizesKnown() || + !dataTy.areAllSizesKnown()) + return rewriter.notifyMatchFailure( + binder.op, + "Expected the input and result type to have known sizes"); + + int64_t rank = dataTy.getSizes().size(); + SmallVector axesList; + Value zero = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(0)); + + // Previous version of the operation had the axes as an attribute: + llvm::SmallVector axesAttr; + if (!binder.s64IntegerArrayAttr(axesAttr, "axes", {})) { + for (int i = 0, s = axesAttr.size(); i < s; ++i) { + axesList.push_back(rewriter.create( + binder.getLoc(), torchIntTy, + rewriter.getI64IntegerAttr(axesAttr[i]))); + } + } + + // Handle cases that axes are explicitly specified. + // Extract the axes values from the axes operand. + // This really shouldn't happen but it helps pass weird tests. + // TODO: Derive the chosen axes from the data type and final result type + // instead of using the dynamic axes at operand[1]. + if (!binder.tensorOperandAtIndex(axes, 1)) { + Torch::BaseTensorType axesType = + axes.getType().cast(); + auto sizes = axesType.getSizes(); + for (int i = 0; i < sizes[0]; i++) { + Value selectIndex = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(i)); + Value extract = rewriter.create( + binder.getLoc(), + axesType.getWithSizesAndDtype(llvm::SmallVector{1}, + axesType.getOptionalDtype()), + axes, zero, selectIndex); + Value dim = rewriter.create(binder.getLoc(), + torchIntTy, extract); + axesList.push_back(dim); + } + } + + // Handle the noop case: + // When axes is empty and noop_with_empty_axes is set to true, input + // tensor will not be reduced, and the output tensor would be + // equivalent to input tensor. + if (axesList.empty() && noop_with_empty_axes) { + rewriter.replaceOp(binder.op, data); + return success(); + } + + // Handle case when no axes arg is passed but not a noop: + // Manually set positive axis to all dims. + if (axesList.empty()) { + for (int i = 0; i < rank; i++) { + Value dimValue = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(i)); + axesList.push_back(dimValue); + } + } + + // Handle negative axis: + Value rankVal = rewriter.create(binder.getLoc(), + torchIntTy, data); + for (Value &axes : axesList) { + Value isNegative = + rewriter.create(binder.getLoc(), axes, zero); + isNegative = rewriter.create(binder.getLoc(), + isNegative); + Value finalOffset = rewriter.create( + binder.getLoc(), isNegative, rankVal); + axes = rewriter.create(binder.getLoc(), axes, + finalOffset); + } + + // Handle multiple axes case: + // ReduceProd on each dim, always set keepDimsBool == True to avoid + // segfault. + Value trueVal = + rewriter.create(binder.getLoc(), true); + Value noneVal = rewriter.create(binder.getLoc()); + SmallVector intermediateShape(rank, Torch::kUnknownSize); + Value dataReduceProd = data; + for (int i = 0, numAxes = axesList.size(); i < numAxes; i++) { + auto axis = axesList[i]; + if (keepDims && i == numAxes - 1) { + dataReduceProd = rewriter.create( + binder.getLoc(), + dataTy.getWithSizesAndDtype(resultType.getSizes(), + dataTy.getOptionalDtype()), + dataReduceProd, axis, trueVal, noneVal); + rewriter.replaceOp(binder.op, dataReduceProd); + return success(); + } + Type resultTyReduceProd = dataTy.getWithSizesAndDtype( + ArrayRef(intermediateShape), dataTy.getOptionalDtype()); + dataReduceProd = rewriter.create( + binder.getLoc(), resultTyReduceProd, dataReduceProd, axis, + trueVal, noneVal); + } + + // Derived the final shape of the tensor after prod loop of each axis. + SmallVector dataReduceProdSize; + auto dataSize = dataTy.getSizes(); + auto resultTypeSizes = resultType.getSizes(); + if (!keepDims) { + // Handle the keepDimsBool == False case: + // 2 point algorithm to derive the static shape after prod loop. + int j = 0; + for (int i = 0; i < rank; i++) { + if (resultTypeSizes.size() && dataSize[i] == resultTypeSizes[j]) { + dataReduceProdSize.push_back(resultTypeSizes[i]); + j++; + continue; + } + dataReduceProdSize.push_back(1); + } + } + + // Handle the keepDimsBool == False case: + // Reshape the prod loop result to the final result shape. + SmallVector dataReduceProdShape; + for (auto dim : dataReduceProdSize) + dataReduceProdShape.push_back(rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(dim))); + Value dataReduceProdShapeList = + rewriter.create( + binder.getLoc(), + rewriter.getType( + rewriter.getType()), + dataReduceProdShape); + rewriter.replaceOpWithNewOp( + binder.op, resultType, dataReduceProd, dataReduceProdShapeList); + return success(); + }); patterns.onOp( "Range", 11, [](OpBinder binder, ConversionPatternRewriter &rewriter) { // ONNX.Range(start, limit, delta) -- limit is exclusive diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 6e94adedadff..abb019f8ed06 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -2029,9 +2029,6 @@ "StdCorrectionLargeInputModule_basic", "StdCorrectionModule_basic", "StdCorrectionNoneModule_basic", - "StdCorrectionSingleDimReduceModule_basic", - "StdDimKeepDimFalseModule_basic", - "StdDimKeepDimTrueModule_basic", "StdDimNoneDimModule_basic", "StdUnbiasedModule_basic", "VarCorrectionAllDimReduceModule_basic", @@ -2039,17 +2036,13 @@ "VarCorrectionLargeInputModule_basic", "VarCorrectionModule_basic", "VarCorrectionNoneModule_basic", - "VarCorrectionSingleDimReduceModule_basic", "VarDimAllDimReduceModule_basic", "VarDimModule_basic", "VarDimMultiDimModule_basic", - "VarDimNegativeModule_basic", "VarDimNoneDimModule_basic", "VarDimSingleDimModule_basic", "VarDimUnbiasedModule_basic", - "VarMeanCorrectionModule_basic", "VarMeanCorrectionNoneModule_basic", - "VarMeanDimModule_basic", "VarMeanUnbiasedModule_basic", "VarUnbiasedModule_basic", diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir index 58b4287a41c5..11c9e4e6e5da 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir @@ -1106,6 +1106,58 @@ func.func @test_reduce_min_attr(%arg0: !torch.vtensor<[4,2],i1>) -> !torch.vtens // ----- +// CHECK-LABEL: func.func @test_reduce_prod_default_axes_keepdims_random +func.func @test_reduce_prod_default_axes_keepdims_random(%arg0: !torch.vtensor<[3,2,2],f32>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64} { + // CHECK: %[[INT0:.*]] = torch.constant.int 0 + // CHECK: %[[INT0_0:.*]] = torch.constant.int 0 + // CHECK: %[[INT1:.*]] = torch.constant.int 1 + // CHECK: %[[INT2:.*]] = torch.constant.int 2 + // CHECK: %[[RANK:.*]] = torch.aten.dim %arg0 : !torch.vtensor<[3,2,2],f32> -> !torch.int + // CHECK: %[[LT:.*]] = torch.aten.lt.int %[[INT0_0]], %[[INT0]] : !torch.int, !torch.int -> !torch.bool + // CHECK: %[[BOOL:.*]] = torch.aten.Int.bool %[[LT]] : !torch.bool -> !torch.int + // CHECK: %[[MUL:.*]] = torch.aten.mul.int %[[BOOL]], %[[RANK]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[ADD:.*]] = torch.aten.add.int %[[INT0_0]], %[[MUL]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[LT_0:.*]] = torch.aten.lt.int %[[INT1]], %[[INT0]] : !torch.int, !torch.int -> !torch.bool + // CHECK: %[[BOOL_0:.*]] = torch.aten.Int.bool %[[LT_0]] : !torch.bool -> !torch.int + // CHECK: %[[MUL_0:.*]] = torch.aten.mul.int %[[BOOL_0]], %[[RANK]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[ADD_0:.*]] = torch.aten.add.int %[[INT1]], %[[MUL_0]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[LT_1:.*]] = torch.aten.lt.int %[[INT2]], %[[INT0]] : !torch.int, !torch.int -> !torch.bool + // CHECK: %[[BOOL_1:.*]] = torch.aten.Int.bool %[[LT_1]] : !torch.bool -> !torch.int + // CHECK: %[[MUL_1:.*]] = torch.aten.mul.int %[[BOOL_1]], %[[RANK]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[ADD_1:.*]] = torch.aten.add.int %[[INT2]], %[[MUL_1]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[TRUE:.*]] = torch.constant.bool true + // CHECK: %[[NONE:.*]] = torch.constant.none + // CHECK: %[[PROD_0:.*]] = torch.aten.prod.dim_int %arg0, %[[ADD]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.int, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?],f32> + // CHECK: %[[PROD_1:.*]] = torch.aten.prod.dim_int %[[PROD_0]], %[[ADD_0]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[?,?,?],f32>, !torch.int, !torch.bool, !torch.none -> !torch.vtensor<[?,?,?],f32> + // CHECK: %[[PROD_2:.*]] = torch.aten.prod.dim_int %[[PROD_1]], %[[ADD_1]], %[[TRUE]], %[[NONE]] : !torch.vtensor<[?,?,?],f32>, !torch.int, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1],f32> + // CHECK: return %[[PROD_2]] : !torch.vtensor<[1,1,1],f32> + %0 = torch.operator "onnx.ReduceProd"(%arg0) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>) -> !torch.vtensor<[1,1,1],f32> + return %0 : !torch.vtensor<[1,1,1],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_reduce_prod_keepdims_random +func.func @test_reduce_prod_keepdims_random(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64} { +// CHECK: %[[INT0:.*]] = torch.constant.int 0 +// CHECK: %[[INT0_0:.*]] = torch.constant.int 0 +// CHECK: %[[SELECT:.*]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> +// CHECK: %[[ITEM:.*]] = torch.aten.item %[[SELECT]] : !torch.vtensor<[1],si64> -> !torch.int +// CHECK: %[[DIM:.*]] = torch.aten.dim %arg0 : !torch.vtensor<[3,2,2],f32> -> !torch.int +// CHECK: %[[LT:.*]] = torch.aten.lt.int %[[ITEM]], %[[INT0]] : !torch.int, !torch.int -> !torch.bool +// CHECK: %[[BOOL:.*]] = torch.aten.Int.bool %[[LT]] : !torch.bool -> !torch.int +// CHECK: %[[MUL:.*]] = torch.aten.mul.int %[[BOOL]], %[[DIM]] : !torch.int, !torch.int -> !torch.int +// CHECK: %[[ADD:.*]] = torch.aten.add.int %[[ITEM]], %[[MUL]] : !torch.int, !torch.int -> !torch.int +// CHECK: %[[BOOL:.*]] = torch.constant.bool true +// CHECK: %[[NONE:.*]] = torch.constant.none +// CHECK: %[[PROD:.*]] = torch.aten.prod.dim_int %arg0, %[[ADD]], %[[BOOL]], %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.int, !torch.bool, !torch.none -> !torch.vtensor<[3,1,2],f32> +// CHECK: return %[[PROD]] : !torch.vtensor<[3,1,2],f32> + %0 = torch.operator "onnx.ReduceProd"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32> + return %0 : !torch.vtensor<[3,1,2],f32> +} + +// ----- + // CHECK-LABEL: func.func @test_sinh func.func @test_sinh_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 4 : si64, torch.onnx_meta.opset_version = 9 : si64} { // CHECK: torch.aten.sinh %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> From a86e89ecb5c7929a39a38743fb7cacadf1ff41bb Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Mon, 4 Mar 2024 11:46:49 -0800 Subject: [PATCH 249/283] [torch] Additional folders for shape computations (#2972) A handful of operations are commonly used in shape calculations (slice, concat, broadcast). Added these additional folders to better propagate simple shape computations. --- lib/Dialect/Torch/IR/TorchOps.cpp | 141 +++++++++++++++++++++------ test/Dialect/Torch/canonicalize.mlir | 63 +++++++++--- 2 files changed, 160 insertions(+), 44 deletions(-) diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index 1aae3735d0d5..03f39be9c806 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -2899,12 +2899,59 @@ OpFoldResult AtenSubIntOp::fold(FoldAdaptor adaptor) { //===----------------------------------------------------------------------===// OpFoldResult AtenCatOp::fold(FoldAdaptor adaptor) { + // We set a maximum folding size of 16. This is a reasonable upper limit + // for shape computations. + constexpr int64_t kMaxFoldSize = 16; auto list = getOperand(0).getDefiningOp(); - if (!list || list.getElements().size() != 1) + if (!list) return nullptr; - if (list.getElements()[0].getType() != getResult().getType()) + + auto elements = list.getElements(); + if (elements.size() == 1 && elements[0].getType() == getResult().getType()) + return list.getElements()[0]; + + auto resultTy = dyn_cast(getType()); + if (!resultTy || !resultTy.hasSizes() || !resultTy.hasDtype()) return nullptr; - return list.getElements()[0]; + + auto bResultTy = resultTy.toBuiltinTensor(); + if (!bResultTy.hasStaticShape() || bResultTy.getNumElements() > kMaxFoldSize) + return nullptr; + + auto dimAttr = dyn_cast_or_null(adaptor.getDim()); + if (!dimAttr) + return nullptr; + auto dim = dimAttr.getValue().getSExtValue(); + dim += dim < 0 ? bResultTy.getRank() : 0; + + for (int i = 0, s = bResultTy.getRank(); i < s; ++i) { + if (i == dim) + continue; + if (bResultTy.getDimSize(i) != 1) + return nullptr; + } + + llvm::SmallVector values; + for (auto operand : list.getOperands()) { + DenseElementsAttr dattr; + if (!matchPattern(operand, m_Constant(&dattr))) + return nullptr; + + auto oty = dyn_cast(dattr.getType()); + if (!oty) + return nullptr; + + if (dattr.isSplat()) { + for (int i = 0, s = oty.getDimSize(dim); i < s; ++i) + values.push_back(dattr.getSplatValue()); + } else { + auto evals = dattr.getValues(); + for (int i = 0, s = oty.getDimSize(dim); i < s; ++i) + values.push_back(evals[i]); + } + } + + return DenseElementsAttr::get(bResultTy.clone(resultTy.getDtype()), values); } void AtenCatOp::getCanonicalizationPatterns(RewritePatternSet &patterns, @@ -2947,19 +2994,32 @@ void AtenCatOp::getCanonicalizationPatterns(RewritePatternSet &patterns, OpFoldResult AtenBroadcastToOp::fold(FoldAdaptor adaptor) { auto inType = getOperand(0).getType().dyn_cast(); auto outType = getResult().getType().dyn_cast(); - if (inType != outType) - return nullptr; - if (!inType || !outType || !inType.hasSizes() || !outType.hasSizes()) + if (!inType || !outType || !inType.hasSizes() || !outType.hasSizes() || + !outType.hasDtype()) return nullptr; - if (inType.getSizes().size() != outType.getSizes().size() || - (!isAssumingStrictSymbolicShapes((*this)->getBlock()) && - (!inType.areAllSizesKnown() || !outType.areAllSizesKnown()))) + + if (!inType.areAllSizesKnown() || !outType.areAllSizesKnown()) return nullptr; - for (size_t i = 0; i < inType.getSizes().size(); ++i) { - if (inType.getSizes()[i] != outType.getSizes()[i]) - return nullptr; + + auto inSizes = inType.getSizes(); + auto outSizes = outType.getSizes(); + if (inSizes.size() == outSizes.size()) { + bool sameSizes = true; + for (int i = 0, s = inSizes.size(); i < s; ++i) + sameSizes &= inSizes[i] == outSizes[i]; + + if (sameSizes) + return getOperand(0); } - return getOperand(0); + + auto selfAttr = dyn_cast_or_null(adaptor.getSelf()); + if (!selfAttr) + return nullptr; + if (!selfAttr.isSplat()) + return nullptr; + + auto attrty = RankedTensorType::get(outType.getSizes(), outType.getDtype()); + return DenseElementsAttr::get(attrty, selfAttr.getSplatValue()); } //===----------------------------------------------------------------------===// @@ -2995,23 +3055,44 @@ OpFoldResult AtenSliceTensorOp::fold(FoldAdaptor adaptor) { outType.toBuiltinTensor().clone(inType.getDtype()), input.getSplatValue()); - // If the output is a single value we can index into a constant input and grab - // that single value: - if (input && start && dim && - llvm::all_of(outType.getSizes(), [](int64_t dim) { return dim == 1; })) { - bool unaryNonDim = true; - int64_t dimInt = dim.getValue().getSExtValue(); - for (int i = 0, s = inType.getSizes().size(); i < s; ++i) { - unaryNonDim &= inType.getSizes()[i] == 1 || i == dimInt; - } - if (unaryNonDim) { - int64_t idx = start.getValue().getSExtValue(); - if (idx < 0) - idx += input.getNumElements(); - Attribute value = input.getValues()[idx]; - return DenseElementsAttr::get( - outType.toBuiltinTensor().clone(inType.getDtype()), value); - } + int count = 1; + for (auto dim : outType.getSizes()) + count = count * dim; + + if (count == 0) + return {}; + + if (!dim) + return nullptr; + int64_t dimInt = dim.getValue().getSExtValue(); + if (dimInt < 0) + dimInt += inType.getSizes().size(); + + bool unaryNonDim = true; + for (int i = 0, s = outType.getSizes().size(); i < s; ++i) + unaryNonDim &= outType.getSizes()[i] == 1 || i == dimInt; + + // Fold the slice if the output tensor is relatively small, currently + // coded to 16: + if (input && start && step && dim && count < 16 && unaryNonDim && + count < 16) { + int64_t inCount = input.getNumElements(); + int64_t begin = start.getValue().getSExtValue(); + int64_t stride = step.getValue().getSExtValue(); + if (stride < 1) + return {}; + int64_t limit = end.getValue().getSExtValue(); + begin = begin < 0 ? begin + inCount : begin; + limit = limit < 0 ? limit + inCount : limit; + limit = limit < 0 ? inType.getSizes()[dimInt] : limit; + limit = std::min(limit, inType.getSizes()[dimInt]); + + llvm::SmallVector values; + for (int i = begin; i < limit; i += stride) + values.push_back(input.getValues()[i]); + + return DenseElementsAttr::get( + outType.toBuiltinTensor().clone(inType.getDtype()), values); } // If the input and output shapes are the same we can just fold: diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir index 952c9f78d0c5..2b5405b75197 100644 --- a/test/Dialect/Torch/canonicalize.mlir +++ b/test/Dialect/Torch/canonicalize.mlir @@ -1,4 +1,4 @@ -// RUN: torch-mlir-opt %s -canonicalize | FileCheck %s +// RUN: torch-mlir-opt %s -canonicalize --split-input-file | FileCheck %s // CHECK-LABEL: func.func @torch.aten.__range_length$fold() -> (!torch.int, !torch.int, !torch.int, !torch.int) { // CHECK-DAG: %[[INT1:.*]] = torch.constant.int 1 @@ -1990,6 +1990,7 @@ func.func @torch.aten.sort$nofold (%arg0 : !torch.vtensor<[3, 1, 4],si64>, %arg1 return %0, %1 : !torch.vtensor<[3, 1, 4],si64>, !torch.vtensor<[3],si64> } +// ----- // CHECK-LABEL: @torch.aten.cat$fold_single_operand // CHECK-SAME: %[[ARG0:.+]]: !torch.tensor @@ -2001,6 +2002,22 @@ func.func @torch.aten.cat$fold_single_operand(%arg0: !torch.tensor) -> !torch.te return %1: !torch.tensor } +// ----- + +// CHECK-LABEL: @torch.aten.cat$fold_zero_dim_operand +// CHECK: %[[FOLD:.+]] = torch.vtensor.literal(dense<[1, 3, 2, 2]> : tensor<4xsi32>) +// CHECK: return %[[FOLD]] : !torch.vtensor +func.func @torch.aten.cat$fold_zero_dim_operand() -> !torch.vtensor<[4],si32> { + %0 = torch.vtensor.literal(dense<[1, 3]> : tensor<2xsi32>) : !torch.vtensor<[2],si32> + %1 = torch.vtensor.literal(dense<2> : tensor<2xsi32>) : !torch.vtensor<[2],si32> + %int0 = torch.constant.int 0 + %list = torch.prim.ListConstruct %0, %1 : (!torch.vtensor<[2],si32>, !torch.vtensor<[2],si32>) -> !torch.list + %cat = torch.aten.cat %list, %int0 : !torch.list, !torch.int -> !torch.vtensor<[4],si32> + return %cat: !torch.vtensor<[4],si32> +} + +// ----- + // CHECK-LABEL: func.func @torch.aten.broadcast_to$fold( // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[3,4,2],f32>) -> !torch.vtensor<[3,4,2],f32> { // CHECK-NEXT: return %[[ARG]] : !torch.vtensor<[3,4,2],f32> @@ -2013,15 +2030,23 @@ func.func @torch.aten.broadcast_to$fold(%arg0: !torch.vtensor<[3,4,2],f32>) -> ! return %0 : !torch.vtensor<[3,4,2],f32> } -// CHECK-LABEL: func.func @torch.aten.broadcast_to_strict$fold( -// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?],f32>, {{.*}}) -> !torch.vtensor<[?],f32> -// CHECK-NEXT: return %[[ARG]] : !torch.vtensor<[?],f32> -func.func @torch.aten.broadcast_to_strict$fold(%arg0: !torch.vtensor<[?],f32>, %arg1: !torch.int) -> !torch.vtensor<[?],f32> attributes {torch.assume_strict_symbolic_shapes} { - %list = torch.prim.ListConstruct %arg1 : (!torch.int) -> !torch.list - %0 = torch.aten.broadcast_to %arg0, %list : !torch.vtensor<[?],f32>, !torch.list -> !torch.vtensor<[?],f32> - return %0 : !torch.vtensor<[?],f32> +// ----- + +// CHECK-LABEL: func.func @torch.aten.broadcast_to$fold_splat +// CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<3.000000e+00> : tensor<3x4x2xf32>) : !torch.vtensor<[3,4,2],f32> +// CHECK: return %[[CST]] +func.func @torch.aten.broadcast_to$fold_splat() -> !torch.vtensor<[3,4,2],f32> { + %tensor = torch.vtensor.literal(dense<3.0> : tensor<1x4x1xf32>) : !torch.vtensor<[1,4,1],f32> + %int3 = torch.constant.int 3 + %int4 = torch.constant.int 4 + %int2 = torch.constant.int 2 + %list = torch.prim.ListConstruct %int3, %int4, %int2 : (!torch.int, !torch.int, !torch.int) -> !torch.list + %0 = torch.aten.broadcast_to %tensor, %list : !torch.vtensor<[1,4,1],f32>, !torch.list -> !torch.vtensor<[3,4,2],f32> + return %0 : !torch.vtensor<[3,4,2],f32> } +// ----- + // CHECK-LABEL: @torch.aten.slice.tensor$fold_full_domain_slice // CHECK-SAME: %[[ARG0:.+]]: !torch.vtensor<[4],f32> // CHECK: return %[[ARG0]] : !torch.vtensor<[4],f32> @@ -2078,11 +2103,21 @@ func.func @torch.aten.slice.tensor$fold_dim_1() -> (!torch.vtensor<[1, 1],si64>, // ----- -// CHECK-LABEL: func.func @torch.aten.slice.tensor$fold_dim_0() -> (!torch.vtensor<[1,1],f32>, !torch.vtensor<[1,1],f32>) { -// CHECK-NOT: torch.aten.slice.Tensor -// CHECK: %[[RET_0:.*]] = torch.vtensor.literal(dense<1.600000e+01> : tensor<1x1xf32>) : !torch.vtensor<[1,1],f32> -// CHECK: %[[RET_1:.*]] = torch.vtensor.literal(dense<6.400000e+01> : tensor<1x1xf32>) : !torch.vtensor<[1,1],f32> -// CHECK: return %[[RET_0]], %[[RET_1]] : !torch.vtensor<[1,1],f32>, !torch.vtensor<[1,1],f32> +// CHECK-LABEL: func.func @torch.aten.slice.tensor$fold_small() -> !torch.vtensor<[2],si32> { +// CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<[3, 5]> : tensor<2xsi32>) : !torch.vtensor<[2],si32> +// CHECK: return %[[CST]] +func.func @torch.aten.slice.tensor$fold_small() -> (!torch.vtensor<[2],si32>) { + %tensor = torch.vtensor.literal(dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]> : tensor<10xsi32>) : !torch.vtensor<[10],si32> + %dim = torch.constant.int 0 + %int3 = torch.constant.int 3 + %int2 = torch.constant.int 2 + %int7 = torch.constant.int 7 + %0 = torch.aten.slice.Tensor %tensor, %dim, %int3, %int7, %int2 : !torch.vtensor<[10], si32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[2], si32> + return %0 : !torch.vtensor<[2],si32> +} + +// ----- + func.func @torch.aten.slice.tensor$fold_dim_0() -> (!torch.vtensor<[1, 1],f32>, !torch.vtensor<[1, 1],f32>) { %tensor = torch.vtensor.literal(dense<[[2.0],[4.0],[8.0],[16.0],[32.0],[64.0],[128.0],[256.0],[512.0],[1024.0]]> : tensor<10x1xf32>) : !torch.vtensor<[10, 1],f32> %int0 = torch.constant.int 0 @@ -2097,7 +2132,7 @@ func.func @torch.aten.slice.tensor$fold_dim_0() -> (!torch.vtensor<[1, 1],f32>, return %0, %1 : !torch.vtensor<[1, 1],f32>, !torch.vtensor<[1, 1], f32> } - +// ----- // CHECK-LABEL: func.func @torch.aten.rsub.Scalar$canonicalize_literal_0d() -> !torch.vtensor<[],si64> { // CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<-1> : tensor) : !torch.vtensor<[],si64> From 4d01b0f1a38708d6e7966d1df326dcc9e52d8c5e Mon Sep 17 00:00:00 2001 From: Yuanqiang Liu Date: Wed, 6 Mar 2024 01:04:38 +0800 Subject: [PATCH 250/283] [FxImporter] remove dataclass slots to support python3.9 (#2974) * that `dataclass`'s `slots` is supported after python 3.10. --- python/torch_mlir/extras/fx_importer.py | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py index 2edfeb6cf340..a703db45398f 100644 --- a/python/torch_mlir/extras/fx_importer.py +++ b/python/torch_mlir/extras/fx_importer.py @@ -315,15 +315,24 @@ def is_builtin_function_or_method(obj: Any) -> bool: return isinstance(obj, (BuiltinMethodType, BuiltinFunctionType)) -@dataclass(frozen=True, slots=True) +# TODO: switch back to `slots=True` when py3.9 support is dropped +@dataclass(frozen=True) class InputInfo: """Provides additional metadata when resolving inputs.""" + __slots__ = [ + "program", + "input_spec", + "node", + "ir_type", + "mutable_producer_node_name", + ] + program: torch.export.ExportedProgram input_spec: TypingInputSpec node: Node ir_type: IrType - mutable_producer_node_name: Optional[str] = None + mutable_producer_node_name: Optional[str] class FxImporterHooks: @@ -546,7 +555,13 @@ def import_program( node_ir_type = self._cc.node_val_to_type(node, mutable=False) parameter_bindings[node] = ( value, - InputInfo(prog, input_spec, node=node, ir_type=node_ir_type), + InputInfo( + prog, + input_spec, + node=node, + ir_type=node_ir_type, + mutable_producer_node_name=None, + ), ) elif input_spec.kind == InputKind.BUFFER and isinstance( arg, TensorArgument From 933db87a07fb828b17b75f2f5f396a434f8f1a17 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Tue, 5 Mar 2024 13:55:13 -0800 Subject: [PATCH 251/283] [onnx] Add support for constants of `i1`s (#2978) `getRawBuffer` expects a densely packed vector of `i1` values however `onnx` does not densely pack the values. Include code to handle the packing / unpacking. --- .../TorchOnnxToTorch/DefaultDomainAtoF.cpp | 14 ++++++++++++-- .../TorchOnnxToTorch/simple_ops_a_to_f.mlir | 19 +++++++++++++++++++ 2 files changed, 31 insertions(+), 2 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp index 1c356db890db..52cd59e898a6 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp @@ -700,9 +700,19 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( } auto ty = cast(attr.getType()); + ElementsAttr denseAttr; auto ptr = attr.getRawHandle().getBlob()->getData(); - DenseElementsAttr denseAttr = - DenseElementsAttr::getFromRawBuffer(ty, ptr); + if (cast(attr.getType()).getElementType().isInteger(1)) { + llvm::SmallVector newContents; + for (auto val : ptr) { + APInt apval(1, val); + newContents.push_back(apval); + } + denseAttr = DenseElementsAttr::get(ty, newContents); + } else { + denseAttr = DenseElementsAttr::getFromRawBuffer(ty, ptr); + } + rewriter.replaceOpWithNewOp( binder.op, resultType, denseAttr); return success(); diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir index 7dc262228f1a..7c465d74bb5f 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir @@ -1409,6 +1409,25 @@ func.func @dense_constant() -> () attributes {torch.onnx_meta.ir_version = 8 : s // ----- +// CHECK-LABEL: @dense_constant_i1 +func.func @dense_constant_i1() -> !torch.vtensor<[5],i1> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 17 : si64} { + // CHECK: %[[CST:.+]] = torch.vtensor.literal(dense<[true, false, false, true, true]> : tensor<5xi1>) : !torch.vtensor<[5],i1> + // CHECK: return %[[CST]] : !torch.vtensor<[5],i1> + %0 = torch.operator "onnx.Constant"() {torch.onnx.value = dense_resource<_> : tensor<5xi1>} : () -> !torch.vtensor<[5],i1> + return %0 : !torch.vtensor<[5],i1> +} + +{-# + dialect_resources: { + builtin: { + _: "0x080000000100000101" + } + } +#-} + +// ----- + + // CHECK-LABEL: @test_flatten_4d_axis_2 func.func @test_flatten_4d_axis_2(%arg0: !torch.vtensor<[2,3,4,5],f32>) -> !torch.vtensor<[6,20],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK-DAG: %[[RIGHT_START:.*]] = torch.constant.int 2 From bc0527676b10f5e6d2d9a55b54ec150cdce2b226 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Tue, 5 Mar 2024 15:01:21 -0800 Subject: [PATCH 252/283] [torch] Add support for `torch.split_with_sizes` via decompose (#2979) Convert to individiual slices and tuple together as a list. --------- Co-authored-by: Scott Todd --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 24 ++++ .../Torch/Transforms/DecomposeComplexOps.cpp | 128 ++++++++++++++++++ projects/pt1/e2e_testing/xfail_sets.py | 12 +- .../build_tools/torch_ods_gen.py | 1 + .../test_suite/slice_like.py | 22 +++ 5 files changed, 177 insertions(+), 10 deletions(-) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 7b698793ae10..be294e97b053 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -12910,6 +12910,30 @@ def Torch_AtenSplitWithSizesOp : Torch_Op<"aten.split_with_sizes", [ }]; } +def Torch_AtenSplitSizesOp : Torch_Op<"aten.split.sizes", [ + AllowsTypeRefinement, + ReadOnly + ]> { + let summary = "Generated op for `aten::split.sizes : (Tensor, int[], int) -> (Tensor[])`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchListOfTorchIntType:$split_size, + Torch_IntType:$dim + ); + let results = (outs + AnyTorchListOfTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenSplitSizesOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); + } + void AtenSplitSizesOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); + } + }]; +} + def Torch_AtenUnbindIntOp : Torch_Op<"aten.unbind.int", [ AllowsTypeRefinement, ReadOnly diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 736d66544e2d..b71e86c33568 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -693,6 +693,131 @@ class DecomposeAtenSelectIntOp : public OpRewritePattern { }; } // namespace +namespace { +class DecomposePrimTolistOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(PrimTolistOp op, + PatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + auto self = op.getOperands()[0]; + auto selfTy = dyn_cast(self.getType()); + if (!selfTy || !selfTy.hasSizes()) + return rewriter.notifyMatchFailure(op, "Unknown self shape"); + + int64_t rank = selfTy.getSizes().size(); + if (rank != 1) + return rewriter.notifyMatchFailure(op, "Expected rank-1"); + + int64_t length = selfTy.getSizes().back(); + if (length == Torch::kUnknownSize) + return rewriter.notifyMatchFailure(op, "Tolist length is unknown"); + + auto resultTy = dyn_cast(op.getType(0)); + if (!resultTy) + return rewriter.notifyMatchFailure(op, "Result type is not list"); + + auto scalarTy = resultTy.getContainedType(); + Value zero = + rewriter.create(loc, rewriter.getI64IntegerAttr(0)); + auto extractTy = rewriter.getType( + llvm::SmallVector{1}, selfTy.getOptionalDtype()); + llvm::SmallVector results; + llvm::SmallVector sizes(selfTy.getSizes()); + for (int64_t i = 0; i < length; ++i) { + Value iv = + rewriter.create(loc, rewriter.getI64IntegerAttr(i)); + Value extract = rewriter.create( + loc, extractTy, self, /*dim=*/zero, /*index=*/iv); + Value scalar = rewriter.create(loc, scalarTy, extract); + results.push_back(scalar); + } + + rewriter.replaceOpWithNewOp(op, resultTy, results); + return failure(); + } +}; +} // namespace + +namespace { +class DecomposeAtenSplitSizesOp : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenSplitSizesOp op, + PatternRewriter &rewriter) const override { + rewriter.replaceOpWithNewOp( + op, op->getResultTypes(), op.getSelf(), op.getSplitSize(), op.getDim()); + return success(); + } +}; +} // namespace + +namespace { +class DecomposeAtenSplitWithSizesOp + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenSplitWithSizesOp op, + PatternRewriter &rewriter) const override { + auto loc = op.getLoc(); + Value self = op.getSelf(); + SmallVector splitSizes; + if (!getListConstructElements(op.getSplitSizes(), splitSizes)) + return rewriter.notifyMatchFailure(op, "Unable to get sizes"); + + if (splitSizes.empty()) + return rewriter.notifyMatchFailure(op, "No split sizes"); + + auto selfTy = dyn_cast(self.getType()); + if (!selfTy || !selfTy.hasSizes()) + return rewriter.notifyMatchFailure(op, "Self shape unknown"); + + int64_t rank = selfTy.getSizes().size(); + auto resultTy = dyn_cast(op.getResult().getType()); + if (!resultTy) + return rewriter.notifyMatchFailure(op, "Result type not a list"); + + auto sliceTy = + dyn_cast_or_null(resultTy.getContainedType()); + if (!isa(sliceTy)) + return rewriter.notifyMatchFailure(op, "Slice type is unknown"); + + int64_t dimInt = 0; + bool hasDim = matchPattern(op.getDim(), m_TorchConstantInt(&dimInt)); + if (dimInt < 0) + dimInt += rank; + + auto intTy = rewriter.getType(); + Value one = + rewriter.create(loc, rewriter.getI64IntegerAttr(1)); + Value begin = + rewriter.create(loc, rewriter.getI64IntegerAttr(0)); + + llvm::SmallVector slices; + llvm::SmallVector sliceSizes(sliceTy.getSizes()); + int64_t defaultLength = !hasDim ? Torch::kUnknownSize : sliceSizes[dimInt]; + for (auto size : splitSizes) { + Value end = rewriter.create(loc, intTy, begin, size); + + int64_t sizeInt; + if (hasDim && matchPattern(size, m_TorchConstantInt(&sizeInt))) { + sliceSizes[dimInt] = sizeInt; + } else if (hasDim) { + sliceSizes[dimInt] = defaultLength; + } + + sliceTy = rewriter.getType(sliceSizes, + sliceTy.getOptionalDtype()); + Value slice = rewriter.create( + loc, sliceTy, op.getSelf(), + /*dim=*/op.getDim(), /*start=*/begin, /*end=*/end, /*step=*/one); + slices.push_back(slice); + begin = end; + } + + rewriter.replaceOpWithNewOp(op, resultTy, slices); + return success(); + } +}; +} // namespace + namespace { class DecomposeAtenNarrowOp : public OpRewritePattern { public: @@ -7008,6 +7133,8 @@ class DecomposeComplexOpsPass addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); @@ -7035,6 +7162,7 @@ class DecomposeComplexOpsPass addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index abb019f8ed06..63c70e364e91 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -20,7 +20,8 @@ # 'linalg.depthwise_conv_2d_nchw_chw' op inferred input/output operand #1 has shape's dimension #0 to be 4, but found 8 "Conv2dWithPaddingDilationStrideStaticModule_depthwise_multiplier", "IscloseStaticModule_basic", - "IscloseStaticModuleTrue_basic" + "IscloseStaticModuleTrue_basic", + "SplitWithSizes_Module_basic", } TORCHDYNAMO_XFAIL_SET = { @@ -1478,15 +1479,6 @@ "VarBiasedModule_basic", "VarMeanBiasedModule_basic", - # Failure - constant int lowering - "SplitTensorGetItem_Module_basic", - "SplitTensorLastSmallerModule_basic", - "SplitTensorListUnpackModule_basic", - "SplitTensorNegativeDimModule_basic", - "SplitWithSizesListUnpackModule_basic", - "UnbindIntGetItem_Module_basic", - "UnbindIntListUnpack_Module_basic", - # Failure - incorrect numerics "AdaptiveAvgPool1dUnitOutputSizeDynamicModule_basic", "AdaptiveAvgPool2dUnitOutputSizeDynamicModule_basic", diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index 3f6cf1b51dd4..ad030e97e8e1 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -741,6 +741,7 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::sort : (Tensor, int, bool) -> (Tensor, Tensor)", has_folder=True) emit("aten::split.Tensor : (Tensor, int, int) -> (Tensor[])") emit("aten::split_with_sizes : (Tensor, int[], int) -> (Tensor[])") + emit("aten::split.sizes : (Tensor, int[], int) -> (Tensor[])") emit("aten::unbind.int : (Tensor, int) -> (Tensor[])") emit("aten::chunk : (Tensor, int, int) -> (Tensor[])") diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/slice_like.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/slice_like.py index 8014758a7411..62d8948dbbba 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/slice_like.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/slice_like.py @@ -897,3 +897,25 @@ def forward(self, x): @register_test_case(module_factory=lambda: ChunkListUnpackUnevenDynamic_Module()) def ChunkListUnpackUnevenDynamic_Module_basic(module, tu: TestUtils): module.forward(tu.rand(2, 13, 2)) + +# ============================================================================== + +class SplitWithSizes_Module(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([5, -1, -1], torch.float32, True), + ]) + def forward(self, x): + split = torch.split(x, [2, 1, 2], dim=0) + return split[0], split[1], split[2] + +@register_test_case(module_factory=lambda: SplitWithSizes_Module()) +def SplitWithSizes_Module_basic(module, tu: TestUtils): + module.forward(tu.rand(5, 2, 2)) + + + From aa7c9a965342116f09d991e20c6d6335a673f729 Mon Sep 17 00:00:00 2001 From: Ze Zhang Date: Tue, 5 Mar 2024 16:31:01 -0800 Subject: [PATCH 253/283] e2e support aten.linalg_norm to aten.linalg_vector_norm (#2953) Add e2d support for `aten.linalg_norm` by decompose it to `aten.linalg_vector_norm`. Lowering to `aten.linalg_matrix_norm` is still unsupported. To Test: `python -m e2e_testing.main -v` --------- Co-authored-by: Ze Zhang --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 27 +++++++++ .../Transforms/AbstractInterpLibrary.cpp | 59 +++++++++++++++++++ .../Torch/Transforms/DecomposeComplexOps.cpp | 32 ++++++++++ .../Transforms/LowerToBackendContract.cpp | 1 + projects/pt1/e2e_testing/xfail_sets.py | 3 + .../build_tools/abstract_interp_lib_gen.py | 26 ++++++++ .../build_tools/torch_ods_gen.py | 1 + .../test_suite/reduction.py | 36 +++++++++++ 8 files changed, 185 insertions(+) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index be294e97b053..dfb1b0382918 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -7938,6 +7938,33 @@ def Torch_AtenLinalgVectorNormOp : Torch_Op<"aten.linalg_vector_norm", [ }]; } +def Torch_AtenLinalgNormOp : Torch_Op<"aten.linalg_norm", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::linalg_norm : (Tensor, Scalar?, int[]?, bool, int?) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchOptionalScalarType:$ord, + AnyTorchOptionalListOfTorchIntType:$dim, + Torch_BoolType:$keepdim, + AnyTorchOptionalIntType:$dtype + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenLinalgNormOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 5, 1); + } + void AtenLinalgNormOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 5, 1); + } + }]; +} + def Torch_AtenLinalgQrOp : Torch_Op<"aten.linalg_qr", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index f9ad383d1000..19c84617a2a1 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -9336,6 +9336,11 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg2, %arg3, %0) : (!torch.list, !torch.optional>, !torch.bool, !torch.any) -> !torch.list\n" " return %1 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.linalg_norm\"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.optional>, %arg3: !torch.bool, %arg4: !torch.optional) -> !torch.list {\n" +" %0 = torch.derefine %arg4 : !torch.optional to !torch.any\n" +" %1 = call @__torch__.torch.jit._shape_functions.sum_mean_dim(%arg0, %arg2, %arg3, %0) : (!torch.list, !torch.optional>, !torch.bool, !torch.any) -> !torch.list\n" +" return %1 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.frobenius_norm.dim\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.bool) -> !torch.list {\n" " %int0 = torch.constant.int 0\n" " %0 = torch.derefine %arg1 : !torch.list to !torch.optional>\n" @@ -12058,6 +12063,60 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %4 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.linalg_norm\"(%arg0: !torch.tuple, %arg1: !torch.optional, %arg2: !torch.optional>, %arg3: !torch.bool, %arg4: !torch.optional) -> !torch.int {\n" +" %true = torch.constant.bool true\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %2 = torch.aten.__not__ %1 : !torch.bool -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %3 = torch.aten.__isnot__ %arg4, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.int) {\n" +" %5 = torch.prim.unchecked_cast %arg4 : !torch.optional -> !torch.int\n" +" %6 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%5) : (!torch.int) -> !torch.bool\n" +" %7 = torch.aten.__not__ %6 : !torch.bool -> !torch.bool\n" +" torch.prim.If %7 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %8 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%0#1) : (!torch.int) -> !torch.bool\n" +" %9 = torch.prim.If %8 -> (!torch.int) {\n" +" %10 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%5) : (!torch.int) -> !torch.bool\n" +" torch.prim.If %10 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %11 = torch.prim.TupleConstruct %0#0, %5 : !torch.int, !torch.int -> !torch.tuple\n" +" %12 = func.call @\"__torch_mlir_dtype_fn.aten.std\"(%11, %true) : (!torch.tuple, !torch.bool) -> !torch.int\n" +" torch.prim.If.yield %12 : !torch.int\n" +" } else {\n" +" %10 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_complex_dtype(%5) : (!torch.int) -> !torch.bool\n" +" %11 = torch.aten.__not__ %10 : !torch.bool -> !torch.bool\n" +" torch.prim.If %11 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.If.yield %5 : !torch.int\n" +" }\n" +" torch.prim.If.yield %9 : !torch.int\n" +" } else {\n" +" %5 = func.call @\"__torch_mlir_dtype_fn.aten.std\"(%arg0, %true) : (!torch.tuple, !torch.bool) -> !torch.int\n" +" torch.prim.If.yield %5 : !torch.int\n" +" }\n" +" return %4 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.norm.Scalar\"(%arg0: !torch.tuple, %arg1: !torch.number) -> !torch.int {\n" " %true = torch.constant.bool true\n" " %int5 = torch.constant.int 5\n" diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index b71e86c33568..b67c43076a8a 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -6971,6 +6971,37 @@ class DecomposeAtenReshapeAsOp : public OpRewritePattern { }; } // namespace +namespace { +// Decompose AtenLinalgNormOp to AtenLinalgVectorNormOp only +class DecomposeAtenLinalgNormOp : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenLinalgNormOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + SmallVector dimList; + if (!getListConstructElements(op.getDim(), dimList)) { + return rewriter.notifyMatchFailure( + op, "dim should comes from a PrimListConstructOp"); + } + if (dimList.size() != 1) { + return rewriter.notifyMatchFailure( + op, "Unimplemented: only dim size of 1 is supported"); + } + + // default ord value is 2 for vector_norm + auto ord = op.getOrd(); + if (ord.getType().isa()) { + ord = rewriter.create(loc, rewriter.getI64IntegerAttr(2)); + } + rewriter.replaceOpWithNewOp( + op, op.getType(), op.getSelf(), ord, op.getDim(), op.getKeepdim(), + op.getDtype()); + return success(); + } +}; +} // namespace + namespace { class DecomposeComplexOpsPass : public DecomposeComplexOpsBase { @@ -7177,6 +7208,7 @@ class DecomposeComplexOpsPass addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); // More specific conv ops addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp index e123f1dee4ac..f52c46789350 100644 --- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp +++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp @@ -520,6 +520,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); for (auto &opName : backendLegalOpsSet) { target.addLegalOp( OperationName(kTorchOpPrefix + opName.first().str(), context)); diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 63c70e364e91..894ede3ca953 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -1112,6 +1112,7 @@ "LiftFreshCopyModule_basic", "LinalgVectorNormKeepDimModule_basic", "LinalgVectorNormModule_basic", + "LinalgNormKeepDimModule_basic", "MaskedFillScalarDefaultModule_basic", "MaskedFillScalarIntValueModule_basic", "MaskedFillScalarIntValueStaticModule_basic", @@ -1885,6 +1886,8 @@ "ScatterReduceIntSumModuleIncludeSelf", "TileBigDimsSizeModule_basic", "TileSmallDimsSizeModule_basic", + "LinalgNormKeepDimModule_basic", + "LinalgNormModule_basic", # Failure - onnx_lowering: onnx.AveragePool "AdaptiveAvgPool1dNonUnitOutputSizeStaticModule_basic", diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index a9bf5640d5e3..8ef43b0082b0 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -1722,6 +1722,9 @@ def aten〇nonzero_static〡shape(self: List[int], size: int, fill_value: int = def aten〇linalg_vector_norm〡shape(self: List[int], ord: float = 2, dim: Optional[List[int]] = None, keepdim: bool = False, dtype: Optional[int] = None) -> List[int]: return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, dtype) +def aten〇linalg_norm〡shape(self: List[int], ord: Optional[float] = None, dim: Optional[List[int]] = None, keepdim: bool = False, dtype: Optional[int] = None) -> List[int]: + return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, dtype) + def aten〇frobenius_norm〇dim〡shape(self: List[int], dim: List[int], keepdim: bool = False) -> List[int]: return upstream_shape_functions.sum_mean_dim(self, dim, keepdim, 0) @@ -3938,6 +3941,29 @@ def aten〇linalg_vector_norm〡dtype(self_rank_dtype: Tuple[int, int], ord: Uni return dtype return aten〇std〡dtype(self_rank_dtype) +@check_dtype_function( + _check_tensors_with_the_same_dtype( + num_of_tensors=1, + error_types={torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64}) + + _check_tensors_with_the_same_dtype( + num_of_tensors=1, + error_types={torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64, torch.complex64, torch.complex128}, dtype=torch.float64) + + _check_tensors_with_the_same_dtype( + num_of_tensors=1, + error_types={torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64, torch.bfloat16, torch.float16, torch.float32, torch.float64}, dtype=torch.complex128) + + [ErrorInvocation(NonZeroDTensorWithDtype(torch.float32), dtype=torch.int32)]) +def aten〇linalg_norm〡dtype(self_rank_dtype: Tuple[int, int], ord: Optional[Union[int, float, complex]] = None, dim: Optional[List[int]] = None, keepdim: bool = False, dtype: Optional[int] = None) -> int: + self_rank, self_dtype = self_rank_dtype + assert not is_integer_dtype(self_dtype) + if dtype is not None: + assert not is_integer_dtype(dtype) + if is_complex_dtype(self_dtype): + assert is_complex_dtype(dtype) + return aten〇std〡dtype((self_rank, dtype)) + assert not is_complex_dtype(dtype) + return dtype + return aten〇std〡dtype(self_rank_dtype) + @check_dtype_function( _check_tensors_with_the_same_dtype( num_of_tensors=1, diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index ad030e97e8e1..2b0ec4aee1cb 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -542,6 +542,7 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::nll_loss_backward : (Tensor, Tensor, Tensor, Tensor?, int, int, Tensor) -> (Tensor)") emit("aten::bincount : (Tensor, Tensor?, int) -> (Tensor)") emit("aten::linalg_vector_norm : (Tensor, Scalar, int[]?, bool, int?) -> (Tensor)") + emit("aten::linalg_norm : (Tensor, Scalar?, int[]?, bool, int?) -> (Tensor)") emit("aten::linalg_qr : (Tensor, str) -> (Tensor, Tensor)") emit("aten::frobenius_norm.dim : (Tensor, int[], bool) -> (Tensor)") emit("aten::mse_loss : (Tensor, Tensor, int) -> (Tensor)") diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py index d0d6c2ea2dfa..9f6f358735c3 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/reduction.py @@ -1228,6 +1228,42 @@ def LinalgVectorNormKeepDimModule_basic(module, tu: TestUtils): # ============================================================================== +class LinalgNormModule(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True), + ]) + def forward(self, a): + return torch.ops.aten.linalg_norm(a, ord=None, dim=[0], keepdim=False) + +@register_test_case(module_factory=lambda: LinalgNormModule()) +def LinalgNormModule_basic(module, tu: TestUtils): + module.forward(torch.rand(3, 4, 5)) + +# ============================================================================== + +class LinalgNormKeepDimModule(torch.nn.Module): + def __init__(self) -> None: + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1], torch.float32, True), + ]) + def forward(self, a): + return torch.ops.aten.linalg_norm(a, ord=None, dim=[0], keepdim=True) + +@register_test_case(module_factory=lambda: LinalgNormKeepDimModule()) +def LinalgNormKeepDimModule_basic(module, tu: TestUtils): + module.forward(torch.rand(3, 4, 5)) + +# ============================================================================== + class MseLossNoReductionModule(torch.nn.Module): def __init__(self): super().__init__() From 06292d9429e4f0052fc0a15cc548b95acd154651 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Wed, 6 Mar 2024 10:19:18 -0800 Subject: [PATCH 254/283] [torch] Rework `aten.repeat` to use flatten and unsqueeze (#2984) Current implementation depends on using `aten.view` which has issues inferring tensor collapse/expand operations during the lowering to `linalg`. Using flatten and unsqueeze better infers what the later reshape behavior. --- .../Torch/Transforms/DecomposeComplexOps.cpp | 184 +++++++++--------- projects/pt1/e2e_testing/xfail_sets.py | 1 - 2 files changed, 89 insertions(+), 96 deletions(-) diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index b67c43076a8a..89fc2f0372a4 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -2398,31 +2398,9 @@ class DecomposeAtenRollOp : public OpRewritePattern { }; } // namespace -// Decompose aten.repeat into aten.expand and aten.view ops. +// Decompose aten.repeat into aten.squeeze, aten.unsqueeze, and aten.broadcast. // // Ref: https://pytorch.org/docs/stable/generated/torch.Tensor.repeat.html -// -// For shape [S1, S2, S3] and repeats [M0, M1, M2, M3] -// MS0 = M0; MS1 = M1 * S1; MS2 = M2 * S2; MS3 = M3 * S3 -// -// def aten_repeat(self, repeats): -// sizes = self.size() -// unsqueezed_sizes = [] -// expanded_sizes = [] -// reshape_sizes = [] -// leading_rank = repeats.size() - sizes.size() -// for r in range(leading_rank): -// unsqueezed_sizes.append(1) -// expanded_sizes.append(repeats[r]) -// reshaped_sizes.append(repeats[r]) -// -// for s, m in zip(sizes, repeats[leading_rank:]): -// unsqueezed_sizes += [1, s] -// expanded_sizes += [m, s] -// reshaped_sizes += [m * s] -// return -// self.view(unsqueezed_sizes).expand(expanded_sizes).view(reshaped_sizes) -// namespace { class DecomposeAtenRepeatOp : public OpRewritePattern { public: @@ -2431,94 +2409,110 @@ class DecomposeAtenRepeatOp : public OpRewritePattern { PatternRewriter &rewriter) const override { Location loc = op.getLoc(); Value self = op.getSelf(); - MLIRContext *context = op.getContext(); - std::optional maybeRank = getTensorRank(self); - if (!maybeRank) - return rewriter.notifyMatchFailure(op, "Unimplemented: unranked tensor"); - unsigned rank = *maybeRank; + auto selfTy = cast(self.getType()); + if (!selfTy.hasSizes()) + return rewriter.notifyMatchFailure( + op, "Unimplemented: no implementation for rankless tensor"); SmallVector repeats; if (!getListConstructElements(op.getRepeats(), repeats)) return rewriter.notifyMatchFailure( op, "Unimplemented: repeats not list of Scalar"); - if (rank > repeats.size()) { + int64_t rank = selfTy.getSizes().size(); + if (rank > static_cast(repeats.size())) { return rewriter.notifyMatchFailure( op, "repeats are not matched with self's rank"); } - auto insertDimSizes = [](SmallVector &dimSizes, - SmallVector &shape, - const ArrayRef &vals) { - dimSizes.insert(dimSizes.end(), vals.begin(), vals.end()); - std::transform(vals.begin(), vals.end(), std::back_inserter(shape), - [&](Value val) -> int64_t { - int64_t cst_val; - if (matchPattern(val, m_TorchConstantInt(&cst_val))) { - return cst_val; - } else { - return kUnknownSize; - } - }); - }; + int64_t repeatSz = repeats.size(); + int64_t batch = repeatSz - rank; - Value one = rewriter.create( - loc, rewriter.getI64IntegerAttr(1)); + if (!selfTy.hasSizes()) + return rewriter.notifyMatchFailure(op, "input sizes unknown"); - SmallVector unsqueezedSizes, expandedSizes, reshapedSizes; - SmallVector unsqueezedIntSizes, expandedIntSizes; - assert(repeats.size() >= rank && "leadingRank should greater than 0"); - auto leadingRank = repeats.size() - rank; - for (size_t i = 0; i < leadingRank; ++i) { - insertDimSizes(unsqueezedSizes, unsqueezedIntSizes, ArrayRef{one}); - insertDimSizes(expandedSizes, expandedIntSizes, - ArrayRef{repeats[i]}); - reshapedSizes.push_back(repeats[i]); + // Materialize out 1 dimensions to broadcast along. This includes + // materializing out preceding batch dimensions: + for (int i = 0; i < repeatSz; ++i) { + auto oldSizes = selfTy.getSizes(); + llvm::SmallVector sizes; + int64_t squeezeDim = i < batch ? i : i * 2 - batch; + + for (int j = 0; j < squeezeDim; ++j) + sizes.push_back(oldSizes[j]); + sizes.push_back(1); + for (int j = squeezeDim, s = oldSizes.size(); j < s; j++) + sizes.push_back(oldSizes[j]); + + Value dim = rewriter.create(loc, squeezeDim); + selfTy = + rewriter.getType(sizes, selfTy.getOptionalDtype()); + self = rewriter.create(loc, selfTy, self, dim); } - auto selfType = self.getType().dyn_cast(); - auto selfShape = selfType.getSizes(); - for (unsigned i = 0; i < rank; i++) { - auto scale = repeats[i + leadingRank]; - Value dimSize; - if (selfShape[i] == kUnknownSize) { - Value dim = rewriter.create( - loc, rewriter.getI64IntegerAttr(i)); - dimSize = rewriter.create(loc, self, dim); - } else { - dimSize = rewriter.create( - loc, rewriter.getI64IntegerAttr(selfShape[i])); + llvm::SmallVector lengths; + for (int i = 0; i < repeatSz; ++i) { + if (i < batch) { + lengths.push_back(repeats[i]); + continue; } - insertDimSizes(unsqueezedSizes, unsqueezedIntSizes, - ArrayRef{one, dimSize}); - insertDimSizes(expandedSizes, expandedIntSizes, - ArrayRef{scale, dimSize}); - - Value scaledSize = rewriter.create(loc, dimSize, scale); - reshapedSizes.push_back(scaledSize); - } - - Type dtype = self.getType().cast().getOptionalDtype(); - Type unsqueezedType = ValueTensorType::get( - context, llvm::ArrayRef(unsqueezedIntSizes), dtype); - Type expandedType = - ValueTensorType::get(context, llvm::ArrayRef(expandedIntSizes), dtype); - - auto listType = Torch::ListType::get(Torch::IntType::get(op.getContext())); - Value unsqueezedDims = - rewriter.create(loc, listType, unsqueezedSizes); - Value expandedDims = - rewriter.create(loc, listType, expandedSizes); - Value reshapedDims = - rewriter.create(loc, listType, reshapedSizes); - auto reshaped = rewriter.create(loc, unsqueezedType, - op.getSelf(), unsqueezedDims); - auto expanded = rewriter.create(loc, expandedType, - reshaped, expandedDims); - - rewriter.replaceOpWithNewOp(op, op.getType(), expanded, - reshapedDims); + Value iv = rewriter.create( + loc, rewriter.getI64IntegerAttr(i * 2 + 1 - batch)); + Value dim = rewriter.create(loc, self, /*dim=*/iv); + lengths.push_back(repeats[i]); + lengths.push_back(dim); + } + + Value lengthv = rewriter.create( + loc, ListType::get(rewriter.getType()), lengths); + + llvm::SmallVector expandShape(selfTy.getSizes()); + for (int i = 0; i < repeatSz; ++i) { + int64_t repeatDim = i < batch ? i : i * 2 - batch; + int64_t repeat; + if (!matchPattern(repeats[i], m_TorchConstantInt(&repeat))) + repeat = Torch::kUnknownSize; + expandShape[repeatDim] = repeat; + } + + auto mulDim = [](int64_t lhs, int64_t rhs) { + if (lhs == Torch::kUnknownSize || rhs == Torch::kUnknownSize) + return Torch::kUnknownSize; + return lhs * rhs; + }; + + BaseTensorType expandTy = rewriter.getType( + expandShape, selfTy.getOptionalDtype()); + Value expand = + rewriter.create(loc, expandTy, self, lengthv); + + for (int i = 0; i < rank; ++i) { + auto oldShape = expandTy.getSizes(); + llvm::SmallVector newShape; + int64_t flattenDim = i + batch; + for (int j = 0; j < flattenDim; ++j) + newShape.push_back(oldShape[j]); + newShape.push_back( + mulDim(oldShape[flattenDim], oldShape[flattenDim + 1])); + for (int j = flattenDim + 2, s = oldShape.size(); j < s; ++j) + newShape.push_back(oldShape[j]); + + expandTy = rewriter.getType(newShape, + expandTy.getOptionalDtype()); + + // Used to keep the return type the same on the last flatten: + expandTy = i < rank - 1 ? expandTy : cast(op.getType()); + + Value start = rewriter.create( + loc, rewriter.getI64IntegerAttr(flattenDim)); + Value end = rewriter.create( + loc, rewriter.getI64IntegerAttr(flattenDim + 1)); + expand = rewriter.create(loc, expandTy, expand, + start, end); + } + + rewriter.replaceOp(op, expand); return success(); } }; diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 894ede3ca953..7218bd4c945d 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -2142,7 +2142,6 @@ "IndexTensorMultiInputThreeIndexers_basic", "IndexTensorMultiInput_basic", "IndexTensorStaticContiguousWithNoneModule_basic", - "RepeatModule_basic", "SelectIntModule_basic", "SliceSingleIdxModule_basic", "ViewFlattenAndExpandModule_basic", From ea76dd12ba08abf3cdfa02b74f2d2bde37ab0529 Mon Sep 17 00:00:00 2001 From: Andreas Falkenberg <149819731+afalkenberg1@users.noreply.github.com> Date: Wed, 6 Mar 2024 10:56:58 -0800 Subject: [PATCH 255/283] [onnx][torch] Gridsampler E2E test and corrections of gridsampler (#2987) The addition of an e2e test is actually provided in the Shark-Testsuite. This adds 2 test cases for the gridsampler e2e test. Also as intended there were some items found which needed correction, so the Gridsampler op has also a change. --- .../TorchOnnxToTorch/DefaultDomainGtoP.cpp | 6 +- .../TorchToLinalg/Uncategorized.cpp | 22 +++--- projects/pt1/e2e_testing/xfail_sets.py | 4 ++ .../test_suite/__init__.py | 1 + .../test_suite/gridsampler.py | 71 +++++++++++++++++++ .../TorchOnnxToTorch/simple_ops_g_to_p.mlir | 4 +- 6 files changed, 94 insertions(+), 14 deletions(-) create mode 100644 projects/pt1/python/torch_mlir_e2e_test/test_suite/gridsampler.py diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index 6855dcc3df33..01b47bc56cbc 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -93,7 +93,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( return success(); }); patterns.onOp( - "GridSample", 20, + "GridSample", 17, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value input; @@ -140,9 +140,9 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( if (binder.s64IntegerAttr(align, "align_corners", 0)) return rewriter.notifyMatchFailure(binder.op, "align_corners bind failure"); - if (align != 0) + if (align != 1) return rewriter.notifyMatchFailure( - binder.op, "currently only align_corners : 0 supported"); + binder.op, "currently only align_corners = 1 supported"); Value interpolationMode = rewriter.create( binder.getLoc(), rewriter.getType(), diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index 8b4297a62e17..86bc4578178f 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -2514,19 +2514,23 @@ class ConvertAtenGridSamplerOp : public OpConversionPattern { auto resultType = getTypeConverter() ->convertType(op.getResult().getType()) .cast(); - llvm::SmallVector resultSize{ - rewriter.create(loc, input, 0), - rewriter.create(loc, input, 1), - rewriter.create(loc, grid, 1), - rewriter.create(loc, grid, 2)}; + SmallVector resultSize{}; + if (resultType.isDynamicDim(0)) + resultSize.push_back(rewriter.create(loc, input, 0)); + if (resultType.isDynamicDim(1)) + resultSize.push_back(rewriter.create(loc, input, 1)); + if (resultType.isDynamicDim(2)) + resultSize.push_back(rewriter.create(loc, grid, 1)); + if (resultType.isDynamicDim(3)) + resultSize.push_back(rewriter.create(loc, grid, 2)); Value resultFinal = rewriter.create(loc, resultType, resultSize); auto sGrid = rewriter.create( loc, TypeRange{resultType}, ValueRange{gridCollapsed0, gridCollapsed1}, ValueRange(resultFinal), gridMaps, gridIterators, [&](OpBuilder &b, Location loc, ValueRange args) { - Value gr0 = args[0]; - Value gr1 = args[1]; + Value gr0 = args[1]; + Value gr1 = args[0]; Value gplus0 = b.create(loc, gr0, oneFloat); Value gplus1 = b.create(loc, gr1, oneFloat); Value result0 = b.create(loc, gplus0, innerDim0e); @@ -2571,8 +2575,8 @@ class ConvertAtenGridSamplerOp : public OpConversionPattern { b.create(loc, notValid1, zeroFloat, result11a); Value lw0a = b.create(loc, floatType, lower0); Value lw1a = b.create(loc, floatType, lower1); - Value d0 = b.create(loc, result0, lw0a); - Value d1 = b.create(loc, result1, lw1a); + Value d1 = b.create(loc, result0, lw0a); + Value d0 = b.create(loc, result1, lw1a); Value resultScaled0 = lambdaInter(b, loc, result00, result01a, d0); Value resultScaled1 = lambdaInter(b, loc, result10a, result11b, d0); Value resultScaled = diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 7218bd4c945d..2570581befb2 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -331,6 +331,10 @@ "FloatImplicitModule_basic", "IntImplicitModule_basic", + + # Others + "GridSamplerBasic1_basic", + "GridSamplerBasic2_basic", } TORCHDYNAMO_CRASHING_SET = { diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/__init__.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/__init__.py index 10130a73fe85..c4d21ea08eaa 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/__init__.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/__init__.py @@ -61,3 +61,4 @@ def register_all_tests(): from . import stats from . import padding from . import diagonal + from . import gridsampler diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/gridsampler.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/gridsampler.py new file mode 100644 index 000000000000..2960041bdc68 --- /dev/null +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/gridsampler.py @@ -0,0 +1,71 @@ +# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +# See https://llvm.org/LICENSE.txt for license information. +# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +# Also available under a BSD-style license. See LICENSE. + +import torch + +from torch_mlir_e2e_test.framework import TestUtils +from torch_mlir_e2e_test.registry import register_test_case +from torch_mlir_e2e_test.annotations import annotate_args, export + +# ============================================================================== + +class GridSamplerBasic1(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([7, 8, 12, 4], torch.float32, True), + ([7, 11, 13, 2], torch.float32, True) + ]) + def forward(self, x, g): + interpolation_mode=0, + padding_mode=0, + align_corners=True, + tRes = torch.ops.aten.grid_sampler(x, g, interpolation_mode[0], + padding_mode[0], align_corners[0]) + return tRes + +@register_test_case( + module_factory=lambda: GridSamplerBasic1()) +def GridSamplerBasic1_basic( + module, tu: TestUtils): + inp = torch.rand(7,8,12,4) + grd = torch.rand(7,11,13,2)*2-1 + module.forward(inp, grd) + + +class GridSamplerBasic2(torch.nn.Module): + + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([1, 1, 4, 4], torch.float32, True), + ([1, 1, 3, 2], torch.float32, True) + ]) + def forward(self, x, g): + interpolation_mode=0, + padding_mode=0, + align_corners=True, + tRes = torch.ops.aten.grid_sampler(x, g, interpolation_mode[0], + padding_mode[0], align_corners[0]) + return tRes + +@register_test_case( + module_factory=lambda: GridSamplerBasic2()) +def GridSamplerBasic2_basic( + module, tu: TestUtils): + inp = torch.tensor([[[[0.4963, 0.7682, 0.0885, 0.1320], + [0.3074, 0.6341, 0.4901, 0.8964], + [0.4556, 0.6323, 0.3489, 0.4017], + [0.0223, 0.1689, 0.2939, 0.5185]]]]).type(torch.FloatTensor) + grd = torch.tensor([[[[-0.3498, -0.8196],[-0.2127, 0.2138],[-0.6515, -0.0513]]]]).type(torch.FloatTensor) + module.forward(inp, grd) + diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir index ee93f13e2c40..493ec8cebfdb 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir @@ -427,8 +427,8 @@ func.func @test_gelu_tanh_2(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtenso // CHECK: %[[INT0_0:.*]] = torch.constant.int 0 // CHECK: %[[B0:.*]] = torch.constant.bool false // CHECK: %[[A0:.*]] = torch.aten.grid_sampler %arg0, %arg1, %[[INT0]], %[[INT0_0]], %[[B0]] : !torch.vtensor<[5,10,10,4],f32> -func.func @test_grid_sampler(%arg0: !torch.vtensor<[5,10,10,4],f32>, %arg1: !torch.vtensor<[5,7,8,2],f32>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - %4 = torch.operator "onnx.GridSample" (%arg0, %arg1) {torch.onnx.align_corners = 0 : si64, torch.onnx.padding_mode = "zeros"} : (!torch.vtensor<[5,10,10,4],f32>, !torch.vtensor<[5,7,8,2],f32>) -> !torch.vtensor<[?,?,?,?],f32> +func.func @test_grid_sampler(%arg0: !torch.vtensor<[5,10,10,4],f32>, %arg1: !torch.vtensor<[5,7,8,2],f32>) -> !torch.vtensor<[?,?,?,?],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 17 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + %4 = torch.operator "onnx.GridSample" (%arg0, %arg1) {torch.onnx.align_corners = 1 : si64, torch.onnx.padding_mode = "zeros"} : (!torch.vtensor<[5,10,10,4],f32>, !torch.vtensor<[5,7,8,2],f32>) -> !torch.vtensor<[?,?,?,?],f32> return %4 : !torch.vtensor<[?,?,?,?],f32> } From a78659742a20d1d99cd71c860e212a9156f03a25 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Wed, 6 Mar 2024 16:48:21 -0800 Subject: [PATCH 256/283] [onnx] Migrate `onnx.ReduceMax` to match `onnx.ReduceMin` (#2981) This mostly copy-pastes the reduce minimum implementation to reduce max to improve test coverage. We also improve the aten lowering for min/max dim for unsigned types. --- .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 254 +++++++++++------- lib/Conversion/TorchToLinalg/Reduction.cpp | 37 ++- .../Torch/Transforms/DecomposeComplexOps.cpp | 4 +- projects/pt1/e2e_testing/xfail_sets.py | 15 +- .../TorchOnnxToTorch/simple_ops_q_to_z.mlir | 166 ++++++++---- 5 files changed, 293 insertions(+), 183 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 2b08630705c4..df3449939138 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -758,107 +758,6 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( binder.op, resultType, operand, vAlpha, vScale, vInputScale); return success(); }); - patterns.onOp( - "ReduceMax", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { - Torch::ValueTensorType resultType; - SmallVector operands; - int64_t keepDims, noop_with_empty_axes; - - if (binder.tensorOperandsList(operands) || - binder.tensorResultType(resultType) || - binder.s64IntegerAttr(keepDims, "keepdims", 1) || - binder.s64IntegerAttr(noop_with_empty_axes, "noop_with_empty_axes", - 0)) - return failure(); - Value data = operands[0]; - - if (operands.size() == 1) { - if (noop_with_empty_axes == 0) { - MLIRContext *context = binder.op->getContext(); - int rank = - data.getType().cast().getSizes().size(); - SmallVector dims; - for (int i = 0; i < rank; i++) { - dims.push_back(rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(i))); - } - Value dimsList = rewriter.create( - binder.getLoc(), - Torch::ListType::get(Torch::IntType::get(context)), dims); - Value keepDimsConstInt = rewriter.create( - binder.getLoc(), rewriter.getType(), - rewriter.getIntegerAttr(rewriter.getIntegerType(64), keepDims)); - Value keepDimsBool = rewriter.create( - binder.getLoc(), keepDimsConstInt); - rewriter.replaceOpWithNewOp( - binder.op, resultType, data, /*dim=*/dimsList, - /*keepdim=*/keepDimsBool); - } else { - rewriter.replaceOp(binder.op, data); - } - return success(); - } - - Value axes = operands[1]; - - SmallVector dimList; - - Torch::BaseTensorType axesType = - axes.getType().cast(); - SmallVector selectSizes; - selectSizes.push_back(1); - Type selectResultType = axesType.getWithSizesAndDtype( - llvm::ArrayRef(selectSizes), axesType.getOptionalDtype()); - auto sizes = - dyn_cast(axes.getType()).getSizes(); - - Value zero = rewriter.create( - binder.getLoc(), rewriter.getType(), - rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); - int64_t adjustmentInt = - cast(data.getType()).getSizes().size(); - Value adjustment = rewriter.create( - binder.getLoc(), rewriter.getType(), - rewriter.getIntegerAttr(rewriter.getIntegerType(64), - adjustmentInt)); - for (int i = 0; i < sizes[0]; i++) { - // Go through the axes list and get each dim in the list - Value selectIndex = rewriter.create( - binder.getLoc(), rewriter.getType(), - rewriter.getIntegerAttr(rewriter.getIntegerType(64), i)); - Value extract = rewriter.create( - binder.getLoc(), selectResultType, axes, zero, selectIndex); - Value dim = rewriter.create( - binder.getLoc(), rewriter.getType(), extract); - // deal with neg axis: if (axis < 0) axis += rank - Value isNegative = - rewriter.create(binder.getLoc(), dim, zero); - isNegative = rewriter.create(binder.getLoc(), - isNegative); - Value finalOffset = rewriter.create( - binder.getLoc(), isNegative, adjustment); - Value finalDim = rewriter.create( - binder.getLoc(), dim, finalOffset); - dimList.push_back(finalDim); - } - - Value dimValueList = rewriter.create( - binder.getLoc(), - Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), - dimList); - - Value keepDimBool; - if (keepDims == 1) { - keepDimBool = - rewriter.create(binder.getLoc(), true); - } else { - keepDimBool = - rewriter.create(binder.getLoc(), false); - } - rewriter.replaceOpWithNewOp( - binder.op, resultType, data, dimValueList, keepDimBool); - return success(); - }); patterns.onOp( "ReduceSum", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) { @@ -1102,6 +1001,159 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( /*dtype=*/noneVal); return success(); }); + patterns.onOp( + "ReduceMax", 13, + [](OpBinder binder, ConversionPatternRewriter &rewriter) { + // AtenAmaxOp allows us to pass a list of dims + Torch::ValueTensorType resultType; + Value data; + Value axes; + int64_t keepDims; + int64_t noop_with_empty_axes; + if (binder.tensorOperandAtIndex(data, 0) || + binder.tensorResultType(resultType) || + binder.s64IntegerAttr(keepDims, "keepdims", 1) || + binder.s64IntegerAttr(noop_with_empty_axes, "noop_with_empty_axes", + 0)) + return failure(); + + auto dataTy = cast(data.getType()); + Torch::IntType torchIntTy = rewriter.getType(); + + // If any of the input dims are 0 we set to the upper limit: + if (llvm::any_of(dataTy.getSizes(), [](int64_t d) { return d == 0; }) && + (llvm::any_of(dataTy.getSizes(), + [](int64_t d) { return d == Torch::kUnknownSize; }) || + keepDims)) { + auto dty = dataTy.getDtype(); + Value scalar; + if (FloatType fpTy = dyn_cast(dty)) { + auto inf = APFloat::getInf(fpTy.getFloatSemantics()); + scalar = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getFloatAttr(rewriter.getF64Type(), + inf.convertToDouble())); + } + + if (IntegerType intTy = dyn_cast(dty)) { + auto mx = + intTy.isSigned() + ? APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth()) + : APInt::getMaxValue(intTy.getIntOrFloatBitWidth()); + scalar = rewriter.create( + binder.getLoc(), torchIntTy, + rewriter.getIntegerAttr(rewriter.getIntegerType(64), + mx.getSExtValue())); + } + + llvm::SmallVector fillDims; + for (int i = 0, s = resultType.getSizes().size(); i < s; ++i) { + auto staticDim = resultType.getSizes()[i]; + if (staticDim != Torch::kUnknownSize) { + fillDims.push_back(rewriter.create( + binder.getLoc(), torchIntTy, + rewriter.getI64IntegerAttr(staticDim))); + continue; + } + + Value iv = rewriter.create( + binder.getLoc(), torchIntTy, rewriter.getI64IntegerAttr(i)); + fillDims.push_back(rewriter.create( + binder.getLoc(), torchIntTy, data, iv)); + } + + Value none = rewriter.create(binder.getLoc()); + Value fillDimsList = rewriter.create( + binder.getLoc(), Torch::ListType::get(torchIntTy), fillDims); + rewriter.replaceOpWithNewOp( + binder.op, resultType, fillDimsList, scalar, none, none, none, + none); + return success(); + } + + // Previous version of the operation had the axes as an attribute: + SmallVector axesList; + llvm::SmallVector axesAttr; + if (!binder.s64IntegerArrayAttr(axesAttr, "axes", {})) { + for (int i = 0, s = axesAttr.size(); i < s; ++i) { + axesList.push_back(rewriter.create( + binder.getLoc(), torchIntTy, + rewriter.getI64IntegerAttr(axesAttr[i]))); + } + } + + // Extract the axes values from the axes operand: + if (!binder.tensorOperandAtIndex(axes, 1)) { + Torch::BaseTensorType axesType = + axes.getType().cast(); + SmallVector selectSizes{1}; + Type selectResultType = axesType.getWithSizesAndDtype( + selectSizes, axesType.getOptionalDtype()); + auto sizes = axesType.getSizes(); + + Value zero = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); + + // Extract the value of each axes: + for (int i = 0; i < sizes[0]; i++) { + // Go through the axes list and get each dim in the list + Value selectIndex = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), i)); + Value extract = rewriter.create( + binder.getLoc(), selectResultType, axes, zero, selectIndex); + Value dim = rewriter.create( + binder.getLoc(), rewriter.getType(), extract); + axesList.push_back(dim); + } + } + + // Handle the noop case: + if (axesList.empty() && noop_with_empty_axes) { + rewriter.replaceOp(binder.op, data); + return success(); + } + + // Deal with case when no axes arg is passed but not a noop: + if (axesList.empty()) { + int64_t numDims = dyn_cast(data.getType()) + .getSizes() + .size(); + for (int i = 0; i < numDims; i++) { + Value curr = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getIntegerAttr(rewriter.getIntegerType(64), i)); + axesList.push_back(curr); + } + } + + // Handle negative axis: + Value rankVal = rewriter.create(binder.getLoc(), + torchIntTy, data); + Value zero = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getI64IntegerAttr(0)); + for (Value &axes : axesList) { + Value isNegative = + rewriter.create(binder.getLoc(), axes, zero); + isNegative = rewriter.create(binder.getLoc(), + isNegative); + Value finalOffset = rewriter.create( + binder.getLoc(), isNegative, rankVal); + axes = rewriter.create(binder.getLoc(), axes, + finalOffset); + } + + Value dimValueList = rewriter.create( + binder.getLoc(), Torch::ListType::get(torchIntTy), axesList); + Value keepDimBool = + rewriter.create(binder.getLoc(), keepDims); + rewriter.replaceOpWithNewOp( + binder.op, resultType, data, dimValueList, keepDimBool); + return success(); + }); + patterns.onOp( "ReduceMin", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) { diff --git a/lib/Conversion/TorchToLinalg/Reduction.cpp b/lib/Conversion/TorchToLinalg/Reduction.cpp index 92f50523c764..952610c5404d 100644 --- a/lib/Conversion/TorchToLinalg/Reduction.cpp +++ b/lib/Conversion/TorchToLinalg/Reduction.cpp @@ -87,6 +87,7 @@ class ConvertAtenMinMaxDimOp : public OpConversionPattern { return rewriter.notifyMatchFailure(op, "dim is not a valid dim"); Type inElementType = inputType.getElementType(); + bool isUnsigned = false; if (!inElementType.isa()) { if (inElementType.isa()) { auto integerTy = op.getSelf() @@ -94,10 +95,7 @@ class ConvertAtenMinMaxDimOp : public OpConversionPattern { .template cast() .getDtype() .template dyn_cast(); - if (integerTy.isUnsigned()) - return rewriter.notifyMatchFailure( - op, opName + " to linalg.* requires input element type " - "to be signed in case of integer"); + isUnsigned = integerTy.isUnsigned(); } else { return rewriter.notifyMatchFailure( op, opName + " to linalg.* requires Float or Integer " @@ -130,12 +128,17 @@ class ConvertAtenMinMaxDimOp : public OpConversionPattern { APFloat::getInf( inElementType.cast().getFloatSemantics(), /*Negative=*/isMax))); - } else { + } else if (!isUnsigned) { auto width = inElementType.cast().getWidth(); auto init = isMax ? APSInt::getSignedMinValue(width) : APSInt::getSignedMaxValue(width); fillValue = rewriter.create( loc, rewriter.getIntegerAttr(inElementType, init)); + } else if (isUnsigned) { + auto width = inElementType.cast().getWidth(); + auto init = isMax ? APInt::getMinValue(width) : APInt::getMaxValue(width); + fillValue = rewriter.create( + loc, rewriter.getIntegerAttr(inElementType, init)); } Value filledTensorVal = @@ -193,13 +196,25 @@ class ConvertAtenMinMaxDimOp : public OpConversionPattern { } else { arith::CmpIPredicate predType; if (isMax) { - predType = arith::CmpIPredicate::sgt; - resultVal = rewriter.create(nestedLoc, newValue, - oldValue); + predType = isUnsigned ? arith::CmpIPredicate::ugt + : arith::CmpIPredicate::sgt; + if (isUnsigned) { + resultVal = rewriter.create(nestedLoc, newValue, + oldValue); + } else { + resultVal = rewriter.create(nestedLoc, newValue, + oldValue); + } } else { - predType = arith::CmpIPredicate::slt; - resultVal = rewriter.create(nestedLoc, newValue, - oldValue); + predType = isUnsigned ? arith::CmpIPredicate::ult + : arith::CmpIPredicate::slt; + if (isUnsigned) { + resultVal = rewriter.create(nestedLoc, newValue, + oldValue); + } else { + resultVal = rewriter.create(nestedLoc, newValue, + oldValue); + } } predicate = rewriter.create(nestedLoc, predType, newValue, oldValue); diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 89fc2f0372a4..157d6f227ae1 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -71,8 +71,8 @@ static Type computeReductionType(PatternRewriter &rewriter, Operation *op, } Type resultType = tensorType.getWithSizesAndDtype( - sizes.size() == 0 ? std::optional>() - : llvm::ArrayRef(sizes), + !tensorType.hasSizes() ? std::optional>() + : llvm::ArrayRef(sizes), tensorType.getOptionalDtype()); return resultType; } diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 2570581befb2..2229118181a5 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -1515,9 +1515,6 @@ "BroadcastToModule_basic", "ExpandModule_basic", "MoveDimIntNegativeIndexModule_basic", - "ReduceAmaxKeepDim_basic", - "ReduceMaxKeepDimReturnBoth_basic", - "ReduceMaxNegativeDim_basic", "ViewSizeFromOtherTensor_basic", # Failure - onnx_export @@ -2122,18 +2119,8 @@ "TriuBroadcastModule_basic", "TriuModule_basic", - # Failure - rankless return - "ReduceAmaxMultiDim_basic", - "ReduceAmaxOutOfOrderDim_basic", - "ReduceAmaxSingleDim_basic", - "ReduceMaxAllDims_basic", - "ReduceMaxAlongDimNegative_basic", - "ReduceMaxAlongDimSignedInt_basic", + # Failure - incorrect dtype "ReduceMaxAlongDimUnsignedInt_basic", - "ReduceMaxAlongDim_basic", - "ReduceMaxFloatModule_basic", - "ReduceMaxSignedIntModule_basic", - "ReduceMaxUnsignedIntModule_basic", # Failure - torch.aten.view lower "IndexTensorDyanmicInputContiguousWithNoneModule_basic", diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir index 11c9e4e6e5da..977c557739b5 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir @@ -747,65 +747,121 @@ func.func @test_selu(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4, // ----- -// CHECK-LABEL: func.func @test_reduce_max_keepdims_example -func.func @test_reduce_max_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[2],si64>) -> !torch.vtensor<[3,1,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[INT0:.*]] = torch.constant.int 0 - // CHECK: %[[RANK:.*]] = torch.constant.int 3 - // CHECK: %[[INT0_0:.*]] = torch.constant.int 0 - // CHECK: %[[SELECT_DIM0:.*]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[2],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: %[[ITEM0:.*]] = torch.aten.item %[[SELECT_DIM0]] : !torch.vtensor<[1],si64> -> !torch.int - // CHECK: %[[LTZERO_0:.*]] = torch.aten.lt.int %[[ITEM0]], %[[INT0]] : !torch.int, !torch.int -> !torch.bool - // CHECK: %[[ISNEG_0:.*]] = torch.aten.Int.bool %[[LTZERO_0]] : !torch.bool -> !torch.int - // CHECK: %[[ADJUSTMENT_0:.*]] = torch.aten.mul.int %[[ISNEG_0]], %[[RANK]] : !torch.int, !torch.int -> !torch.int - // CHECK: %[[FINAL_0:.*]] = torch.aten.add.int %[[ITEM0]], %[[ADJUSTMENT_0]] : !torch.int, !torch.int -> !torch.int - // CHECK: %[[INT1:.*]] = torch.constant.int 1 - // CHECK: %[[SELECT_DIM1:.*]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT1]] : !torch.vtensor<[2],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: %[[ITEM1:.*]] = torch.aten.item %[[SELECT_DIM1]] : !torch.vtensor<[1],si64> -> !torch.int - // CHECK: %[[LTZERO_1:.*]] = torch.aten.lt.int %[[ITEM1]], %[[INT0]] : !torch.int, !torch.int -> !torch.bool - // CHECK: %[[ISNEG_1:.*]] = torch.aten.Int.bool %[[LTZERO_1]] : !torch.bool -> !torch.int - // CHECK: %[[ADJUSTMENT_1:.*]] = torch.aten.mul.int %[[ISNEG_1]], %[[RANK]] : !torch.int, !torch.int -> !torch.int - // CHECK: %[[FINAL_1:.*]] = torch.aten.add.int %[[ITEM1]], %[[ADJUSTMENT_1]] : !torch.int, !torch.int -> !torch.int - // CHECK: %[[DIMS:.*]] = torch.prim.ListConstruct %[[FINAL_0]], %[[FINAL_1]] : (!torch.int, !torch.int) -> !torch.list - // CHECK: %[[KEEPDIMS:.*]] = torch.constant.bool true - // CHECK: torch.aten.amax %arg0, %[[DIMS]], %[[KEEPDIMS]] : !torch.vtensor<[3,2,2],f32>, !torch.list, !torch.bool -> !torch.vtensor<[3,1,1],f32> - %0 = torch.operator "onnx.ReduceMax"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[2],si64>) -> !torch.vtensor<[3,1,1],f32> - return %0 : !torch.vtensor<[3,1,1],f32> - } +// CHECK-LABEL: func.func @test_reduce_max_empty_set_fp +func.func @test_reduce_max_empty_set_fp(%arg0: !torch.vtensor<[2,0,4],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[2,1,4],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK-DAG: %[[INF:.+]] = torch.constant.float 0x7FF0000000000000 + // CHECK-DAG: %[[INT2:.+]] = torch.constant.int 2 + // CHECK-DAG: %[[INT1:.+]] = torch.constant.int 1 + // CHECK-DAG: %[[INT4:.+]] = torch.constant.int 4 + // CHECK-DAG: %[[NONE:.+]] = torch.constant.none + // CHECK-DAG: %[[LIST:.+]] = torch.prim.ListConstruct %[[INT2]], %[[INT1]], %[[INT4]] + // CHECK-DAG: %[[FULL:.+]] = torch.aten.full %[[LIST]], %[[INF]], %[[NONE]], %[[NONE]], %[[NONE]] + // CHECK: return %[[FULL]] + %0 = torch.operator "onnx.ReduceMax"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[2,0,4],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[2,1,4],f32> + return %0 : !torch.vtensor<[2,1,4],f32> +} // ----- -// CHECK-LABEL: func.func @test_reduce_max_default_axes_keepdim_example -func.func @test_reduce_max_default_axes_keepdim_example(%arg0: !torch.vtensor<[3,2,2],f32>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[INT0:.*]] = torch.constant.int 0 - // CHECK: %[[INT1:.*]] = torch.constant.int 1 - // CHECK: %[[INT2:.*]] = torch.constant.int 2 - // CHECK: %[[DIMS:.*]] = torch.prim.ListConstruct %[[INT0]], %[[INT1]], %[[INT2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list - // CHECK: %[[INT1_0:.*]] = torch.constant.int 1 - // CHECK: %[[KEEPDIMS:.*]] = torch.aten.Bool.int %[[INT1_0]] : !torch.int -> !torch.bool - // CHECK: torch.aten.amax %arg0, %[[DIMS]], %[[KEEPDIMS]] : !torch.vtensor<[3,2,2],f32>, !torch.list, !torch.bool -> !torch.vtensor<[1,1,1],f32> - %0 = torch.operator "onnx.ReduceMax"(%arg0) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>) -> !torch.vtensor<[1,1,1],f32> - return %0 : !torch.vtensor<[1,1,1],f32> - } +// CHECK-LABEL: func.func @test_reduce_max_empty_set_int +func.func @test_reduce_max_empty_set_int(%arg0: !torch.vtensor<[2,0,4],si32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[2,1,4],si32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK-DAG: %[[INF:.+]] = torch.constant.int 2147483647 + // CHECK-DAG: %[[INT2:.+]] = torch.constant.int 2 + // CHECK-DAG: %[[INT1:.+]] = torch.constant.int 1 + // CHECK-DAG: %[[INT4:.+]] = torch.constant.int 4 + // CHECK-DAG: %[[NONE:.+]] = torch.constant.none + // CHECK-DAG: %[[LIST:.+]] = torch.prim.ListConstruct %[[INT2]], %[[INT1]], %[[INT4]] + // CHECK-DAG: %[[FULL:.+]] = torch.aten.full %[[LIST]], %[[INF]], %[[NONE]], %[[NONE]], %[[NONE]] + // CHECK: return %[[FULL]] + %0 = torch.operator "onnx.ReduceMax"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[2,0,4],si32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[2,1,4],si32> + return %0 : !torch.vtensor<[2,1,4],si32> +} // ----- -// CHECK-LABEL: func.func @test_reduce_max_do_not_keepdims_example - func.func @test_reduce_max_do_not_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[INT0:.*]] = torch.constant.int 0 - // CHECK: %[[RANK:.*]] = torch.constant.int 3 - // CHECK: %[[INT0_0:.*]] = torch.constant.int 0 - // CHECK: %[[SELECT_DIM:.*]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: %[[ITEM:.*]] = torch.aten.item %[[SELECT_DIM]] : !torch.vtensor<[1],si64> -> !torch.int - // CHECK: %[[LTZERO:.*]] = torch.aten.lt.int %[[ITEM]], %[[INT0]] : !torch.int, !torch.int -> !torch.bool - // CHECK: %[[ISNEG:.*]] = torch.aten.Int.bool %[[LTZERO]] : !torch.bool -> !torch.int - // CHECK: %[[ADJUSTMENT:.*]] = torch.aten.mul.int %[[ISNEG]], %[[RANK]] : !torch.int, !torch.int -> !torch.int - // CHECK: %[[FINAL:.*]] = torch.aten.add.int %[[ITEM]], %[[ADJUSTMENT]] : !torch.int, !torch.int -> !torch.int - // CHECK: %[[DIMS:.*]] = torch.prim.ListConstruct %[[FINAL]] : (!torch.int) -> !torch.list - // CHECK: %[[FALSE:.*]] = torch.constant.bool false - // CHECK: torch.aten.amax %arg0, %[[DIMS]], %[[FALSE]] : !torch.vtensor<[3,2,2],f32>, !torch.list, !torch.bool -> !torch.vtensor<[3,2],f32> - %0 = torch.operator "onnx.ReduceMax"(%arg0, %arg1) {torch.onnx.keepdims = 0 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32> - return %0 : !torch.vtensor<[3,2],f32> - } +// CHECK-LABEL: func.func @test_reduce_max_bool_inputs +func.func @test_reduce_max_bool_inputs(%arg0: !torch.vtensor<[4,2],i1>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[4,1],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[IDX:.+]] = torch.constant.int 0 + // CHECK: %[[SZ:.+]] = torch.constant.int 0 + // CHECK: %[[SEL:.+]] = torch.aten.select.int %arg1, %[[IDX]], %[[SZ]] + // CHECK: %[[ITEM:.+]] = torch.aten.item %[[SEL]] + // CHECK: %[[DIM:.+]] = torch.aten.dim %arg0 : !torch.vtensor<[4,2],i1> -> !torch.int + // CHECK: %[[C0:.+]] = torch.constant.int 0 + // CHECK: %[[LT:.+]] = torch.aten.lt.int %[[ITEM]], %[[C0]] : !torch.int, !torch.int -> !torch.bool + // CHECK: %[[BOOL:.+]] = torch.aten.Int.bool %[[LT]] : !torch.bool -> !torch.int + // CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[BOOL]], %[[DIM]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[ADD:.+]] = torch.aten.add.int %[[ITEM]], %[[MUL]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[LST:.+]] = torch.prim.ListConstruct %[[ADD]] : (!torch.int) -> !torch.list + // CHECK: %[[TRUE:.+]] = torch.constant.bool true + // CHECK: %[[AMAX:.+]] = torch.aten.amax %arg0, %[[LST]], %[[TRUE]] : !torch.vtensor<[4,2],i1>, !torch.list, !torch.bool -> !torch.vtensor<[4,1],i1> + // CHECK: return %[[AMAX]] : !torch.vtensor<[4,1],i1> + %0 = torch.operator "onnx.ReduceMax"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[4,2],i1>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[4,1],i1> + return %0 : !torch.vtensor<[4,1],i1> +} + +// ----- + +// CHECK-LABEL: func.func @test_reduce_max_bool_inputs_nokeepdims +func.func @test_reduce_max_bool_inputs_nokeepdims(%arg0: !torch.vtensor<[4,2],i1>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[4],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[IDX:.+]] = torch.constant.int 0 + // CHECK: %[[SZ:.+]] = torch.constant.int 0 + // CHECK: %[[SEL:.+]] = torch.aten.select.int %arg1, %[[IDX]], %[[SZ]] + // CHECK: %[[ITEM:.+]] = torch.aten.item %[[SEL]] + // CHECK: %[[DIM:.+]] = torch.aten.dim %arg0 : !torch.vtensor<[4,2],i1> -> !torch.int + // CHECK: %[[C0:.+]] = torch.constant.int 0 + // CHECK: %[[LT:.+]] = torch.aten.lt.int %[[ITEM]], %[[C0]] : !torch.int, !torch.int -> !torch.bool + // CHECK: %[[BOOL:.+]] = torch.aten.Int.bool %[[LT]] : !torch.bool -> !torch.int + // CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[BOOL]], %[[DIM]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[ADD:.+]] = torch.aten.add.int %[[ITEM]], %[[MUL]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[LST:.+]] = torch.prim.ListConstruct %[[ADD]] : (!torch.int) -> !torch.list + // CHECK: %[[FALSE:.+]] = torch.constant.bool false + // CHECK: %[[AMAX:.+]] = torch.aten.amax %arg0, %[[LST]], %[[FALSE]] : !torch.vtensor<[4,2],i1>, !torch.list, !torch.bool -> !torch.vtensor<[4],i1> + // CHECK: return %[[AMAX]] : !torch.vtensor<[4],i1> + %0 = torch.operator "onnx.ReduceMax"(%arg0, %arg1) {torch.onnx.keepdims = 0 : si64} : (!torch.vtensor<[4,2],i1>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[4],i1> + return %0 : !torch.vtensor<[4],i1> +} + +// ----- + +// CHECK-LABEL: func.func @test_reduce_max_all_dims_default +func.func @test_reduce_max_all_dims_default(%arg0: !torch.vtensor<[4,2],i1>) -> !torch.vtensor<[],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[I0:.+]] = torch.constant.int 0 + // CHECK: %[[I1:.+]] = torch.constant.int 1 + // CHECK: %[[RANK:.+]] = torch.aten.dim %arg0 : !torch.vtensor<[4,2],i1> -> !torch.int + // CHECK: %[[C0:.+]] = torch.constant.int 0 + // CHECK: %[[LT:.+]] = torch.aten.lt.int %[[I0]], %[[C0]] : !torch.int, !torch.int -> !torch.bool + // CHECK: %[[BOOL:.+]] = torch.aten.Int.bool %[[LT]] : !torch.bool -> !torch.int + // CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[BOOL]], %[[RANK]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[A0:.+]] = torch.aten.add.int %[[I0]], %[[MUL]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[LT:.+]] = torch.aten.lt.int %[[I1]], %[[C0]] : !torch.int, !torch.int -> !torch.bool + // CHECK: %[[BOOL:.+]] = torch.aten.Int.bool %[[LT]] : !torch.bool -> !torch.int + // CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[BOOL]], %[[RANK]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[A1:.+]] = torch.aten.add.int %[[I1]], %[[MUL]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[A0]], %[[A1]] + // CHECK: %[[FALSE:.+]] = torch.constant.bool false + // CHECK: %[[MAX:.+]] = torch.aten.amax %arg0, %[[LIST]], %[[FALSE]] : !torch.vtensor<[4,2],i1>, !torch.list, !torch.bool -> !torch.vtensor<[],i1> + // CHECK: return %[[MAX]] : !torch.vtensor<[],i1> + %0 = torch.operator "onnx.ReduceMax"(%arg0) {torch.onnx.keepdims = 0 : si64} : (!torch.vtensor<[4,2],i1>) -> !torch.vtensor<[],i1> + return %0 : !torch.vtensor<[],i1> +} + +// ----- + +func.func @test_reduce_max_attr(%arg0: !torch.vtensor<[4,2],i1>) -> !torch.vtensor<[4],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[INT1:.+]] = torch.constant.int 1 + // CHECK: %[[DIM:.+]] = torch.aten.dim %arg0 : !torch.vtensor<[4,2],i1> -> !torch.int + // CHECK: %[[INT0:.+]] = torch.constant.int 0 + // CHECK: %[[LT:.+]] = torch.aten.lt.int %[[INT1]], %[[INT0]] : !torch.int, !torch.int -> !torch.bool + // CHECK: %[[BOOL:.+]] = torch.aten.Int.bool %[[LT]] : !torch.bool -> !torch.int + // CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[BOOL]], %[[DIM]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[ADD:.+]] = torch.aten.add.int %[[INT1]], %[[MUL]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[ADD]] : (!torch.int) -> !torch.list + // CHECK: %[[FALSE:.+]] = torch.constant.bool false + // CHECK: %[[AMAX:.+]] = torch.aten.amax %arg0, %[[LIST]], %[[FALSE]] : !torch.vtensor<[4,2],i1>, !torch.list, !torch.bool -> !torch.vtensor<[4],i1> + // CHECK: return %[[AMAX]] + %0 = torch.operator "onnx.ReduceMax"(%arg0) {torch.onnx.keepdims = 0 : si64, torch.onnx.axes=[1 : si64]} : (!torch.vtensor<[4,2],i1>) -> !torch.vtensor<[4],i1> + return %0 : !torch.vtensor<[4],i1> +} // ----- @@ -1064,8 +1120,8 @@ func.func @test_reduce_min_bool_inputs_nokeepdims(%arg0: !torch.vtensor<[4,2],i1 // ----- -// CHECK-LABEL: func.func @test_reduce_all_dims_default -func.func @test_reduce_all_dims_default(%arg0: !torch.vtensor<[4,2],i1>) -> !torch.vtensor<[],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { +// CHECK-LABEL: func.func @test_reduce_min_all_dims_default +func.func @test_reduce_min_all_dims_default(%arg0: !torch.vtensor<[4,2],i1>) -> !torch.vtensor<[],i1> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 20 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: %[[I0:.+]] = torch.constant.int 0 // CHECK: %[[I1:.+]] = torch.constant.int 1 // CHECK: %[[RANK:.+]] = torch.aten.dim %arg0 : !torch.vtensor<[4,2],i1> -> !torch.int From c15f1a2bd2276b2ed6e9b47fdb9b8f9b8da5b2dd Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Wed, 6 Mar 2024 17:01:05 -0800 Subject: [PATCH 257/283] [onnx] Adding lowering for `onnx.Size` operation (#2985) We can support `onnx.Size` by requesing the size of each dimensions and taking the product of the results, then packing it into a tensor. --------- Co-authored-by: Scott Todd --- .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 44 +++++++++++++++++++ .../TorchOnnxToTorch/simple_ops_q_to_z.mlir | 21 +++++++++ 2 files changed, 65 insertions(+) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index df3449939138..34282bfef531 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -2032,6 +2032,50 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( none, none, none); return success(); }); + patterns.onOp( + "Size", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value operand; + if (binder.tensorOperand(operand) || + binder.tensorResultType(resultType)) + return failure(); + + auto loc = binder.getLoc(); + auto &op = binder.op; + auto operandTy = cast(operand.getType()); + + if (!operandTy.hasSizes()) + return rewriter.notifyMatchFailure(op, "input rank unknown"); + + llvm::SmallVector dims; + int64_t rank = operandTy.getSizes().size(); + for (int i = 0; i < rank; ++i) { + auto iv = rewriter.create( + loc, rewriter.getI64IntegerAttr(i)); + Value dim = rewriter.create( + loc, rewriter.getType(), operand, iv); + dims.push_back(dim); + } + + Value cstFalse = rewriter.create(loc, false); + Value none = rewriter.create(loc); + + if (dims.empty()) { + Value one = rewriter.create( + loc, rewriter.getI64IntegerAttr(1)); + rewriter.replaceOpWithNewOp( + op, resultType, one, none, none, cstFalse); + return success(); + } + + Value prod = dims[0]; + for (int i = 1, s = dims.size(); i < s; ++i) + prod = rewriter.create(loc, prod, dims[i]); + + rewriter.replaceOpWithNewOp( + op, resultType, prod, none, none, cstFalse); + return success(); + }); patterns.onOp( "Tile", 6, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir index 977c557739b5..bba74b6d9877 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir @@ -1649,3 +1649,24 @@ func.func @test_sign(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4, %0 = torch.operator "onnx.Sign"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> return %0 : !torch.vtensor<[3,4,5],f32> } + +// ----- + +// CHECK-LABEL: func.func @test_size +func.func @test_size(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[],si32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 9 : si64} { + // CHECK-DAG %[[INT0:.+]] = torch.constant.int 0 + // CHECK-DAG %[[INT1:.+]] = torch.constant.int 1 + // CHECK-DAG %[[INT2:.+]] = torch.constant.int 2 + // CHECK-DAG %[[D0:.+]] = torch.aten.size.int %arg0, %[[INT0]] + // CHECK-DAG %[[D1:.+]] = torch.aten.size.int %arg0, %[[INT1]] + // CHECK-DAG %[[D2:.+]] = torch.aten.size.int %arg0, %[[INT2]] + // CHECK-DAG %[[FALSE:.+]] = torch.constant.bool false + // CHECK-DAG %[[NONE:.+]] = torch.constant.none + // CHECK-DAG %[[MUL0:.+]] = torch.aten.mul.int %[[D0]], %[[D1]] + // CHECK-DAG %[[MUL1:.+]] = torch.aten.mul.int %[[MUL0]], %[[D3]] + // CHECK-DAG %[[TENSOR:.+]] = torch.aten.tensor.int %[[MUL1]], %[[NONE]], %[[NONE]], %[[FALSE]] + // CHECK return %[[TENSOR]] + %0 = torch.operator "onnx.Size"(%arg0) : (!torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[],si32> + return %0 : !torch.vtensor<[],si32> +} + From d5693b3f51a8414cddf4e486daff52d8fa87cfa5 Mon Sep 17 00:00:00 2001 From: penguin_wwy <940375606@qq.com> Date: Thu, 7 Mar 2024 11:52:34 +0800 Subject: [PATCH 258/283] [doc] fix broken links in documents (#2990) Co-authored-by: wenyangwang --- docs/adding_an_e2e_test.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/adding_an_e2e_test.md b/docs/adding_an_e2e_test.md index 7b74b904a0f8..91eee0520f56 100644 --- a/docs/adding_an_e2e_test.md +++ b/docs/adding_an_e2e_test.md @@ -87,7 +87,7 @@ following order: 1. Shape of input tensor. Use `-1` for dynamic dimensions 2. Dtype of the input tensor -3. Boolean representing whether the input tensor [has value semantics](https://github.com/llvm/torch-mlir/blob/ba17a4d6c09b4bbb4ef21b1d8d4a93cb056be109/python/torch_mlir/jit_ir_importer/csrc/class_annotator.h#L54-L67). This +3. Boolean representing whether the input tensor [has value semantics](https://github.com/llvm/torch-mlir/blob/main/projects/jit_ir_common/csrc/jit_ir_importer/class_annotator.h#L54-L67). This will always be true for E2E tests, since the [Torch-MLIR backend contract](architecture.md#the-backend-contract) requires all tensors in the IR to eventually have value semantics. From 6e84752c395a828eb612b21be4ab26d9f7e60b22 Mon Sep 17 00:00:00 2001 From: Vivek Khandelwal Date: Thu, 7 Mar 2024 21:42:38 +0530 Subject: [PATCH 259/283] build: manually update PyTorch version (#2992) Set PyTorch and TorchVision version to nightly release 2024-03-07. This commit also removes the deprecated constraints API: https://github.com/pytorch/pytorch/commit/342e7929b804ec56121e82e92d6a199b549c38b1 Signed-Off By: Vivek Khandelwal --- projects/pt1/e2e_testing/xfail_sets.py | 7 ++++++- python/torch_mlir/fx.py | 3 +-- pytorch-hash.txt | 2 +- pytorch-requirements.txt | 2 +- test/python/fx_importer/sparse_test.py | 2 +- torchvision-requirements.txt | 2 +- 6 files changed, 11 insertions(+), 7 deletions(-) diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 2229118181a5..ee0b3608ca1c 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -2167,7 +2167,6 @@ "ElementwiseTanIntModule_basic", "ElementwiseUnaryIntModule_basic", "ElementwiseUnsqueezeNegDimsModule_basic", - "ElementwiseWhereScalarModule_basic", "EmbeddingModuleF16_basic", "EmbeddingModuleI32_basic", "EmbeddingModuleI64_basic", @@ -2192,5 +2191,11 @@ "TensorsStackPromoteDTypeModule_basic", } +if torch_version_for_comparison() < version.parse("2.3.0.dev"): + ONNX_XFAIL_SET = ONNX_XFAIL_SET | { + # ERROR: dtype (torch.float64) is not equal to golden dtype (torch.float32) + "ElementwiseWhereScalarModule_basic", + } + ONNX_CRASHING_SET = { } diff --git a/python/torch_mlir/fx.py b/python/torch_mlir/fx.py index 76cd91f82e0a..3622efafd9d2 100644 --- a/python/torch_mlir/fx.py +++ b/python/torch_mlir/fx.py @@ -20,7 +20,6 @@ def export_and_import( f, *args, fx_importer: Optional[FxImporter] = None, - constraints: Optional[torch.export.Constraint] = None, experimental_support_mutation: bool = False, hooks: Optional[FxImporterHooks] = None, func_name: str = "main", @@ -31,7 +30,7 @@ def export_and_import( if fx_importer is None: fx_importer = FxImporter(context=context, hooks=hooks) - prog = torch.export.export(f, args, kwargs, constraints=constraints) + prog = torch.export.export(f, args, kwargs) decomp_table = get_decomposition_table() prog = prog.run_decompositions(decomp_table) if experimental_support_mutation: diff --git a/pytorch-hash.txt b/pytorch-hash.txt index 81f0390b4ebb..a5e23f46ea17 100644 --- a/pytorch-hash.txt +++ b/pytorch-hash.txt @@ -1 +1 @@ -8efa066dc0870521652c1319bd6b5b0f6dc3fe25 +ce013333221ff2d1285a8e8cf7c427584e65fea2 diff --git a/pytorch-requirements.txt b/pytorch-requirements.txt index 26abce08d1aa..e1bb617456de 100644 --- a/pytorch-requirements.txt +++ b/pytorch-requirements.txt @@ -1,3 +1,3 @@ -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html --pre -torch==2.3.0.dev20240220 +torch==2.3.0.dev20240307 diff --git a/test/python/fx_importer/sparse_test.py b/test/python/fx_importer/sparse_test.py index 6d801a1d8799..6260a5bbaab3 100644 --- a/test/python/fx_importer/sparse_test.py +++ b/test/python/fx_importer/sparse_test.py @@ -104,7 +104,7 @@ def sparse_export( mask = [a.layout in SPARSE_LAYOUTS for a in args] # Build the regular FX traced graph with only dense arguments # (the current version would crash otherwise, see issue above). - prog = torch.export.export(f, dargs, kwargs, constraints=None) + prog = torch.export.export(f, dargs, kwargs) # Annotate sparse arguments in the graph. Note that we currently # only account for sparsity defined by the user inputs to the model. # TODO: support sparsity in model parameters (weights, biases) diff --git a/torchvision-requirements.txt b/torchvision-requirements.txt index ce099fb91709..a0b4c6fe6bed 100644 --- a/torchvision-requirements.txt +++ b/torchvision-requirements.txt @@ -1,3 +1,3 @@ -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html --pre -torchvision==0.18.0.dev20240220 +torchvision==0.18.0.dev20240307 From 7b18646defbe24653e041279ab4a772b757f6a23 Mon Sep 17 00:00:00 2001 From: Scott Todd Date: Thu, 7 Mar 2024 09:25:14 -0800 Subject: [PATCH 260/283] [onnx] Handle optional arguments in Clip op pattern. (#2976) Spec: https://onnx.ai/onnx/operators/onnx__Clip.html --- .../TorchOnnxToTorch/DefaultDomainAtoF.cpp | 80 +++++++++++++------ projects/pt1/e2e_testing/xfail_sets.py | 8 -- .../TorchOnnxToTorch/simple_ops_a_to_f.mlir | 24 ++++++ 3 files changed, 78 insertions(+), 34 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp index 52cd59e898a6..785c631c1bc9 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp @@ -602,38 +602,66 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( return success(); }); patterns.onOp( - "Clip", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + "Clip", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + // https://onnx.ai/onnx/operators/onnx__Clip.html + + // Inputs and outputs must be tensors. + Value source; Torch::ValueTensorType resultType; - if (binder.op->getNumOperands() == 1) { - Value source; - if (binder.tensorOperand(source) || - binder.tensorResultType(resultType)) + if (binder.tensorOperandAtIndex(source, 0) || + binder.tensorResultType(resultType)) { + return failure(); + } + + // Min and max can be args (version 11+) or attributes (version 6-). + // They default to numeric_limits::lowest() and numeric_limits::max(). + Value min; + Value max; + if (binder.op->getNumOperands() >= 2) + min = binder.op->getOperand(1); + if (binder.op->getNumOperands() == 3) + max = binder.op->getOperand(2); + + // Note: attribute versions of the op only support float types. + auto resultDtype = resultType.getDtype(); + if (!min && binder.op->hasAttr("torch.onnx.min")) { + float minValue; + if (binder.f32FloatAttr(minValue, "min", + std::numeric_limits::lowest())) return failure(); - Value cstNone = - rewriter.create(binder.getLoc()); - rewriter.replaceOpWithNewOp( - binder.op, resultType, source, /*min=*/cstNone, /*max=*/cstNone); - return success(); - } else if (binder.op->getNumOperands() == 2) { - Value source, min; - if (binder.tensorOperands(source, min) || - binder.tensorResultType(resultType)) + auto minSplatAttr = SplatElementsAttr::get( + resultType.toBuiltinTensor().clone(resultDtype), + rewriter.getFloatAttr(resultDtype, minValue)); + min = rewriter.create( + binder.getLoc(), resultType, minSplatAttr); + } + if (!max && binder.op->hasAttr("torch.onnx.max")) { + float maxValue; + if (binder.f32FloatAttr(maxValue, "max", + std::numeric_limits::max())) return failure(); - rewriter.replaceOpWithNewOp( - binder.op, resultType, source, /*min=*/min); + auto maxSplatAttr = SplatElementsAttr::get( + resultType.toBuiltinTensor().clone(resultDtype), + rewriter.getFloatAttr(resultDtype, maxValue)); + max = rewriter.create( + binder.getLoc(), resultType, maxSplatAttr); + } + + if (!min && !max) { + // Cliping with no limits is a no-op. + rewriter.replaceOp(binder.op, source); return success(); - } else if (binder.op->getNumOperands() == 3) { - Value source, min, max; - if (binder.tensorOperandAtIndex(source, 0) || - binder.tensorOperandAtIndex(min, 1) || - binder.tensorOperandAtIndex(max, 2) || - binder.tensorResultType(resultType)) - return failure(); - rewriter.replaceOpWithNewOp( - binder.op, resultType, source, min, max); + } + + if (!max) { + rewriter.replaceOpWithNewOp( + binder.op, resultType, source, min); return success(); } - return failure(); + + rewriter.replaceOpWithNewOp( + binder.op, resultType, source, min, max); + return success(); }); patterns.onOp( "Concat", 13, [](OpBinder binder, ConversionPatternRewriter &rewriter) { diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index ee0b3608ca1c..4e96fca863ed 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -1913,14 +1913,6 @@ "TypeConversionI64ToI32Module_basic", # Failure - onnx_lowering: onnx.Clip - "ElementwiseClampMaxModule_basic", - "ElementwiseClampMinModule_basic", - "ElementwiseClampMinTensorFloatModule_basic", - "ElementwiseClampMinTensorIntModule_basic", - "ElementwiseClampModule_basic", - "ElementwiseClampTensorFloatModule_basic", - "ElementwiseClampTensorInt8Module_basic", - "ElementwiseClampTensorIntModule_basic", "NormalizeModule_basic", # Failure - onnx_lowering: onnx.Einsum diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir index 7c465d74bb5f..2c013553bb3c 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir @@ -522,6 +522,16 @@ func.func @test_clip_default_int8_min(%arg0: !torch.vtensor<[3,4,5],si8>, %arg1: // ----- +// CHECK-LABEL: @test_clip_default_int8_max +func.func @test_clip_default_int8_max(%arg0: !torch.vtensor<[3,4,5],si8>, %arg1: !torch.vtensor<[],si8>) -> !torch.vtensor<[3,4,5],si8> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + %none = torch.constant.none + // CHECK: torch.aten.clamp.Tensor %arg0, %none, %arg1 : !torch.vtensor<[3,4,5],si8>, !torch.none, !torch.vtensor<[],si8> -> !torch.vtensor<[3,4,5],si8> + %0 = torch.operator "onnx.Clip"(%arg0, %none, %arg1) : (!torch.vtensor<[3,4,5],si8>, !torch.none, !torch.vtensor<[],si8>) -> !torch.vtensor<[3,4,5],si8> + return %0 : !torch.vtensor<[3,4,5],si8> +} + +// ----- + // CHECK-LABEL: @test_clip_default_min func.func @test_clip_default_min(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.clamp_min.Tensor %arg0, %arg1 : !torch.vtensor<[3,4,5],f32>, !torch.vtensor<[],f32> -> !torch.vtensor<[3,4,5],f32> @@ -549,6 +559,20 @@ func.func @test_clip(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vtensor<[ // ----- +module { + func.func @test_clip_attrs(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 6 : si64} { + %none = torch.constant.none + + // CHECK: %[[MIN:.+]] = torch.vtensor.literal(dense<-5.000000e-01> : tensor<3x4xf32>) : !torch.vtensor<[3,4],f32> + // CHECK: %[[MAX:.+]] = torch.vtensor.literal(dense<5.000000e-01> : tensor<3x4xf32>) : !torch.vtensor<[3,4],f32> + // CHECK: %[[CLAMP:.+]] = torch.aten.clamp.Tensor %arg0, %[[MIN]], %[[MAX]] : !torch.vtensor<[3,4],f32>, !torch.vtensor<[3,4],f32>, !torch.vtensor<[3,4],f32> -> !torch.vtensor<[3,4],f32> + %0 = torch.operator "onnx.Clip"(%arg0) {torch.onnx.max = 5.000000e-01 : f32, torch.onnx.min = -5.000000e-01 : f32} : (!torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f32> + return %0 : !torch.vtensor<[3,4],f32> + } +} + +// ----- + // CHECK-LABEL: @test_cos_example func.func @test_cos_example(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 3 : si64, torch.onnx_meta.opset_version = 7 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.cos %arg0 : !torch.vtensor<[3],f32> -> !torch.vtensor<[3],f32> From 1964208d19a296ce55d267b5f8a8895025cb09a3 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Thu, 7 Mar 2024 13:29:50 -0800 Subject: [PATCH 261/283] [onnx] Fix constant pad for dynamic shape (#2989) The current padding operation was not functional for dynamic shapes. Updated and enabled tests so that onnx.pad tests pass. Work TBD for reflection padding. --- .../TorchOnnxToTorch/DefaultDomainGtoP.cpp | 110 ++++++++---------- .../TorchToLinalg/TensorConstructors.cpp | 63 +++++++--- projects/pt1/e2e_testing/xfail_sets.py | 6 - .../torch_mlir_e2e_test/test_suite/padding.py | 2 - .../TorchOnnxToTorch/simple_ops_g_to_p.mlir | 22 ++-- 5 files changed, 105 insertions(+), 98 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index 01b47bc56cbc..b956666f8f7b 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -908,7 +908,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( return success(); }); patterns.onOp( - "Pad", 19, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + "Pad", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value data, pads, axes; std::string mode; @@ -925,36 +925,6 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( return failure(); Location loc = binder.getLoc(); - Value constantValue; - if (binder.getNumOperands() >= 3) { - if (binder.tensorOperandAtIndex(constantValue, 2)) { - llvm::errs() << "failed to bind to index 2\n"; - return failure(); - } - } else { - auto dataTensorType = data.getType().cast(); - - auto maybeZeroAttr = [&]() -> std::optional { - if (dataTensorType.getDtype().isa()) { - return rewriter.getI64IntegerAttr(0); - } - if (dataTensorType.getDtype().isa()) { - return rewriter.getFloatAttr(dataTensorType.getDtype(), 0.0f); - } - return std::nullopt; - }(); - - if (!maybeZeroAttr) { - return rewriter.notifyMatchFailure( - binder.op, "expected integer or float data tensor"); - } - - auto shapedType = dataTensorType.toBuiltinTensor(); - auto splat = SplatElementsAttr::get(shapedType, *maybeZeroAttr); - constantValue = rewriter.create( - loc, dataTensorType, splat); - } - // Get pads shape and rank. The pads tensor is expected to be 1-D // tensor. auto padsTensorType = pads.getType().cast(); @@ -964,14 +934,48 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( } ArrayRef padsShape = padsTensorType.getSizes(); int64_t padsRank = padsShape.size(); - if (padsRank != 1) { + if (padsRank != 1) + return rewriter.notifyMatchFailure(binder.op, + "expect 1-d pad tensor"); + + int64_t padsSize = padsShape[0]; + if (padsSize == Torch::kUnknownSize) return rewriter.notifyMatchFailure(binder.op, - "Expect 1-D pad tensor"); + "pad length is unknown"); + + Value constantValue; + if (binder.getNumOperands() >= 3) { + if (!binder.tensorOperandAtIndex(constantValue, 2)) { + auto constTy = + dyn_cast(constantValue.getType()); + if (!constTy || !constTy.hasDtype()) + return rewriter.notifyMatchFailure( + binder.op, "constant ty is unsupport type"); + + Type scalarTy = rewriter.getType(); + if (isa(constTy.getDtype())) + scalarTy = rewriter.getType(); + constantValue = rewriter.create(loc, scalarTy, + constantValue); + } + } + + if (!constantValue) { + auto dataTensorType = data.getType().cast(); + if (dataTensorType.getDtype().isa()) + constantValue = rewriter.create( + loc, rewriter.getI64IntegerAttr(0)); + if (dataTensorType.getDtype().isa()) + constantValue = rewriter.create( + loc, rewriter.getF64FloatAttr(0.0f)); + + if (!constantValue) + return rewriter.notifyMatchFailure( + binder.op, "expected integer or float data tensor"); } // Extract all the values of 1-D pad tensor and create a list of all // these values as torch.pad op expects pad list. - int64_t padsSize = padsShape[0]; Value constZero = rewriter.create( loc, rewriter.getI64IntegerAttr(0)); SmallVector padsTensorValue; @@ -982,8 +986,11 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( for (uint32_t i = 0; i < padsSize; ++i) { Value index = rewriter.create( loc, rewriter.getI64IntegerAttr(i)); - padsTensorValue.emplace_back(rewriter.create( - loc, padsElemType, pads, constZero, index)); + auto select = rewriter.create( + loc, padsElemType, pads, constZero, index); + Value selectInt = rewriter.create( + loc, rewriter.getType(), select); + padsTensorValue.push_back(selectInt); } // The torch.pad op expects a different arrangement of padding pairs for @@ -991,43 +998,22 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( // tensor to satisfy torch.pad op semantics. SmallVector padsRearrange; for (uint32_t i = 0; i < padsSize / 2; i++) { - padsRearrange.emplace_back(padsTensorValue[(padsSize / 2) - 1 - i]); - padsRearrange.emplace_back(padsTensorValue[padsSize - 1 - i]); + padsRearrange.emplace_back(padsTensorValue[i]); + padsRearrange.emplace_back(padsTensorValue[(padsSize / 2) + i]); } Value padsSizeList = rewriter - .create( + .create( loc, Torch::ListType::get(rewriter.getType()), padsRearrange) - .getResult(0); + .getResult(); Value modeVal = rewriter.create( loc, rewriter.getStringAttr(mode)); - // The constant value is a 0-d tensor, which needs to be converted to a - // float scalar as torch.pad op expects a float scalar. - auto constValueType = - constantValue.getType().cast(); - if (!constValueType) { - return rewriter.notifyMatchFailure(binder.op, - "Expect non-none constant value"); - } - auto resultTensorType = Torch::ValueTensorType::get( - constValueType.getContext(), emptyShape, rewriter.getF64Type()); - Value none = rewriter.create(loc); - Value cstFalse = rewriter.create(loc, false); - Value constFloatValue = rewriter.create( - loc, resultTensorType, constantValue, - Torch::getDtypeIntValueForType(rewriter, loc, - resultTensorType.getOptionalDtype()), - /*non_blocking=*/cstFalse, /*copy=*/cstFalse, - /*memory_format=*/none); - Value constScalar = rewriter.create( - loc, rewriter.getType(), constFloatValue); - rewriter.replaceOpWithNewOp( - binder.op, resultType, data, padsSizeList, modeVal, constScalar); + binder.op, resultType, data, padsSizeList, modeVal, constantValue); return success(); }); patterns.onOp("Pow", 1, diff --git a/lib/Conversion/TorchToLinalg/TensorConstructors.cpp b/lib/Conversion/TorchToLinalg/TensorConstructors.cpp index 2b8eac49447a..385f5b435e1b 100644 --- a/lib/Conversion/TorchToLinalg/TensorConstructors.cpp +++ b/lib/Conversion/TorchToLinalg/TensorConstructors.cpp @@ -45,36 +45,69 @@ class ConvertAtenConstantPadNdOp auto type = self.getType().cast(); int64_t rank = type.getRank(); - // Pattern match against the op's original operands, because otherwise we - // will get the lowered version of the operands which is harder to pattern - // match. - SmallVector padInts; - if (!matchPattern(op.getPad(), m_TorchListOfConstantInts(padInts))) - return rewriter.notifyMatchFailure( - op, "only support constant int pad ranges"); - uint64_t padRank = padInts.size() / 2; - if (padRank * 2 != padInts.size()) + auto primList = op.getPad().getDefiningOp(); + if (!primList) { + return rewriter.notifyMatchFailure(op, "unable to get pad values"); + } + + SmallVector padVals(primList.getOperands()); + + uint64_t padRank = padVals.size() / 2; + if (padRank * 2 != padVals.size()) return rewriter.notifyMatchFailure(op, "pad range size is not even"); if (rank < 0 || padRank > (uint64_t)rank) return rewriter.notifyMatchFailure(op, "padding exceeds tensor rank"); // Initialize low/high paddings with the dims that should not be padded. - SmallVector lowPadding(/*Size=*/rank - padRank, /*Value=*/0); - SmallVector highPadding(/*Size=*/rank - padRank, /*Value=*/0); + int64_t noPad = rank - padRank; + Attribute zero = rewriter.getIndexAttr(0); + SmallVector staticLow(noPad, 0); + SmallVector staticHigh(noPad, 0); + SmallVector lowPad(noPad, zero); + SmallVector highPad(noPad, zero); + + auto tc = getTypeConverter(); + // Add the requested padding - note op.pad() is highest dim first ordered // pairs of low,high. for (uint64_t i = padRank; i > 0; --i) { - lowPadding.push_back(padInts[i * 2 - 2]); - highPadding.push_back(padInts[i * 2 - 1]); + int64_t lowi, highi; + Value lowv = padVals[i * 2 - 2]; + Value highv = padVals[i * 2 - 1]; + if (!matchPattern(lowv, m_TorchConstantInt(&lowi))) { + Type cty = tc->convertType(lowv.getType()); + lowv = tc->materializeTargetConversion(rewriter, loc, cty, lowv); + lowv = rewriter.create(loc, rewriter.getIndexType(), + lowv); + lowPad.push_back(lowv); + staticLow.push_back(ShapedType::kDynamic); + } else { + lowPad.push_back(rewriter.getIndexAttr(lowi)); + staticLow.push_back(lowi); + } + + if (!matchPattern(highv, m_TorchConstantInt(&highi))) { + Type cty = tc->convertType(highv.getType()); + highv = tc->materializeTargetConversion(rewriter, loc, cty, highv); + highv = rewriter.create( + loc, rewriter.getIndexType(), highv); + highPad.push_back(highv); + staticHigh.push_back(ShapedType::kDynamic); + } else { + highPad.push_back(rewriter.getIndexAttr(highi)); + staticHigh.push_back(highi); + } } Type newResultType = getTypeConverter()->convertType(op.getType()); Type elementType = newResultType.cast().getElementType(); Value castedValue = convertScalarToDtype(rewriter, loc, adaptor.getValue(), elementType); - Value paddedInput = torch_to_linalg::getPaddedTensor( - op, rewriter, self, lowPadding, highPadding, castedValue); + Type padType = tensor::PadOp::inferResultType( + self.getType().cast(), staticLow, staticHigh); + Value paddedInput = rewriter.create( + loc, padType, self, lowPad, highPad, castedValue); rewriter.replaceOpWithNewOp(op, newResultType, paddedInput); return success(); } diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 4e96fca863ed..9f5db2c86f6f 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -1955,12 +1955,6 @@ "OneHotModule_basic", # Failure - onnx_lowering: onnx.Pad - "ConstantPad2dStaticModule_basic", - "ConstantPadNdModule_basic", - "ConstantPadNdPartialStaticModule_basic", - "ConstantPadNdStaticModule_basic", - "PadModule_basic", - "PadWithNoneValModule_basic", "ReflectionPad1dModule2dInput_Right", "ReflectionPad1dModule2dInput_basic", "ReflectionPad1dModule3dInput_Left", diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/padding.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/padding.py index 6b7bdeab2b48..59961fedcc27 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/padding.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/padding.py @@ -109,5 +109,3 @@ def forward(self, x): @register_test_case(module_factory=lambda: ReflectionPad2dModuleRight()) def ReflectionPad2dModule_Right(module, tu: TestUtils): module.forward(tu.rand(2, 3, 20, 20)) - -# ============================================================================== diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir index 493ec8cebfdb..dcd0d28932e5 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir @@ -447,23 +447,23 @@ func.func @test_less_or_equal(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch. // CHECK-LABEL: func.func @test_pad func.func @test_pad(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[4], si64>, %arg2: !torch.vtensor<[], f32>) -> !torch.vtensor<[5,4],f32> attributes {torch.onnx_meta.opset_version = 19 : si64} { + // CHECK: %[[VAL:.+]] = torch.aten.item %arg2 : !torch.vtensor<[],f32> -> !torch.float // CHECK: %[[INT0:.+]] = torch.constant.int 0 // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 // CHECK: %[[SELECT_0:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> + // CHECK: %[[ITEM_0:.+]] = torch.aten.item %[[SELECT_0]] : !torch.vtensor<[],si64> -> !torch.int // CHECK: %[[INT1:.+]] = torch.constant.int 1 // CHECK: %[[SELECT_1:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT1]] : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> + // CHECK: %[[ITEM_1:.+]] = torch.aten.item %[[SELECT_1]] : !torch.vtensor<[],si64> -> !torch.int // CHECK: %[[INT2:.+]] = torch.constant.int 2 // CHECK: %[[SELECT_2:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT2]] : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> + // CHECK: %[[ITEM_2:.+]] = torch.aten.item %[[SELECT_2]] : !torch.vtensor<[],si64> -> !torch.int // CHECK: %[[INT3:.+]] = torch.constant.int 3 // CHECK: %[[SELECT_3:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT3]] : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> - // CHECK: %[[LIST:.+]] = torch.prim.tolist(%[[SELECT_1]], %[[SELECT_3]], %[[SELECT_0]], %[[SELECT_2]]) : !torch.vtensor<[],si64>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.list + // CHECK: %[[ITEM_3:.+]] = torch.aten.item %[[SELECT_3]] : !torch.vtensor<[],si64> -> !torch.int + // CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[ITEM_0]], %[[ITEM_2]], %[[ITEM_1]], %[[ITEM_3]] : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list // CHECK: %[[STR:.+]] = torch.constant.str "constant" - // CHECK: %[[NONE:.+]] = torch.constant.none - // CHECK: %[[FALSE:.+]] = torch.constant.bool false - // CHECK: %[[INT7:.+]] = torch.constant.int 7 - // CHECK: %[[CONVERT:.+]] = torch.aten.to.dtype %arg2, %[[INT7]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f64> - // CHECK: %[[ITEM:.+]] = torch.aten.item %[[CONVERT]] : !torch.vtensor<[],f64> -> !torch.float - // CHECK: %[[PAD:.+]] = torch.aten.pad %arg0, %[[LIST]], %[[STR]], %[[ITEM]] : !torch.vtensor<[3,4],f32>, !torch.list, !torch.str, !torch.float -> !torch.vtensor<[5,4],f32> + // CHECK: %[[PAD:.+]] = torch.aten.pad %arg0, %[[LIST]], %[[STR]], %[[VAL]] : !torch.vtensor<[3,4],f32>, !torch.list, !torch.str, !torch.float -> !torch.vtensor<[5,4],f32> // CHECK: return %[[PAD]] : !torch.vtensor<[5,4],f32> %0 = torch.operator "onnx.Pad"(%arg0, %arg1, %arg2) {torch.onnx.mode = "constant"} : (!torch.vtensor<[3,4],f32>, !torch.vtensor<[4], si64>, !torch.vtensor<[], f32>) -> !torch.vtensor<[5,4],f32> return %0 : !torch.vtensor<[5,4],f32> @@ -474,13 +474,9 @@ func.func @test_pad(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[4], // CHECK-LABEL: @test_pad_optional_constant // CHECK-SAME: %[[ARG0:.*]]: !torch.vtensor<[3,4],f32> // CHECK-SAME: %[[ARG1:.*]]: !torch.vtensor<[4],si64> +// CHECK: %[[VAL:.+]] = torch.constant.float 0 // CHECK: %[[CONST_STR:.*]] = torch.constant.str "constant" -// CHECK: %[[NONE:.*]] = torch.constant.none -// CHECK: %[[FALSE:.*]] = torch.constant.bool false -// CHECK: %[[SEVEN:.*]] = torch.constant.int 7 -// CHECK: %[[DTYPE:.*]] = torch.aten.to.dtype %0, %[[SEVEN]], %[[FALSE]], %[[FALSE]], %[[NONE]] : !torch.vtensor<[3,4],f32>, !torch.int, !torch.bool, !torch.bool, !torch.none -> !torch.vtensor<[],f64> -// CHECK: %[[ITEM:.*]] = torch.aten.item %[[DTYPE]] : !torch.vtensor<[],f64> -> !torch.float -// CHECK: torch.aten.pad %[[ARG0]], %{{.*}}, %[[CONST_STR]], %[[ITEM]] : !torch.vtensor<[3,4],f32>, !torch.list, !torch.str, !torch.float -> !torch.vtensor<[5,4],f32> +// CHECK: torch.aten.pad %[[ARG0]], %{{.*}}, %[[CONST_STR]], %[[VAL]] : !torch.vtensor<[3,4],f32>, !torch.list, !torch.str, !torch.float -> !torch.vtensor<[5,4],f32> func.func @test_pad_optional_constant(%arg0: !torch.vtensor<[3,4],f32>, %arg1: !torch.vtensor<[4], si64>) -> !torch.vtensor<[5,4],f32> attributes {torch.onnx_meta.opset_version = 19 : si64} { %0 = torch.operator "onnx.Pad"(%arg0, %arg1) {torch.onnx.mode = "constant"} : (!torch.vtensor<[3,4],f32>, !torch.vtensor<[4], si64>) -> !torch.vtensor<[5,4],f32> From 551a4e45f36574b6f0bf892a2d67d877e0253441 Mon Sep 17 00:00:00 2001 From: Andreas Falkenberg <149819731+afalkenberg1@users.noreply.github.com> Date: Thu, 7 Mar 2024 15:58:38 -0800 Subject: [PATCH 262/283] [onnx] Add support for `onnx.Gemm` with no bias (#2993) Previous gemm version required a bias vector. This provides an alternate path to `Torch::AtenMm` with no bias operation. --- .../TorchOnnxToTorch/DefaultDomainGtoP.cpp | 11 ++++++++++- projects/pt1/e2e_testing/xfail_sets.py | 8 -------- .../TorchOnnxToTorch/simple_ops_g_to_p.mlir | 13 +++++++++++-- 3 files changed, 21 insertions(+), 11 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index b956666f8f7b..8a677b8ce058 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -687,7 +687,6 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( int64_t transA, transB; if (binder.tensorOperandAtIndex(a, 0) || binder.tensorOperandAtIndex(b, 1) || - binder.tensorOperandAtIndex(c, 2) || binder.s64IntegerAttr(transA, "transA", 0) || binder.s64IntegerAttr(transB, "transB", 0) || binder.f32FloatAttr(alpha, "alpha", 1.0f) || @@ -724,6 +723,16 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( b = transpose(b); } + if (binder.getNumOperands() == 2) { + rewriter.replaceOpWithNewOp(binder.op, resultType, a, + b); + return success(); + } + + if (binder.tensorOperandAtIndex(c, 2)) + return rewriter.notifyMatchFailure(binder.op, + "Expected either 2 or 3 inputs"); + Value mm = rewriter.create(binder.getLoc(), resultType, a, b); if (alpha == 1.0 && beta == 1.0) { diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 9f5db2c86f6f..d16a20893dbf 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -1920,14 +1920,6 @@ "EinsumStaticFourDimensionModule_basic", "EinsumStaticModule_basic", - # Failure - onnx_lowering: onnx.Gemm - "AtenMmFloatTypes_basic", - "AtenMmIntTypes_basic", - "MmDagModule_basic", - "MmModule_basic", - "MmModule_chained", - "MmTanhModule_basic", - # Failure - onnx_lowering: onnx.HardSwish "HardswishModule_basic", "HardswishRandomModule_basic", diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir index dcd0d28932e5..d1f4307d4de6 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir @@ -97,8 +97,17 @@ func.func @test_gather_elements(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torc // ----- -// CHECK-LABEL: func.func @test_gemm_default -func.func @test_gemm_default(%arg0: !torch.vtensor<[3,5],f32>, %arg1: !torch.vtensor<[5,4],f32>, %arg2: !torch.vtensor<[1,4],f32>) -> !torch.vtensor<[3,4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64} { +// CHECK-LABEL: func.func @test_gemm_defaultA +func.func @test_gemm_defaultA(%arg0: !torch.vtensor<[3,5],f32>, %arg1: !torch.vtensor<[5,4],f32>) -> !torch.vtensor<[3,4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64} { + // CHECK-DAG: %[[MM:.+]] = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[3,5],f32>, !torch.vtensor<[5,4],f32> -> !torch.vtensor<[3,4],f32> + %0 = torch.operator "onnx.Gemm"(%arg0, %arg1) : (!torch.vtensor<[3,5],f32>, !torch.vtensor<[5,4],f32>) -> !torch.vtensor<[3,4],f32> + return %0 : !torch.vtensor<[3,4],f32> +} + +// ----- + +// CHECK-LABEL: func.func @test_gemm_defaultB +func.func @test_gemm_defaultB(%arg0: !torch.vtensor<[3,5],f32>, %arg1: !torch.vtensor<[5,4],f32>, %arg2: !torch.vtensor<[1,4],f32>) -> !torch.vtensor<[3,4],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64} { // CHECK: %[[I1:.+]] = torch.constant.int 1 // CHECK: %[[MM:.+]] = torch.aten.mm %arg0, %arg1 : !torch.vtensor<[3,5],f32>, !torch.vtensor<[5,4],f32> -> !torch.vtensor<[3,4],f32> // CHECK: torch.aten.add.Tensor %[[MM]], %arg2, %[[I1]] : !torch.vtensor<[3,4],f32>, !torch.vtensor<[1,4],f32>, !torch.int -> !torch.vtensor<[3,4],f32> From 6b3a7d07c2c76f5e8437ff4e88110899621557b9 Mon Sep 17 00:00:00 2001 From: Dmitry Babokin Date: Thu, 7 Mar 2024 20:26:53 -0800 Subject: [PATCH 263/283] Fix link to roadmap in README.md (#2995) The file was renamed by PR https://github.com/llvm/torch-mlir/pull/2842. --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index a10b9ac36bb5..1b0fff13bdb3 100644 --- a/README.md +++ b/README.md @@ -29,7 +29,7 @@ We have few paths to lower down to the Torch MLIR Dialect. - LazyTensorCore Read more details [here](docs/ltc_backend.md). - We also have basic TorchDynamo/PyTorch 2.0 support, see our - [long-term roadmap](docs/long_term_roadmap.md) and + [long-term roadmap](docs/roadmap.md) and [Thoughts on PyTorch 2.0](https://discourse.llvm.org/t/thoughts-on-pytorch-2-0/67000/3) for more details. From 80c7bc3f7ae12413836a2f610a6491794b4dbb08 Mon Sep 17 00:00:00 2001 From: Daniel Garvey <34486624+dan-garvey@users.noreply.github.com> Date: Fri, 8 Mar 2024 14:58:50 -0600 Subject: [PATCH 264/283] fximporter: support newer torch versions (#2999) uses version checking since attributes exist in both versions, the only thing that changes is what we're receiving as an fx graph --- python/torch_mlir/extras/fx_importer.py | 72 ++++++++++++++++++------- 1 file changed, 52 insertions(+), 20 deletions(-) diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py index a703db45398f..952b638c1988 100644 --- a/python/torch_mlir/extras/fx_importer.py +++ b/python/torch_mlir/extras/fx_importer.py @@ -220,19 +220,47 @@ "gt": torch.ops.aten.gt, } -SYMBOLIC_TORCH_OPS = { - torch.ops.aten.sym_size, - torch.ops.aten.sym_stride, - torch.ops.aten.sym_numel, -} - -SYMBOLIC_OP_TO_TORCH_OP = { - (torch.ops.aten.sym_size, 1): torch.ops.aten.size.default, - (torch.ops.aten.sym_size, 2): torch.ops.aten.size.int, - (torch.ops.aten.sym_stride, 1): torch.ops.aten.stride.default, - (torch.ops.aten.sym_stride, 2): torch.ops.aten.stride.int, - (torch.ops.aten.sym_numel, 1): torch.ops.aten.numel.default, -} +# torch with cuda has a __version__ that looks like "2.1.0+cu113", +# so split by + and 0 index will always give the base version +_IS_TORCH_2_1_OR_EARLIER = torch.__version__.split("+")[0] <= "2.1.0" + +# The following are maps from symbolic ops to their non symbolic equivalents. +# In <=2.1.0, imported fx graphs come with a type inspecific torch.ops.aten.sym_size +# We identify it using the number of args in the node, 1 being default, 2 being int +# In the mapping below (torch.aten.sym_size, 2) indicates len(args)=2 therefore +# map to torch.aten.size.int. +# Thankfully, newer versions provide a specific torch.ops.aten.sym_size.. +# Once we drop support for <2.1.0, we can get rid of the the SYMBOLIC_TORCH_OPS +# set and just check key existence in SYMBOLIC_OP_TO_TORCH_OP + +if _IS_TORCH_2_1_OR_EARLIER: + SYMBOLIC_TORCH_OPS = { + torch.ops.aten.sym_size, + torch.ops.aten.sym_stride, + torch.ops.aten.sym_numel, + } + + SYMBOLIC_OP_TO_TORCH_OP = { + (torch.ops.aten.sym_size, 1): torch.ops.aten.size.default, + (torch.ops.aten.sym_size, 2): torch.ops.aten.size.int, + (torch.ops.aten.sym_stride, 1): torch.ops.aten.stride.default, + (torch.ops.aten.sym_stride, 2): torch.ops.aten.stride.int, + (torch.ops.aten.sym_numel, 1): torch.ops.aten.numel.default, + } +else: + SYMBOLIC_TORCH_OPS = { + torch.ops.aten.sym_size.int, + torch.ops.aten.sym_stride.int, + torch.ops.aten.sym_numel.default, + } + + SYMBOLIC_OP_TO_TORCH_OP = { + torch.ops.aten.sym_size.default: torch.ops.aten.size.default, + torch.ops.aten.sym_size.int: torch.ops.aten.size.int, + torch.ops.aten.sym_stride.default: torch.ops.aten.stride.default, + torch.ops.aten.sym_stride.int: torch.ops.aten.stride.int, + torch.ops.aten.sym_numel.default: torch.ops.aten.numel.default, + } @dataclass(frozen=True) @@ -638,7 +666,9 @@ def import_program( node_importer.return_node_values(loc, user_outputs) self.symbol_table.insert(func_op) - def import_frozen_program(self, prog: torch.export.ExportedProgram, func_name: str = "main"): + def import_frozen_program( + self, prog: torch.export.ExportedProgram, func_name: str = "main" + ): """Imports a consolidated torch.export.ExportedProgram instance. If using the new torch.export path (vs a lower level precursor), then this is @@ -1137,14 +1167,14 @@ def import_nodes( raise NotImplementedError( f"General getitem access to non-multi-result ops" ) - elif isinstance(target, TorchOpOverload): - # Dispatch to an ATen op. - self._import_torch_op_overload(loc, node, target) elif target in SYMBOLIC_TORCH_OPS or ( is_symbolic(node.meta.get("val")) and is_builtin_function_or_method(target) ): self._import_symbolic_torch_op(loc, node, target) + elif isinstance(target, TorchOpOverload): + # Dispatch to an ATen op. + self._import_torch_op_overload(loc, node, target) else: raise NotImplementedError( f"FIX ME: Unimplemented call_function: target={node.target}, {node.meta}" @@ -1227,7 +1257,10 @@ def _import_symbolic_torch_op( ), f"Unsupported builtin function for symbolic types: {target} with args {node.args}" concrete_target = getattr(torch_op, op_overload) else: - concrete_target = SYMBOLIC_OP_TO_TORCH_OP.get((target, len(node.args))) + if _IS_TORCH_2_1_OR_EARLIER: + concrete_target = SYMBOLIC_OP_TO_TORCH_OP.get((target, len(node.args))) + else: + concrete_target = SYMBOLIC_OP_TO_TORCH_OP.get(target) assert ( concrete_target is not None @@ -1628,8 +1661,7 @@ def lookup(self, t: type) -> Any: # Opaque value to indicate something is empty. Used in cases where 'None' # may have a different meaning. -class EmptyType: - ... +class EmptyType: ... Empty = EmptyType() From 07235849361e01152decd9471be0985c42b4a1f7 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Fri, 8 Mar 2024 13:44:00 -0800 Subject: [PATCH 265/283] [torch] Add folder for torch.aten.*.Scalar comparisons (#3000) This folds small version of the tensor-scalar comparison operators as they are commonly used for shape computations. This includes le, lt, ge, gt, eq, and ne. --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 1120 +++++++++-------- .../TorchOnnxToTorch/DefaultDomainGtoP.cpp | 2 +- lib/Dialect/Torch/IR/TorchOps.cpp | 191 +++ projects/pt1/e2e_testing/xfail_sets.py | 24 - .../build_tools/torch_ods_gen.py | 12 +- .../test_suite/elementwise.py | 2 +- .../TorchOnnxToTorch/simple_ops_g_to_p.mlir | 4 +- test/Dialect/Torch/canonicalize.mlir | 125 ++ 8 files changed, 889 insertions(+), 591 deletions(-) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index dfb1b0382918..41ca1f5801dc 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -2131,12 +2131,12 @@ def Torch_AtenDiv_ScalarOp : Torch_Op<"aten.div_.Scalar", [ }]; } -def Torch_AtenNeScalarOp : Torch_Op<"aten.ne.Scalar", [ +def Torch_AtenFmodScalarOp : Torch_Op<"aten.fmod.Scalar", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::ne.Scalar : (Tensor, Scalar) -> (Tensor)`"; + let summary = "Generated op for `aten::fmod.Scalar : (Tensor, Scalar) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, AnyTorchScalarType:$other @@ -2146,20 +2146,20 @@ def Torch_AtenNeScalarOp : Torch_Op<"aten.ne.Scalar", [ ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenNeScalarOp::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenFmodScalarOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenNeScalarOp::print(OpAsmPrinter &printer) { + void AtenFmodScalarOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenNe_ScalarOp : Torch_Op<"aten.ne_.Scalar", [ +def Torch_AtenFmod_ScalarOp : Torch_Op<"aten.fmod_.Scalar", [ IsTrailingUnderscoreInplaceVariant, AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::ne_.Scalar : (Tensor, Scalar) -> (Tensor)`"; + let summary = "Generated op for `aten::fmod_.Scalar : (Tensor, Scalar) -> (Tensor)`"; let arguments = (ins Torch_NonValueTensorType:$self, AnyTorchScalarType:$other @@ -2169,638 +2169,628 @@ def Torch_AtenNe_ScalarOp : Torch_Op<"aten.ne_.Scalar", [ ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenNe_ScalarOp::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenFmod_ScalarOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenNe_ScalarOp::print(OpAsmPrinter &printer) { + void AtenFmod_ScalarOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenEqScalarOp : Torch_Op<"aten.eq.Scalar", [ +def Torch_AtenMaskedFillScalarOp : Torch_Op<"aten.masked_fill.Scalar", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::eq.Scalar : (Tensor, Scalar) -> (Tensor)`"; + let summary = "Generated op for `aten::masked_fill.Scalar : (Tensor, Tensor, Scalar) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - AnyTorchScalarType:$other + AnyTorchTensorType:$mask, + AnyTorchScalarType:$value ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenEqScalarOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); + ParseResult AtenMaskedFillScalarOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); } - void AtenEqScalarOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); + void AtenMaskedFillScalarOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); } }]; } -def Torch_AtenEq_ScalarOp : Torch_Op<"aten.eq_.Scalar", [ +def Torch_AtenMaskedFill_ScalarOp : Torch_Op<"aten.masked_fill_.Scalar", [ IsTrailingUnderscoreInplaceVariant, AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::eq_.Scalar : (Tensor, Scalar) -> (Tensor)`"; + let summary = "Generated op for `aten::masked_fill_.Scalar : (Tensor, Tensor, Scalar) -> (Tensor)`"; let arguments = (ins Torch_NonValueTensorType:$self, - AnyTorchScalarType:$other + Torch_NonValueTensorType:$mask, + AnyTorchScalarType:$value ); let results = (outs Torch_NonValueTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenEq_ScalarOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); + ParseResult AtenMaskedFill_ScalarOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); } - void AtenEq_ScalarOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); + void AtenMaskedFill_ScalarOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); } }]; } -def Torch_AtenGtScalarOp : Torch_Op<"aten.gt.Scalar", [ +def Torch_AtenClampOp : Torch_Op<"aten.clamp", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::gt.Scalar : (Tensor, Scalar) -> (Tensor)`"; + let summary = "Generated op for `aten::clamp : (Tensor, Scalar?, Scalar?) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - AnyTorchScalarType:$other + AnyTorchOptionalScalarType:$min, + AnyTorchOptionalScalarType:$max ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenGtScalarOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); + ParseResult AtenClampOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); } - void AtenGtScalarOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); + void AtenClampOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); } }]; } -def Torch_AtenGt_ScalarOp : Torch_Op<"aten.gt_.Scalar", [ +def Torch_AtenClamp_Op : Torch_Op<"aten.clamp_", [ IsTrailingUnderscoreInplaceVariant, AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::gt_.Scalar : (Tensor, Scalar) -> (Tensor)`"; + let summary = "Generated op for `aten::clamp_ : (Tensor, Scalar?, Scalar?) -> (Tensor)`"; let arguments = (ins Torch_NonValueTensorType:$self, - AnyTorchScalarType:$other + AnyTorchOptionalScalarType:$min, + AnyTorchOptionalScalarType:$max ); let results = (outs Torch_NonValueTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenGt_ScalarOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); + ParseResult AtenClamp_Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); } - void AtenGt_ScalarOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); + void AtenClamp_Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); } }]; } -def Torch_AtenGeScalarOp : Torch_Op<"aten.ge.Scalar", [ +def Torch_AtenClampTensorOp : Torch_Op<"aten.clamp.Tensor", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::ge.Scalar : (Tensor, Scalar) -> (Tensor)`"; + let summary = "Generated op for `aten::clamp.Tensor : (Tensor, Tensor?, Tensor?) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - AnyTorchScalarType:$other + AnyTorchOptionalTensorType:$min, + AnyTorchOptionalTensorType:$max ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenGeScalarOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); + ParseResult AtenClampTensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); } - void AtenGeScalarOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); + void AtenClampTensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); } }]; } -def Torch_AtenGe_ScalarOp : Torch_Op<"aten.ge_.Scalar", [ +def Torch_AtenClamp_TensorOp : Torch_Op<"aten.clamp_.Tensor", [ IsTrailingUnderscoreInplaceVariant, AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::ge_.Scalar : (Tensor, Scalar) -> (Tensor)`"; + let summary = "Generated op for `aten::clamp_.Tensor : (Tensor, Tensor?, Tensor?) -> (Tensor)`"; let arguments = (ins Torch_NonValueTensorType:$self, - AnyTorchScalarType:$other + AnyTorchOptionalNonValueTensorType:$min, + AnyTorchOptionalNonValueTensorType:$max ); let results = (outs Torch_NonValueTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenGe_ScalarOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); + ParseResult AtenClamp_TensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); } - void AtenGe_ScalarOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); + void AtenClamp_TensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); } }]; } -def Torch_AtenLtScalarOp : Torch_Op<"aten.lt.Scalar", [ +def Torch_AtenClampMinOp : Torch_Op<"aten.clamp_min", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::lt.Scalar : (Tensor, Scalar) -> (Tensor)`"; + let summary = "Generated op for `aten::clamp_min : (Tensor, Scalar) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - AnyTorchScalarType:$other + AnyTorchScalarType:$min ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenLtScalarOp::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenClampMinOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenLtScalarOp::print(OpAsmPrinter &printer) { + void AtenClampMinOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenLt_ScalarOp : Torch_Op<"aten.lt_.Scalar", [ +def Torch_AtenClampMin_Op : Torch_Op<"aten.clamp_min_", [ IsTrailingUnderscoreInplaceVariant, AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::lt_.Scalar : (Tensor, Scalar) -> (Tensor)`"; + let summary = "Generated op for `aten::clamp_min_ : (Tensor, Scalar) -> (Tensor)`"; let arguments = (ins Torch_NonValueTensorType:$self, - AnyTorchScalarType:$other + AnyTorchScalarType:$min ); let results = (outs Torch_NonValueTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenLt_ScalarOp::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenClampMin_Op::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenLt_ScalarOp::print(OpAsmPrinter &printer) { + void AtenClampMin_Op::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenLeScalarOp : Torch_Op<"aten.le.Scalar", [ +def Torch_AtenClampMinTensorOp : Torch_Op<"aten.clamp_min.Tensor", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::le.Scalar : (Tensor, Scalar) -> (Tensor)`"; + let summary = "Generated op for `aten::clamp_min.Tensor : (Tensor, Tensor) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - AnyTorchScalarType:$other + AnyTorchTensorType:$min ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenLeScalarOp::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenClampMinTensorOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenLeScalarOp::print(OpAsmPrinter &printer) { + void AtenClampMinTensorOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenLe_ScalarOp : Torch_Op<"aten.le_.Scalar", [ +def Torch_AtenClampMin_TensorOp : Torch_Op<"aten.clamp_min_.Tensor", [ IsTrailingUnderscoreInplaceVariant, AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::le_.Scalar : (Tensor, Scalar) -> (Tensor)`"; + let summary = "Generated op for `aten::clamp_min_.Tensor : (Tensor, Tensor) -> (Tensor)`"; let arguments = (ins Torch_NonValueTensorType:$self, - AnyTorchScalarType:$other + Torch_NonValueTensorType:$min ); let results = (outs Torch_NonValueTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenLe_ScalarOp::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenClampMin_TensorOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenLe_ScalarOp::print(OpAsmPrinter &printer) { + void AtenClampMin_TensorOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenFmodScalarOp : Torch_Op<"aten.fmod.Scalar", [ +def Torch_AtenClampMaxOp : Torch_Op<"aten.clamp_max", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::fmod.Scalar : (Tensor, Scalar) -> (Tensor)`"; + let summary = "Generated op for `aten::clamp_max : (Tensor, Scalar) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - AnyTorchScalarType:$other + AnyTorchScalarType:$max ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenFmodScalarOp::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenClampMaxOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenFmodScalarOp::print(OpAsmPrinter &printer) { + void AtenClampMaxOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenFmod_ScalarOp : Torch_Op<"aten.fmod_.Scalar", [ +def Torch_AtenClampMax_Op : Torch_Op<"aten.clamp_max_", [ IsTrailingUnderscoreInplaceVariant, AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::fmod_.Scalar : (Tensor, Scalar) -> (Tensor)`"; + let summary = "Generated op for `aten::clamp_max_ : (Tensor, Scalar) -> (Tensor)`"; let arguments = (ins Torch_NonValueTensorType:$self, - AnyTorchScalarType:$other + AnyTorchScalarType:$max ); let results = (outs Torch_NonValueTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenFmod_ScalarOp::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenClampMax_Op::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenFmod_ScalarOp::print(OpAsmPrinter &printer) { + void AtenClampMax_Op::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenMaskedFillScalarOp : Torch_Op<"aten.masked_fill.Scalar", [ +def Torch_AtenClampMaxTensorOp : Torch_Op<"aten.clamp_max.Tensor", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::masked_fill.Scalar : (Tensor, Tensor, Scalar) -> (Tensor)`"; + let summary = "Generated op for `aten::clamp_max.Tensor : (Tensor, Tensor) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - AnyTorchTensorType:$mask, - AnyTorchScalarType:$value + AnyTorchTensorType:$max ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenMaskedFillScalarOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 3, 1); + ParseResult AtenClampMaxTensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenMaskedFillScalarOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 3, 1); + void AtenClampMaxTensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenMaskedFill_ScalarOp : Torch_Op<"aten.masked_fill_.Scalar", [ +def Torch_AtenClampMax_TensorOp : Torch_Op<"aten.clamp_max_.Tensor", [ IsTrailingUnderscoreInplaceVariant, AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::masked_fill_.Scalar : (Tensor, Tensor, Scalar) -> (Tensor)`"; + let summary = "Generated op for `aten::clamp_max_.Tensor : (Tensor, Tensor) -> (Tensor)`"; let arguments = (ins Torch_NonValueTensorType:$self, - Torch_NonValueTensorType:$mask, - AnyTorchScalarType:$value + Torch_NonValueTensorType:$max ); let results = (outs Torch_NonValueTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenMaskedFill_ScalarOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 3, 1); + ParseResult AtenClampMax_TensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenMaskedFill_ScalarOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 3, 1); + void AtenClampMax_TensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenClampOp : Torch_Op<"aten.clamp", [ +def Torch_AtenLog2Op : Torch_Op<"aten.log2", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::clamp : (Tensor, Scalar?, Scalar?) -> (Tensor)`"; + let summary = "Generated op for `aten::log2 : (Tensor) -> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$self, - AnyTorchOptionalScalarType:$min, - AnyTorchOptionalScalarType:$max + AnyTorchTensorType:$self ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenClampOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 3, 1); + ParseResult AtenLog2Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenClampOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 3, 1); + void AtenLog2Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); } }]; } -def Torch_AtenClamp_Op : Torch_Op<"aten.clamp_", [ +def Torch_AtenLog2_Op : Torch_Op<"aten.log2_", [ IsTrailingUnderscoreInplaceVariant, AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::clamp_ : (Tensor, Scalar?, Scalar?) -> (Tensor)`"; + let summary = "Generated op for `aten::log2_ : (Tensor) -> (Tensor)`"; let arguments = (ins - Torch_NonValueTensorType:$self, - AnyTorchOptionalScalarType:$min, - AnyTorchOptionalScalarType:$max + Torch_NonValueTensorType:$self ); let results = (outs Torch_NonValueTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenClamp_Op::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 3, 1); + ParseResult AtenLog2_Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenClamp_Op::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 3, 1); + void AtenLog2_Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); } }]; } -def Torch_AtenClampTensorOp : Torch_Op<"aten.clamp.Tensor", [ +def Torch_AtenLog10Op : Torch_Op<"aten.log10", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::clamp.Tensor : (Tensor, Tensor?, Tensor?) -> (Tensor)`"; + let summary = "Generated op for `aten::log10 : (Tensor) -> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$self, - AnyTorchOptionalTensorType:$min, - AnyTorchOptionalTensorType:$max + AnyTorchTensorType:$self ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenClampTensorOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 3, 1); + ParseResult AtenLog10Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenClampTensorOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 3, 1); + void AtenLog10Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); } }]; } -def Torch_AtenClamp_TensorOp : Torch_Op<"aten.clamp_.Tensor", [ +def Torch_AtenLog10_Op : Torch_Op<"aten.log10_", [ IsTrailingUnderscoreInplaceVariant, AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::clamp_.Tensor : (Tensor, Tensor?, Tensor?) -> (Tensor)`"; + let summary = "Generated op for `aten::log10_ : (Tensor) -> (Tensor)`"; let arguments = (ins - Torch_NonValueTensorType:$self, - AnyTorchOptionalNonValueTensorType:$min, - AnyTorchOptionalNonValueTensorType:$max + Torch_NonValueTensorType:$self ); let results = (outs Torch_NonValueTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenClamp_TensorOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 3, 1); + ParseResult AtenLog10_Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenClamp_TensorOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 3, 1); + void AtenLog10_Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); } }]; } -def Torch_AtenClampMinOp : Torch_Op<"aten.clamp_min", [ +def Torch_AtenSqrtOp : Torch_Op<"aten.sqrt", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::clamp_min : (Tensor, Scalar) -> (Tensor)`"; + let summary = "Generated op for `aten::sqrt : (Tensor) -> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$self, - AnyTorchScalarType:$min + AnyTorchTensorType:$self ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenClampMinOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); + ParseResult AtenSqrtOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenClampMinOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); + void AtenSqrtOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); } }]; } -def Torch_AtenClampMin_Op : Torch_Op<"aten.clamp_min_", [ +def Torch_AtenSqrt_Op : Torch_Op<"aten.sqrt_", [ IsTrailingUnderscoreInplaceVariant, AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::clamp_min_ : (Tensor, Scalar) -> (Tensor)`"; + let summary = "Generated op for `aten::sqrt_ : (Tensor) -> (Tensor)`"; let arguments = (ins - Torch_NonValueTensorType:$self, - AnyTorchScalarType:$min + Torch_NonValueTensorType:$self ); let results = (outs Torch_NonValueTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenClampMin_Op::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); + ParseResult AtenSqrt_Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenClampMin_Op::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); + void AtenSqrt_Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); } }]; } -def Torch_AtenClampMinTensorOp : Torch_Op<"aten.clamp_min.Tensor", [ +def Torch_AtenLog1pOp : Torch_Op<"aten.log1p", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::clamp_min.Tensor : (Tensor, Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::log1p : (Tensor) -> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$self, - AnyTorchTensorType:$min + AnyTorchTensorType:$self ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenClampMinTensorOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); + ParseResult AtenLog1pOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenClampMinTensorOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); + void AtenLog1pOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); } }]; } -def Torch_AtenClampMin_TensorOp : Torch_Op<"aten.clamp_min_.Tensor", [ +def Torch_AtenLog1p_Op : Torch_Op<"aten.log1p_", [ IsTrailingUnderscoreInplaceVariant, AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::clamp_min_.Tensor : (Tensor, Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::log1p_ : (Tensor) -> (Tensor)`"; let arguments = (ins - Torch_NonValueTensorType:$self, - Torch_NonValueTensorType:$min + Torch_NonValueTensorType:$self ); let results = (outs Torch_NonValueTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenClampMin_TensorOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); + ParseResult AtenLog1p_Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenClampMin_TensorOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); + void AtenLog1p_Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); } }]; } -def Torch_AtenClampMaxOp : Torch_Op<"aten.clamp_max", [ +def Torch_AtenLogitOp : Torch_Op<"aten.logit", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::clamp_max : (Tensor, Scalar) -> (Tensor)`"; + let summary = "Generated op for `aten::logit : (Tensor, float?) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - AnyTorchScalarType:$max + AnyTorchOptionalFloatType:$eps ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenClampMaxOp::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenLogitOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenClampMaxOp::print(OpAsmPrinter &printer) { + void AtenLogitOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenClampMax_Op : Torch_Op<"aten.clamp_max_", [ +def Torch_AtenLogit_Op : Torch_Op<"aten.logit_", [ IsTrailingUnderscoreInplaceVariant, AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::clamp_max_ : (Tensor, Scalar) -> (Tensor)`"; + let summary = "Generated op for `aten::logit_ : (Tensor, float?) -> (Tensor)`"; let arguments = (ins Torch_NonValueTensorType:$self, - AnyTorchScalarType:$max + AnyTorchOptionalFloatType:$eps ); let results = (outs Torch_NonValueTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenClampMax_Op::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenLogit_Op::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenClampMax_Op::print(OpAsmPrinter &printer) { + void AtenLogit_Op::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenClampMaxTensorOp : Torch_Op<"aten.clamp_max.Tensor", [ +def Torch_AtenRsqrtOp : Torch_Op<"aten.rsqrt", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::clamp_max.Tensor : (Tensor, Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::rsqrt : (Tensor) -> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$self, - AnyTorchTensorType:$max + AnyTorchTensorType:$self ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenClampMaxTensorOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); + ParseResult AtenRsqrtOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenClampMaxTensorOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); + void AtenRsqrtOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); } }]; } -def Torch_AtenClampMax_TensorOp : Torch_Op<"aten.clamp_max_.Tensor", [ +def Torch_AtenRsqrt_Op : Torch_Op<"aten.rsqrt_", [ IsTrailingUnderscoreInplaceVariant, AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::clamp_max_.Tensor : (Tensor, Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::rsqrt_ : (Tensor) -> (Tensor)`"; let arguments = (ins - Torch_NonValueTensorType:$self, - Torch_NonValueTensorType:$max + Torch_NonValueTensorType:$self ); let results = (outs Torch_NonValueTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenClampMax_TensorOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); + ParseResult AtenRsqrt_Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenClampMax_TensorOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); + void AtenRsqrt_Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); } }]; } -def Torch_AtenLog2Op : Torch_Op<"aten.log2", [ +def Torch_AtenAbsOp : Torch_Op<"aten.abs", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::log2 : (Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::abs : (Tensor) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self ); @@ -2809,20 +2799,20 @@ def Torch_AtenLog2Op : Torch_Op<"aten.log2", [ ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenLog2Op::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenAbsOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenLog2Op::print(OpAsmPrinter &printer) { + void AtenAbsOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 1, 1); } }]; } -def Torch_AtenLog2_Op : Torch_Op<"aten.log2_", [ +def Torch_AtenAbs_Op : Torch_Op<"aten.abs_", [ IsTrailingUnderscoreInplaceVariant, AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::log2_ : (Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::abs_ : (Tensor) -> (Tensor)`"; let arguments = (ins Torch_NonValueTensorType:$self ); @@ -2831,21 +2821,21 @@ def Torch_AtenLog2_Op : Torch_Op<"aten.log2_", [ ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenLog2_Op::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenAbs_Op::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenLog2_Op::print(OpAsmPrinter &printer) { + void AtenAbs_Op::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 1, 1); } }]; } -def Torch_AtenLog10Op : Torch_Op<"aten.log10", [ +def Torch_AtenReciprocalOp : Torch_Op<"aten.reciprocal", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::log10 : (Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::reciprocal : (Tensor) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self ); @@ -2854,20 +2844,20 @@ def Torch_AtenLog10Op : Torch_Op<"aten.log10", [ ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenLog10Op::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenReciprocalOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenLog10Op::print(OpAsmPrinter &printer) { + void AtenReciprocalOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 1, 1); } }]; } -def Torch_AtenLog10_Op : Torch_Op<"aten.log10_", [ +def Torch_AtenReciprocal_Op : Torch_Op<"aten.reciprocal_", [ IsTrailingUnderscoreInplaceVariant, AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::log10_ : (Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::reciprocal_ : (Tensor) -> (Tensor)`"; let arguments = (ins Torch_NonValueTensorType:$self ); @@ -2876,904 +2866,931 @@ def Torch_AtenLog10_Op : Torch_Op<"aten.log10_", [ ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenLog10_Op::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenReciprocal_Op::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenLog10_Op::print(OpAsmPrinter &printer) { + void AtenReciprocal_Op::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 1, 1); } }]; } -def Torch_AtenSqrtOp : Torch_Op<"aten.sqrt", [ +def Torch_AtenBitwiseAndTensorOp : Torch_Op<"aten.bitwise_and.Tensor", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::sqrt : (Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::bitwise_and.Tensor : (Tensor, Tensor) -> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$self + AnyTorchTensorType:$self, + AnyTorchTensorType:$other ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenSqrtOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 1, 1); + ParseResult AtenBitwiseAndTensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenSqrtOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 1, 1); + void AtenBitwiseAndTensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenSqrt_Op : Torch_Op<"aten.sqrt_", [ +def Torch_AtenBitwiseAnd_TensorOp : Torch_Op<"aten.bitwise_and_.Tensor", [ IsTrailingUnderscoreInplaceVariant, AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::sqrt_ : (Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::bitwise_and_.Tensor : (Tensor, Tensor) -> (Tensor)`"; let arguments = (ins - Torch_NonValueTensorType:$self + Torch_NonValueTensorType:$self, + Torch_NonValueTensorType:$other ); let results = (outs Torch_NonValueTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenSqrt_Op::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 1, 1); + ParseResult AtenBitwiseAnd_TensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenSqrt_Op::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 1, 1); + void AtenBitwiseAnd_TensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenLog1pOp : Torch_Op<"aten.log1p", [ +def Torch_AtenBitwiseAndScalarOp : Torch_Op<"aten.bitwise_and.Scalar", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::log1p : (Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::bitwise_and.Scalar : (Tensor, Scalar) -> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$self + AnyTorchTensorType:$self, + AnyTorchScalarType:$other ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenLog1pOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 1, 1); + ParseResult AtenBitwiseAndScalarOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenLog1pOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 1, 1); + void AtenBitwiseAndScalarOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenLog1p_Op : Torch_Op<"aten.log1p_", [ +def Torch_AtenBitwiseAnd_ScalarOp : Torch_Op<"aten.bitwise_and_.Scalar", [ IsTrailingUnderscoreInplaceVariant, AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::log1p_ : (Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::bitwise_and_.Scalar : (Tensor, Scalar) -> (Tensor)`"; let arguments = (ins - Torch_NonValueTensorType:$self + Torch_NonValueTensorType:$self, + AnyTorchScalarType:$other ); let results = (outs Torch_NonValueTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenLog1p_Op::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 1, 1); + ParseResult AtenBitwiseAnd_ScalarOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenLog1p_Op::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 1, 1); + void AtenBitwiseAnd_ScalarOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenLogitOp : Torch_Op<"aten.logit", [ +def Torch_AtenBitwiseOrTensorOp : Torch_Op<"aten.bitwise_or.Tensor", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::logit : (Tensor, float?) -> (Tensor)`"; + let summary = "Generated op for `aten::bitwise_or.Tensor : (Tensor, Tensor) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - AnyTorchOptionalFloatType:$eps + AnyTorchTensorType:$other ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenLogitOp::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenBitwiseOrTensorOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenLogitOp::print(OpAsmPrinter &printer) { + void AtenBitwiseOrTensorOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenLogit_Op : Torch_Op<"aten.logit_", [ +def Torch_AtenBitwiseOr_TensorOp : Torch_Op<"aten.bitwise_or_.Tensor", [ IsTrailingUnderscoreInplaceVariant, AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::logit_ : (Tensor, float?) -> (Tensor)`"; + let summary = "Generated op for `aten::bitwise_or_.Tensor : (Tensor, Tensor) -> (Tensor)`"; let arguments = (ins Torch_NonValueTensorType:$self, - AnyTorchOptionalFloatType:$eps + Torch_NonValueTensorType:$other ); let results = (outs Torch_NonValueTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenLogit_Op::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenBitwiseOr_TensorOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenLogit_Op::print(OpAsmPrinter &printer) { + void AtenBitwiseOr_TensorOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenRsqrtOp : Torch_Op<"aten.rsqrt", [ +def Torch_AtenBitwiseXorTensorOp : Torch_Op<"aten.bitwise_xor.Tensor", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::rsqrt : (Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::bitwise_xor.Tensor : (Tensor, Tensor) -> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$self + AnyTorchTensorType:$self, + AnyTorchTensorType:$other ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenRsqrtOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 1, 1); + ParseResult AtenBitwiseXorTensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenRsqrtOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 1, 1); + void AtenBitwiseXorTensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenRsqrt_Op : Torch_Op<"aten.rsqrt_", [ +def Torch_AtenBitwiseXor_TensorOp : Torch_Op<"aten.bitwise_xor_.Tensor", [ IsTrailingUnderscoreInplaceVariant, AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::rsqrt_ : (Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::bitwise_xor_.Tensor : (Tensor, Tensor) -> (Tensor)`"; let arguments = (ins - Torch_NonValueTensorType:$self + Torch_NonValueTensorType:$self, + Torch_NonValueTensorType:$other ); let results = (outs Torch_NonValueTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenRsqrt_Op::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 1, 1); + ParseResult AtenBitwiseXor_TensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenRsqrt_Op::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 1, 1); + void AtenBitwiseXor_TensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenAbsOp : Torch_Op<"aten.abs", [ +def Torch_AtenBitwiseLeftShiftTensorOp : Torch_Op<"aten.bitwise_left_shift.Tensor", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::abs : (Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::bitwise_left_shift.Tensor : (Tensor, Tensor) -> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$self + AnyTorchTensorType:$self, + AnyTorchTensorType:$other ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenAbsOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 1, 1); + ParseResult AtenBitwiseLeftShiftTensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenAbsOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 1, 1); + void AtenBitwiseLeftShiftTensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenAbs_Op : Torch_Op<"aten.abs_", [ +def Torch_AtenBitwiseLeftShift_TensorOp : Torch_Op<"aten.bitwise_left_shift_.Tensor", [ IsTrailingUnderscoreInplaceVariant, AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::abs_ : (Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::bitwise_left_shift_.Tensor : (Tensor, Tensor) -> (Tensor)`"; let arguments = (ins - Torch_NonValueTensorType:$self + Torch_NonValueTensorType:$self, + Torch_NonValueTensorType:$other ); let results = (outs Torch_NonValueTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenAbs_Op::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 1, 1); + ParseResult AtenBitwiseLeftShift_TensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenAbs_Op::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 1, 1); + void AtenBitwiseLeftShift_TensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenReciprocalOp : Torch_Op<"aten.reciprocal", [ +def Torch_AtenBitwiseRightShiftTensorOp : Torch_Op<"aten.bitwise_right_shift.Tensor", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::reciprocal : (Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::bitwise_right_shift.Tensor : (Tensor, Tensor) -> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$self + AnyTorchTensorType:$self, + AnyTorchTensorType:$other ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenReciprocalOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 1, 1); + ParseResult AtenBitwiseRightShiftTensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenReciprocalOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 1, 1); + void AtenBitwiseRightShiftTensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenReciprocal_Op : Torch_Op<"aten.reciprocal_", [ +def Torch_AtenBitwiseRightShift_TensorOp : Torch_Op<"aten.bitwise_right_shift_.Tensor", [ IsTrailingUnderscoreInplaceVariant, AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::reciprocal_ : (Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::bitwise_right_shift_.Tensor : (Tensor, Tensor) -> (Tensor)`"; let arguments = (ins - Torch_NonValueTensorType:$self + Torch_NonValueTensorType:$self, + Torch_NonValueTensorType:$other ); let results = (outs Torch_NonValueTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenReciprocal_Op::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 1, 1); + ParseResult AtenBitwiseRightShift_TensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenReciprocal_Op::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 1, 1); + void AtenBitwiseRightShift_TensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenBitwiseAndTensorOp : Torch_Op<"aten.bitwise_and.Tensor", [ +def Torch_AtenThresholdOp : Torch_Op<"aten.threshold", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::bitwise_and.Tensor : (Tensor, Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::threshold : (Tensor, Scalar, Scalar) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - AnyTorchTensorType:$other + AnyTorchScalarType:$threshold, + AnyTorchScalarType:$value ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenBitwiseAndTensorOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); + ParseResult AtenThresholdOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); } - void AtenBitwiseAndTensorOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); + void AtenThresholdOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); } }]; } -def Torch_AtenBitwiseAnd_TensorOp : Torch_Op<"aten.bitwise_and_.Tensor", [ +def Torch_AtenThreshold_Op : Torch_Op<"aten.threshold_", [ IsTrailingUnderscoreInplaceVariant, AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::bitwise_and_.Tensor : (Tensor, Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::threshold_ : (Tensor, Scalar, Scalar) -> (Tensor)`"; let arguments = (ins Torch_NonValueTensorType:$self, - Torch_NonValueTensorType:$other + AnyTorchScalarType:$threshold, + AnyTorchScalarType:$value ); let results = (outs Torch_NonValueTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenBitwiseAnd_TensorOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); + ParseResult AtenThreshold_Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); } - void AtenBitwiseAnd_TensorOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); + void AtenThreshold_Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); } }]; } -def Torch_AtenBitwiseAndScalarOp : Torch_Op<"aten.bitwise_and.Scalar", [ +def Torch_AtenSquareOp : Torch_Op<"aten.square", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::bitwise_and.Scalar : (Tensor, Scalar) -> (Tensor)`"; + let summary = "Generated op for `aten::square : (Tensor) -> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$self, - AnyTorchScalarType:$other + AnyTorchTensorType:$self ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenBitwiseAndScalarOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); + ParseResult AtenSquareOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenBitwiseAndScalarOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); + void AtenSquareOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); } }]; } -def Torch_AtenBitwiseAnd_ScalarOp : Torch_Op<"aten.bitwise_and_.Scalar", [ +def Torch_AtenSquare_Op : Torch_Op<"aten.square_", [ IsTrailingUnderscoreInplaceVariant, AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::bitwise_and_.Scalar : (Tensor, Scalar) -> (Tensor)`"; + let summary = "Generated op for `aten::square_ : (Tensor) -> (Tensor)`"; let arguments = (ins - Torch_NonValueTensorType:$self, - AnyTorchScalarType:$other + Torch_NonValueTensorType:$self ); let results = (outs Torch_NonValueTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenBitwiseAnd_ScalarOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); + ParseResult AtenSquare_Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenBitwiseAnd_ScalarOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); + void AtenSquare_Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); } }]; } -def Torch_AtenBitwiseOrTensorOp : Torch_Op<"aten.bitwise_or.Tensor", [ +def Torch_AtenUnsqueezeOp : Torch_Op<"aten.unsqueeze", [ AllowsTypeRefinement, - HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::bitwise_or.Tensor : (Tensor, Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::unsqueeze : (Tensor, int) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - AnyTorchTensorType:$other + Torch_IntType:$dim ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenBitwiseOrTensorOp::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenUnsqueezeOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenBitwiseOrTensorOp::print(OpAsmPrinter &printer) { + void AtenUnsqueezeOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenBitwiseOr_TensorOp : Torch_Op<"aten.bitwise_or_.Tensor", [ +def Torch_AtenUnsqueeze_Op : Torch_Op<"aten.unsqueeze_", [ IsTrailingUnderscoreInplaceVariant, AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::bitwise_or_.Tensor : (Tensor, Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::unsqueeze_ : (Tensor, int) -> (Tensor)`"; let arguments = (ins Torch_NonValueTensorType:$self, - Torch_NonValueTensorType:$other + Torch_IntType:$dim ); let results = (outs Torch_NonValueTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenBitwiseOr_TensorOp::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenUnsqueeze_Op::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenBitwiseOr_TensorOp::print(OpAsmPrinter &printer) { + void AtenUnsqueeze_Op::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenBitwiseXorTensorOp : Torch_Op<"aten.bitwise_xor.Tensor", [ +def Torch_AtenZeroOp : Torch_Op<"aten.zero", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::bitwise_xor.Tensor : (Tensor, Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::zero : (Tensor) -> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$self, - AnyTorchTensorType:$other + AnyTorchTensorType:$self ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenBitwiseXorTensorOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); + ParseResult AtenZeroOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenBitwiseXorTensorOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); + void AtenZeroOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); } }]; } -def Torch_AtenBitwiseXor_TensorOp : Torch_Op<"aten.bitwise_xor_.Tensor", [ +def Torch_AtenZero_Op : Torch_Op<"aten.zero_", [ IsTrailingUnderscoreInplaceVariant, AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::bitwise_xor_.Tensor : (Tensor, Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::zero_ : (Tensor) -> (Tensor)`"; let arguments = (ins - Torch_NonValueTensorType:$self, - Torch_NonValueTensorType:$other + Torch_NonValueTensorType:$self ); let results = (outs Torch_NonValueTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenBitwiseXor_TensorOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); + ParseResult AtenZero_Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); } - void AtenBitwiseXor_TensorOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); + void AtenZero_Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); } }]; } -def Torch_AtenBitwiseLeftShiftTensorOp : Torch_Op<"aten.bitwise_left_shift.Tensor", [ +def Torch_AtenFillScalarOp : Torch_Op<"aten.fill.Scalar", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::bitwise_left_shift.Tensor : (Tensor, Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::fill.Scalar : (Tensor, Scalar) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - AnyTorchTensorType:$other + AnyTorchScalarType:$value ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenBitwiseLeftShiftTensorOp::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenFillScalarOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenBitwiseLeftShiftTensorOp::print(OpAsmPrinter &printer) { + void AtenFillScalarOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenBitwiseLeftShift_TensorOp : Torch_Op<"aten.bitwise_left_shift_.Tensor", [ +def Torch_AtenFill_ScalarOp : Torch_Op<"aten.fill_.Scalar", [ IsTrailingUnderscoreInplaceVariant, AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::bitwise_left_shift_.Tensor : (Tensor, Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::fill_.Scalar : (Tensor, Scalar) -> (Tensor)`"; let arguments = (ins Torch_NonValueTensorType:$self, - Torch_NonValueTensorType:$other + AnyTorchScalarType:$value ); let results = (outs Torch_NonValueTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenBitwiseLeftShift_TensorOp::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenFill_ScalarOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenBitwiseLeftShift_TensorOp::print(OpAsmPrinter &printer) { + void AtenFill_ScalarOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenBitwiseRightShiftTensorOp : Torch_Op<"aten.bitwise_right_shift.Tensor", [ +def Torch_AtenFillTensorOp : Torch_Op<"aten.fill.Tensor", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::bitwise_right_shift.Tensor : (Tensor, Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::fill.Tensor : (Tensor, Tensor) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - AnyTorchTensorType:$other + AnyTorchTensorType:$value ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenBitwiseRightShiftTensorOp::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenFillTensorOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenBitwiseRightShiftTensorOp::print(OpAsmPrinter &printer) { + void AtenFillTensorOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenBitwiseRightShift_TensorOp : Torch_Op<"aten.bitwise_right_shift_.Tensor", [ +def Torch_AtenFill_TensorOp : Torch_Op<"aten.fill_.Tensor", [ IsTrailingUnderscoreInplaceVariant, AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::bitwise_right_shift_.Tensor : (Tensor, Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::fill_.Tensor : (Tensor, Tensor) -> (Tensor)`"; let arguments = (ins Torch_NonValueTensorType:$self, - Torch_NonValueTensorType:$other + Torch_NonValueTensorType:$value ); let results = (outs Torch_NonValueTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenBitwiseRightShift_TensorOp::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenFill_TensorOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenBitwiseRightShift_TensorOp::print(OpAsmPrinter &printer) { + void AtenFill_TensorOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenThresholdOp : Torch_Op<"aten.threshold", [ +def Torch_AtenDivTensorModeOp : Torch_Op<"aten.div.Tensor_mode", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::threshold : (Tensor, Scalar, Scalar) -> (Tensor)`"; + let summary = "Generated op for `aten::div.Tensor_mode : (Tensor, Tensor, str?) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - AnyTorchScalarType:$threshold, - AnyTorchScalarType:$value + AnyTorchTensorType:$other, + AnyTorchOptionalStringType:$rounding_mode ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenThresholdOp::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenDivTensorModeOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 3, 1); } - void AtenThresholdOp::print(OpAsmPrinter &printer) { + void AtenDivTensorModeOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 3, 1); } }]; + let hasCanonicalizer = 1; } -def Torch_AtenThreshold_Op : Torch_Op<"aten.threshold_", [ +def Torch_AtenDiv_TensorModeOp : Torch_Op<"aten.div_.Tensor_mode", [ IsTrailingUnderscoreInplaceVariant, AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::threshold_ : (Tensor, Scalar, Scalar) -> (Tensor)`"; + let summary = "Generated op for `aten::div_.Tensor_mode : (Tensor, Tensor, str?) -> (Tensor)`"; let arguments = (ins Torch_NonValueTensorType:$self, - AnyTorchScalarType:$threshold, - AnyTorchScalarType:$value + Torch_NonValueTensorType:$other, + AnyTorchOptionalStringType:$rounding_mode ); let results = (outs Torch_NonValueTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenThreshold_Op::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenDiv_TensorModeOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 3, 1); } - void AtenThreshold_Op::print(OpAsmPrinter &printer) { + void AtenDiv_TensorModeOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 3, 1); } }]; } -def Torch_AtenSquareOp : Torch_Op<"aten.square", [ +def Torch_AtenMulTensorOp : Torch_Op<"aten.mul.Tensor", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::square : (Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::mul.Tensor : (Tensor, Tensor) -> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$self + AnyTorchTensorType:$self, + AnyTorchTensorType:$other ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenSquareOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 1, 1); + ParseResult AtenMulTensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenSquareOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 1, 1); + void AtenMulTensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); } }]; + let hasFolder = 1; + let hasCanonicalizer = 1; } -def Torch_AtenSquare_Op : Torch_Op<"aten.square_", [ +def Torch_AtenMul_TensorOp : Torch_Op<"aten.mul_.Tensor", [ IsTrailingUnderscoreInplaceVariant, AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::square_ : (Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::mul_.Tensor : (Tensor, Tensor) -> (Tensor)`"; let arguments = (ins - Torch_NonValueTensorType:$self + Torch_NonValueTensorType:$self, + Torch_NonValueTensorType:$other ); let results = (outs Torch_NonValueTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenSquare_Op::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 1, 1); + ParseResult AtenMul_TensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenSquare_Op::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 1, 1); + void AtenMul_TensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenUnsqueezeOp : Torch_Op<"aten.unsqueeze", [ +def Torch_AtenAddTensorOp : Torch_Op<"aten.add.Tensor", [ AllowsTypeRefinement, + HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::unsqueeze : (Tensor, int) -> (Tensor)`"; + let summary = "Generated op for `aten::add.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - Torch_IntType:$dim + AnyTorchTensorType:$other, + AnyTorchScalarType:$alpha ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenUnsqueezeOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); + ParseResult AtenAddTensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); } - void AtenUnsqueezeOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); + void AtenAddTensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); } }]; + let hasFolder = 1; + let hasCanonicalizer = 1; } -def Torch_AtenUnsqueeze_Op : Torch_Op<"aten.unsqueeze_", [ +def Torch_AtenAdd_TensorOp : Torch_Op<"aten.add_.Tensor", [ IsTrailingUnderscoreInplaceVariant, AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::unsqueeze_ : (Tensor, int) -> (Tensor)`"; + let summary = "Generated op for `aten::add_.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)`"; let arguments = (ins Torch_NonValueTensorType:$self, - Torch_IntType:$dim + Torch_NonValueTensorType:$other, + AnyTorchScalarType:$alpha ); let results = (outs Torch_NonValueTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenUnsqueeze_Op::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); + ParseResult AtenAdd_TensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); } - void AtenUnsqueeze_Op::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); + void AtenAdd_TensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); } }]; } -def Torch_AtenZeroOp : Torch_Op<"aten.zero", [ +def Torch_AtenSubTensorOp : Torch_Op<"aten.sub.Tensor", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::zero : (Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::sub.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)`"; let arguments = (ins - AnyTorchTensorType:$self + AnyTorchTensorType:$self, + AnyTorchTensorType:$other, + AnyTorchScalarType:$alpha ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenZeroOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 1, 1); + ParseResult AtenSubTensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); } - void AtenZeroOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 1, 1); + void AtenSubTensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); } }]; + let hasFolder = 1; + let hasCanonicalizer = 1; } -def Torch_AtenZero_Op : Torch_Op<"aten.zero_", [ +def Torch_AtenSub_TensorOp : Torch_Op<"aten.sub_.Tensor", [ IsTrailingUnderscoreInplaceVariant, AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::zero_ : (Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::sub_.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)`"; let arguments = (ins - Torch_NonValueTensorType:$self + Torch_NonValueTensorType:$self, + Torch_NonValueTensorType:$other, + AnyTorchScalarType:$alpha ); let results = (outs Torch_NonValueTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenZero_Op::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 1, 1); + ParseResult AtenSub_TensorOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); } - void AtenZero_Op::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 1, 1); + void AtenSub_TensorOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); } }]; } -def Torch_AtenFillScalarOp : Torch_Op<"aten.fill.Scalar", [ +def Torch_AtenAddScalarOp : Torch_Op<"aten.add.Scalar", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::fill.Scalar : (Tensor, Scalar) -> (Tensor)`"; + let summary = "Generated op for `aten::add.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - AnyTorchScalarType:$value + AnyTorchScalarType:$other, + AnyTorchScalarType:$alpha ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenFillScalarOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); + ParseResult AtenAddScalarOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); } - void AtenFillScalarOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); + void AtenAddScalarOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); } }]; + let hasCanonicalizer = 1; } -def Torch_AtenFill_ScalarOp : Torch_Op<"aten.fill_.Scalar", [ +def Torch_AtenAdd_ScalarOp : Torch_Op<"aten.add_.Scalar", [ IsTrailingUnderscoreInplaceVariant, AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::fill_.Scalar : (Tensor, Scalar) -> (Tensor)`"; + let summary = "Generated op for `aten::add_.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)`"; let arguments = (ins Torch_NonValueTensorType:$self, - AnyTorchScalarType:$value + AnyTorchScalarType:$other, + AnyTorchScalarType:$alpha ); let results = (outs Torch_NonValueTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenFill_ScalarOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); + ParseResult AtenAdd_ScalarOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); } - void AtenFill_ScalarOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); + void AtenAdd_ScalarOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); } }]; } -def Torch_AtenFillTensorOp : Torch_Op<"aten.fill.Tensor", [ +def Torch_AtenSubScalarOp : Torch_Op<"aten.sub.Scalar", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::fill.Tensor : (Tensor, Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::sub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - AnyTorchTensorType:$value + AnyTorchScalarType:$other, + AnyTorchScalarType:$alpha ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenFillTensorOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); + ParseResult AtenSubScalarOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); } - void AtenFillTensorOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); + void AtenSubScalarOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); } }]; + let hasCanonicalizer = 1; } -def Torch_AtenFill_TensorOp : Torch_Op<"aten.fill_.Tensor", [ +def Torch_AtenSub_ScalarOp : Torch_Op<"aten.sub_.Scalar", [ IsTrailingUnderscoreInplaceVariant, AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::fill_.Tensor : (Tensor, Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::sub_.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)`"; let arguments = (ins Torch_NonValueTensorType:$self, - Torch_NonValueTensorType:$value + AnyTorchScalarType:$other, + AnyTorchScalarType:$alpha ); let results = (outs Torch_NonValueTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenFill_TensorOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 2, 1); + ParseResult AtenSub_ScalarOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); } - void AtenFill_TensorOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 2, 1); + void AtenSub_ScalarOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); } }]; } -def Torch_AtenDivTensorModeOp : Torch_Op<"aten.div.Tensor_mode", [ +def Torch_AtenMulScalarOp : Torch_Op<"aten.mul.Scalar", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::div.Tensor_mode : (Tensor, Tensor, str?) -> (Tensor)`"; + let summary = "Generated op for `aten::mul.Scalar : (Tensor, Scalar) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - AnyTorchTensorType:$other, - AnyTorchOptionalStringType:$rounding_mode + AnyTorchScalarType:$other ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenDivTensorModeOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 3, 1); + ParseResult AtenMulScalarOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenDivTensorModeOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 3, 1); + void AtenMulScalarOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); } }]; let hasCanonicalizer = 1; } -def Torch_AtenDiv_TensorModeOp : Torch_Op<"aten.div_.Tensor_mode", [ +def Torch_AtenMul_ScalarOp : Torch_Op<"aten.mul_.Scalar", [ IsTrailingUnderscoreInplaceVariant, AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::div_.Tensor_mode : (Tensor, Tensor, str?) -> (Tensor)`"; + let summary = "Generated op for `aten::mul_.Scalar : (Tensor, Scalar) -> (Tensor)`"; let arguments = (ins Torch_NonValueTensorType:$self, - Torch_NonValueTensorType:$other, - AnyTorchOptionalStringType:$rounding_mode + AnyTorchScalarType:$other ); let results = (outs Torch_NonValueTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenDiv_TensorModeOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 3, 1); + ParseResult AtenMul_ScalarOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenDiv_TensorModeOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 3, 1); + void AtenMul_ScalarOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenMulTensorOp : Torch_Op<"aten.mul.Tensor", [ +def Torch_AtenEqTensorOp : Torch_Op<"aten.eq.Tensor", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::mul.Tensor : (Tensor, Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::eq.Tensor : (Tensor, Tensor) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, AnyTorchTensorType:$other @@ -3783,22 +3800,21 @@ def Torch_AtenMulTensorOp : Torch_Op<"aten.mul.Tensor", [ ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenMulTensorOp::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenEqTensorOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenMulTensorOp::print(OpAsmPrinter &printer) { + void AtenEqTensorOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 2, 1); } }]; let hasFolder = 1; - let hasCanonicalizer = 1; } -def Torch_AtenMul_TensorOp : Torch_Op<"aten.mul_.Tensor", [ +def Torch_AtenEq_TensorOp : Torch_Op<"aten.eq_.Tensor", [ IsTrailingUnderscoreInplaceVariant, AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::mul_.Tensor : (Tensor, Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::eq_.Tensor : (Tensor, Tensor) -> (Tensor)`"; let arguments = (ins Torch_NonValueTensorType:$self, Torch_NonValueTensorType:$other @@ -3808,223 +3824,213 @@ def Torch_AtenMul_TensorOp : Torch_Op<"aten.mul_.Tensor", [ ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenMul_TensorOp::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenEq_TensorOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenMul_TensorOp::print(OpAsmPrinter &printer) { + void AtenEq_TensorOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenAddTensorOp : Torch_Op<"aten.add.Tensor", [ +def Torch_AtenLeScalarOp : Torch_Op<"aten.le.Scalar", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::add.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)`"; + let summary = "Generated op for `aten::le.Scalar : (Tensor, Scalar) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - AnyTorchTensorType:$other, - AnyTorchScalarType:$alpha + AnyTorchScalarType:$other ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenAddTensorOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 3, 1); + ParseResult AtenLeScalarOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenAddTensorOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 3, 1); + void AtenLeScalarOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); } }]; let hasFolder = 1; - let hasCanonicalizer = 1; } -def Torch_AtenAdd_TensorOp : Torch_Op<"aten.add_.Tensor", [ +def Torch_AtenLe_ScalarOp : Torch_Op<"aten.le_.Scalar", [ IsTrailingUnderscoreInplaceVariant, AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::add_.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)`"; + let summary = "Generated op for `aten::le_.Scalar : (Tensor, Scalar) -> (Tensor)`"; let arguments = (ins Torch_NonValueTensorType:$self, - Torch_NonValueTensorType:$other, - AnyTorchScalarType:$alpha + AnyTorchScalarType:$other ); let results = (outs Torch_NonValueTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenAdd_TensorOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 3, 1); + ParseResult AtenLe_ScalarOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenAdd_TensorOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 3, 1); + void AtenLe_ScalarOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenSubTensorOp : Torch_Op<"aten.sub.Tensor", [ +def Torch_AtenLtScalarOp : Torch_Op<"aten.lt.Scalar", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::sub.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)`"; + let summary = "Generated op for `aten::lt.Scalar : (Tensor, Scalar) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - AnyTorchTensorType:$other, - AnyTorchScalarType:$alpha + AnyTorchScalarType:$other ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenSubTensorOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 3, 1); + ParseResult AtenLtScalarOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenSubTensorOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 3, 1); + void AtenLtScalarOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); } }]; let hasFolder = 1; - let hasCanonicalizer = 1; } -def Torch_AtenSub_TensorOp : Torch_Op<"aten.sub_.Tensor", [ +def Torch_AtenLt_ScalarOp : Torch_Op<"aten.lt_.Scalar", [ IsTrailingUnderscoreInplaceVariant, AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::sub_.Tensor : (Tensor, Tensor, Scalar) -> (Tensor)`"; + let summary = "Generated op for `aten::lt_.Scalar : (Tensor, Scalar) -> (Tensor)`"; let arguments = (ins Torch_NonValueTensorType:$self, - Torch_NonValueTensorType:$other, - AnyTorchScalarType:$alpha + AnyTorchScalarType:$other ); let results = (outs Torch_NonValueTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenSub_TensorOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 3, 1); + ParseResult AtenLt_ScalarOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenSub_TensorOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 3, 1); + void AtenLt_ScalarOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenAddScalarOp : Torch_Op<"aten.add.Scalar", [ +def Torch_AtenGtScalarOp : Torch_Op<"aten.gt.Scalar", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::add.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)`"; + let summary = "Generated op for `aten::gt.Scalar : (Tensor, Scalar) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - AnyTorchScalarType:$other, - AnyTorchScalarType:$alpha + AnyTorchScalarType:$other ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenAddScalarOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 3, 1); + ParseResult AtenGtScalarOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenAddScalarOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 3, 1); + void AtenGtScalarOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); } }]; - let hasCanonicalizer = 1; + let hasFolder = 1; } -def Torch_AtenAdd_ScalarOp : Torch_Op<"aten.add_.Scalar", [ +def Torch_AtenGt_ScalarOp : Torch_Op<"aten.gt_.Scalar", [ IsTrailingUnderscoreInplaceVariant, AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::add_.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)`"; + let summary = "Generated op for `aten::gt_.Scalar : (Tensor, Scalar) -> (Tensor)`"; let arguments = (ins Torch_NonValueTensorType:$self, - AnyTorchScalarType:$other, - AnyTorchScalarType:$alpha + AnyTorchScalarType:$other ); let results = (outs Torch_NonValueTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenAdd_ScalarOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 3, 1); + ParseResult AtenGt_ScalarOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenAdd_ScalarOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 3, 1); + void AtenGt_ScalarOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenSubScalarOp : Torch_Op<"aten.sub.Scalar", [ +def Torch_AtenGeScalarOp : Torch_Op<"aten.ge.Scalar", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::sub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)`"; + let summary = "Generated op for `aten::ge.Scalar : (Tensor, Scalar) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - AnyTorchScalarType:$other, - AnyTorchScalarType:$alpha + AnyTorchScalarType:$other ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenSubScalarOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 3, 1); + ParseResult AtenGeScalarOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenSubScalarOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 3, 1); + void AtenGeScalarOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); } }]; - let hasCanonicalizer = 1; + let hasFolder = 1; } -def Torch_AtenSub_ScalarOp : Torch_Op<"aten.sub_.Scalar", [ +def Torch_AtenGe_ScalarOp : Torch_Op<"aten.ge_.Scalar", [ IsTrailingUnderscoreInplaceVariant, AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::sub_.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)`"; + let summary = "Generated op for `aten::ge_.Scalar : (Tensor, Scalar) -> (Tensor)`"; let arguments = (ins Torch_NonValueTensorType:$self, - AnyTorchScalarType:$other, - AnyTorchScalarType:$alpha + AnyTorchScalarType:$other ); let results = (outs Torch_NonValueTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenSub_ScalarOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 3, 1); + ParseResult AtenGe_ScalarOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenSub_ScalarOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 3, 1); + void AtenGe_ScalarOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenMulScalarOp : Torch_Op<"aten.mul.Scalar", [ +def Torch_AtenEqScalarOp : Torch_Op<"aten.eq.Scalar", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::mul.Scalar : (Tensor, Scalar) -> (Tensor)`"; + let summary = "Generated op for `aten::eq.Scalar : (Tensor, Scalar) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, AnyTorchScalarType:$other @@ -4034,21 +4040,21 @@ def Torch_AtenMulScalarOp : Torch_Op<"aten.mul.Scalar", [ ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenMulScalarOp::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenEqScalarOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenMulScalarOp::print(OpAsmPrinter &printer) { + void AtenEqScalarOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 2, 1); } }]; - let hasCanonicalizer = 1; + let hasFolder = 1; } -def Torch_AtenMul_ScalarOp : Torch_Op<"aten.mul_.Scalar", [ +def Torch_AtenEq_ScalarOp : Torch_Op<"aten.eq_.Scalar", [ IsTrailingUnderscoreInplaceVariant, AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::mul_.Scalar : (Tensor, Scalar) -> (Tensor)`"; + let summary = "Generated op for `aten::eq_.Scalar : (Tensor, Scalar) -> (Tensor)`"; let arguments = (ins Torch_NonValueTensorType:$self, AnyTorchScalarType:$other @@ -4058,58 +4064,58 @@ def Torch_AtenMul_ScalarOp : Torch_Op<"aten.mul_.Scalar", [ ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenMul_ScalarOp::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenEq_ScalarOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenMul_ScalarOp::print(OpAsmPrinter &printer) { + void AtenEq_ScalarOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 2, 1); } }]; } -def Torch_AtenEqTensorOp : Torch_Op<"aten.eq.Tensor", [ +def Torch_AtenNeScalarOp : Torch_Op<"aten.ne.Scalar", [ AllowsTypeRefinement, HasValueSemantics, ReadOnly ]> { - let summary = "Generated op for `aten::eq.Tensor : (Tensor, Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::ne.Scalar : (Tensor, Scalar) -> (Tensor)`"; let arguments = (ins AnyTorchTensorType:$self, - AnyTorchTensorType:$other + AnyTorchScalarType:$other ); let results = (outs AnyTorchTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenEqTensorOp::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenNeScalarOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenEqTensorOp::print(OpAsmPrinter &printer) { + void AtenNeScalarOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 2, 1); } }]; let hasFolder = 1; } -def Torch_AtenEq_TensorOp : Torch_Op<"aten.eq_.Tensor", [ +def Torch_AtenNe_ScalarOp : Torch_Op<"aten.ne_.Scalar", [ IsTrailingUnderscoreInplaceVariant, AllowsTypeRefinement ]> { - let summary = "Generated op for `aten::eq_.Tensor : (Tensor, Tensor) -> (Tensor)`"; + let summary = "Generated op for `aten::ne_.Scalar : (Tensor, Scalar) -> (Tensor)`"; let arguments = (ins Torch_NonValueTensorType:$self, - Torch_NonValueTensorType:$other + AnyTorchScalarType:$other ); let results = (outs Torch_NonValueTensorType:$result ); let hasCustomAssemblyFormat = 1; let extraClassDefinition = [{ - ParseResult AtenEq_TensorOp::parse(OpAsmParser &parser, OperationState &result) { + ParseResult AtenNe_ScalarOp::parse(OpAsmParser &parser, OperationState &result) { return parseDefaultTorchOp(parser, result, 2, 1); } - void AtenEq_TensorOp::print(OpAsmPrinter &printer) { + void AtenNe_ScalarOp::print(OpAsmPrinter &printer) { printDefaultTorchOp(printer, *this, 2, 1); } }]; diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp index 8a677b8ce058..a7bdddbc8d78 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp @@ -591,7 +591,7 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP( Value one = rewriter.create( loc, intTy, rewriter.getI64IntegerAttr(1)); Value lt = - rewriter.create(loc, boolTy, indices, zero); + rewriter.create(loc, boolTy, indices, zero); Value dim = rewriter.create(loc, intTy, data, index); Value add = rewriter.create(loc, indicesTy, diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index 03f39be9c806..9f10c8bce3ba 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -1481,6 +1481,197 @@ OpFoldResult AtenEqTensorOp::fold(FoldAdaptor adaptor) { return nullptr; } +//===----------------------------------------------------------------------===// +// AtenLeScalarOp +//===----------------------------------------------------------------------===// + +using ComparisonFoldFpOperator = std::function; +using ComparisonFoldIntOperator = std::function; + +static OpFoldResult comparisonScaleFolder(DenseElementsAttr lhs, Attribute rhs, + ValueTensorType resultTy, + ComparisonFoldFpOperator fpFolder, + ComparisonFoldIntOperator intFolder) { + constexpr int64_t kMaxFold = 16; + if (!lhs || !rhs || !resultTy) + return nullptr; + if (!resultTy.hasSizes() || !resultTy.hasDtype()) + return nullptr; + + for (auto size : resultTy.getSizes()) + if (size == Torch::kUnknownSize) + return nullptr; + + auto ctx = lhs.getContext(); + auto resultETy = resultTy.getDtype(); + auto tensorETy = cast(lhs.getType()).getElementType(); + if (lhs.isSplat()) { + if (auto intAttr = dyn_cast(rhs)) { + auto unsign = cast(tensorETy).isUnsigned(); + auto scalarAP = intAttr.getValue(); + auto tensorAP = lhs.getSplatValue().getValue(); + tensorAP = APInt( + scalarAP.getBitWidth(), + unsign ? tensorAP.getZExtValue() : tensorAP.getSExtValue(), !unsign); + auto resultBool = intFolder(tensorAP, scalarAP, unsign); + auto resultAP = IntegerAttr::get(IntegerType::get(ctx, 1), resultBool); + return DenseElementsAttr::get(resultTy.toBuiltinTensor().clone(resultETy), + resultAP); + } + + if (auto floatAttr = dyn_cast(rhs)) { + APFloat scalarAP = floatAttr.getValue(); + APFloat tensorAP = lhs.getSplatValue().getValue(); + auto resultBool = + fpFolder(tensorAP.convertToDouble(), scalarAP.convertToDouble()); + auto resultAP = IntegerAttr::get(IntegerType::get(ctx, 1), resultBool); + return DenseElementsAttr::get(resultTy.toBuiltinTensor().clone(resultETy), + resultAP); + } + return nullptr; + } + + int64_t count = 1; + for (auto size : resultTy.getSizes()) + count *= size; + + if (count > kMaxFold) + return nullptr; + + if (auto intAttr = dyn_cast(rhs)) { + auto unsign = cast(tensorETy).isUnsigned(); + llvm::SmallVector values; + for (auto tensorAP : lhs.getValues()) { + auto scalarAP = intAttr.getValue(); + tensorAP = APInt( + scalarAP.getBitWidth(), + unsign ? tensorAP.getZExtValue() : tensorAP.getSExtValue(), !unsign); + auto resultBool = intFolder(tensorAP, scalarAP, unsign); + values.push_back(resultBool); + } + return DenseElementsAttr::get(resultTy.toBuiltinTensor().clone(resultETy), + values); + } + + if (auto floatAttr = dyn_cast(rhs)) { + llvm::SmallVector values; + for (auto tensorAP : lhs.getValues()) { + APFloat scalarAP = floatAttr.getValue(); + auto resultBool = + fpFolder(tensorAP.convertToDouble(), scalarAP.convertToDouble()); + values.push_back(resultBool); + } + return DenseElementsAttr::get(resultTy.toBuiltinTensor().clone(resultETy), + values); + } + + return nullptr; +} + +OpFoldResult AtenLeScalarOp::fold(FoldAdaptor adaptor) { + auto self = dyn_cast_or_null(adaptor.getSelf()); + auto other = adaptor.getOther(); + auto resultTy = dyn_cast(getType()); + + auto fpFold = [](double lhs, double rhs) -> bool { return lhs <= rhs; }; + + auto intFold = [](APInt lhs, APInt rhs, bool unsign) -> bool { + return unsign ? lhs.ule(rhs) : lhs.sle(rhs); + }; + + return comparisonScaleFolder(self, other, resultTy, fpFold, intFold); +} + +//===----------------------------------------------------------------------===// +// AtenLtScalarOp +//===----------------------------------------------------------------------===// + +OpFoldResult AtenLtScalarOp::fold(FoldAdaptor adaptor) { + auto self = dyn_cast_or_null(adaptor.getSelf()); + auto other = adaptor.getOther(); + auto resultTy = dyn_cast(getType()); + + auto fpFold = [](double lhs, double rhs) -> bool { return lhs < rhs; }; + + auto intFold = [](APInt lhs, APInt rhs, bool unsign) -> bool { + return unsign ? lhs.ult(rhs) : lhs.slt(rhs); + }; + + return comparisonScaleFolder(self, other, resultTy, fpFold, intFold); +} + +//===----------------------------------------------------------------------===// +// AtenGtScalarOp +//===----------------------------------------------------------------------===// + +OpFoldResult AtenGtScalarOp::fold(FoldAdaptor adaptor) { + auto self = dyn_cast_or_null(adaptor.getSelf()); + auto other = adaptor.getOther(); + auto resultTy = dyn_cast(getType()); + + auto fpFold = [](double lhs, double rhs) -> bool { return lhs > rhs; }; + + auto intFold = [](APInt lhs, APInt rhs, bool unsign) -> bool { + return unsign ? lhs.ugt(rhs) : lhs.sgt(rhs); + }; + + return comparisonScaleFolder(self, other, resultTy, fpFold, intFold); +} + +//===----------------------------------------------------------------------===// +// AtenGeScalarOp +//===----------------------------------------------------------------------===// + +OpFoldResult AtenGeScalarOp::fold(FoldAdaptor adaptor) { + auto self = dyn_cast_or_null(adaptor.getSelf()); + auto other = adaptor.getOther(); + auto resultTy = dyn_cast(getType()); + + auto fpFold = [](double lhs, double rhs) -> bool { return lhs >= rhs; }; + + auto intFold = [](APInt lhs, APInt rhs, bool unsign) -> bool { + return unsign ? lhs.uge(rhs) : lhs.sge(rhs); + }; + + return comparisonScaleFolder(self, other, resultTy, fpFold, intFold); +} + +//===----------------------------------------------------------------------===// +// AtenEqScalarOp +//===----------------------------------------------------------------------===// + +OpFoldResult AtenEqScalarOp::fold(FoldAdaptor adaptor) { + auto self = dyn_cast_or_null(adaptor.getSelf()); + auto other = adaptor.getOther(); + auto resultTy = dyn_cast(getType()); + + auto fpFold = [](double lhs, double rhs) -> bool { return lhs == rhs; }; + + auto intFold = [](APInt lhs, APInt rhs, bool unsign) -> bool { + return lhs.eq(rhs); + }; + + return comparisonScaleFolder(self, other, resultTy, fpFold, intFold); +} + +//===----------------------------------------------------------------------===// +// AtenNeScalarOp +//===----------------------------------------------------------------------===// + +OpFoldResult AtenNeScalarOp::fold(FoldAdaptor adaptor) { + auto self = dyn_cast_or_null(adaptor.getSelf()); + auto other = adaptor.getOther(); + auto resultTy = dyn_cast(getType()); + + auto fpFold = [](double lhs, double rhs) -> bool { return lhs != rhs; }; + + auto intFold = [](APInt lhs, APInt rhs, bool unsign) -> bool { + return lhs.ne(rhs); + }; + + return comparisonScaleFolder(self, other, resultTy, fpFold, intFold); +} + //===----------------------------------------------------------------------===// // AtenFloorOp //===----------------------------------------------------------------------===// diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index d16a20893dbf..695b51c18c2c 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -1495,11 +1495,6 @@ "FlipNegativeIndexModule_basic", "HardsigmoidModule_basic", "HardsigmoidRandomModule_basic", - "IndexSelectDynamicInputSizeModule_basic", - "IndexSelectWholeDimensionModule_basic", - "IndexSelectWholeTensorModule_basic", - "IndexTensorStaticModule_basic", - "IndexTensorStaticNonContiguousWithNoneModule_basic", "PixelShuffleModuleStaticRank4Float32_basic", "ResNet18Module_basic", "SliceCopyEndGreaterThanDimSize_Module_basic", @@ -1998,24 +1993,15 @@ "NativeDropoutTrainModule_basic", "NativeDropoutTrainStaticShapeModule_basic", "ReduceProdDimIntFloatModule_basic", - "StdCorrectionAllDimReduceModule_basic", - "StdCorrectionKeepDimModule_basic", "StdCorrectionLargeInputModule_basic", "StdCorrectionModule_basic", "StdCorrectionNoneModule_basic", "StdDimNoneDimModule_basic", "StdUnbiasedModule_basic", - "VarCorrectionAllDimReduceModule_basic", - "VarCorrectionKeepDimModule_basic", "VarCorrectionLargeInputModule_basic", "VarCorrectionModule_basic", "VarCorrectionNoneModule_basic", - "VarDimAllDimReduceModule_basic", - "VarDimModule_basic", - "VarDimMultiDimModule_basic", "VarDimNoneDimModule_basic", - "VarDimSingleDimModule_basic", - "VarDimUnbiasedModule_basic", "VarMeanCorrectionNoneModule_basic", "VarMeanUnbiasedModule_basic", "VarUnbiasedModule_basic", @@ -2110,9 +2096,6 @@ "IndexTensorMultiInputOneDim_basic", "IndexTensorMultiInputThreeIndexers_basic", "IndexTensorMultiInput_basic", - "IndexTensorStaticContiguousWithNoneModule_basic", - "SelectIntModule_basic", - "SliceSingleIdxModule_basic", "ViewFlattenAndExpandModule_basic", "ViewSizeDimFollowedByCollapsedOnesModule_basic", "ViewSizeDimFollowedByExpandedOnesModule_basic", @@ -2151,7 +2134,6 @@ "FlattenDynamicModule_basic", "GluStaticModule_basic", "GroupNormModule_basic", - "IndexSelectDynamicModulebasic", "IndexTensorHackedTwinModule3dInput_basic", "IndexTensorHackedTwinModule_basic", "IndexTensorModule3dInput_basic", @@ -2169,11 +2151,5 @@ "TensorsStackPromoteDTypeModule_basic", } -if torch_version_for_comparison() < version.parse("2.3.0.dev"): - ONNX_XFAIL_SET = ONNX_XFAIL_SET | { - # ERROR: dtype (torch.float64) is not equal to golden dtype (torch.float32) - "ElementwiseWhereScalarModule_basic", - } - ONNX_CRASHING_SET = { } diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index 2b0ec4aee1cb..ba41d4220e2f 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -301,12 +301,6 @@ def emit_with_mutating_variants(key, **kwargs): "aten::le.Tensor : (Tensor, Tensor) -> (Tensor)", "aten::ne.Tensor : (Tensor, Tensor) -> (Tensor)", "aten::div.Scalar : (Tensor, Scalar) -> (Tensor)", - "aten::ne.Scalar : (Tensor, Scalar) -> (Tensor)", - "aten::eq.Scalar : (Tensor, Scalar) -> (Tensor)", - "aten::gt.Scalar : (Tensor, Scalar) -> (Tensor)", - "aten::ge.Scalar : (Tensor, Scalar) -> (Tensor)", - "aten::lt.Scalar : (Tensor, Scalar) -> (Tensor)", - "aten::le.Scalar : (Tensor, Scalar) -> (Tensor)", "aten::fmod.Scalar : (Tensor, Scalar) -> (Tensor)", "aten::masked_fill.Scalar : (Tensor, Tensor, Scalar) -> (Tensor)", "aten::clamp : (Tensor, Scalar?, Scalar?) -> (Tensor)", @@ -347,6 +341,12 @@ def emit_with_mutating_variants(key, **kwargs): emit_with_mutating_variants("aten::sub.Scalar : (Tensor, Scalar, Scalar) -> (Tensor)", has_canonicalizer=True) emit_with_mutating_variants("aten::mul.Scalar : (Tensor, Scalar) -> (Tensor)", has_canonicalizer=True) emit_with_mutating_variants("aten::eq.Tensor : (Tensor, Tensor) -> (Tensor)", has_folder=True) + emit_with_mutating_variants("aten::le.Scalar : (Tensor, Scalar) -> (Tensor)", has_folder=True) + emit_with_mutating_variants("aten::lt.Scalar : (Tensor, Scalar) -> (Tensor)", has_folder=True) + emit_with_mutating_variants("aten::gt.Scalar : (Tensor, Scalar) -> (Tensor)", has_folder=True) + emit_with_mutating_variants("aten::ge.Scalar : (Tensor, Scalar) -> (Tensor)", has_folder=True) + emit_with_mutating_variants("aten::eq.Scalar : (Tensor, Scalar) -> (Tensor)", has_folder=True) + emit_with_mutating_variants("aten::ne.Scalar : (Tensor, Scalar) -> (Tensor)", has_folder=True) emit_with_mutating_variants("aten::floor : (Tensor) -> (Tensor)", has_canonicalizer=True) emit_with_mutating_variants("aten::masked_fill.Tensor : (Tensor, Tensor, Tensor) -> (Tensor)", has_canonicalizer=True) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py index d9921d23d677..689fe182f57c 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/elementwise.py @@ -413,7 +413,7 @@ def __init__(self): ([-1, -1, -1], torch.float32, True), ]) def forward(self, a): - return torch.where(a > 0.5, 4.0, 8.0) + return torch.where(a > 0.5, 4.0, 8.0).to(torch.float) @register_test_case(module_factory=lambda: ElementwiseWhereScalarModule()) diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir index d1f4307d4de6..9dceff316eaa 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir @@ -42,7 +42,7 @@ func.func @test_gather_nd(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch.vten // CHECK: %[[AXIS:.+]] = torch.constant.int 0 // CHECK: %[[ZERO:.+]] = torch.constant.int 0 // CHECK: %[[ONE:.+]] = torch.constant.int 1 - // CHECK: %[[LT:.+]] = torch.aten.le.Scalar %arg1, %[[ZERO]] + // CHECK: %[[LT:.+]] = torch.aten.lt.Scalar %arg1, %[[ZERO]] // CHECK: %[[SZ:.+]] = torch.aten.size.int %arg0, %[[AXIS]] // CHECK: %[[ADD:.+]] = torch.aten.add.Scalar %arg1, %[[SZ]], %[[ONE]] // CHECK: %[[SEL:.+]] = torch.aten.where.self %[[LT]], %[[ADD]], %arg1 @@ -72,7 +72,7 @@ func.func @test_gather_scalar(%arg0: !torch.vtensor<[3,4,5],f32>, %arg1: !torch. // CHECK: %[[AXIS:.+]] = torch.constant.int 0 // CHECK: %[[ZERO:.+]] = torch.constant.int 0 // CHECK: %[[ONE:.+]] = torch.constant.int 1 - // CHECK: %[[LT:.+]] = torch.aten.le.Scalar %arg1, %[[ZERO]] + // CHECK: %[[LT:.+]] = torch.aten.lt.Scalar %arg1, %[[ZERO]] // CHECK: %[[SZ:.+]] = torch.aten.size.int %arg0, %[[AXIS]] // CHECK: %[[ADD:.+]] = torch.aten.add.Scalar %arg1, %[[SZ]], %[[ONE]] // CHECK: %[[SEL:.+]] = torch.aten.where.self %[[LT]], %[[ADD]], %arg1 diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir index 2b5405b75197..a607365f4918 100644 --- a/test/Dialect/Torch/canonicalize.mlir +++ b/test/Dialect/Torch/canonicalize.mlir @@ -2708,3 +2708,128 @@ func.func @aten_cat_zero(%arg0 : !torch.vtensor<[4,5,6],f32>, %arg1 : !torch.vte %0 = torch.aten.cat %list, %dim : !torch.list, !torch.int -> !torch.vtensor<[4,5,6],f32> return %0 : !torch.vtensor<[4,5,6],f32> } + +// ----- + +// CHECK-LABEL: @aten_tensor_scalar_lt +func.func @aten_tensor_scalar_lt() -> (!torch.vtensor<[4],i1>, !torch.vtensor<[4],i1>) { + // CHECK: %[[CST:.+]] = torch.vtensor.literal(dense : tensor<4xi1>) : !torch.vtensor<[4],i1> + // CHECK: return %[[CST]], %[[CST]] : !torch.vtensor<[4],i1>, !torch.vtensor<[4],i1> + %intTensor = torch.vtensor.literal(dense<1> : tensor<4xsi8>) : !torch.vtensor<[4],si8> + %fpTensor = torch.vtensor.literal(dense<1.0> : tensor<4xf32>) : !torch.vtensor<[4],f32> + %intScalar = torch.constant.int 2 + %fpScalar = torch.constant.float 2.0 + %intBool = torch.aten.lt.Scalar %intTensor, %intScalar : !torch.vtensor<[4],si8>, !torch.int -> !torch.vtensor<[4],i1> + %fpBool = torch.aten.lt.Scalar %fpTensor, %fpScalar : !torch.vtensor<[4],f32>, !torch.float -> !torch.vtensor<[4],i1> + return %intBool, %fpBool : !torch.vtensor<[4],i1>, !torch.vtensor<[4],i1> +} + + +// ----- + +// CHECK-LABEL: @aten_tensor_tensor_lt +func.func @aten_tensor_tensor_lt() -> (!torch.vtensor<[4],i1>, !torch.vtensor<[4],i1>, !torch.vtensor<[4],i1>) { + // CHECK: %[[UNSIGN:.+]] = torch.vtensor.literal(dense : tensor<4xi1>) : !torch.vtensor<[4],i1> + // CHECK: %[[SIGNED:.+]] = torch.vtensor.literal(dense<[true, false, false, false]> : tensor<4xi1>) : !torch.vtensor<[4],i1> + // CHECK: return %[[UNSIGN]], %[[SIGNED]], %[[SIGNED]] + %intTensor = torch.vtensor.literal(dense<[127, -128, -127, -126]> : tensor<4xsi8>) : !torch.vtensor<[4],si8> + %uintTensor = torch.vtensor.literal(dense<[127, 128, 129, 130]> : tensor<4xui8>) : !torch.vtensor<[4],ui8> + %fpTensor = torch.vtensor.literal(dense<[127.0, 128.0, 129.0, 130.0]> : tensor<4xf32>) : !torch.vtensor<[4],f32> + %intScalar = torch.constant.int 128 + %fpScalar = torch.constant.float 128.0 + %intBool = torch.aten.lt.Scalar %intTensor, %intScalar : !torch.vtensor<[4],si8>, !torch.int -> !torch.vtensor<[4],i1> + %uintBool = torch.aten.lt.Scalar %uintTensor, %intScalar : !torch.vtensor<[4],ui8>, !torch.int -> !torch.vtensor<[4],i1> + %fpBool = torch.aten.lt.Scalar %fpTensor, %fpScalar : !torch.vtensor<[4],f32>, !torch.float -> !torch.vtensor<[4],i1> + return %intBool, %uintBool, %fpBool : !torch.vtensor<[4],i1>, !torch.vtensor<[4],i1>, !torch.vtensor<[4],i1> +} + +// ----- + +// CHECK-LABEL: @aten_tensor_tensor_le +func.func @aten_tensor_tensor_le() -> (!torch.vtensor<[4],i1>, !torch.vtensor<[4],i1>, !torch.vtensor<[4],i1>) { + // CHECK: %[[UNSIGN:.+]] = torch.vtensor.literal(dense : tensor<4xi1>) : !torch.vtensor<[4],i1> + // CHECK: %[[SIGNED:.+]] = torch.vtensor.literal(dense<[true, true, false, false]> : tensor<4xi1>) : !torch.vtensor<[4],i1> + // CHECK: return %[[UNSIGN]], %[[SIGNED]], %[[SIGNED]] + %intTensor = torch.vtensor.literal(dense<[127, -128, -127, -126]> : tensor<4xsi8>) : !torch.vtensor<[4],si8> + %uintTensor = torch.vtensor.literal(dense<[127, 128, 129, 130]> : tensor<4xui8>) : !torch.vtensor<[4],ui8> + %fpTensor = torch.vtensor.literal(dense<[127.0, 128.0, 129.0, 130.0]> : tensor<4xf32>) : !torch.vtensor<[4],f32> + %intScalar = torch.constant.int 128 + %fpScalar = torch.constant.float 128.0 + %intBool = torch.aten.le.Scalar %intTensor, %intScalar : !torch.vtensor<[4],si8>, !torch.int -> !torch.vtensor<[4],i1> + %uintBool = torch.aten.le.Scalar %uintTensor, %intScalar : !torch.vtensor<[4],ui8>, !torch.int -> !torch.vtensor<[4],i1> + %fpBool = torch.aten.le.Scalar %fpTensor, %fpScalar : !torch.vtensor<[4],f32>, !torch.float -> !torch.vtensor<[4],i1> + return %intBool, %uintBool, %fpBool : !torch.vtensor<[4],i1>, !torch.vtensor<[4],i1>, !torch.vtensor<[4],i1> +} + + +// ----- + +// CHECK-LABEL: @aten_tensor_tensor_ge +func.func @aten_tensor_tensor_ge() -> (!torch.vtensor<[4],i1>, !torch.vtensor<[4],i1>, !torch.vtensor<[4],i1>) { + // CHECK: %[[UNSIGN:.+]] = torch.vtensor.literal(dense : tensor<4xi1>) : !torch.vtensor<[4],i1> + // CHECK: %[[SIGNED:.+]] = torch.vtensor.literal(dense<[false, true, true, true]> : tensor<4xi1>) : !torch.vtensor<[4],i1> + // CHECK: return %[[UNSIGN]], %[[SIGNED]], %[[SIGNED]] + %intTensor = torch.vtensor.literal(dense<[127, -128, -127, -126]> : tensor<4xsi8>) : !torch.vtensor<[4],si8> + %uintTensor = torch.vtensor.literal(dense<[127, 128, 129, 130]> : tensor<4xui8>) : !torch.vtensor<[4],ui8> + %fpTensor = torch.vtensor.literal(dense<[127.0, 128.0, 129.0, 130.0]> : tensor<4xf32>) : !torch.vtensor<[4],f32> + %intScalar = torch.constant.int 128 + %fpScalar = torch.constant.float 128.0 + %intBool = torch.aten.ge.Scalar %intTensor, %intScalar : !torch.vtensor<[4],si8>, !torch.int -> !torch.vtensor<[4],i1> + %uintBool = torch.aten.ge.Scalar %uintTensor, %intScalar : !torch.vtensor<[4],ui8>, !torch.int -> !torch.vtensor<[4],i1> + %fpBool = torch.aten.ge.Scalar %fpTensor, %fpScalar : !torch.vtensor<[4],f32>, !torch.float -> !torch.vtensor<[4],i1> + return %intBool, %uintBool, %fpBool : !torch.vtensor<[4],i1>, !torch.vtensor<[4],i1>, !torch.vtensor<[4],i1> +} + +// ----- + +// CHECK-LABEL: @aten_tensor_tensor_gt +func.func @aten_tensor_tensor_gt() -> (!torch.vtensor<[4],i1>, !torch.vtensor<[4],i1>, !torch.vtensor<[4],i1>) { + // CHECK: %[[UNSIGN:.+]] = torch.vtensor.literal(dense : tensor<4xi1>) : !torch.vtensor<[4],i1> + // CHECK: %[[SIGNED:.+]] = torch.vtensor.literal(dense<[false, false, true, true]> : tensor<4xi1>) : !torch.vtensor<[4],i1> + // CHECK: return %[[UNSIGN]], %[[SIGNED]], %[[SIGNED]] + %intTensor = torch.vtensor.literal(dense<[127, -128, -127, -126]> : tensor<4xsi8>) : !torch.vtensor<[4],si8> + %uintTensor = torch.vtensor.literal(dense<[127, 128, 129, 130]> : tensor<4xui8>) : !torch.vtensor<[4],ui8> + %fpTensor = torch.vtensor.literal(dense<[127.0, 128.0, 129.0, 130.0]> : tensor<4xf32>) : !torch.vtensor<[4],f32> + %intScalar = torch.constant.int 128 + %fpScalar = torch.constant.float 128.0 + %intBool = torch.aten.gt.Scalar %intTensor, %intScalar : !torch.vtensor<[4],si8>, !torch.int -> !torch.vtensor<[4],i1> + %uintBool = torch.aten.gt.Scalar %uintTensor, %intScalar : !torch.vtensor<[4],ui8>, !torch.int -> !torch.vtensor<[4],i1> + %fpBool = torch.aten.gt.Scalar %fpTensor, %fpScalar : !torch.vtensor<[4],f32>, !torch.float -> !torch.vtensor<[4],i1> + return %intBool, %uintBool, %fpBool : !torch.vtensor<[4],i1>, !torch.vtensor<[4],i1>, !torch.vtensor<[4],i1> +} + +// ----- + +// CHECK-LABEL: @aten_tensor_tensor_eq +func.func @aten_tensor_tensor_eq() -> (!torch.vtensor<[4],i1>, !torch.vtensor<[4],i1>, !torch.vtensor<[4],i1>) { + // CHECK: %[[UNSIGN:.+]] = torch.vtensor.literal(dense : tensor<4xi1>) : !torch.vtensor<[4],i1> + // CHECK: %[[SIGNED:.+]] = torch.vtensor.literal(dense<[false, true, false, false]> : tensor<4xi1>) : !torch.vtensor<[4],i1> + // CHECK: return %[[UNSIGN]], %[[SIGNED]], %[[SIGNED]] + %intTensor = torch.vtensor.literal(dense<[127, -128, -127, -126]> : tensor<4xsi8>) : !torch.vtensor<[4],si8> + %uintTensor = torch.vtensor.literal(dense<[127, 128, 129, 130]> : tensor<4xui8>) : !torch.vtensor<[4],ui8> + %fpTensor = torch.vtensor.literal(dense<[127.0, 128.0, 129.0, 130.0]> : tensor<4xf32>) : !torch.vtensor<[4],f32> + %intScalar = torch.constant.int 128 + %fpScalar = torch.constant.float 128.0 + %intBool = torch.aten.eq.Scalar %intTensor, %intScalar : !torch.vtensor<[4],si8>, !torch.int -> !torch.vtensor<[4],i1> + %uintBool = torch.aten.eq.Scalar %uintTensor, %intScalar : !torch.vtensor<[4],ui8>, !torch.int -> !torch.vtensor<[4],i1> + %fpBool = torch.aten.eq.Scalar %fpTensor, %fpScalar : !torch.vtensor<[4],f32>, !torch.float -> !torch.vtensor<[4],i1> + return %intBool, %uintBool, %fpBool : !torch.vtensor<[4],i1>, !torch.vtensor<[4],i1>, !torch.vtensor<[4],i1> +} + +// ----- + +// CHECK-LABEL: @aten_tensor_tensor_ne +func.func @aten_tensor_tensor_ne() -> (!torch.vtensor<[4],i1>, !torch.vtensor<[4],i1>, !torch.vtensor<[4],i1>) { + // CHECK: %[[UNSIGN:.+]] = torch.vtensor.literal(dense : tensor<4xi1>) : !torch.vtensor<[4],i1> + // CHECK: %[[SIGNED:.+]] = torch.vtensor.literal(dense<[true, false, true, true]> : tensor<4xi1>) : !torch.vtensor<[4],i1> + // CHECK: return %[[UNSIGN]], %[[SIGNED]], %[[SIGNED]] + %intTensor = torch.vtensor.literal(dense<[127, -128, -127, -126]> : tensor<4xsi8>) : !torch.vtensor<[4],si8> + %uintTensor = torch.vtensor.literal(dense<[127, 128, 129, 130]> : tensor<4xui8>) : !torch.vtensor<[4],ui8> + %fpTensor = torch.vtensor.literal(dense<[127.0, 128.0, 129.0, 130.0]> : tensor<4xf32>) : !torch.vtensor<[4],f32> + %intScalar = torch.constant.int 128 + %fpScalar = torch.constant.float 128.0 + %intBool = torch.aten.ne.Scalar %intTensor, %intScalar : !torch.vtensor<[4],si8>, !torch.int -> !torch.vtensor<[4],i1> + %uintBool = torch.aten.ne.Scalar %uintTensor, %intScalar : !torch.vtensor<[4],ui8>, !torch.int -> !torch.vtensor<[4],i1> + %fpBool = torch.aten.ne.Scalar %fpTensor, %fpScalar : !torch.vtensor<[4],f32>, !torch.float -> !torch.vtensor<[4],i1> + return %intBool, %uintBool, %fpBool : !torch.vtensor<[4],i1>, !torch.vtensor<[4],i1>, !torch.vtensor<[4],i1> +} From bd7f1baa42f55dbdc19dd3f88cbd44bce937928c Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Fri, 8 Mar 2024 16:23:07 -0800 Subject: [PATCH 266/283] [onnx] Fix expand operation for dynamic shape max (#3001) If the broadcast shape is length-1 at a dim while `?` in the input dim then we need to broadcast to the dynamic dim. This is equivalent to taking a max of two dimensions. --- .../TorchOnnxToTorch/DefaultDomainAtoF.cpp | 36 +++++++-- projects/pt1/e2e_testing/xfail_sets.py | 2 - .../TorchOnnxToTorch/simple_ops_a_to_f.mlir | 78 ++++++++----------- 3 files changed, 62 insertions(+), 54 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp index 785c631c1bc9..c976b49842c8 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp @@ -1495,23 +1495,34 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( patterns.onOp( "Expand", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { // uses ideas and code from onnx.Reshape + auto loc = binder.getLoc(); Torch::ValueTensorType resultType; Value data, shape; if (binder.tensorOperands(data, shape) || binder.tensorResultType(resultType)) return failure(); - Torch::BaseTensorType shapeType = - shape.getType().cast(); + + auto dataType = cast(data.getType()); + auto shapeType = cast(shape.getType()); + if (!dataType.hasSizes() || !shapeType.hasSizes()) + return failure(); + + auto shapeSizes = shapeType.getSizes(); + int64_t dataRank = dataType.getSizes().size(); + int64_t shapeRank = shapeSizes.size(); + if (shapeRank != 1 || shapeSizes[0] == Torch::kUnknownSize) + return failure(); + + auto rankDifference = dataRank - shapeSizes[0]; + SmallVector selectSizes; Type selectResultType = shapeType.getWithSizesAndDtype( llvm::ArrayRef(selectSizes), shapeType.getOptionalDtype()); // Variable to store 1-D onnx shape tensor, shapeSizes[0] has the // dimension size - auto shapeSizes = - dyn_cast(shape.getType()).getSizes(); // A constant zero value Value zero = rewriter.create( - binder.getLoc(), rewriter.getI64IntegerAttr(0)); + loc, rewriter.getI64IntegerAttr(0)); // Variable to store pytorch int list of shape (dimension) SmallVector dimList; @@ -1520,12 +1531,21 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( // takes list of int for (int i = 0; i < shapeSizes[0]; i++) { Value selectIndex = rewriter.create( - binder.getLoc(), rewriter.getType(), + loc, rewriter.getType(), rewriter.getIntegerAttr(rewriter.getIntegerType(64), i)); Value extract = rewriter.create( - binder.getLoc(), selectResultType, shape, zero, selectIndex); + loc, selectResultType, shape, zero, selectIndex); Value dim = rewriter.create( - binder.getLoc(), rewriter.getType(), extract); + loc, rewriter.getType(), extract); + + if (i + rankDifference >= 0) { + Value iv = + rewriter.create(loc, i + rankDifference); + auto sz = rewriter.create( + loc, rewriter.getType(), data, iv); + dim = rewriter.create(loc, dim, sz); + } + dimList.push_back(dim); } Value dimValueList = rewriter.create( diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 695b51c18c2c..dd49760188d2 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -1507,8 +1507,6 @@ "ArangeStartOutDtypeModule_basic", "ArangeStartOutViewModule_basic", "BroadcastDynamicDimModule_basic", - "BroadcastToModule_basic", - "ExpandModule_basic", "MoveDimIntNegativeIndexModule_basic", "ViewSizeFromOtherTensor_basic", diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir index 2c013553bb3c..1e816a38e2ea 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir @@ -1164,15 +1164,21 @@ func.func @test_exp(%arg0: !torch.vtensor<[3,4],f32>) -> !torch.vtensor<[3,4],f3 // CHECK-LABEL: @test_expand_dim2_shape2 func.func @test_expand_dim2_shape2(%arg0: !torch.vtensor<[1,4],f32>, %arg1: !torch.vtensor<[2],si32>) -> !torch.vtensor<[3,4],f32> attributes {torch.onnx_meta.ir_version = 9 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[INT0:.+]] = torch.constant.int 0 - // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 - // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[2],si32>, !torch.int, !torch.int -> !torch.vtensor<[],si32> - // CHECK: torch.aten.item %0 : !torch.vtensor<[],si32> -> !torch.int - // CHECK: %[[INT1:.+]] = torch.constant.int 1 - // CHECK: torch.aten.select.int %arg1, %int0, %int1 : !torch.vtensor<[2],si32>, !torch.int, !torch.int -> !torch.vtensor<[],si32> - // CHECK: torch.aten.item %2 : !torch.vtensor<[],si32> -> !torch.int - // CHECK: torch.prim.ListConstruct %1, %3 : (!torch.int, !torch.int) -> !torch.list - // CHECK: torch.aten.broadcast_to %arg0, %4 : !torch.vtensor<[1,4],f32>, !torch.list -> !torch.vtensor<[3,4],f32> + // CHECK-DAG: %[[INT0:.+]] = torch.constant.int 0 + // CHECK-DAG: %[[INT0_0:.+]] = torch.constant.int 0 + // CHECK-DAG: %[[SEL0:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT0_0]] : !torch.vtensor<[2],si32>, !torch.int, !torch.int -> !torch.vtensor<[],si32> + // CHECK-DAG: %[[ITEM0:.+]] = torch.aten.item %[[SEL0]] : !torch.vtensor<[],si32> -> !torch.int + // CHECK-DAG: %[[I0:.+]] = torch.constant.int 0 + // CHECK-DAG: %[[SZ0:.+]] = torch.aten.size.int %arg0, %[[I0]] : !torch.vtensor<[1,4],f32>, !torch.int -> !torch.int + // CHECK-DAG: %[[MX0:.+]] = torch.prim.max.int %[[ITEM0]], %[[SZ0]] : !torch.int, !torch.int -> !torch.int + // CHECK-DAG: %[[INT1:.+]] = torch.constant.int 1 + // CHECK-DAG: %[[SEL1:.+]] = torch.aten.select.int %arg1, %[[INT0]], %[[INT1]] : !torch.vtensor<[2],si32>, !torch.int, !torch.int -> !torch.vtensor<[],si32> + // CHECK-DAG: %[[ITEM1:.+]] = torch.aten.item %[[SEL1]] : !torch.vtensor<[],si32> -> !torch.int + // CHECK-DAG: %[[I1:.+]] = torch.constant.int 1 + // CHECK-DAG: %[[SZ1:.+]] = torch.aten.size.int %arg0, %[[I1]] : !torch.vtensor<[1,4],f32>, !torch.int -> !torch.int + // CHECK-DAG: %[[MX1:.+]] = torch.prim.max.int %[[ITEM1]], %[[SZ1]] : !torch.int, !torch.int -> !torch.int + // CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[MX0]], %[[MX1]] : (!torch.int, !torch.int) -> !torch.list + // CHECK: torch.aten.broadcast_to %arg0, %[[LIST]] : !torch.vtensor<[1,4],f32>, !torch.list -> !torch.vtensor<[3,4],f32> %0 = torch.operator "onnx.Expand"(%arg0, %arg1) : (!torch.vtensor<[1,4],f32>, !torch.vtensor<[2],si32>) -> !torch.vtensor<[3,4],f32> return %0 : !torch.vtensor<[3,4],f32> } @@ -1181,47 +1187,31 @@ func.func @test_expand_dim2_shape2(%arg0: !torch.vtensor<[1,4],f32>, %arg1: !tor // CHECK-LABEL: @test_expand_dim2_shape3 func.func @test_expand_dim2_shape3(%arg0: !torch.vtensor<[3,1],f32>, %arg1: !torch.vtensor<[3],si64>) -> !torch.vtensor<[2,3,6],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[INT0:.+]] = torch.constant.int 0 - // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 - // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> - // CHECK: torch.aten.item %0 : !torch.vtensor<[],si64> -> !torch.int - // CHECK: %[[INT1:.+]] = torch.constant.int 1 - // CHECK: torch.aten.select.int %arg1, %int0, %int1 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> - // CHECK: torch.aten.item %2 : !torch.vtensor<[],si64> -> !torch.int - // CHECK: %[[INT2:.+]] = torch.constant.int 2 - // CHECK: torch.aten.select.int %arg1, %int0, %int2 : !torch.vtensor<[3],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> - // CHECK: torch.aten.item %4 : !torch.vtensor<[],si64> -> !torch.int - // CHECK: torch.prim.ListConstruct %1, %3, %5 : (!torch.int, !torch.int, !torch.int) -> !torch.list - // CHECK: torch.aten.broadcast_to %arg0, %6 : !torch.vtensor<[3,1],f32>, !torch.list -> !torch.vtensor<[2,3,6],f32> + // CHECK: %[[I0:.+]] = torch.constant.int 0 + // CHECK-NEXT: %[[I0_0:.+]] = torch.constant.int 0 + // CHECK-NEXT: %[[SEL0:.+]] = torch.aten.select.int %arg1, %[[I0]], %[[I0_0]] + // CHECK-NEXT: %[[ITEM0:.+]] = torch.aten.item %[[SEL0]] + // CHECK-NEXT: %[[I1:.+]] = torch.constant.int 1 + // CHECK-NEXT: %[[SEL1:.+]] = torch.aten.select.int %arg1, %[[I0]], %[[I1]] + // CHECK-NEXT: %[[ITEM1:.+]] = torch.aten.item %[[SEL1]] + // CHECK-NEXT: %[[D1:.+]] = torch.constant.int 0 + // CHECK-NEXT: %[[SZ1:.+]] = torch.aten.size.int %arg0, %[[D1]] + // CHECK-NEXT: %[[MX1:.+]] = torch.prim.max.int %[[ITEM1]], %[[SZ1]] : !torch.int, !torch.int -> !torch.int + // CHECK-NEXT: %[[I2:.+]] = torch.constant.int 2 + // CHECK-NEXT: %[[SEL2:.+]] = torch.aten.select.int %arg1, %[[I0]], %[[I2]] + // CHECK-NEXT: %[[ITEM2:.+]] = torch.aten.item %[[SEL2]] + // CHECK-NEXT: %[[D2:.+]] = torch.constant.int 1 + // CHECK-NEXT: %[[SZ2:.+]] = torch.aten.size.int %arg0, %[[D2]] + // CHECK-NEXT: %[[MX2:.+]] = torch.prim.max.int %[[ITEM2]], %[[SZ2]] + // CHECK-NEXT: %[[LIST:.+]] = torch.prim.ListConstruct %[[ITEM0]], %[[MX1]], %[[MX2]] + // CHECK-NEXT: %[[EXPAND:.+]] = torch.aten.broadcast_to %arg0, %[[LIST]] + // CHECK: return %[[EXPAND]] %0 = torch.operator "onnx.Expand"(%arg0, %arg1) : (!torch.vtensor<[3,1],f32>, !torch.vtensor<[3],si64>) -> !torch.vtensor<[2,3,6],f32> return %0 : !torch.vtensor<[2,3,6],f32> } // ----- -// CHECK-LABEL: @test_expand_dim3_shape4 -func.func @test_expand_dim3_shape4(%arg0: !torch.vtensor<[1,3,1],f32>, %arg1: !torch.vtensor<[4],si64>) -> !torch.vtensor<[3,3,3,3],f32> attributes {torch.onnx_meta.ir_version = 4 : si64, torch.onnx_meta.opset_version = 9 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[INT0:.+]] = torch.constant.int 0 - // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 - // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> - // CHECK: torch.aten.item %0 : !torch.vtensor<[],si64> -> !torch.int - // CHECK: %[[INT1:.+]] = torch.constant.int 1 - // CHECK: torch.aten.select.int %arg1, %int0, %int1 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> - // CHECK: torch.aten.item %2 : !torch.vtensor<[],si64> -> !torch.int - // CHECK: %[[INT2:.+]] = torch.constant.int 2 - // CHECK: torch.aten.select.int %arg1, %int0, %int2 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> - // CHECK: torch.aten.item %4 : !torch.vtensor<[],si64> -> !torch.int - // CHECK: %[[INT3:.+]] = torch.constant.int 3 - // CHECK: torch.aten.select.int %arg1, %int0, %int3 : !torch.vtensor<[4],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64> - // CHECK: torch.aten.item %6 : !torch.vtensor<[],si64> -> !torch.int - // CHECK: torch.prim.ListConstruct %1, %3, %5, %7 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list - // CHECK: %9 = torch.aten.broadcast_to %arg0, %8 : !torch.vtensor<[1,3,1],f32>, !torch.list -> !torch.vtensor<[3,3,3,3],f32> - %0 = torch.operator "onnx.Expand"(%arg0, %arg1) : (!torch.vtensor<[1,3,1],f32>, !torch.vtensor<[4],si64>) -> !torch.vtensor<[3,3,3,3],f32> - return %0 : !torch.vtensor<[3,3,3,3],f32> -} - -// ----- - // CHECK-LABEL: @test_dropout func.func @test_dropout(%arg0: !torch.vtensor<[3],f32>) -> !torch.vtensor<[3],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.dropout %arg0, %float5.000000e-01, %false : !torch.vtensor<[3],f32>, !torch.float, !torch.bool -> !torch.vtensor<[3],f32 From a3fe130f73f160c4d5b984f8de5b26aa471cfc9a Mon Sep 17 00:00:00 2001 From: Yuanqiang Liu Date: Sun, 10 Mar 2024 08:29:08 +0800 Subject: [PATCH 267/283] [Torch Dialect] emit aten::warn (#3003) * torch-mlir may not handle `aten.warn`. But it could be handled by custom users' backend which involves torch-mlir. --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 23 +++++++++++++++++++ .../build_tools/torch_ods_gen.py | 1 + 2 files changed, 24 insertions(+) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 41ca1f5801dc..02e7ed2c74d0 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -13148,6 +13148,29 @@ def Torch_AtenJoinOp : Torch_Op<"aten.join", [ }]; } +def Torch_AtenWarnOp : Torch_Op<"aten.warn", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::warn : (str, int) -> ()`"; + let arguments = (ins + Torch_StringType:$message, + Torch_IntType:$stacklevel + ); + let results = (outs + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenWarnOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 0); + } + void AtenWarnOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 0); + } + }]; +} + def Torch_AtenFloatScalarOp : Torch_Op<"aten.Float.Scalar", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index ba41d4220e2f..54cdbea6ad59 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -753,6 +753,7 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::str : (t) -> (str)") emit("aten::format : (...) -> (str)") emit("aten::join : (str, str[]) -> (str)") + emit("aten::warn : (str, int) -> ()") # Type conversion ops. emit("aten::Float.Scalar : (Scalar) -> (float)", has_folder=True) From 229ca3a9e1cf1fc24a45e29315fd6af41322dcd3 Mon Sep 17 00:00:00 2001 From: Yuanqiang Liu Date: Mon, 11 Mar 2024 19:59:34 +0800 Subject: [PATCH 268/283] [Torch Dialect] emit aten::mul and add folder (#3007) --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 25 +++++++++++++++++++ lib/Dialect/Torch/IR/TorchOps.cpp | 19 ++++++++++++++ .../build_tools/torch_ods_gen.py | 1 + 3 files changed, 45 insertions(+) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 02e7ed2c74d0..3e92d40992b8 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -14275,6 +14275,31 @@ def Torch_Aten_SetItemTOp : Torch_Op<"aten._set_item.t", [ }]; } +def Torch_AtenMulOp : Torch_Op<"aten.mul", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::mul : (Scalar, Scalar) -> (Scalar)`"; + let arguments = (ins + AnyTorchScalarType:$a, + AnyTorchScalarType:$b + ); + let results = (outs + AnyTorchScalarType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenMulOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 2, 1); + } + void AtenMulOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 2, 1); + } + }]; + let hasFolder = 1; +} + def Torch_AtenDivOp : Torch_Op<"aten.div", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index 9f10c8bce3ba..db2988f25539 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -3346,6 +3346,25 @@ OpFoldResult AtenAddOp::fold(FoldAdaptor adaptor) { [](double a, double b) -> double { return a + b; }); } +//===----------------------------------------------------------------------===// +// AtenMulOp +//===----------------------------------------------------------------------===// + +OpFoldResult AtenMulOp::fold(FoldAdaptor adaptor) { + if (!adaptor.getA() || !adaptor.getB()) { + return nullptr; + } + + if (adaptor.getA().isa() && adaptor.getB().isa()) { + return atenBinaryIntOperatorFoldHelper( + adaptor.getOperands(), + [](int64_t a, int64_t b) -> int64_t { return a * b; }); + } + return atenBinaryFloatOperatorFoldHelper( + adaptor.getOperands(), + [](double a, double b) -> double { return a * b; }); +} + //===----------------------------------------------------------------------===// // AtenSubOp //===----------------------------------------------------------------------===// diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index 54cdbea6ad59..055f7127c9f2 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -805,6 +805,7 @@ def emit_with_mutating_variants(key, **kwargs): has_canonicalizer=True) emit("aten::__getitem__.t : (t[], int) -> (t)", has_canonicalizer=True) emit("aten::_set_item.t : (t[], int, t) -> (t[])") + emit("aten::mul : (Scalar, Scalar) -> (Scalar)", has_folder=True) emit("aten::div : (Scalar, Scalar) -> (float)", has_folder=True) emit("aten::add : (Scalar, Scalar) -> (Scalar)", has_folder=True) emit("aten::sub : (Scalar, Scalar) -> (Scalar)", has_folder=True) From 8fb28661f9168c7b76a691125d8ebdff1732f920 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Mon, 11 Mar 2024 11:32:53 -0700 Subject: [PATCH 269/283] [onnx] Fix onnx.ReduceMean lowering (#3002) Reduce mean lowerings did not succesfully lower to `linalg` via torched. There were two separate paths that could be consolidated to a single simpler pass. This resulted in a significant improvement in test coverage. --- .../TorchOnnxToTorch/DefaultDomainQtoZ.cpp | 179 ++++++------------ projects/pt1/e2e_testing/xfail_sets.py | 18 -- .../TorchOnnxToTorch/simple_ops_q_to_z.mlir | 120 ++++++------ 3 files changed, 116 insertions(+), 201 deletions(-) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp index 34282bfef531..b5e9162bc2bf 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainQtoZ.cpp @@ -845,157 +845,96 @@ void mlir::torch::onnx_c::populateDefaultDomainQtoZ( /*dtype=*/noneVal); return success(); }); - // onnx.ReduceMean with axes provided as argument introduced in opset 18 patterns.onOp( - "ReduceMean", 18, + "ReduceMean", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { Torch::ValueTensorType resultType; Value data; - Value axes; int64_t keepDims; int64_t noop_with_empty_axes; - if (binder.tensorOperands(data, axes) || + if (binder.tensorOperandAtIndex(data, 0) || binder.tensorResultType(resultType) || binder.s64IntegerAttr(keepDims, "keepdims", 1) || binder.s64IntegerAttr(noop_with_empty_axes, "noop_with_empty_axes", 0)) return failure(); - Torch::BaseTensorType axesType = - axes.getType().cast(); - SmallVector dimList; - SmallVector selectSizes; - selectSizes.push_back(1); - Type selectResultType = axesType.getWithSizesAndDtype( - llvm::ArrayRef(selectSizes), axesType.getOptionalDtype()); - auto sizes = - dyn_cast(axes.getType()).getSizes(); - Value noneVal = rewriter.create(binder.getLoc()); - // deal with case when axes is empty - if (sizes.size() == 1 && sizes[0] == 0) { - if (noop_with_empty_axes == 0) { - Value keepDimsConstInt = rewriter.create( + + SmallVector axesList; + + Value axesVal; + if (!binder.tensorOperandAtIndex(axesVal, 1)) { + Torch::BaseTensorType axesType = + axesVal.getType().cast(); + SmallVector dimList; + SmallVector selectSizes{1}; + auto selType = rewriter.getType( + selectSizes, axesType.getOptionalDtype()); + auto axesTy = dyn_cast(axesVal.getType()); + auto axesShape = axesTy.getSizes(); + + if (axesShape.size() != 1 || axesShape[0] == Torch::kUnknownSize) + return failure(); + + Value zero = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getI64IntegerAttr(0)); + int64_t numAxes = axesShape[0]; + for (int64_t i = 0; i < numAxes; ++i) { + Value iv = rewriter.create( binder.getLoc(), rewriter.getType(), - rewriter.getIntegerAttr(rewriter.getIntegerType(64), keepDims)); - Value keepDimsBool = rewriter.create( - binder.getLoc(), keepDimsConstInt); - rewriter.replaceOpWithNewOp( - binder.op, resultType, data, /*dim=*/noneVal, keepDimsBool, - /*dtype=*/noneVal); - } else { - rewriter.replaceOp(binder.op, data); + rewriter.getI64IntegerAttr(i)); + Value extract = rewriter.create( + binder.getLoc(), selType, axesVal, zero, iv); + Value dim = rewriter.create( + binder.getLoc(), rewriter.getType(), extract); + axesList.push_back(dim); } + } + + SmallVector axesInts; + if (!binder.s64IntegerArrayAttr(axesInts, "axes", {})) { + for (int64_t i = 0, s = axesInts.size(); i < s; ++i) { + Value iv = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getI64IntegerAttr(axesInts[i])); + axesList.push_back(iv); + } + } + + // deal with case when axes is empty + if (axesList.empty() && noop_with_empty_axes) { + rewriter.replaceOp(binder.op, data); return success(); } + Value zero = rewriter.create( binder.getLoc(), rewriter.getType(), - rewriter.getIntegerAttr(rewriter.getIntegerType(64), 0)); + rewriter.getI64IntegerAttr(0)); int64_t adjustmentInt = cast(data.getType()).getSizes().size(); Value adjustment = rewriter.create( binder.getLoc(), rewriter.getType(), - rewriter.getIntegerAttr(rewriter.getIntegerType(64), - adjustmentInt)); - // convert axes (tensor) into torch int list while dealing with neg axis - for (int i = 0; i < sizes[0]; i++) { - // Go through the axes list and get each dim in the list - Value selectIndex = rewriter.create( - binder.getLoc(), rewriter.getType(), - rewriter.getIntegerAttr(rewriter.getIntegerType(64), i)); - Value extract = rewriter.create( - binder.getLoc(), selectResultType, axes, zero, selectIndex); - Value dim = rewriter.create( - binder.getLoc(), rewriter.getType(), extract); - // deal with neg axis: if (axis < 0) axis += rank - Value isNegative = - rewriter.create(binder.getLoc(), dim, zero); + rewriter.getI64IntegerAttr(adjustmentInt)); + + // Handle if the axes value is less than zero: + for (int i = 0, s = axesList.size(); i < s; i++) { + Value isNegative = rewriter.create( + binder.getLoc(), axesList[i], zero); isNegative = rewriter.create(binder.getLoc(), isNegative); Value finalOffset = rewriter.create( binder.getLoc(), isNegative, adjustment); Value finalDim = rewriter.create( - binder.getLoc(), dim, finalOffset); - dimList.push_back(finalDim); + binder.getLoc(), axesList[i], finalOffset); + axesList[i] = finalDim; } Value dimValueList = rewriter.create( binder.getLoc(), Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), - dimList); - Value keepDimBool; - if (keepDims == 1) { - keepDimBool = - rewriter.create(binder.getLoc(), true); - } else { - keepDimBool = - rewriter.create(binder.getLoc(), false); - } - rewriter.replaceOpWithNewOp( - binder.op, resultType, data, dimValueList, keepDimBool, - /*dtype=*/noneVal); - return success(); - }); - - // onnx.ReduceMean with axes provided as attribute - patterns.onOp( - "ReduceMean", 1, - [](OpBinder binder, ConversionPatternRewriter &rewriter) { - Torch::ValueTensorType resultType; - Value data; - llvm::SmallVector axes; - int64_t keepDims; - int64_t noop_with_empty_axes; - if (binder.tensorOperand(data) || binder.tensorResultType(resultType) || - binder.s64IntegerArrayAttr(axes, "axes", 0) || - binder.s64IntegerAttr(keepDims, "keepdims", 1) || - binder.s64IntegerAttr(noop_with_empty_axes, "noop_with_empty_axes", - 0)) - return failure(); - SmallVector dimList; - SmallVector selectSizes; - selectSizes.push_back(1); + axesList); + Value keepDimBool = + rewriter.create(binder.getLoc(), keepDims); Value noneVal = rewriter.create(binder.getLoc()); - // deal with case when axes is empty - if (axes.size() == 0) { - if (noop_with_empty_axes == 0) { - Value keepDimsConstInt = rewriter.create( - binder.getLoc(), rewriter.getType(), - rewriter.getIntegerAttr(rewriter.getIntegerType(64), keepDims)); - Value keepDimsBool = rewriter.create( - binder.getLoc(), keepDimsConstInt); - rewriter.replaceOpWithNewOp( - binder.op, resultType, data, /*dim=*/noneVal, keepDimsBool, - /*dtype=*/noneVal); - } else { - rewriter.replaceOp(binder.op, data); - } - return success(); - } - int64_t adjustmentInt = - cast(data.getType()).getSizes().size(); - // convert axes (tensor) into torch int list while dealing with neg axis - for (uint64_t i = 0; i < axes.size(); i++) { - // Go through the axes list and get each dim in the list - int64_t dim = axes[i]; - if (dim < 0) { - dim += adjustmentInt; - } - // deal with neg axis: if (axis < 0) axis += rank - Value finalDim = rewriter.create( - binder.getLoc(), rewriter.getType(), - rewriter.getIntegerAttr(rewriter.getIntegerType(64), dim)); - dimList.push_back(finalDim); - } - Value dimValueList = rewriter.create( - binder.getLoc(), - Torch::ListType::get(Torch::IntType::get(binder.op->getContext())), - dimList); - Value keepDimBool; - if (keepDims == 1) { - keepDimBool = - rewriter.create(binder.getLoc(), true); - } else { - keepDimBool = - rewriter.create(binder.getLoc(), false); - } rewriter.replaceOpWithNewOp( binder.op, resultType, data, dimValueList, keepDimBool, /*dtype=*/noneVal); diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index dd49760188d2..428c19788e18 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -1474,15 +1474,7 @@ ONNX_XFAIL_SET = { # Failure - cast error - "MeanDimNoneDimModule_basic", - "MeanDtypeModule_basic", - "MeanDynamicSizesModule_basic", - "MeanModule_basic", - "MseLossMeanReductionModule_basic", "PermuteNegativeIndexModule_basic", - "StdBiasedModule_basic", - "VarBiasedModule_basic", - "VarMeanBiasedModule_basic", # Failure - incorrect numerics "AdaptiveAvgPool1dUnitOutputSizeDynamicModule_basic", @@ -1992,17 +1984,7 @@ "NativeDropoutTrainStaticShapeModule_basic", "ReduceProdDimIntFloatModule_basic", "StdCorrectionLargeInputModule_basic", - "StdCorrectionModule_basic", - "StdCorrectionNoneModule_basic", - "StdDimNoneDimModule_basic", - "StdUnbiasedModule_basic", "VarCorrectionLargeInputModule_basic", - "VarCorrectionModule_basic", - "VarCorrectionNoneModule_basic", - "VarDimNoneDimModule_basic", - "VarMeanCorrectionNoneModule_basic", - "VarMeanUnbiasedModule_basic", - "VarUnbiasedModule_basic", # Failure - onnx_lowering: onnx.ReduceSum "MseLossSumReductionWithDifferentElemTypeModule_basic", diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir index bba74b6d9877..508ed55d3337 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_q_to_z.mlir @@ -969,77 +969,71 @@ func.func @test_reduce_sum_negative_axes_keepdims_example(%arg0: !torch.vtensor< // ----- -// CHECK-LABEL: func.func @test_reduce_mean_default_axes_keepdims_example -func.func @test_reduce_mean_default_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[NONE:.+]] = torch.constant.none - // CHECK: %[[INT1:.+]] = torch.constant.int 1 - // CHECK: torch.aten.Bool.int %int1 : !torch.int -> !torch.bool - // CHECK: torch.aten.mean.dim %arg0, %[[NONE]], %0, %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.none, !torch.bool, !torch.none -> !torch.vtensor<[1,1,1],f32> - %0 = torch.operator "onnx.ReduceMean"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[0],si64>) -> !torch.vtensor<[1,1,1],f32> - return %0 : !torch.vtensor<[1,1,1],f32> +// CHECK-LABEL: @test_reduce_mean_negative_axes_keepdims_example +func.func @test_reduce_mean_negative_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>) -> !torch.vtensor<[3,1,2],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64} { + // CHECK: %[[TENSOR:.+]] = torch.vtensor.literal(dense<-2> : tensor<1xsi64>) : !torch.vtensor<[1],si64> + // CHECK: %[[DIM:.+]] = torch.constant.int 0 + // CHECK: %[[A0:.+]] = torch.constant.int 0 + // CHECK: %[[SEL0:.+]] = torch.aten.select.int %[[TENSOR]], %[[DIM]], %[[A0]] + // CHECK: %[[ITEM0:.+]] = torch.aten.item %[[SEL0]] + // CHECK: %[[ZERO:.+]] = torch.constant.int 0 + // CHECK: %[[RANK:.+]] = torch.constant.int 3 + // CHECK: %[[LT0:.+]] = torch.aten.lt.int %[[ITEM0]], %[[ZERO]] + // CHECK: %[[BOOL0:.+]] = torch.aten.Int.bool %[[LT0]] + // CHECK: %[[MUL0:.+]] = torch.aten.mul.int %[[BOOL0]], %[[RANK]] + // CHECK: %[[ADD0:.+]] = torch.aten.add.int %[[ITEM0]], %[[MUL0]] + // CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[ADD0]] + // CHECK: %[[TRUE:.+]] = torch.constant.bool true + // CHECK: %[[NONE:.+]] = torch.constant.none + // CHECK: %[[SUM:.+]] = torch.aten.mean.dim %arg0, %[[LIST]], %[[TRUE]], %[[NONE]] + // CHECK: return %[[SUM]] + %cst = torch.vtensor.literal(dense<-2> : tensor<1xsi64>) : !torch.vtensor<[1],si64> + %0 = torch.operator "onnx.ReduceMean"(%arg0, %cst) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32> + return %0 : !torch.vtensor<[3,1,2],f32> } // ----- -// CHECK-LABEL: func.func @test_reduce_mean_do_not_keepdims_example -func.func @test_reduce_mean_do_not_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[NONE:.+]] = torch.constant.none - // CHECK: %[[INT0:.+]] = torch.constant.int 0 - // CHECK: %[[INT3:.+]] = torch.constant.int 3 - // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 - // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int - // CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool - // CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int - // CHECK: torch.aten.mul.int %3, %int3 : !torch.int, !torch.int -> !torch.int - // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int - // CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list - // CHECK: %[[FALSE:.+]] = torch.constant.bool false - // CHECK: torch.aten.mean.dim %arg0, %6, %false, %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[3,2],f32> - %0 = torch.operator "onnx.ReduceMean"(%arg0, %arg1) {torch.onnx.keepdims = 0 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32> +// CHECK-LABEL: @test_reduce_mean_one_axes_dropdims_example +func.func @test_reduce_mean_one_axes_dropdims_example(%arg0: !torch.vtensor<[3,2,2],f32>) -> !torch.vtensor<[3,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64} { + // CHECK: %[[TENSOR:.+]] = torch.vtensor.literal(dense<1> : tensor<1xsi64>) : !torch.vtensor<[1],si64> + // CHECK: %[[DIM:.+]] = torch.constant.int 0 + // CHECK: %[[A0:.+]] = torch.constant.int 0 + // CHECK: %[[SEL0:.+]] = torch.aten.select.int %[[TENSOR]], %[[DIM]], %[[A0]] + // CHECK: %[[ITEM0:.+]] = torch.aten.item %[[SEL0]] + // CHECK: %[[ZERO:.+]] = torch.constant.int 0 + // CHECK: %[[RANK:.+]] = torch.constant.int 3 + // CHECK: %[[LT0:.+]] = torch.aten.lt.int %[[ITEM0]], %[[ZERO]] + // CHECK: %[[BOOL0:.+]] = torch.aten.Int.bool %[[LT0]] + // CHECK: %[[MUL0:.+]] = torch.aten.mul.int %[[BOOL0]], %[[RANK]] + // CHECK: %[[ADD0:.+]] = torch.aten.add.int %[[ITEM0]], %[[MUL0]] + // CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[ADD0]] + // CHECK: %[[FALSE:.+]] = torch.constant.bool false + // CHECK: %[[NONE:.+]] = torch.constant.none + // CHECK: %[[SUM:.+]] = torch.aten.mean.dim %arg0, %[[LIST]], %[[FALSE]], %[[NONE]] + // CHECK: return %[[SUM]] + %cst = torch.vtensor.literal(dense<1> : tensor<1xsi64>) : !torch.vtensor<[1],si64> + %0 = torch.operator "onnx.ReduceMean"(%arg0, %cst) {torch.onnx.keepdims = 0 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,2],f32> return %0 : !torch.vtensor<[3,2],f32> } - // ----- -// CHECK-LABEL: func.func @test_reduce_mean_keepdims_example -func.func @test_reduce_mean_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { - // CHECK: %[[NONE:.+]] = torch.constant.none +// CHECK-LABEL: @test_reduce_mean_one_axesattr_dropdims_example +func.func @test_reduce_mean_one_axesattr_dropdims_example(%arg0: !torch.vtensor<[3,2,2],f32>) -> !torch.vtensor<[3,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64} { + // CHECK: %[[INT1:.+]] = torch.constant.int 1 // CHECK: %[[INT0:.+]] = torch.constant.int 0 // CHECK: %[[INT3:.+]] = torch.constant.int 3 - // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 - // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int - // CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool - // CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int - // CHECK: torch.aten.mul.int %3, %int3 : !torch.int, !torch.int -> !torch.int - // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int - // CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list - // CHECK: %[[TRUE:.+]] = torch.constant.bool true - // CHECK: torch.aten.mean.dim %arg0, %6, %true, %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[3,1,2],f32> - %0 = torch.operator "onnx.ReduceMean"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32> - return %0 : !torch.vtensor<[3,1,2],f32> -} - -// ----- - -// CHECK-LABEL: func.func @test_reduce_mean_negative_axes_keepdims_example -func.func @test_reduce_mean_negative_axes_keepdims_example(%arg0: !torch.vtensor<[3,2,2],f32>, %arg1: !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32> attributes {torch.onnx_meta.ir_version = 8 : si64, torch.onnx_meta.opset_version = 18 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: %[[LT:.+]] = torch.aten.lt.int %[[INT1]], %[[INT0]] + // CHECK: %[[BOOL:.+]] = torch.aten.Int.bool %[[LT]] + // CHECK: %[[MUL:.+]] = torch.aten.mul.int %[[BOOL]], %[[INT3]] + // CHECK: %[[ADD:.+]] = torch.aten.add.int %[[INT1]], %[[MUL]] + // CHECK: %[[LIST:.+]] = torch.prim.ListConstruct %[[ADD]] + // CHECK: %[[FALSE:.+]] = torch.constant.bool false // CHECK: %[[NONE:.+]] = torch.constant.none - // CHECK: %[[INT0:.+]] = torch.constant.int 0 - // CHECK: %[[INT3:.+]] = torch.constant.int 3 - // CHECK: %[[INT0_0:.+]] = torch.constant.int 0 - // CHECK: torch.aten.select.int %arg1, %int0, %int0_0 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[1],si64> - // CHECK: torch.aten.item %0 : !torch.vtensor<[1],si64> -> !torch.int - // CHECK: torch.aten.lt.int %1, %int0 : !torch.int, !torch.int -> !torch.bool - // CHECK: torch.aten.Int.bool %2 : !torch.bool -> !torch.int - // CHECK: torch.aten.mul.int %3, %int3 : !torch.int, !torch.int -> !torch.int - // CHECK: torch.aten.add.int %1, %4 : !torch.int, !torch.int -> !torch.int - // CHECK: torch.prim.ListConstruct %5 : (!torch.int) -> !torch.list - // CHECK: %[[TRUE:.+]] = torch.constant.bool true - // CHECK: torch.aten.mean.dim %arg0, %6, %true, %[[NONE]] : !torch.vtensor<[3,2,2],f32>, !torch.list, !torch.bool, !torch.none -> !torch.vtensor<[3,1,2],f32> - %0 = torch.operator "onnx.ReduceMean"(%arg0, %arg1) {torch.onnx.keepdims = 1 : si64} : (!torch.vtensor<[3,2,2],f32>, !torch.vtensor<[1],si64>) -> !torch.vtensor<[3,1,2],f32> - return %0 : !torch.vtensor<[3,1,2],f32> + // CHECK: %[[MEAN:.+]] = torch.aten.mean.dim %arg0, %[[LIST]], %[[FALSE]], %[[NONE]] + // CHECK: return %[[MEAN]] + %0 = torch.operator "onnx.ReduceMean"(%arg0) {torch.onnx.keepdims = 0 : si64, torch.onnx.axes = [1 : si64]} : (!torch.vtensor<[3,2,2],f32>) -> !torch.vtensor<[3,2],f32> + return %0 : !torch.vtensor<[3,2],f32> } // ----- @@ -1387,11 +1381,11 @@ func.func @test_slice_default_axes_and_slices(%arg0: !torch.vtensor<[20,10,5],f3 // CHECK: %[[ZERO0:.*]] = torch.constant.int 0 // CHECK-NEXT: %[[ZERO1:.*]] = torch.constant.int 0 // CHECK-NEXT: %[[SCALAR:.*]] = torch.prim.NumToTensor.Scalar %[[ZERO1]] : !torch.int -> !torch.vtensor<[1],si64> -// CHECK-NEXT: %[[SELECT0:.*]] = torch.aten.index_select %[[ARG1]], %[[ZERO]], %[[SCALAR]] : !torch.vtensor<[1],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> +// CHECK-NEXT: %[[SELECT0:.*]] = torch.aten.index_select %[[ARG1]], %[[ZERO0]], %[[SCALAR]] : !torch.vtensor<[1],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> // CHECK-NEXT: %[[ITEM0:.*]] = torch.aten.item %[[SELECT0]] : !torch.vtensor<[1],si64> -> !torch.int -// CHECK-NEXT: %[[SELECT1:.*]] = torch.aten.index_select %[[ARG2]], %[[ZERO]], %[[SCALAR]] : !torch.vtensor<[1],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> +// CHECK-NEXT: %[[SELECT1:.*]] = torch.aten.index_select %[[ARG2]], %[[ZERO0]], %[[SCALAR]] : !torch.vtensor<[1],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> // CHECK-NEXT: %[[ITEM1:.*]] = torch.aten.item %[[SELECT1]] : !torch.vtensor<[1],si64> -> !torch.int -// CHECK-NEXT: %[[SELECT3:.*]] = torch.aten.index_select %{{.*}}, %[[ZERO]], %[[SCALAR]] : !torch.vtensor<[1],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> +// CHECK-NEXT: %[[SELECT3:.*]] = torch.aten.index_select %{{.*}}, %[[ZERO0]], %[[SCALAR]] : !torch.vtensor<[1],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> // CHECK-NEXT: %[[ITEM3:.*]] = torch.aten.item %[[SELECT3]] : !torch.vtensor<[1],si64> -> !torch.int // CHECK: torch.aten.slice.Tensor %[[ARG0]], %[[ZERO1]], %[[ITEM0]], %[[ITEM1]], %[[ITEM3]] : !torch.vtensor<[20,10,5],f32>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[20,10,1],f32> From 4b1e87ce672a1f19ae04d2136fb83c93c95b545d Mon Sep 17 00:00:00 2001 From: Devjiu Date: Mon, 11 Mar 2024 20:22:05 +0100 Subject: [PATCH 270/283] [TorchDynamo] Enable Elemtwise ops for Scalar arg (#2744) This commit provides dummy solution to support elmentwise operations (mul, add) with scalar argument. ( op(Tensor, Scalar) ) It replaces `torch.aten.add.Tensor` with `torch.aten.add.Scalar`. ``` Unexpected outcome summary: (torchdynamo) ****** Unexpectedly Passed tests - 22 tests XPASS - "AddCDivModule_basic" XPASS - "BatchNorm1DModule_basic" XPASS - "BatchNorm1DStaticShapeModule_basic" XPASS - "BatchNorm1DWith2DInputModule_basic" XPASS - "BatchNorm2DModule_basic" XPASS - "BatchNorm3DModule_basic" XPASS - "ElementwiseAddScalarInt64Module_basic" XPASS - "ElementwiseAddScalarIntModule_basic" XPASS - "ElementwiseMulScalarModule_basic" XPASS - "ElementwiseMulScalarModule_float" XPASS - "ElementwiseMulScalarModule_int" XPASS - "GroupNormModule_basic" XPASS - "GroupNormNoWeightAndBiasModule_basic" XPASS - "MobilenetV3Module_basic" XPASS - "NativeBatchNorm1DModule_basic" XPASS - "NativeBatchNorm2DModule_basic" XPASS - "NativeBatchNorm3DModule_basic" XPASS - "NativeBatchNormNoneWeightModule_basic" XPASS - "NativeGroupNormBackwardModule_basic" XPASS - "NativeGroupNormModule_basic" XPASS - "ResNet18Module_basic" XPASS - "ResNet18StaticModule_basic" ``` And segfault for test "ElementwiseAddScalar_TensorLiteralInt32_Module_basic". Somehow this change doesn't allow to use Tensors, that are not forward arguments, but local variables of model. e.g. `self.x = torch.tensor(..)` See also: #2745 Signed-off-by: Dmitrii Makarenko --- projects/pt1/e2e_testing/xfail_sets.py | 30 ++-------------- .../configs/torchdynamo.py | 36 +++++++++++++++++++ 2 files changed, 39 insertions(+), 27 deletions(-) diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 428c19788e18..810db7268753 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -215,10 +215,6 @@ 'ConstantBoolParameterModule_basic', # START tests failing due to: 'torch.aten.mul.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float' - "AddCDivModule_basic", - "ElementwiseMulScalarModule_basic", - "ElementwiseMulScalarModule_float", - "NativeGroupNormBackwardModule_basic", "UpSampleNearest2dDynamicSize_basic", "UpSampleNearest2dStaticFactor_basic", "UpSampleNearest2dStaticSize_basic", @@ -226,23 +222,7 @@ # END tests failing due to: 'torch.aten.mul.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float' # START tests failing due to: 'torch.aten.add.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float' - "AtenInstanceNormModule_basic", - "BatchNorm1DModule_basic", - "BatchNorm1DWith2DInputModule_basic", - "BatchNorm2DModule_basic", - "BatchNorm3DModule_basic", - "BatchNorm1DStaticShapeModule_basic", "ElementwiseAddScalarFloatModule_basic", - "ElementwiseAddScalarInt64Module_basic", - "ElementwiseAddScalarIntModule_basic", - "MobilenetV3Module_basic", - "NativeBatchNorm1DModule_basic", - "NativeBatchNorm2DModule_basic", - "NativeBatchNorm3DModule_basic", - "NativeBatchNormNoneWeightModule_basic", - "NativeGroupNormModule_basic", - "ResNet18Module_basic", - "ResNet18StaticModule_basic", # END tests failing due to: 'torch.aten.add.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float' # ERROR: 'torch.aten.add.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.int' @@ -255,9 +235,6 @@ # ERROR: 'torch.aten.div.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.int' "ElementwiseAtenDivIntScalarModule_basic", - # ERROR: 'torch.aten.mul.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.int' - "ElementwiseMulScalarModule_int", - # ERROR: 'torch.aten.sub.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float' "ElementwiseSubScalarFloatModule_basic", "ElementwiseSubScalarIntModule_basic", @@ -315,10 +292,6 @@ # ERROR: shape (torch.Size([12])) is not equal to golden shape (torch.Size([3, 4])) "ArangeStartOutViewModule_basic", - # ERROR: 'torch.aten.add.Tensor' op operand #1 must be Any Torch tensor type, but got '!torch.float' - "GroupNormModule_basic", - "GroupNormNoWeightAndBiasModule_basic", - # Dynamo does not support tracing quantized tensors "ElementwiseDequantizePerChannelModule_basic", "ElementwiseDequantizePerTensorModule_basic", @@ -376,6 +349,9 @@ "MaxPool3dModule_basic", "MaxPool3dStaticCeilModeTrueModule_basic", "MaxPool3dStaticModule_basic", + + # Looks like incorrect fx graph conversion + "ElementwiseAddScalar_TensorLiteralInt32_Module_basic", } STABLEHLO_PASS_SET = { diff --git a/projects/pt1/python/torch_mlir_e2e_test/configs/torchdynamo.py b/projects/pt1/python/torch_mlir_e2e_test/configs/torchdynamo.py index e5c2475c7669..bdc410741cae 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/configs/torchdynamo.py +++ b/projects/pt1/python/torch_mlir_e2e_test/configs/torchdynamo.py @@ -53,6 +53,40 @@ def _returns_empty_tuple(fx_graph: torch.fx.GraphModule) -> bool: return False return True +# Replaces torch.aten.add.Tensor/torch.aten.mul.Tensor to +# torch.aten.add.Scalar/torch.aten.mul.Scalar in case of Scalar argument +# Cannot be done on earlier stage, e.g. in _FXGraphImporter as it +# needs to check argument types, which are not yet determined. +# Maybe schema or target should be changed, but it decided in +# _dynamo eval_frame on pytorch side. Also Python schema not matches +# with mlir Schema - check include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +# So in general it covers some of overload cases, which done on Python side automatically. +# e.g. conversion Scalar -> Tensor and vice versa +def scalarize_tensor_ops_on_scalars(gm: torch.fx.GraphModule): + # Modify gm.graph + for node in gm.graph.nodes: + # Checks if we're calling a function (i.e: + # torch.add) + if node.op == 'call_function': + # The target attribute is the function + # that call_function calls. + # call_function[target=torch.ops.aten.add.Tensor](args = (%arg64_1, 1), kwargs = {}) + if node.target == torch.ops.aten.add.Tensor: + if len(node.args) != 2 or node.kwargs != {}: + continue + elif not isinstance(node.args[1], torch.fx.node.Node): + node.target = torch.ops.aten.add.Scalar + if node.target == torch.ops.aten.mul.Tensor: + if len(node.args) != 2 or node.kwargs != {}: + continue + elif not isinstance(node.args[1], torch.fx.node.Node): + node.target = torch.ops.aten.mul.Scalar + + gm.graph.lint() # Does some checks to make sure the + + # Recompile the forward() method of `gm` from its Graph + gm.recompile() + def jit( model: torch.nn.Module, @@ -87,6 +121,8 @@ def my_aot_autograd_backend(gm: torch.fx.GraphModule, # way of differentiating between the two. assert not _returns_empty_tuple(gm), "encountered graph that does not return anything" + scalarize_tensor_ops_on_scalars(gm) + nonlocal mlir_module *_, model_name, nth_graph = get_aot_compilation_context() mlir_module = import_fx_graph_as_func(gm.graph, model_name) From e78c99e74e115b4733f06f2ed186f74514982f74 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Mon, 11 Mar 2024 13:45:49 -0700 Subject: [PATCH 271/283] [torch] Update folders for splat operators (#3012) Splat operators required the output is 1-D. This was not a required restriction and was loosened to 2d. --- lib/Dialect/Torch/IR/TorchOps.cpp | 52 +++++++++++++++++-------------- 1 file changed, 29 insertions(+), 23 deletions(-) diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index db2988f25539..9ecf0e3e262e 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -3768,15 +3768,18 @@ OpFoldResult AtenOnesOp::fold(FoldAdaptor adaptor) { Type resultType = getResult().getType(); BaseTensorType resultTensorType = resultType.dyn_cast(); - if (!resultTensorType || !resultTensorType.hasDtype()) { + if (!resultTensorType || !resultTensorType.hasDtype() || + !resultTensorType.hasSizes()) { return nullptr; } - int64_t ct = sizes.size(); - if (resultTensorType.getSizes().size() != 1) - return nullptr; - if (resultTensorType.getSizes()[0] != ct) - return nullptr; + for (auto sz : sizes) + if (sz == Torch::kUnknownSize || sz < 0) + return nullptr; + + for (auto sz : resultTensorType.getSizes()) + if (sz == Torch::kUnknownSize || sz < 0) + return nullptr; ShapedType shapedty = mlir::RankedTensorType::get( // convert Torch type to builtin ShapedType @@ -3804,15 +3807,18 @@ OpFoldResult AtenZerosOp::fold(FoldAdaptor adaptor) { Type resultType = getResult().getType(); BaseTensorType resultTensorType = resultType.dyn_cast(); - if (!resultTensorType || !resultTensorType.hasDtype()) { + if (!resultTensorType || !resultTensorType.hasDtype() || + !resultTensorType.hasSizes()) { return nullptr; } - int64_t ct = sizes.size(); - if (resultTensorType.getSizes().size() != 1) - return nullptr; - if (resultTensorType.getSizes()[0] != ct) - return nullptr; + for (auto sz : sizes) + if (sz == Torch::kUnknownSize || sz < 0) + return nullptr; + + for (auto sz : resultTensorType.getSizes()) + if (sz == Torch::kUnknownSize || sz < 0) + return nullptr; ShapedType shapedty = mlir::RankedTensorType::get( // convert Torch type to builtin ShapedType @@ -3842,22 +3848,22 @@ OpFoldResult AtenFullOp::fold(FoldAdaptor adaptor) { Type resultType = getResult().getType(); BaseTensorType resultTensorType = resultType.dyn_cast(); - if (!resultTensorType || !resultTensorType.hasDtype()) { + if (!resultTensorType || !resultTensorType.hasDtype() || + !resultTensorType.hasSizes()) { return nullptr; } - int64_t ct = sizes.size(); - if (resultTensorType.getSizes().size() != 1) - return nullptr; - if (resultTensorType.getSizes()[0] != ct) - return nullptr; + for (auto sz : sizes) + if (sz == Torch::kUnknownSize || sz < 0) + return nullptr; + + for (auto sz : resultTensorType.getSizes()) + if (sz == Torch::kUnknownSize || sz < 0) + return nullptr; ShapedType shapedty = - mlir::RankedTensorType::get( // convert Torch type to builtin ShapedType - sizes, resultTensorType.getDtype()); - if (!shapedty) { - return nullptr; - } + mlir::RankedTensorType::get(sizes, resultTensorType.getDtype()); + auto elementType = shapedty.getElementType(); if (elementType.isa()) { int64_t value = 0; From ad6159c7cb6e5105827c99aa6bb2cbe01f7a36d5 Mon Sep 17 00:00:00 2001 From: Yuanqiang Liu Date: Tue, 12 Mar 2024 08:58:20 +0800 Subject: [PATCH 272/283] [Stablehlo] lowering aten.round to stablehlo.round_nearest_even (#3011) --- lib/Conversion/TorchToStablehlo/Basic.cpp | 1 + projects/pt1/e2e_testing/xfail_sets.py | 2 ++ 2 files changed, 3 insertions(+) diff --git a/lib/Conversion/TorchToStablehlo/Basic.cpp b/lib/Conversion/TorchToStablehlo/Basic.cpp index d902202e8202..22743e6a9dee 100644 --- a/lib/Conversion/TorchToStablehlo/Basic.cpp +++ b/lib/Conversion/TorchToStablehlo/Basic.cpp @@ -1779,6 +1779,7 @@ void mlir::torch::torch_to_stablehlo::populateBasicOpPatternsAndLegality( INSERT_UNARY_FPONLY_PATTERN(AtenCosOp, stablehlo::CosineOp); INSERT_UNARY_FPONLY_PATTERN(AtenCeilOp, stablehlo::CeilOp); INSERT_UNARY_FPONLY_PATTERN(AtenFloorOp, stablehlo::FloorOp); + INSERT_UNARY_FPONLY_PATTERN(AtenRoundOp, stablehlo::RoundNearestEvenOp); #undef INSERT_UNARY_FPONLY_PATTERN #define INSERT_CONSTANT_FILL_PATTERN(AtenOp, fillVal) \ diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 810db7268753..885164e35cb3 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -405,6 +405,8 @@ "AtenItemIntOpModule_basic", "AtenMmFloatTypes_basic", "AtenMmIntTypes_basic", + "AtenRoundFloatHalfToEvenModule_basic", + "AtenRoundFloatModule_basic", "AtenRoundIntModule_basic", "AtenSubFloatModule_basic", "AtenToDeviceModule_basic", From 5ecc1d5c0dbb8ac166c36a083c982eb27e6f33eb Mon Sep 17 00:00:00 2001 From: Nithin Meganathan <18070964+nithinsubbiah@users.noreply.github.com> Date: Tue, 12 Mar 2024 15:07:45 -0700 Subject: [PATCH 273/283] Align softmax accumulation types with Torch's CUDA implementation (#2996) --- .../torch-mlir/Dialect/Torch/Utils/Utils.h | 9 ++++++ .../Torch/Transforms/DecomposeComplexOps.cpp | 31 ++++++++++++++----- lib/Dialect/Torch/Utils/Utils.cpp | 30 ++++++++++++++++++ 3 files changed, 63 insertions(+), 7 deletions(-) diff --git a/include/torch-mlir/Dialect/Torch/Utils/Utils.h b/include/torch-mlir/Dialect/Torch/Utils/Utils.h index beafe7d21adc..33a1c9f91fe7 100644 --- a/include/torch-mlir/Dialect/Torch/Utils/Utils.h +++ b/include/torch-mlir/Dialect/Torch/Utils/Utils.h @@ -131,6 +131,15 @@ Value createRank0Tensor(PatternRewriter &rewriter, Location loc, LogicalResult getTransposedType(BaseTensorType inType, int64_t dimA, int64_t dimB, Type &transposedType); +// Approximates the heuristic in the torch `acc_type` template for kernels +// that are defined in terms of it. For now, this just returns accumulators +// as if for CUDA from that implementation. In the future, this could be +// extended to look at hints on the `forOp` or its container to better +// control the behavior. Such support would be done in coordination with +// the fx_importer and APIs, which could add hints to the IR (based on +// Torch flags, user options, etc). +Type getDefaultAccType(PatternRewriter &rewriter, Type inputType); + } // namespace Torch } // namespace torch } // namespace mlir diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 157d6f227ae1..09d5b90f0eeb 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -1237,22 +1237,32 @@ class DecomposeAtenTraceOp : public OpRewritePattern { // softmax = unnorm / sum(unnorm, dim, keepdim = True) template static Value getSoftmaxResult(OpTy op, Value self, Type resultType, - PatternRewriter &rewriter) { + Type accumulatorType, PatternRewriter &rewriter) { Location loc = op.getLoc(); Value dim = op.getDim(); + if (resultType != accumulatorType) + self = convertTensorToDtype(rewriter, loc, self, accumulatorType); Value xMax = createMaxAlongDimension(rewriter, loc, op, self, dim, /*keepDim=*/true); + if (!xMax) return nullptr; - Value unNormalized = createTensorSub(rewriter, loc, resultType, self, xMax); + Value unNormalized = + createTensorSub(rewriter, loc, self.getType(), self, xMax); Value unNormalizedExp = - rewriter.create(loc, resultType, unNormalized); + rewriter.create(loc, self.getType(), unNormalized); Value sum = createSumAlongDimension(rewriter, loc, op, unNormalizedExp, dim, /*keepDim=*/true); if (!sum) return nullptr; - return rewriter.create(loc, resultType, unNormalizedExp, - sum); + + Value result = rewriter.create(loc, self.getType(), + unNormalizedExp, sum); + if (resultType != accumulatorType) + result = convertTensorToDtype(rewriter, loc, result, + resultType.cast().getDtype()); + + return result; } // Decompose softmax into: exp(x) / sum(exp(x)) @@ -1284,7 +1294,10 @@ class DecomposeAtenSoftmaxIntOp : public OpRewritePattern { /*non_blocking=*/cstFalse, /*copy=*/cstFalse, /*memory_format=*/none); } - Value result = getSoftmaxResult(op, self, resultTensorType, rewriter); + Type accumulatorTensorType = getDefaultAccType(rewriter, resultTensorDtype); + + Value result = getSoftmaxResult(op, self, resultTensorType, + accumulatorTensorType, rewriter); if (!result) return failure(); rewriter.replaceOpWithNewOp(op, op.getType(), @@ -1329,7 +1342,11 @@ class DecomposeAten_SoftmaxOp : public OpRewritePattern { getDtypeIntValueForType(rewriter, loc, resultTensorDtype), /*non_blocking=*/cstFalse, /*copy=*/cstFalse, /*memory_format=*/none); } - Value result = getSoftmaxResult(op, self, resultTensorType, rewriter); + + Type accumulatorTensorType = getDefaultAccType(rewriter, resultTensorDtype); + + Value result = getSoftmaxResult(op, self, resultTensorType, + accumulatorTensorType, rewriter); if (!result) return op.emitError("failed to get softmax result"); rewriter.replaceOpWithNewOp(op, resultTensorType, diff --git a/lib/Dialect/Torch/Utils/Utils.cpp b/lib/Dialect/Torch/Utils/Utils.cpp index e2abee51b817..0cd672e5d2cc 100644 --- a/lib/Dialect/Torch/Utils/Utils.cpp +++ b/lib/Dialect/Torch/Utils/Utils.cpp @@ -525,3 +525,33 @@ LogicalResult Torch::getTransposedType(BaseTensorType inType, int64_t dimA, inType.getOptionalDtype()); return success(); } + +Type Torch::getDefaultAccType(PatternRewriter &rewriter, Type inputType) { + if (inputType.isF16()) + return rewriter.getF32Type(); + if (inputType.isBF16()) + return rewriter.getF32Type(); + if (inputType.isa()) + return rewriter.getF32Type(); + if (inputType.isa()) + return rewriter.getF64Type(); + if (inputType.isFloat8E5M2()) + return rewriter.getF32Type(); + if (inputType.isFloat8E4M3FN()) + return rewriter.getF32Type(); + if (inputType.isFloat8E5M2FNUZ()) + return rewriter.getF32Type(); + if (inputType.isFloat8E4M3FNUZ()) + return rewriter.getF32Type(); + if (inputType.isSignedInteger(8)) + return rewriter.getI64Type(); + if (inputType.isUnsignedInteger(8)) + return rewriter.getI64Type(); + if (inputType.isSignedInteger(16)) + return rewriter.getI64Type(); + if (inputType.isSignedInteger(32)) + return rewriter.getI64Type(); + if (inputType.isSignedInteger(64)) + return rewriter.getI64Type(); + llvm::report_fatal_error("unhandled type for getDefaultAccType"); +} From 6fa21bd8b19882fe7e46b02ced7de66c3b332d71 Mon Sep 17 00:00:00 2001 From: aldesilv <35313749+aldesilv@users.noreply.github.com> Date: Wed, 13 Mar 2024 08:04:10 -0700 Subject: [PATCH 274/283] OnnxToTorch lower celu op (#2920) --- .../TorchOnnxToTorch/DefaultDomainAtoF.cpp | 41 +++++++++++++++++++ .../TorchOnnxToTorch/simple_ops_a_to_f.mlir | 20 +++++++++ 2 files changed, 61 insertions(+) diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp index c976b49842c8..2e3f3e8b8053 100644 --- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp +++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp @@ -601,6 +601,47 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF( binder.op, resultType, operand); return success(); }); + patterns.onOp( + "Celu", 12, [](OpBinder binder, ConversionPatternRewriter &rewriter) { + Torch::ValueTensorType resultType; + Value operand; + float alpha; + if (binder.tensorOperand(operand) || + binder.tensorResultType(resultType) || + binder.f32FloatAttr(alpha, "alpha", 1.0f)) + return failure(); + // exp(x/alpha) + Value constAlpha = rewriter.create( + binder.getLoc(), rewriter.getType(), + rewriter.getF64FloatAttr(alpha)); + Value xDivAlpha = rewriter.create( + binder.getLoc(), resultType, operand, constAlpha); + Value expXDivAlpha = rewriter.create( + binder.getLoc(), resultType, xDivAlpha); + // alpha * (exp(x/alpha) - 1) + Value constantOne = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(1)); + Value subOne = rewriter.create( + binder.getLoc(), resultType, expXDivAlpha, constantOne, + constantOne); + Value mulAlpha = rewriter.create( + binder.getLoc(), resultType, subOne, constAlpha); + Value constantZero = rewriter.create( + binder.getLoc(), rewriter.getI64IntegerAttr(0)); + Value zeroTensor = createRank0Tensor(rewriter, binder.getLoc(), + resultType, constantZero); + // min(0, alpha * (exp(x/alpha) - 1)) + Value minExpression = rewriter.create( + binder.getLoc(), resultType, zeroTensor, mulAlpha); + + // max(0, x) + Value maxExpression = rewriter.create( + binder.getLoc(), resultType, zeroTensor, operand); + // max(0,x) + min(0, alpha * (exp(x/alpha) - 1)) + rewriter.replaceOpWithNewOp( + binder.op, resultType, maxExpression, minExpression, constantOne); + return success(); + }); patterns.onOp( "Clip", 1, [](OpBinder binder, ConversionPatternRewriter &rewriter) { // https://onnx.ai/onnx/operators/onnx__Clip.html diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir index 1e816a38e2ea..8cd8bab0032f 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_a_to_f.mlir @@ -1645,3 +1645,23 @@ func.func @test_constant_of_shape_dense_int_cst() -> !torch.vtensor<[2,3,4], si6 %0 = "torch.operator"(%cst) <{name = "onnx.ConstantOfShape"}> {torch.onnx.value = dense<3> : tensor<1xsi64>}: (!torch.vtensor<[3], si64>) -> !torch.vtensor<[2,3,4], si64> return %0 : !torch.vtensor<[2,3,4], si64> } + +// CHECK-LABEL: func.func @test_celu +func.func @test_celu(%arg0: !torch.vtensor<[3,3,3,1],f32>) -> !torch.vtensor<[3,3,3,1],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 12 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { +// CHECK: %[[ALPHA:.*]] = torch.constant.float 2.000000e+00 +// CHECK: %0 = torch.aten.div.Scalar %arg0, %[[ALPHA]] : !torch.vtensor<[3,3,3,1],f32>, !torch.float -> !torch.vtensor<[3,3,3,1],f32> +// CHECK: %1 = torch.aten.exp %0 : !torch.vtensor<[3,3,3,1],f32> -> !torch.vtensor<[3,3,3,1],f32> +// CHECK: %int1 = torch.constant.int 1 +// CHECK: %2 = torch.aten.sub.Scalar %1, %int1, %int1 : !torch.vtensor<[3,3,3,1],f32>, !torch.int, !torch.int -> !torch.vtensor<[3,3,3,1],f32> +// CHECK: %3 = torch.aten.mul.Scalar %2, %[[ALPHA]] : !torch.vtensor<[3,3,3,1],f32>, !torch.float -> !torch.vtensor<[3,3,3,1],f32> +// CHECK: %int0 = torch.constant.int 0 +// CHECK: %4 = torch.prim.ListConstruct : () -> !torch.list +// CHECK: %none = torch.constant.none +// CHECK: %int6 = torch.constant.int 6 +// CHECK: %[[ZERO:.*]] = torch.aten.full %4, %int0, %int6, %none, %none, %none : !torch.list, !torch.int, !torch.int, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[],f32> +// CHECK: %[[MIN:.*]] = torch.aten.minimum %[[ZERO]], %3 : !torch.vtensor<[],f32>, !torch.vtensor<[3,3,3,1],f32> -> !torch.vtensor<[3,3,3,1],f32> +// CHECK: %[[MAX:.*]] = torch.aten.maximum %[[ZERO]], %arg0 : !torch.vtensor<[],f32>, !torch.vtensor<[3,3,3,1],f32> -> !torch.vtensor<[3,3,3,1],f32> +// CHECK: %8 = torch.aten.add.Tensor %[[MAX]], %[[MIN]], %int1 : !torch.vtensor<[3,3,3,1],f32>, !torch.vtensor<[3,3,3,1],f32>, !torch.int -> !torch.vtensor<[3,3,3,1],f32> + %0 = torch.operator "onnx.Celu"(%arg0) {torch.onnx.alpha = 2.000000e+00 : f32} : (!torch.vtensor<[3,3,3,1],f32>) -> !torch.vtensor<[3,3,3,1],f32> + return %0 : !torch.vtensor<[3,3,3,1],f32> +} From 524ff99216e99805dd1fe003fb9c2947fb8a9770 Mon Sep 17 00:00:00 2001 From: ptrifunovic98 <156185835+ptrifunovic98@users.noreply.github.com> Date: Wed, 13 Mar 2024 20:17:22 +0100 Subject: [PATCH 275/283] Implement lowering of torch.aten.linalg_cross (#2986) Closes [nod-ai/SHARK-Turbine#497](https://github.com/nod-ai/SHARK-Turbine/issues/497) --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 26 ++++ lib/Dialect/Torch/IR/TorchOps.cpp | 90 ++++++++++++++ .../Transforms/AbstractInterpLibrary.cpp | 76 ++++++++++++ .../Torch/Transforms/DecomposeComplexOps.cpp | 112 ++++++++++++++++++ .../Transforms/LowerToBackendContract.cpp | 1 + projects/pt1/e2e_testing/xfail_sets.py | 3 + .../build_tools/abstract_interp_lib_gen.py | 24 ++++ .../build_tools/torch_ods_gen.py | 1 + .../torch_mlir_e2e_test/test_suite/matmul.py | 111 +++++++++++++++++ 9 files changed, 444 insertions(+) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 3e92d40992b8..7c0d7a73e89c 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -11732,6 +11732,32 @@ def Torch_AtenLinspaceOp : Torch_Op<"aten.linspace", [ }]; } +def Torch_AtenLinalgCrossOp : Torch_Op<"aten.linalg_cross", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::linalg_cross : (Tensor, Tensor, int) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self, + AnyTorchTensorType:$other, + Torch_IntType:$dim + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenLinalgCrossOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 3, 1); + } + void AtenLinalgCrossOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 3, 1); + } + }]; + let hasVerifier = 1; +} + def Torch_AtenAliasCopyOp : Torch_Op<"aten.alias_copy", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index 9ecf0e3e262e..30e1ff987aa9 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -4278,6 +4278,96 @@ LogicalResult AtenPermuteOp::verify() { return success(); } +//===----------------------------------------------------------------------===// +// AtenLinalgCrossOp +//===----------------------------------------------------------------------===// + +LogicalResult AtenLinalgCrossOp::verify() { + + auto selfType = getSelf().getType().cast(); + auto otherType = getOther().getType().cast(); + + if (!selfType.hasDtype() || !otherType.hasDtype() || !selfType.hasSizes() || + !otherType.hasSizes()) { + return success(); + } + + Type selfDtype = selfType.getDtype(); + Type otherDtype = otherType.getDtype(); + + // the operation succeeds only if both inputs have the same dtype + if (selfDtype != otherDtype) { + return emitOpError("input tensors must have the same dtype, but got ") + << selfDtype << " and " << otherDtype; + } + + // Check if any of the input tensors has torch.bool dtype. + // The operation does not support this type. + // The docs state that only float, double, cfloat and cdouble dtypes are + // supported, but, when testing, it fails only for boolean dtype. Update to + // fit the docs if necessary. + // https://pytorch.org/docs/stable/generated/torch.linalg.cross.html + if (selfDtype.isSignlessInteger(1) || otherDtype.isSignlessInteger(1)) { + return emitOpError("input tensors must not have bool dtype"); + } + + ArrayRef selfShape = selfType.getSizes(); + ArrayRef otherShape = otherType.getSizes(); + + int64_t selfRank = selfShape.size(); + int64_t otherRank = otherShape.size(); + + // check if both input tensors have the same number of dims + if (selfRank != otherRank) { + return emitOpError("input tensors must have the same number of dimensions, " + "but got ") + << selfRank << " and " << otherRank; + } + + // convert dim to an integer type + int64_t dim; + if (!matchPattern(getDim(), m_TorchConstantInt(&dim))) { + return success(); + } + + // check if dim is in the correct range + if (dim >= selfRank || dim < -selfRank) { + return emitOpError("dim expected to be in rank of [") + << -selfRank << ", " << selfRank - 1 << "], but got " << dim; + } + + // compensate for possible negative dim value + if (dim < 0) { + dim += selfRank; + } + + // check if the size of the dimensions specified by 'dim' is equal to 3 + // (required by the operation) + if ((selfShape[dim] != 3 && selfShape[dim] != kUnknownSize) || + (otherShape[dim] != 3 && otherShape[dim] != kUnknownSize)) { + return emitOpError("inputs dimension ") + << dim << " must have length 3, but got " << selfShape[dim] + << " and " << otherShape[dim]; + } + + // Check if there is a disparity between dimension sizes. + // Dimensions at the same index must either have the same size, + // or one of them must be equal to 1. + int32_t i = 0; + for (auto [selfCurrent, otherCurrent] : + llvm::zip_equal(selfShape, otherShape)) { + if (selfCurrent != otherCurrent && selfCurrent != 1 && otherCurrent != 1) { + return emitOpError("the size of first tensor (") + << selfCurrent << ") must match the size of second tensor (" + << otherCurrent << ") at dimension " << i + << " or one of them must be 1"; + } + ++i; + } + + return success(); +} + //===----------------------------------------------------------------------===// // DtypeCalculateYieldDtypesOp //===----------------------------------------------------------------------===// diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index 19c84617a2a1..c55a2421f5be 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -6793,6 +6793,57 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.linalg_cross\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.int) -> !torch.list {\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %str_0 = torch.constant.str \"the size of first tensor ({}) must match the size of second tensor ({}) at dimension {}\"\n" +" %true = torch.constant.bool true\n" +" %none = torch.constant.none\n" +" %str_1 = torch.constant.str \"AssertionError: inputs must have the same number of dimensions\"\n" +" %int1 = torch.constant.int 1\n" +" %0 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" %1 = torch.aten.len.t %arg1 : !torch.list -> !torch.int\n" +" %2 = torch.aten.eq.int %0, %1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %2 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str_1, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %3 = torch.aten.len.t %arg0 : !torch.list -> !torch.int\n" +" torch.prim.Loop %3, %true, init() {\n" +" ^bb0(%arg3: !torch.int):\n" +" %5 = torch.aten.__getitem__.t %arg0, %arg3 : !torch.list, !torch.int -> !torch.int\n" +" %6 = torch.aten.__getitem__.t %arg1, %arg3 : !torch.list, !torch.int -> !torch.int\n" +" %7 = torch.aten.eq.int %5, %6 : !torch.int, !torch.int -> !torch.bool\n" +" %8 = torch.prim.If %7 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %10 = torch.aten.__getitem__.t %arg0, %arg3 : !torch.list, !torch.int -> !torch.int\n" +" %11 = torch.aten.eq.int %10, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %11 : !torch.bool\n" +" }\n" +" %9 = torch.prim.If %8 -> (!torch.bool) {\n" +" torch.prim.If.yield %true : !torch.bool\n" +" } else {\n" +" %10 = torch.aten.__getitem__.t %arg1, %arg3 : !torch.list, !torch.int -> !torch.int\n" +" %11 = torch.aten.eq.int %10, %int1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If.yield %11 : !torch.bool\n" +" }\n" +" torch.prim.If %9 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" %10 = torch.aten.__getitem__.t %arg0, %arg3 : !torch.list, !torch.int -> !torch.int\n" +" %11 = torch.aten.__getitem__.t %arg1, %arg3 : !torch.list, !torch.int -> !torch.int\n" +" %12 = torch.aten.format(%str_0, %10, %11, %arg3) : !torch.str, !torch.int, !torch.int, !torch.int -> !torch.str\n" +" %13 = torch.aten.add.str %str, %12 : !torch.str, !torch.str -> !torch.str\n" +" torch.prim.RaiseException %13, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" torch.prim.Loop.condition %true, iter()\n" +" } : (!torch.int, !torch.bool) -> ()\n" +" %4 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" +" return %4 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten._log_softmax_backward_data\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.int, %arg3: !torch.int) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.unary(%arg0) : (!torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -10033,6 +10084,31 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.linalg_cross\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.int) -> !torch.int {\n" +" %int11 = torch.constant.int 11\n" +" %none = torch.constant.none\n" +" %str = torch.constant.str \"AssertionError: \"\n" +" %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %1:2 = torch.prim.TupleUnpack %arg1 : !torch.tuple -> !torch.int, !torch.int\n" +" %2 = torch.prim.ListConstruct %0#0, %1#0 : (!torch.int, !torch.int) -> !torch.list>\n" +" %3 = torch.aten.eq.int %0#1, %1#1 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %3 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %4 = torch.aten.ne.int %0#1, %int11 : !torch.int, !torch.int -> !torch.bool\n" +" torch.prim.If %4 -> () {\n" +" torch.prim.If.yield\n" +" } else {\n" +" torch.prim.RaiseException %str, %none : !torch.str, !torch.none\n" +" torch.prim.If.yield\n" +" }\n" +" %5 = torch.prim.ListConstruct %0#1, %1#1 : (!torch.int, !torch.int) -> !torch.list\n" +" %6 = call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.promote_dtypes(%2, %5) : (!torch.list>, !torch.list) -> !torch.int\n" +" return %6 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten._log_softmax_backward_data\"(%arg0: !torch.tuple, %arg1: !torch.tuple, %arg2: !torch.int, %arg3: !torch.int) -> !torch.int {\n" " return %arg3 : !torch.int\n" " }\n" diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 09d5b90f0eeb..5335cbba9bb0 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -1823,6 +1823,117 @@ class DecomposeAtenMvOp : public OpRewritePattern { }; } // namespace +// Decompose aten.linalg_cross into: aten.broadcast_to, aten.index_select, +// aten.add.Tensor and aten.mull.Tensor. See +// https://github.com/pytorch/pytorch/blob/ed3c256b61f05720843454a9282aa7c903da2c81/torch/_refs/linalg/__init__.py#L70. +// def linalg_cross(self: Tensor, other: Tensor, dim: int = -1): +// broadcast_shape = compute_broadcast_shape(self, other) +// a = torch.broadcast_to(self, broadcast_shape) +// b = torch.broadcast_to(other, broadcast_shape) +// idx = torch.arange(3) +// return a.index_select(dim, (idx + 1) % 3) * +// b.index_select(dim, (idx + 2) % 3) - +// a.index_select(dim, (idx + 2) % 3) * +// b.index_select(dim, (idx + 1) % 3) +namespace { +class DecomposeAtenLinalgCrossOp : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenLinalgCrossOp op, + PatternRewriter &rewriter) const override { + + Location loc = op.getLoc(); + Value self = op.getSelf(); + Value other = op.getOther(); + Type opType = op.getType(); + Value dim = op.getDim(); + + auto resType = self.getType().cast(); + if (!resType.hasDtype()) { + return rewriter.notifyMatchFailure(op, "result should have dtype"); + } + + Type dtype = resType.getDtype(); + if (dtype.isa()) { + return rewriter.notifyMatchFailure( + op, "lowering of aten.linalg_cross for complex inputs dtype is " + "currently unimplemented"); + } + + // calculate common shape for broadcast + SmallVector broadcastShape; + SmallVector broadcastShapeValue; + computeBroadcastShape(rewriter, loc, self, other, broadcastShape, + broadcastShapeValue); + + Type broadcastType = ValueTensorType::get( + op.getContext(), llvm::ArrayRef(broadcastShape), dtype); + + Value indexBroadcastShapeTorchList = rewriter.create( + loc, Torch::ListType::get(Torch::IntType::get(op.getContext())), + broadcastShapeValue); + + // broadcast tensors to common shape + auto a = rewriter.create(loc, broadcastType, self, + indexBroadcastShapeTorchList); + auto b = rewriter.create(loc, broadcastType, other, + indexBroadcastShapeTorchList); + + // create constants + Value constOne = rewriter.create( + loc, rewriter.getI64IntegerAttr(1)); + Value constTwo = rewriter.create( + loc, rewriter.getI64IntegerAttr(2)); + Value constThree = rewriter.create( + loc, rewriter.getI64IntegerAttr(3)); + Value none = rewriter.create(loc); + + // idx = torch.arange(3) + auto outType = opType.dyn_cast(); + auto arangeType = outType.getWithSizesAndDtype( + llvm::ArrayRef(3), + IntegerType::get(op.getContext(), 64, IntegerType::Signed)); + auto idx = rewriter.create( + loc, arangeType, constThree, /*dtype=*/none, /*layout=*/none, + /*device=*/none, /*pin_memory=*/none); + + // (idx + 1) and (idx + 2) + auto idxPlusOne = rewriter.create(loc, arangeType, idx, + constOne, constOne); + auto idxPlusTwo = rewriter.create(loc, arangeType, idx, + constTwo, constOne); + + // (idx + 1) % 3 and (idx + 2) % 3 + auto idxPlusOneRemainderThree = rewriter.create( + loc, arangeType, idxPlusOne, constThree); + auto idxPlusTwoRemainderThree = rewriter.create( + loc, arangeType, idxPlusTwo, constThree); + + // a.index_select(dim, (idx + 1) % 3) * b.index_select(dim, (idx + 2) % 3) + auto idxSelectAPlusOne = rewriter.create( + loc, opType, a, dim, idxPlusOneRemainderThree); + auto idxSelectBPlusTwo = rewriter.create( + loc, opType, b, dim, idxPlusTwoRemainderThree); + auto firstMul = rewriter.create( + loc, opType, idxSelectAPlusOne, idxSelectBPlusTwo); + + // a.index_select(dim, (idx + 2) % 3) * b.index_select(dim, (idx + 1) % 3) + auto idxSelectAPlusTwo = rewriter.create( + loc, opType, a, dim, idxPlusTwoRemainderThree); + auto idxSelectBPlusOne = rewriter.create( + loc, opType, b, dim, idxPlusOneRemainderThree); + auto secondMul = rewriter.create( + loc, opType, idxSelectAPlusTwo, idxSelectBPlusOne); + + // subtract the results of the two multiplications from above + rewriter.replaceOpWithNewOp(op, opType, firstMul, + secondMul, constOne); + + return success(); + } +}; +} // namespace + // Decompose aten.pixel_shuffle into: prims.split_dim, aten.permute, and // prims.collapse operations. // @@ -7081,6 +7192,7 @@ class DecomposeComplexOpsPass addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal( diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp index f52c46789350..44a3986ac52f 100644 --- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp +++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp @@ -395,6 +395,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 885164e35cb3..e262f52e3018 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -2107,6 +2107,9 @@ "ReduceMinAlongDimUnsignedInt_basic", "TensorsStackNegativeDimModule_basic", "TensorsStackPromoteDTypeModule_basic", + + # Failure - "RuntimeError: linalg.cross: inputs dimension 1 must have length 3. Got 1 and 1" + "AtenLinalgCrossDynamic_basic" } ONNX_CRASHING_SET = { } diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index 8ef43b0082b0..1f6e450da987 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -384,6 +384,17 @@ def aten〇clone〡shape(self: List[int], memory_format: Optional[int] = None) - def aten〇lift_fresh_copy〡shape(self: List[int]) -> List[int]: return upstream_shape_functions.unary(self) +@check_shape_function([ + Invocation(TensorOfShape(1, 2, 3), TensorOfShape(4, 1, 3)), # two dimensions to broadcast, self[0] and other[1] + ErrorInvocation(TensorOfShape(3), TensorOfShape(2, 3)), # different number of dimensions + ErrorInvocation(TensorOfShape(2, 3), TensorOfShape(4, 3)) # non-broadcastable dimensions +]) +def aten〇linalg_cross〡shape(self: List[int], other: List[int], dim: int = -1) -> List[int]: + assert len(self) == len(other), "inputs must have the same number of dimensions" + for i in range(len(self)): + assert (self[i] == other[i]) or self[i] == 1 or other[i] == 1, f"the size of first tensor ({self[i]}) must match the size of second tensor ({other[i]}) at dimension {i}" + return upstream_shape_functions.broadcast(self, other) + def aten〇_log_softmax_backward_data〡shape(grad_output: List[int], output: List[int], dim: int, input_dtype: int) -> List[int]: return upstream_shape_functions.unary(grad_output) @@ -2381,6 +2392,19 @@ def aten〇lift_fresh_copy〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype return self_dtype +@check_dtype_function( + _check_tensors_with_the_same_dtype(tensor_device="cpu", tensor_shapes=[(2,3), (2,3)], error_types={torch.bool}) + # same dtype + [ErrorInvocation(TensorOfShape(2, 3, dtype=torch.int32, device="cpu"), TensorOfShape(2, 3, dtype=torch.float16, device="cpu"))] #different dtypes +) +def aten〇linalg_cross〡dtype(self_rank_dtype: Tuple[int, int], other_rank_dtype: Tuple[int, int], dim: int = -1) -> int: + self_rank, self_dtype = self_rank_dtype + other_rank, other_dtype = other_rank_dtype + ranks: List[Optional[int]] = [self_rank, other_rank] + assert self_dtype == other_dtype + assert self_dtype != torch.bool + dtypes = [self_dtype, other_dtype] + return promote_dtypes(ranks, dtypes) + @check_dtype_function( _check_two_tensor_op(dim=0, input_dtype=torch.float32) + _check_two_tensor_op(dim=0, input_dtype=torch.float64)) diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index 055f7127c9f2..fa469e035064 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -687,6 +687,7 @@ def emit_with_mutating_variants(key, **kwargs): emit("aten::fmod.Tensor : (Tensor, Tensor) -> (Tensor)") emit("aten::unique_consecutive : (Tensor, bool, bool, int?) -> (Tensor, Tensor, Tensor)") emit("aten::linspace : (Scalar, Scalar, int, int?, int?, Device?, bool?) -> (Tensor)") + emit("aten::linalg_cross : (Tensor, Tensor, int) -> (Tensor)", has_verifier=True) # Functionalization ops emit("aten::alias_copy : (Tensor) -> (Tensor)") diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py index 72a4097bc302..80f02a7b5dc8 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/matmul.py @@ -289,3 +289,114 @@ def forward(self, x, y): def AtenMmQuint8_basic(module, tu: TestUtils): module.forward(tu.randint(3, 4, low=-128, high=127).to(torch.int8), tu.randint(4, 3, low=-128, high=127).to(torch.int8)) + +# ============================================================================== + +class AtenLinalgCrossInt(torch.nn.Module): + + @export + @annotate_args([ + None, + ([2, 3], torch.int64, True), + ([2, 3], torch.int64, True), + ]) + def forward(self, a, b): + return torch.ops.aten.linalg_cross(a, b) + + +@register_test_case(module_factory=lambda: AtenLinalgCrossInt()) +def AtenLinalgCrossInt_basic(module, tu: TestUtils): + module.forward(tu.randint(2, 3), tu.randint(2, 3)) + +# ============================================================================== + +class AtenLinalgCrossFloat(torch.nn.Module): + + @export + @annotate_args([ + None, + ([2, 3], torch.float32, True), + ([2, 3], torch.float32, True), + ]) + def forward(self, a, b): + return torch.ops.aten.linalg_cross(a, b) + + +@register_test_case(module_factory=lambda: AtenLinalgCrossFloat()) +def AtenLinalgCrossFloat_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 3), tu.rand(2, 3)) + + +# ============================================================================== + +class AtenLinalgCrossBroadcast(torch.nn.Module): + + @export + @annotate_args([ + None, + ([1, 4, 3], torch.float32, True), + ([5, 4, 3], torch.float32, True), + ]) + def forward(self, a, b): + return torch.ops.aten.linalg_cross(a, b) + + +@register_test_case(module_factory=lambda: AtenLinalgCrossBroadcast()) +def AtenLinalgCrossBroadcast_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 4, 3), tu.rand(5, 4, 3)) + +# ============================================================================== + +class AtenLinalgCrossCustomDim(torch.nn.Module): + + @export + @annotate_args([ + None, + ([1, 4, 3, 2, 2], torch.float32, True), + ([5, 4, 3, 2, 1], torch.float32, True), + ]) + def forward(self, a, b): + return torch.ops.aten.linalg_cross(a, b, dim=2) + + +@register_test_case(module_factory=lambda: AtenLinalgCrossCustomDim()) +def AtenLinalgCrossCustomDim_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 4, 3, 2, 2), tu.rand(5, 4, 3, 2, 1)) + +# ============================================================================== + +class AtenLinalgCrossNegativeDim(torch.nn.Module): + + @export + @annotate_args([ + None, + ([1, 4, 3, 2, 2], torch.float32, True), + ([5, 4, 3, 2, 1], torch.float32, True), + ]) + def forward(self, a, b): + return torch.ops.aten.linalg_cross(a, b, dim=-3) + + +@register_test_case(module_factory=lambda: AtenLinalgCrossNegativeDim()) +def AtenLinalgCrossNegativeDim_basic(module, tu: TestUtils): + module.forward(tu.rand(1, 4, 3, 2, 2), tu.rand(5, 4, 3, 2, 1)) + +# ============================================================================== + +class AtenLinalgCrossDynamic(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ([-1, -1, -1, -1], torch.float32, True), + ([-1, -1, -1, -1], torch.float32, True), + ]) + def forward(self, a, b): + return torch.ops.aten.linalg_cross(a, b, dim=1) + + +@register_test_case(module_factory=lambda: AtenLinalgCrossDynamic()) +def AtenLinalgCrossDynamic_basic(module, tu: TestUtils): + module.forward(tu.rand(4, 3, 1, 6), tu.rand(4, 3, 7, 1)) \ No newline at end of file From 43c6996a3199c77dcd33a6ed544d121bf9fde0df Mon Sep 17 00:00:00 2001 From: Yuanqiang Liu Date: Thu, 14 Mar 2024 07:41:58 +0800 Subject: [PATCH 276/283] =?UTF-8?q?[Torch=20Dialect]=20add=20folder=20for?= =?UTF-8?q?=20aten.ceil=20and=20unify=20patterns=20of=20ceil,=20=E2=80=A6?= =?UTF-8?q?=20(#3010)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit …floor, round --- .../Dialect/Torch/IR/GeneratedTorchOps.td | 185 +++++++++--------- lib/Dialect/Torch/IR/TorchOps.cpp | 59 +++--- .../build_tools/torch_ods_gen.py | 6 +- 3 files changed, 130 insertions(+), 120 deletions(-) diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td index 7c0d7a73e89c..19c65e8774df 100644 --- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td +++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td @@ -1428,51 +1428,6 @@ def Torch_AtenNeg_Op : Torch_Op<"aten.neg_", [ }]; } -def Torch_AtenCeilOp : Torch_Op<"aten.ceil", [ - AllowsTypeRefinement, - HasValueSemantics, - ReadOnly - ]> { - let summary = "Generated op for `aten::ceil : (Tensor) -> (Tensor)`"; - let arguments = (ins - AnyTorchTensorType:$self - ); - let results = (outs - AnyTorchTensorType:$result - ); - let hasCustomAssemblyFormat = 1; - let extraClassDefinition = [{ - ParseResult AtenCeilOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 1, 1); - } - void AtenCeilOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 1, 1); - } - }]; -} - -def Torch_AtenCeil_Op : Torch_Op<"aten.ceil_", [ - IsTrailingUnderscoreInplaceVariant, - AllowsTypeRefinement - ]> { - let summary = "Generated op for `aten::ceil_ : (Tensor) -> (Tensor)`"; - let arguments = (ins - Torch_NonValueTensorType:$self - ); - let results = (outs - Torch_NonValueTensorType:$result - ); - let hasCustomAssemblyFormat = 1; - let extraClassDefinition = [{ - ParseResult AtenCeil_Op::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 1, 1); - } - void AtenCeil_Op::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 1, 1); - } - }]; -} - def Torch_AtenBitwiseNotOp : Torch_Op<"aten.bitwise_not", [ AllowsTypeRefinement, HasValueSemantics, @@ -4142,7 +4097,7 @@ def Torch_AtenFloorOp : Torch_Op<"aten.floor", [ printDefaultTorchOp(printer, *this, 1, 1); } }]; - let hasCanonicalizer = 1; + let hasFolder = 1; } def Torch_AtenFloor_Op : Torch_Op<"aten.floor_", [ @@ -4167,6 +4122,98 @@ def Torch_AtenFloor_Op : Torch_Op<"aten.floor_", [ }]; } +def Torch_AtenCeilOp : Torch_Op<"aten.ceil", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::ceil : (Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenCeilOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenCeilOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; + let hasFolder = 1; +} + +def Torch_AtenCeil_Op : Torch_Op<"aten.ceil_", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::ceil_ : (Tensor) -> (Tensor)`"; + let arguments = (ins + Torch_NonValueTensorType:$self + ); + let results = (outs + Torch_NonValueTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenCeil_Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenCeil_Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + +def Torch_AtenRoundOp : Torch_Op<"aten.round", [ + AllowsTypeRefinement, + HasValueSemantics, + ReadOnly + ]> { + let summary = "Generated op for `aten::round : (Tensor) -> (Tensor)`"; + let arguments = (ins + AnyTorchTensorType:$self + ); + let results = (outs + AnyTorchTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenRoundOp::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenRoundOp::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; + let hasFolder = 1; +} + +def Torch_AtenRound_Op : Torch_Op<"aten.round_", [ + IsTrailingUnderscoreInplaceVariant, + AllowsTypeRefinement + ]> { + let summary = "Generated op for `aten::round_ : (Tensor) -> (Tensor)`"; + let arguments = (ins + Torch_NonValueTensorType:$self + ); + let results = (outs + Torch_NonValueTensorType:$result + ); + let hasCustomAssemblyFormat = 1; + let extraClassDefinition = [{ + ParseResult AtenRound_Op::parse(OpAsmParser &parser, OperationState &result) { + return parseDefaultTorchOp(parser, result, 1, 1); + } + void AtenRound_Op::print(OpAsmPrinter &printer) { + printDefaultTorchOp(printer, *this, 1, 1); + } + }]; +} + def Torch_AtenMaskedFillTensorOp : Torch_Op<"aten.masked_fill.Tensor", [ AllowsTypeRefinement, HasValueSemantics, @@ -5409,52 +5456,6 @@ def Torch_AtenTril_Op : Torch_Op<"aten.tril_", [ }]; } -def Torch_AtenRoundOp : Torch_Op<"aten.round", [ - AllowsTypeRefinement, - HasValueSemantics, - ReadOnly - ]> { - let summary = "Generated op for `aten::round : (Tensor) -> (Tensor)`"; - let arguments = (ins - AnyTorchTensorType:$self - ); - let results = (outs - AnyTorchTensorType:$result - ); - let hasCustomAssemblyFormat = 1; - let extraClassDefinition = [{ - ParseResult AtenRoundOp::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 1, 1); - } - void AtenRoundOp::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 1, 1); - } - }]; - let hasFolder = 1; -} - -def Torch_AtenRound_Op : Torch_Op<"aten.round_", [ - IsTrailingUnderscoreInplaceVariant, - AllowsTypeRefinement - ]> { - let summary = "Generated op for `aten::round_ : (Tensor) -> (Tensor)`"; - let arguments = (ins - Torch_NonValueTensorType:$self - ); - let results = (outs - Torch_NonValueTensorType:$result - ); - let hasCustomAssemblyFormat = 1; - let extraClassDefinition = [{ - ParseResult AtenRound_Op::parse(OpAsmParser &parser, OperationState &result) { - return parseDefaultTorchOp(parser, result, 1, 1); - } - void AtenRound_Op::print(OpAsmPrinter &printer) { - printDefaultTorchOp(printer, *this, 1, 1); - } - }]; -} - def Torch_AtenIndexPutOp : Torch_Op<"aten.index_put", [ AllowsTypeRefinement, HasValueSemantics, diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp index 30e1ff987aa9..b7ec59bed7a5 100644 --- a/lib/Dialect/Torch/IR/TorchOps.cpp +++ b/lib/Dialect/Torch/IR/TorchOps.cpp @@ -743,20 +743,6 @@ OpFoldResult AtenSqueezeDimOp::fold(FoldAdaptor adaptor) { return nullptr; } -//===----------------------------------------------------------------------===// -// AtenRoundOp -//===----------------------------------------------------------------------===// - -OpFoldResult AtenRoundOp::fold(FoldAdaptor adaptor) { - if (getSelf().getType() != getResult().getType()) - return nullptr; - if (auto selfType = getSelf().getType().dyn_cast()) { - if (selfType.hasDtype() && selfType.getDtype().isa()) - return getSelf(); - } - return nullptr; -} - //===----------------------------------------------------------------------===// // AtenToDtypeOp //===----------------------------------------------------------------------===// @@ -1675,17 +1661,40 @@ OpFoldResult AtenNeScalarOp::fold(FoldAdaptor adaptor) { //===----------------------------------------------------------------------===// // AtenFloorOp //===----------------------------------------------------------------------===// -void AtenFloorOp::getCanonicalizationPatterns(RewritePatternSet &patterns, - MLIRContext *context) { - patterns.add(+[](AtenFloorOp op, PatternRewriter &rewriter) { - auto outputTy = op.getType().dyn_cast(); - if (outputTy && outputTy.hasDtype() && - outputTy.getDtype().isa()) { - rewriter.replaceOp(op, op.getSelf()); - return success(); - } - return failure(); - }); + +OpFoldResult AtenFloorOp::fold(FoldAdaptor adaptor) { + auto resultType = getType().dyn_cast(); + if (resultType && resultType.hasDtype() && + resultType.getDtype().isa()) { + return getSelf(); + } + return {}; +} + +//===----------------------------------------------------------------------===// +// AtenCeilOp +//===----------------------------------------------------------------------===// + +OpFoldResult AtenCeilOp::fold(FoldAdaptor adaptor) { + auto resultType = getType().dyn_cast(); + if (resultType && resultType.hasDtype() && + resultType.getDtype().isa()) { + return getSelf(); + } + return {}; +} + +//===----------------------------------------------------------------------===// +// AtenRoundOp +//===----------------------------------------------------------------------===// + +OpFoldResult AtenRoundOp::fold(FoldAdaptor adaptor) { + auto resultType = getType().dyn_cast(); + if (resultType && resultType.hasDtype() && + resultType.getDtype().isa()) { + return getSelf(); + } + return {}; } //===----------------------------------------------------------------------===// diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py index fa469e035064..8b7d585cfecc 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py @@ -286,7 +286,6 @@ def emit_with_mutating_variants(key, **kwargs): "aten::atanh : (Tensor) -> (Tensor)", "aten::atan2 : (Tensor, Tensor) -> (Tensor)", "aten::neg : (Tensor) -> (Tensor)", - "aten::ceil : (Tensor) -> (Tensor)", "aten::bitwise_not : (Tensor) -> (Tensor)", "aten::div.Tensor : (Tensor, Tensor) -> (Tensor)", "aten::logical_or : (Tensor, Tensor) -> (Tensor)", @@ -347,7 +346,9 @@ def emit_with_mutating_variants(key, **kwargs): emit_with_mutating_variants("aten::ge.Scalar : (Tensor, Scalar) -> (Tensor)", has_folder=True) emit_with_mutating_variants("aten::eq.Scalar : (Tensor, Scalar) -> (Tensor)", has_folder=True) emit_with_mutating_variants("aten::ne.Scalar : (Tensor, Scalar) -> (Tensor)", has_folder=True) - emit_with_mutating_variants("aten::floor : (Tensor) -> (Tensor)", has_canonicalizer=True) + emit_with_mutating_variants("aten::floor : (Tensor) -> (Tensor)", has_folder=True) + emit_with_mutating_variants("aten::ceil : (Tensor) -> (Tensor)", has_folder=True) + emit_with_mutating_variants("aten::round : (Tensor) -> (Tensor)", has_folder=True) emit_with_mutating_variants("aten::masked_fill.Tensor : (Tensor, Tensor, Tensor) -> (Tensor)", has_canonicalizer=True) emit_with_mutating_variants("aten::addcmul : (Tensor, Tensor, Tensor, Scalar) -> (Tensor)") @@ -397,7 +398,6 @@ def emit_with_mutating_variants(key, **kwargs): emit_with_mutating_variants("aten::triu : (Tensor, int) -> (Tensor)") emit_with_mutating_variants("aten::tril : (Tensor, int) -> (Tensor)") - emit_with_mutating_variants("aten::round : (Tensor) -> (Tensor)", has_folder=True) emit_with_mutating_variants( "aten::index_put : (Tensor, Tensor?[], Tensor, bool) -> (Tensor)") emit_with_mutating_variants( From 870e63bc3c1d8b0c7d3e5aa847806f3aff5089d7 Mon Sep 17 00:00:00 2001 From: Yuanqiang Liu Date: Thu, 14 Mar 2024 08:28:33 +0800 Subject: [PATCH 277/283] [Torch Dialect] support decomposition of aten.linspace (#3006) --- .../Transforms/AbstractInterpLibrary.cpp | 16 ++++ .../Torch/Transforms/DecomposeComplexOps.cpp | 73 +++++++++++++++++ .../Transforms/LowerToBackendContract.cpp | 1 + projects/pt1/e2e_testing/xfail_sets.py | 8 ++ .../build_tools/abstract_interp_lib_gen.py | 13 +++ .../torch_mlir_e2e_test/test_suite/arange.py | 82 +++++++++++++++++++ 6 files changed, 193 insertions(+) diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index c55a2421f5be..aa9dab02d563 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -8334,6 +8334,10 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " %5 = call @__torch__.torch.jit._shape_functions.arange_end(%0, %1, %2, %3, %4) : (!torch.union, !torch.any, !torch.any, !torch.any, !torch.any) -> !torch.list\n" " return %5 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.linspace\"(%arg0: !torch.float, %arg1: !torch.float, %arg2: !torch.int, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional, %arg6: !torch.optional) -> !torch.list {\n" +" %0 = torch.prim.ListConstruct %arg2 : (!torch.int) -> !torch.list\n" +" return %0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.add.Tensor\"(%arg0: !torch.list, %arg1: !torch.list, %arg2: !torch.float) -> !torch.list {\n" " %0 = call @__torch__.torch.jit._shape_functions.broadcast(%arg0, %arg1) : (!torch.list, !torch.list) -> !torch.list\n" " return %0 : !torch.list\n" @@ -12568,6 +12572,18 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %1 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.linspace\"(%arg0: !torch.number, %arg1: !torch.number, %arg2: !torch.int, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional, %arg6: !torch.optional) -> !torch.int {\n" +" %int6 = torch.constant.int 6\n" +" %none = torch.constant.none\n" +" %0 = torch.aten.__is__ %arg3, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %1 = torch.prim.If %0 -> (!torch.int) {\n" +" torch.prim.If.yield %int6 : !torch.int\n" +" } else {\n" +" %2 = torch.prim.unchecked_cast %arg3 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %2 : !torch.int\n" +" }\n" +" return %1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.normal_functional\"(%arg0: !torch.tuple, %arg1: !torch.float, %arg2: !torch.float, %arg3: !torch.any) -> !torch.int {\n" " %none = torch.constant.none\n" " %str = torch.constant.str \"AssertionError: \"\n" diff --git a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp index 5335cbba9bb0..9b897f991417 100644 --- a/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp +++ b/lib/Dialect/Torch/Transforms/DecomposeComplexOps.cpp @@ -6331,6 +6331,78 @@ class DecomposeAtenRandOp : public OpRewritePattern { }; } // namespace +namespace { +class DecomposeAtenLinspaceOp : public OpRewritePattern { +public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AtenLinspaceOp op, + PatternRewriter &rewriter) const override { + Location loc = op.getLoc(); + MLIRContext *context = getContext(); + + auto baseType = ValueTensorType::getWithLeastStaticInformation(context); + Value none = rewriter.create(loc); + Value falseVal = rewriter.create(loc, false); + Value zero = + rewriter.create(loc, rewriter.getI64IntegerAttr(0)); + Value one = + rewriter.create(loc, rewriter.getI64IntegerAttr(1)); + + Value addStart; + int64_t steps; + if (matchPattern(op.getSteps(), m_TorchConstantInt(&steps)) && steps == 1) { + // specically handle steps == 1 + Value arange = rewriter.create( + loc, baseType, zero, op.getSteps(), /*dtype=*/none, op.getLayout(), + op.getDevice(), op.getPinMemory()); + addStart = rewriter.create(loc, baseType, arange, + op.getStart(), one); + } else { + // handle steps != 1 or dynamic steps + Value neOrNot = rewriter.create(loc, op.getSteps(), one); + rewriter.create( + loc, neOrNot, + rewriter.getStringAttr("linspace's dynamic steps must not be 1")); + // create arange: [0, ..., steps - 1] + Value arange = rewriter.create( + loc, baseType, zero, op.getSteps(), /*dtype=*/none, op.getLayout(), + op.getDevice(), op.getPinMemory()); + // calculate (end - start) / (steps - 1) + Value sub; + if (op.getEnd().getType().isa() || + op.getStart().getType().isa()) { + sub = rewriter.create(loc, Torch::FloatType::get(context), + op.getEnd(), op.getStart()); + } else { + sub = rewriter.create(loc, op.getEnd(), op.getStart()); + } + Value div = rewriter.create( + loc, sub, rewriter.create(loc, op.getSteps(), one)); + // calculate [0, ..., steps - 1] * ((end - start) / (steps - 1)) + start + Value mulScalar = + rewriter.create(loc, baseType, arange, div); + addStart = rewriter.create(loc, baseType, mulScalar, + op.getStart(), one); + } + // to dtype + Value result; + if (!op.getDtype().getType().isa()) { + result = rewriter.create( + loc, op.getType(), addStart, op.getDtype(), /*non_blocking=*/falseVal, + /*copy=*/falseVal, /*memory_format=*/none); + } else { + Value f32Type = rewriter.create( + loc, (int)torch_upstream::ScalarType::Float); + result = rewriter.create( + loc, op.getType(), addStart, f32Type, /*non_blocking=*/falseVal, + /*copy=*/falseVal, /*memory_format=*/none); + } + rewriter.replaceOp(op, result); + return success(); + } +}; +} // namespace + namespace { class DecomposeAtenVarMeanOp : public OpRewritePattern { public: @@ -7216,6 +7288,7 @@ class DecomposeComplexOpsPass addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal(patterns); + addPatternIfTargetOpIsIllegal(patterns); addPatternIfTargetOpIsIllegal< DecomposeAtenArgMinMaxOp>(patterns); addPatternIfTargetOpIsIllegal< diff --git a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp index 44a3986ac52f..912f7990c2ae 100644 --- a/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp +++ b/lib/Dialect/Torch/Transforms/LowerToBackendContract.cpp @@ -424,6 +424,7 @@ static void markDecomposedOpsAsIllegal(MLIRContext *context, target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); + target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); target.addIllegalOp(); diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index e262f52e3018..f1ec03592e79 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -841,6 +841,11 @@ "ZerosModuleFloat3D_basic", "ZerosModuleInt2D_basic", "ZerosModuleInt3D_basic", + "LinspaceDtypeModule_basic", + "LinspaceEmptyModule_basic", + "LinspaceModule_basic", + "LinspaceOneSizeModule_basic", + "LinspaceTwoSizeModule_basic", } STABLEHLO_CRASHING_SET = { @@ -1260,6 +1265,9 @@ "_LogSoftmaxModuleStable_basic", "_LogSoftmaxModule_basic", "_SoftmaxModule_basic", + "LinspaceModule_basic", + "LinspaceOneSizeModule_basic", + "LinspaceTwoSizeModule_basic", } MAKE_FX_TOSA_PASS_SET = (TOSA_PASS_SET | { diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index 1f6e450da987..7edae5c97dc5 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -1124,6 +1124,9 @@ def aten〇arange〇start〡shape(start: float, end: float, dtype: Optional[int] def aten〇arange〡shape(end: float, dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]: return upstream_shape_functions.arange_end(end, dtype, layout, device, pin_memory) +def aten〇linspace〡shape(start: float, end: float, steps: int, dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]: + return [steps] + @check_shape_function([ Invocation(TensorOfShape(2, 3), TensorOfShape(2, 3)), # Basic case. Invocation(TensorOfShape(2, 3), TensorOfShape(3)), # Rank broadcasting. @@ -4248,6 +4251,16 @@ def aten〇randn〡dtype(size: List[int], dtype: Optional[int] = None, layout: O assert not is_integer_dtype(dtype) return dtype +@check_dtype_function([Invocation(start=1, end=10, steps=9), + Invocation(start=1, end=10, steps=9, dtype=torch.int32), + Invocation(start=1, end=10, steps=9, dtype=torch.double), + Invocation(start=1, end=10, steps=9, dtype=torch.complex64), + Invocation(start=1, end=10, steps=9, dtype=torch.complex128)]) +def aten〇linspace〡dtype(start: Union[int, float, complex], end: Union[int, float, complex], steps: int, dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> int: + if dtype is None: + return torch.float32 + return dtype + @check_dtype_function(_check_tensors_with_the_same_dtype( num_of_tensors=1, error_types={torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64})) diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/arange.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/arange.py index be41d71edbe3..fff3e60c4605 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/arange.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/arange.py @@ -62,6 +62,7 @@ def forward(self): def ArangeZeroElementOutputModule_basic(module, tu: TestUtils): module.forward() +# ============================================================================== class ArangeStartIntModule(torch.nn.Module): def __init__(self): @@ -130,6 +131,7 @@ def forward(self): def ArangeNegativeStartFloatModule_basic(module, tu: TestUtils): module.forward() +# ============================================================================== class ArangeStartStepIntModule(torch.nn.Module): def __init__(self): @@ -198,6 +200,7 @@ def forward(self): def ArangeStartNegativeStepFloatModule_basic(module, tu: TestUtils): module.forward() +# ============================================================================== class ArangeDtypeFloatModule(torch.nn.Module): def __init__(self): @@ -232,6 +235,7 @@ def forward(self): def ArangeDtypeIntModule_basic(module, tu: TestUtils): module.forward() +# ============================================================================== class ArangeFalsePinMemoryModule(torch.nn.Module): def __init__(self): @@ -298,3 +302,81 @@ def forward(self, x): @register_test_case(module_factory=lambda: ArangeStartOutDtypeModule()) def ArangeStartOutDtypeModule_basic(module, tu: TestUtils): module.forward(torch.zeros(12).to(torch.int64)) + +# ============================================================================== + +class LinspaceModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ]) + def forward(self): + return torch.linspace(-10.1, 10.1, 10) + +@register_test_case(module_factory=lambda: LinspaceModule()) +def LinspaceModule_basic(module, tu: TestUtils): + module.forward() + +class LinspaceDtypeModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ]) + def forward(self): + return torch.linspace(-10.1, 10.1, 10, dtype=torch.int64) + + +@register_test_case(module_factory=lambda: LinspaceDtypeModule()) +def LinspaceDtypeModule_basic(module, tu: TestUtils): + module.forward() + +class LinspaceEmptyModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ]) + def forward(self): + return torch.linspace(-10.1, 10.1, 0) + +@register_test_case(module_factory=lambda: LinspaceEmptyModule()) +def LinspaceEmptyModule_basic(module, tu: TestUtils): + module.forward() + +class LinspaceOneSizeModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ]) + def forward(self): + return torch.linspace(-10.1, 10.1, 1) + +@register_test_case(module_factory=lambda: LinspaceOneSizeModule()) +def LinspaceOneSizeModule_basic(module, tu: TestUtils): + module.forward() + +class LinspaceTwoSizeModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args([ + None, + ]) + def forward(self): + return torch.linspace(-10.1, 10.1, 2) + +@register_test_case(module_factory=lambda: LinspaceTwoSizeModule()) +def LinspaceTwoSizeModule_basic(module, tu: TestUtils): + module.forward() From 29ac23a7903cd6e74f113020f1c037c2c27f8fac Mon Sep 17 00:00:00 2001 From: penguin_wwy <940375606@qq.com> Date: Thu, 14 Mar 2024 11:41:48 +0800 Subject: [PATCH 278/283] Setuptools uses a separate build directory (#3023) * setuptools not steal the build directory name https://github.com/llvm/torch-mlir/pull/3021#issuecomment-1994447855 * support pre-built LLVM * support CMAKE_BUILD_TYPE env --- .gitignore | 1 + setup.py | 124 +++++++++++++++++++++++++++++++++++++---------------- 2 files changed, 87 insertions(+), 38 deletions(-) diff --git a/.gitignore b/.gitignore index 5c407428929c..00a5bc96f221 100644 --- a/.gitignore +++ b/.gitignore @@ -11,6 +11,7 @@ externals/pytorch/ libtorch* /build/ +/setup_build/ __pycache__ *.pyc diff --git a/setup.py b/setup.py index 77c8b2ad047d..4863a9807522 100644 --- a/setup.py +++ b/setup.py @@ -30,68 +30,123 @@ # on the CMake side to organize that directory already, so we avoid duplicating # that here, and just package up its contents. import os +import pathlib import shutil import subprocess import sys -import sysconfig +import multiprocessing from distutils.command.build import build as _build -from distutils.sysconfig import get_python_inc from setuptools import setup, Extension from setuptools.command.build_ext import build_ext from setuptools.command.build_py import build_py +def check_env_flag(name: str, default=None) -> bool: + return str(os.getenv(name, default)).upper() in ["ON", "1", "YES", "TRUE", "Y"] + + PACKAGE_VERSION = os.environ.get("TORCH_MLIR_PYTHON_PACKAGE_VERSION") or "0.0.1" # If true, enable LTC build by default TORCH_MLIR_ENABLE_LTC_DEFAULT = True -TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS = int(os.environ.get('TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS', False)) +TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS = check_env_flag( + 'TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS', False) +LLVM_INSTALL_DIR = os.getenv('LLVM_INSTALL_DIR', None) +SRC_DIR = pathlib.Path(__file__).parent.absolute() +CMAKE_BUILD_TYPE = os.getenv("CMAKE_BUILD_TYPE", "Release") + # Build phase discovery is unreliable. Just tell it what phases to run. class CustomBuild(_build): + def initialize_options(self): + _build.initialize_options(self) + # Make setuptools not steal the build directory name, + # because the mlir c++ developers are quite + # used to having build/ be for cmake + self.build_base = "setup_build" + def run(self): self.run_command("build_py") self.run_command("build_ext") self.run_command("build_scripts") + class CMakeBuild(build_py): + def cmake_build(self, cmake_build_dir): + llvm_dir = str(SRC_DIR / "externals" / "llvm-project" / "llvm") + enable_ltc = check_env_flag('TORCH_MLIR_ENABLE_LTC', TORCH_MLIR_ENABLE_LTC_DEFAULT) + max_jobs = os.getenv("MAX_JOBS") or str(multiprocessing.cpu_count()) + + cmake_config_args = [ + f"cmake", + f"-DCMAKE_BUILD_TYPE={CMAKE_BUILD_TYPE}", + f"-DPython3_EXECUTABLE={sys.executable}", + f"-DPython3_FIND_VIRTUALENV=ONLY", + f"-DMLIR_ENABLE_BINDINGS_PYTHON=ON", + f"-DLLVM_TARGETS_TO_BUILD=host", + f"-DLLVM_ENABLE_ZSTD=OFF", + # Optimization options for building wheels. + f"-DCMAKE_VISIBILITY_INLINES_HIDDEN=ON", + f"-DCMAKE_C_VISIBILITY_PRESET=hidden", + f"-DCMAKE_CXX_VISIBILITY_PRESET=hidden", + f"-DTORCH_MLIR_ENABLE_LTC={'ON' if enable_ltc else 'OFF'}", + f"-DTORCH_MLIR_ENABLE_PYTORCH_EXTENSIONS={'OFF' if TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS else 'ON'}", + ] + if LLVM_INSTALL_DIR: + cmake_config_args += [ + f"-DMLIR_DIR='{LLVM_INSTALL_DIR}/lib/cmake/mlir/'", + f"-DLLVM_DIR='{LLVM_INSTALL_DIR}/lib/cmake/llvm/'", + f"{SRC_DIR}", + ] + else: + cmake_config_args += [ + f"-DLLVM_ENABLE_PROJECTS=mlir", + f"-DLLVM_EXTERNAL_PROJECTS='torch-mlir'", + f"-DLLVM_EXTERNAL_TORCH_MLIR_SOURCE_DIR={SRC_DIR}", + f"{llvm_dir}", + ] + cmake_build_args = [ + f"cmake", + f"--build", + f".", + f"--config", + f"{CMAKE_BUILD_TYPE}", + f"--target", + f"TorchMLIRPythonModules", + f"--", + f"-j{max_jobs}" + ] + try: + subprocess.check_call(cmake_config_args, cwd=cmake_build_dir) + subprocess.check_call(cmake_build_args, cwd=cmake_build_dir) + except subprocess.CalledProcessError as e: + print("cmake build failed with\n", e) + print("debug by follow cmake command:") + sys.exit(e.returncode) + finally: + print(f"cmake config: {' '.join(cmake_config_args)}") + print(f"cmake build: {' '.join(cmake_build_args)}") + print(f"cmake workspace: {cmake_build_dir}") + + def run(self): target_dir = self.build_lib cmake_build_dir = os.getenv("TORCH_MLIR_CMAKE_BUILD_DIR") if not cmake_build_dir: cmake_build_dir = os.path.abspath( os.path.join(target_dir, "..", "cmake_build")) - python_package_dir = os.path.join(cmake_build_dir, - "tools", "torch-mlir", "python_packages", - "torch_mlir") + if LLVM_INSTALL_DIR: + python_package_dir = os.path.join(cmake_build_dir, + "python_packages", + "torch_mlir") + else: + python_package_dir = os.path.join(cmake_build_dir, + "tools", "torch-mlir", "python_packages", + "torch_mlir") if not os.getenv("TORCH_MLIR_CMAKE_BUILD_DIR_ALREADY_BUILT"): - src_dir = os.path.abspath(os.path.dirname(__file__)) - llvm_dir = os.path.join( - src_dir, "externals", "llvm-project", "llvm") - - enable_ltc = int(os.environ.get('TORCH_MLIR_ENABLE_LTC', TORCH_MLIR_ENABLE_LTC_DEFAULT)) - - cmake_args = [ - f"-DCMAKE_BUILD_TYPE=Release", - f"-DPython3_EXECUTABLE={sys.executable}", - f"-DPython3_FIND_VIRTUALENV=ONLY", - f"-DLLVM_TARGETS_TO_BUILD=host", - f"-DMLIR_ENABLE_BINDINGS_PYTHON=ON", - f"-DLLVM_ENABLE_PROJECTS=mlir", - f"-DLLVM_ENABLE_ZSTD=OFF", - f"-DLLVM_EXTERNAL_PROJECTS=torch-mlir", - f"-DLLVM_EXTERNAL_TORCH_MLIR_SOURCE_DIR={src_dir}", - # Optimization options for building wheels. - f"-DCMAKE_VISIBILITY_INLINES_HIDDEN=ON", - f"-DCMAKE_C_VISIBILITY_PRESET=hidden", - f"-DCMAKE_CXX_VISIBILITY_PRESET=hidden", - f"-DTORCH_MLIR_ENABLE_LTC={'ON' if enable_ltc else 'OFF'}", - f"-DTORCH_MLIR_ENABLE_PYTORCH_EXTENSIONS={'OFF' if TORCH_MLIR_ENABLE_ONLY_MLIR_PYTHON_BINDINGS else 'ON'}", - ] - os.makedirs(cmake_build_dir, exist_ok=True) cmake_cache_file = os.path.join(cmake_build_dir, "CMakeCache.txt") if os.path.exists(cmake_cache_file): @@ -109,14 +164,7 @@ def run(self): shutil.rmtree(mlir_libs_dir) else: print(f"Not removing _mlir_libs dir (does not exist): {mlir_libs_dir}") - - subprocess.check_call(["cmake", llvm_dir] + - cmake_args, cwd=cmake_build_dir) - subprocess.check_call(["cmake", - "--build", ".", - "--config", "Release", - "--target", "TorchMLIRPythonModules"], - cwd=cmake_build_dir) + self.cmake_build(cmake_build_dir) if os.path.exists(target_dir): shutil.rmtree(target_dir, ignore_errors=False, onerror=None) From f83b0da914a268edc1ff506483fd2a78c495f939 Mon Sep 17 00:00:00 2001 From: Matthias Gehre Date: Wed, 17 Apr 2024 16:07:01 +0200 Subject: [PATCH 279/283] update xfails --- projects/pt1/e2e_testing/xfail_sets.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index c6a6e98bac81..05343f20c1dd 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -1023,15 +1023,11 @@ "EinsumStaticModule_basic", "ElementwiseAbsFloatModule_basic", "ElementwiseAbsIntModule_basic", - "ElementwiseAcosModule_basic", - "ElementwiseAcosTensorFloatModule_basic", "ElementwiseAddModule_basic", "ElementwiseAddScalarFloatModule_basic", "ElementwiseAddScalarInt64Module_basic", "ElementwiseAddScalarIntModule_basic", "ElementwiseAddScalar_TensorLiteralInt32_Module_basic", - "ElementwiseAsinModule_basic", - "ElementwiseAsinTensorFloatModule_basic", "ElementwiseAtenDivIntScalarModule_basic", "ElementwiseAtenLogicalAndOpModule_basic", "ElementwiseAtenLogicalAndOpPromoteBroadcastModule_basic", From 33f46303596d8a8197e752410f586ee0c32b1594 Mon Sep 17 00:00:00 2001 From: Matthias Gehre Date: Mon, 22 Apr 2024 13:53:14 +0200 Subject: [PATCH 280/283] lib/InitAll.cpp: Explicitly depend on sparse_tensors for tests Especially when not using stablehlo, which also pulls this in. Need for some tests that use it. --- lib/InitAll.cpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/lib/InitAll.cpp b/lib/InitAll.cpp index eebfc940870c..1205d6343e43 100644 --- a/lib/InitAll.cpp +++ b/lib/InitAll.cpp @@ -16,6 +16,7 @@ #include "mlir/Dialect/MLProgram/IR/MLProgram.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" #include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Tosa/IR/TosaOps.h" #include "mlir/IR/Dialect.h" @@ -47,7 +48,8 @@ void mlir::torch::registerOptionalInputDialects( mlir::DialectRegistry ®istry) { registry.insert(); + scf::SCFDialect, tensor::TensorDialect, tosa::TosaDialect, + sparse_tensor::SparseTensorDialect>(); } void mlir::torch::registerAllPasses() { From f9b3b0c9866fa8602cd0ee91a1c757d9be3aed7c Mon Sep 17 00:00:00 2001 From: Matthias Gehre Date: Mon, 22 Apr 2024 13:54:05 +0200 Subject: [PATCH 281/283] Disable some tests on older onnx/torch versions --- test/python/fx_importer/basic_test.py | 2 ++ test/python/fx_importer/sparse_test.py | 2 ++ test/python/onnx_importer/command_line_test.py | 2 ++ 3 files changed, 6 insertions(+) diff --git a/test/python/fx_importer/basic_test.py b/test/python/fx_importer/basic_test.py index fc5b2030b648..a51032273999 100644 --- a/test/python/fx_importer/basic_test.py +++ b/test/python/fx_importer/basic_test.py @@ -3,6 +3,8 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception # Also available under a BSD-style license. See LICENSE. +# Requires torch>=2.3.0.dev20240307 +# UNSUPPORTED: true # RUN: %PYTHON %s | FileCheck %s from typing import Optional diff --git a/test/python/fx_importer/sparse_test.py b/test/python/fx_importer/sparse_test.py index 6260a5bbaab3..40c633cfc778 100644 --- a/test/python/fx_importer/sparse_test.py +++ b/test/python/fx_importer/sparse_test.py @@ -3,6 +3,8 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception # Also available under a BSD-style license. See LICENSE. +# Requires torch>=2.3.0.dev20240307 +# UNSUPPORTED: true # RUN: %PYTHON %s | FileCheck %s from typing import Any, Callable, Optional diff --git a/test/python/onnx_importer/command_line_test.py b/test/python/onnx_importer/command_line_test.py index 32dc0cbeb22f..f379376f0a4d 100644 --- a/test/python/onnx_importer/command_line_test.py +++ b/test/python/onnx_importer/command_line_test.py @@ -5,6 +5,8 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception # Also available under a BSD-style license. See LICENSE. +# Requires onnx==1.15.0 +# UNSUPPORTED: true # RUN: %PYTHON %s --output %t from pathlib import Path From 4f9aeef9a76df0ea292edbd7082e16dc95e0f2f2 Mon Sep 17 00:00:00 2001 From: Ferdinand Lemaire Date: Wed, 24 Apr 2024 09:41:23 +0100 Subject: [PATCH 282/283] Add unsupported to tests relying on python3.10 since the pipeline uses 3.8 --- test/python/compile.py | 1 - test/python/onnx_importer/_torch_mlir_config.py | 2 ++ test/python/onnx_importer/import_onnx_tool.runlit | 2 ++ test/python/onnx_importer/import_smoke_test.py | 2 ++ 4 files changed, 6 insertions(+), 1 deletion(-) diff --git a/test/python/compile.py b/test/python/compile.py index b336adafcf33..678a4137acf6 100644 --- a/test/python/compile.py +++ b/test/python/compile.py @@ -23,7 +23,6 @@ def forward(self, x): return x -# CHECK-LABEL: TEST: test_enable_ir_printing @run_test def test_enable_ir_printing(): torchscript.compile(TinyModel(), diff --git a/test/python/onnx_importer/_torch_mlir_config.py b/test/python/onnx_importer/_torch_mlir_config.py index f597b63b4dec..fdcf61cb81d7 100644 --- a/test/python/onnx_importer/_torch_mlir_config.py +++ b/test/python/onnx_importer/_torch_mlir_config.py @@ -4,6 +4,8 @@ # Also available under a BSD-style license. See LICENSE. # RUN: %PYTHON %s +# Requires python>=3.10 +# UNSUPPORTED: true """This file exists so that the tests can find/configure torch_mlir. diff --git a/test/python/onnx_importer/import_onnx_tool.runlit b/test/python/onnx_importer/import_onnx_tool.runlit index 45b733f9da7a..2f170c739896 100644 --- a/test/python/onnx_importer/import_onnx_tool.runlit +++ b/test/python/onnx_importer/import_onnx_tool.runlit @@ -1,3 +1,5 @@ # RUN: %PYTHON -m torch_mlir.tools.import_onnx %S/LeakyReLU.onnx | FileCheck %s +# Requires python>=3.10 +# UNSUPPORTED: true # CHECK: torch.operator "onnx.LeakyRelu" diff --git a/test/python/onnx_importer/import_smoke_test.py b/test/python/onnx_importer/import_smoke_test.py index bd687ae37049..533ffbc45d70 100644 --- a/test/python/onnx_importer/import_smoke_test.py +++ b/test/python/onnx_importer/import_smoke_test.py @@ -6,6 +6,8 @@ # Also available under a BSD-style license. See LICENSE. # RUN: %PYTHON %s --output %t +# Requires python>=3.10 +# UNSUPPORTED: true from glob import glob from pathlib import Path From 1adadd30b6e3c07be092584b05096e61ed25d88f Mon Sep 17 00:00:00 2001 From: Ferdinand Lemaire Date: Wed, 24 Apr 2024 16:41:02 +0100 Subject: [PATCH 283/283] Unsupport more tests --- projects/pt1/python/test/dynamo_fx_importer/basic.py | 2 ++ projects/pt1/python/test/torchscript_e2e_test/basic.py | 2 ++ .../pt1/python/test/torchscript_e2e_test/compilation_failure.py | 2 ++ projects/pt1/python/test/torchscript_e2e_test/error_reports.py | 2 ++ .../pt1/python/test/torchscript_e2e_test/non_tensor_values.py | 2 ++ .../pt1/python/test/torchscript_e2e_test/runtime_failure.py | 2 ++ projects/pt1/python/test/torchscript_e2e_test/submodule.py | 2 ++ 7 files changed, 14 insertions(+) diff --git a/projects/pt1/python/test/dynamo_fx_importer/basic.py b/projects/pt1/python/test/dynamo_fx_importer/basic.py index cea2f639f01d..fd3dcc7f4c2d 100644 --- a/projects/pt1/python/test/dynamo_fx_importer/basic.py +++ b/projects/pt1/python/test/dynamo_fx_importer/basic.py @@ -4,6 +4,8 @@ # Also available under a BSD-style license. See LICENSE. # RUN: %PYTHON %s | FileCheck %s +# Requires python>=3.10 +# UNSUPPORTED: true from typing import List diff --git a/projects/pt1/python/test/torchscript_e2e_test/basic.py b/projects/pt1/python/test/torchscript_e2e_test/basic.py index fa3f6f29729b..2dcface6f4e8 100644 --- a/projects/pt1/python/test/torchscript_e2e_test/basic.py +++ b/projects/pt1/python/test/torchscript_e2e_test/basic.py @@ -4,6 +4,8 @@ # Also available under a BSD-style license. See LICENSE. # RUN: %PYTHON %s | FileCheck %s +# Requires python>=3.10 +# UNSUPPORTED: true import torch diff --git a/projects/pt1/python/test/torchscript_e2e_test/compilation_failure.py b/projects/pt1/python/test/torchscript_e2e_test/compilation_failure.py index 9b9091452f01..36d81d83ab04 100644 --- a/projects/pt1/python/test/torchscript_e2e_test/compilation_failure.py +++ b/projects/pt1/python/test/torchscript_e2e_test/compilation_failure.py @@ -4,6 +4,8 @@ # Also available under a BSD-style license. See LICENSE. # RUN: %PYTHON %s | FileCheck %s +# Requires python>=3.10 +# UNSUPPORTED: true import torch diff --git a/projects/pt1/python/test/torchscript_e2e_test/error_reports.py b/projects/pt1/python/test/torchscript_e2e_test/error_reports.py index f3321285999a..1ebc3dd6dd42 100644 --- a/projects/pt1/python/test/torchscript_e2e_test/error_reports.py +++ b/projects/pt1/python/test/torchscript_e2e_test/error_reports.py @@ -4,6 +4,8 @@ # Also available under a BSD-style license. See LICENSE. # RUN: %PYTHON %s | FileCheck %s +# Requires python>=3.10 +# UNSUPPORTED: true from typing import List, Tuple, Dict diff --git a/projects/pt1/python/test/torchscript_e2e_test/non_tensor_values.py b/projects/pt1/python/test/torchscript_e2e_test/non_tensor_values.py index a1c8c5adfdf4..899dae0c1b9f 100644 --- a/projects/pt1/python/test/torchscript_e2e_test/non_tensor_values.py +++ b/projects/pt1/python/test/torchscript_e2e_test/non_tensor_values.py @@ -4,6 +4,8 @@ # Also available under a BSD-style license. See LICENSE. # RUN: %PYTHON %s | FileCheck %s +# Requires python>=3.10 +# UNSUPPORTED: true from typing import List, Tuple, Dict diff --git a/projects/pt1/python/test/torchscript_e2e_test/runtime_failure.py b/projects/pt1/python/test/torchscript_e2e_test/runtime_failure.py index 3581c1b6d41f..a5cc12e66857 100644 --- a/projects/pt1/python/test/torchscript_e2e_test/runtime_failure.py +++ b/projects/pt1/python/test/torchscript_e2e_test/runtime_failure.py @@ -4,6 +4,8 @@ # Also available under a BSD-style license. See LICENSE. # RUN: %PYTHON %s | FileCheck %s +# Requires python>=3.10 +# UNSUPPORTED: true import torch diff --git a/projects/pt1/python/test/torchscript_e2e_test/submodule.py b/projects/pt1/python/test/torchscript_e2e_test/submodule.py index c88ad53b31b3..8fc520c94396 100644 --- a/projects/pt1/python/test/torchscript_e2e_test/submodule.py +++ b/projects/pt1/python/test/torchscript_e2e_test/submodule.py @@ -4,6 +4,8 @@ # Also available under a BSD-style license. See LICENSE. # RUN: %PYTHON %s | FileCheck %s +# Requires python>=3.10 +# UNSUPPORTED: true import torch