From f34c187ac417d8f7cf84b1d6c69726dc69d90b7b Mon Sep 17 00:00:00 2001
From: penguin_wwy <940375606@qq.com>
Date: Fri, 15 Mar 2024 23:29:48 +0800
Subject: [PATCH 1/9] Normalize type hints to be compatible with multiple
 Python versions (#3028)

Although we provide a wheel package for Python 3.8, it may actually
throw the following exception:
`TypeError: 'type' object is not subscriptable`
Built-in generics such as `list[tuple[int]]` are only subscriptable at
runtime on Python 3.9+, so the `typing` equivalents (`List`, `Tuple`,
`Dict`, `Set`) are used instead.
---
 .../build_tools/abstract_interp_lib_gen.py | 12 ++++----
 python/torch_mlir/extras/fx_importer.py    | 28 ++++++++----------
 python/torch_mlir/extras/onnx_importer.py  | 30 +++++++++----------
 setup.py                                   |  1 +
 test/python/fx_importer/sparse_test.py     |  4 +--
 5 files changed, 38 insertions(+), 37 deletions(-)

diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
index 9ef07ffd073b..faa31ae5965b 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py
@@ -3,7 +3,7 @@
 # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
 # Also available under a BSD-style license. See LICENSE.
 
-from typing import List, Optional, Any, Tuple, Union
+from typing import List, Optional, Any, Tuple, Union, Dict, Set
 
 import argparse
 import os
@@ -1767,9 +1767,9 @@ def aten〇upsample_nearest2d〡shape(self: List[int], output_size: List[int], s
 
 def _check_tensors_with_the_same_dtype(
         num_of_tensors: Optional[int] = None,
-        tensor_shapes: Optional[list[tuple[int]]] = None,
+        tensor_shapes: Optional[List[Tuple[int]]] = None,
         tensor_device: Optional[torch.device] = None,
-        error_types: Optional[set[int]] = None, *args, **kwargs):
+        error_types: Optional[Set[int]] = None, *args, **kwargs):
     """Create invocations where all tensors have the same dtype.
 
     This function generates invocations with `num_of_tensors` tensors
@@ -1801,10 +1801,10 @@ def _check_tensors_with_the_same_dtype(
     return invocations
 
 def _check_two_tensor_op(
-        tensor_shapes: Optional[list[tuple[int]]] = None,
+        tensor_shapes: Optional[List[Tuple[int]]] = None,
         tensor_device: Optional[torch.device] = None,
-        input_error_types: Optional[set[int]] = None,
-        output_error_types: Optional[set[int]] = None, **kwargs):
+        input_error_types: Optional[Set[int]] = None,
+        output_error_types: Optional[Set[int]] = None, **kwargs):
     """Generate invocations for basic two-tensor dtype functions.
 
     This helper function is meant to be used to check dtype functions that
diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py
index 952b638c1988..4eaeb7ac8dfb 100644
--- a/python/torch_mlir/extras/fx_importer.py
+++ b/python/torch_mlir/extras/fx_importer.py
@@ -276,7 +276,7 @@ class SparsityMeta:
     batch_dim: int
     sparse_dim: int
     dense_dim: int
-    blocksize: Optional[tuple[int, int]]
+    blocksize: Optional[Tuple[int, int]]
     pos_dtype: torch.dtype
     crd_dtype: torch.dtype
 
@@ -489,8 +489,8 @@ def import_program(
             default policy is to capture them as frozen values.
         """
         # Create lookaside table of placeholders/outputs.
-        placeholder_nodes: dict[str, Node] = {}
-        all_producer_nodes: dict[str, Node] = {}
+        placeholder_nodes: Dict[str, Node] = {}
+        all_producer_nodes: Dict[str, Node] = {}
         loc: Optional[Location] = None
         for node in prog.graph.nodes:
             if loc is None:
@@ -522,15 +522,15 @@ def import_program(
         }
 
         # Additional bindings that we need to set up after the function is created.
-        mutable_buffer_target_producers: dict[str, str] = {}
-        constant_tensors: dict[Node, torch.Tensor] = {}
-        parameter_bindings: dict[Node, tuple[Any, InputInfo]] = {}
-        buffer_bindings: dict[Node, tuple[Any, InputInfo]] = {}
+        mutable_buffer_target_producers: Dict[str, str] = {}
+        constant_tensors: Dict[Node, torch.Tensor] = {}
+        parameter_bindings: Dict[Node, Tuple[Any, InputInfo]] = {}
+        buffer_bindings: Dict[Node, Tuple[Any, InputInfo]] = {}
 
         # Derive user outputs that we preserve. These will be nodes of the
         # producer for the output.
-        user_outputs: list[Node] = []
-        user_output_types: list[IrType] = []
+        user_outputs: List[Node] = []
+        user_output_types: List[IrType] = []
         for output_spec in sig.output_specs:
             kind = output_spec.kind
             arg = output_spec.arg
@@ -548,8 +548,8 @@ def import_program(
                 mutable_buffer_target_producers[output_spec.target] = arg.name
 
         # Derive user inputs. These will be op=='placeholder' nodes.
-        user_inputs: list[Node] = []
-        user_input_types: list[IrType] = []
+        user_inputs: List[Node] = []
+        user_input_types: List[IrType] = []
         for input_spec in sig.input_specs:
             arg = input_spec.arg
             if input_spec.kind == InputKind.USER_INPUT:
@@ -700,7 +700,7 @@ def import_frozen_program(
         """
         sig = prog.graph_signature
         state_dict = prog.state_dict
-        arg_replacements: dict[str, Any] = {}
+        arg_replacements: Dict[str, Any] = {}
 
         # If there is no "constants" attribute, consult the "state_dict". Otherwise, only look
         # at "constants". Relevant upstream patch: https://github.com/pytorch/pytorch/pull/118969
@@ -1003,7 +1003,7 @@ def __init__(
         # constructs and returns a value.
         self._v: Dict[Union[Callable[[], Value], Tuple[torch_fx.Node, int]], Value] = {}
         # Map of node name to hook that should be called when it is produced.
-        self._on_node_produced: dict[str, Callable[[Value], None]] = {}
+        self._on_node_produced: Dict[str, Callable[[Value], None]] = {}
         # Statically multi-result nodes which we have de-tupled are noted here.
         # They will have their getitem calls short-circuited.
         self._multi_result_nodes: Set[torch_fx.Node] = set()
@@ -1118,7 +1118,7 @@ def on_produced(value: Value):
 
         self._on_node_produced[info.mutable_producer_node_name] = on_produced
 
-    def return_node_values(self, loc, nodes: list[Node]):
+    def return_node_values(self, loc, nodes: List[Node]):
         with loc, InsertionPoint(self._b):
             operands = [self.resolve_node_value(n) for n in nodes]
             func_dialect.ReturnOp(operands, loc=loc)
diff --git a/python/torch_mlir/extras/onnx_importer.py b/python/torch_mlir/extras/onnx_importer.py
index 289e5722efce..f53922d2a818 100644
--- a/python/torch_mlir/extras/onnx_importer.py
+++ b/python/torch_mlir/extras/onnx_importer.py
@@ -33,7 +33,7 @@
         "The onnx package (`pip install onnx`) is required to use the onnx importer"
     ) from e
 
-from typing import Optional
+from typing import Optional, List, Dict, Tuple
 
 from dataclasses import dataclass
 
@@ -113,16 +113,16 @@ class GraphInfo:
     def __init__(self, model_info: ModelInfo, graph_proto: onnx.GraphProto):
         self.model_info = model_info
         self.graph_proto = graph_proto
-        self.initializer_map: dict[str, onnx.TensorProto] = {
+        self.initializer_map: Dict[str, onnx.TensorProto] = {
             n.name: n for n in graph_proto.initializer
         }
-        self.value_info_map: dict[str, onnx.ValueInfoProto] = {
+        self.value_info_map: Dict[str, onnx.ValueInfoProto] = {
             n.name: n for n in graph_proto.value_info
         }
-        self.declared_input_map: dict[str, onnx.ValueInfoProto] = {
+        self.declared_input_map: Dict[str, onnx.ValueInfoProto] = {
             n.name: n for n in graph_proto.input
         }
-        self.output_map: dict[str, onnx.ValueInfoProto] = {
+        self.output_map: Dict[str, onnx.ValueInfoProto] = {
             n.name: n for n in graph_proto.output
         }
 
@@ -191,7 +191,7 @@ def __init__(
         self._gi = graph_info
         self._p = parent_op
         self._b = block
-        self._nv_map: dict[str, Value] = {}
+        self._nv_map: Dict[str, Value] = {}
 
     @classmethod
     def define_function(
@@ -225,7 +225,7 @@ def _populate_graph_attrs(self, container_op: Operation):
         with container_op.context:
             i64_type = IntegerType.get_signed(64)
             default_opset_version = 0
-            opset_versions: dict[str, IntegerAttr] = {}
+            opset_versions: Dict[str, IntegerAttr] = {}
             for opset_import in m.opset_import:
                 if opset_import.domain:
                     opset_versions[opset_import.domain] = IntegerAttr.get(
@@ -335,7 +335,7 @@ def import_node(self, node: onnx.NodeProto):
             for output_name, output_value in zip(output_names, custom_op.results):
                 self._nv_map[output_name] = output_value
 
-    def import_attributes(self, onnx_attrs: list[onnx.AttributeProto]):
+    def import_attributes(self, onnx_attrs: List[onnx.AttributeProto]):
         attrs = {}
         for onnx_attr in onnx_attrs:
             attr_type = onnx_attr.type
@@ -358,14 +358,14 @@ def import_attributes(self, onnx_attrs: list[onnx.AttributeProto]):
             attrs[f"torch.onnx.{onnx_attr.name}"] = result
         return attrs
 
-    def count_regions(self, onnx_attrs: list[onnx.AttributeProto]):
+    def count_regions(self, onnx_attrs: List[onnx.AttributeProto]):
         count = 0
         for onnx_attr in onnx_attrs:
             if onnx_attr.type == onnx.AttributeProto.AttributeType.GRAPH:
                 count += 1
         return count
 
-    def import_regions(self, onnx_attrs: list[onnx.AttributeProto], op):
+    def import_regions(self, onnx_attrs: List[onnx.AttributeProto], op):
         attr_map = {}
         for onnx_attr in onnx_attrs:
             attr_type = onnx_attr.type
@@ -458,10 +458,10 @@ class ContextCache:
 
     def __init__(self, context: Context):
         self._c = context
-        self._elem_type_map: dict[int, IrType] = {}
-        self._list_type_map:dict[str, IrType] = {}
-        self._optional_type_map:dict[str, IrType] = {}
-        self._vtensor_type_map: dict[tuple[tuple[Optional[int]],
-                                           IrType], IrType] = {}
+        self._elem_type_map: Dict[int, IrType] = {}
+        self._list_type_map:Dict[str, IrType] = {}
+        self._optional_type_map:Dict[str, IrType] = {}
+        self._vtensor_type_map: Dict[Tuple[Tuple[Optional[int]], IrType], IrType] = {}
 
     def tensor_element_type(self, elem_type: int) -> IrType:
         t = self._elem_type_map.get(elem_type)
@@ -539,7 +539,7 @@ def get_optional_element_type(self, tp: onnx.TypeProto) -> IrType:
                 f"Unsupport optional element type")
 
     def get_vtensor_type(
-        self, dims: tuple[Optional[int]], element_type: IrType
+        self, dims: Tuple[Optional[int]], element_type: IrType
     ) -> IrType:
         key = (dims, element_type)
         t = self._vtensor_type_map.get(key)
diff --git a/setup.py b/setup.py
index 4863a9807522..3cd5f2eca2d4 100644
--- a/setup.py
+++ b/setup.py
@@ -229,6 +229,7 @@ def build_extension(self, ext):
         "build_py": CMakeBuild,
     },
     ext_modules=EXT_MODULES,
+    python_requires=">=3.8",
     install_requires=INSTALL_REQUIRES,
     extras_require={
         "onnx": [
diff --git a/test/python/fx_importer/sparse_test.py b/test/python/fx_importer/sparse_test.py
index 6260a5bbaab3..52f10de321e7 100644
--- a/test/python/fx_importer/sparse_test.py
+++ b/test/python/fx_importer/sparse_test.py
@@ -5,7 +5,7 @@
 
 # RUN: %PYTHON %s | FileCheck %s
 
-from typing import Any, Callable, Optional
+from typing import Any, Callable, Optional, Tuple, Dict
 
 import torch
 import torch.export
@@ -80,7 +80,7 @@ def sparse_metadata(a: torch.Tensor) -> SparsityMeta:
 
 def sparse_export(
-    f: Callable, args: tuple[Any, ...], kwargs: Optional[dict[str, Any]] = None
+    f: Callable, args: Tuple[Any, ...], kwargs: Optional[Dict[str, Any]] = None
 ) -> torch.export.ExportedProgram:
     """
     This is a ***temporary*** wrapper around `torch.export.export`
From d8a52e82c2c3920697e3ceb71ecf981d46f92fef Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Xinan=20Jiang=28=E5=A7=9C=E6=9B=A6=E6=A5=A0=29?=
Date: Sat, 16 Mar 2024 01:14:09 +0800
Subject: [PATCH 2/9] [onnx] Fix onnx.cast cases between int32 and int64
 (#2982)

2 modifications:
1. torch.int64 is enum 4 in TORCH_DTYPE_TO_INT
2. add int32 support
---
 lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp | 4 +++-
 projects/pt1/e2e_testing/xfail_sets.py                | 3 ---
 2 files changed, 3 insertions(+), 4 deletions(-)

diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
index 2e3f3e8b8053..364714136264 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
@@ -43,8 +43,10 @@ static int64_t onnxDtypeIntToTorchDtypeInt(int64_t dtypeIntOnnx) {
   switch (dtypeIntOnnx) {
   case 1:
     return 6; // float
+  case 6:
+    return 3; // int32
   case 7:
-    return 5; // int64
+    return 4; // int64
   case 9:
     return 11; // bool
   case 10:
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index b8b0e12a658d..fbd0f5bb6eb0 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -1883,10 +1883,7 @@
     "BucketizeTensorOutInt32RightModule_basic",
     "ElementwiseToDtypeI64ToI8Module_basic",
     "ElementwiseToDtypeI64ToUI8Module_basic",
-    "HBC_basic",
     "QuantizedMLP_basic",
-    "TypeConversionI1ToI32Module_basic",
-    "TypeConversionI64ToI32Module_basic",
 
     # Failure - onnx_lowering: onnx.Clip
     "NormalizeModule_basic",
 
From c51e2130f269615e0418ca421e5976f94022daee Mon Sep 17 00:00:00 2001
From: Pavani Chowdary
Date: Mon, 18 Mar 2024 17:54:37 +0530
Subject: [PATCH 3/9] [onnx] support for lowering mod op from onnx to torch
 (#2859)

nod-ai/Shark-Turbine#267

---------

Authored-by: boddu.pavani@research.iiit.ac.in
Co-authored-by: Vivek Khandelwal
---
 .../TorchOnnxToTorch/DefaultDomainGtoP.cpp  | 21 +++++++++++++++++++
 projects/pt1/e2e_testing/xfail_sets.py      | 10 ---------
 .../TorchOnnxToTorch/simple_ops_g_to_p.mlir | 18 ++++++++++++++++
 3 files changed, 39 insertions(+), 10 deletions(-)

diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
index a7bdddbc8d78..4d3711e2e60a 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
@@ -1153,4 +1153,25 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
           binder.op, resultType, tensor, slope);
       return success();
     });
+  patterns.onOp("Mod", 13,
+                [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+                  Torch::ValueTensorType resultType;
+                  Value self, other;
+                  int64_t fmod;
+                  if (binder.tensorOperands(self, other) ||
+                      binder.tensorResultType(resultType) ||
+                      binder.s64IntegerAttr(fmod, "fmod", 0)) {
+                    return failure();
+                  }
+
+                  if (fmod) {
+                    rewriter.replaceOpWithNewOp<Torch::AtenFmodTensorOp>(
+                        binder.op, resultType, self, other);
+                    return success();
+                  }
+
+                  rewriter.replaceOpWithNewOp<Torch::AtenRemainderTensorOp>(
+                      binder.op, resultType, self, other);
+                  return success();
+                });
 }
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index fbd0f5bb6eb0..6c6666a28619 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -1611,8 +1611,6 @@
     "ElementwiseOrTensorStaticShapeModule_basic",
     "ElementwiseQuantizePerTensorModule_basic",
     "ElementwiseRemainderTensorModule_Int_basic",
-    "ElementwiseFmodTensor_Float_basic",
-    "ElementwiseFmodTensor_Int_Float_basic",
     "ElementwiseFmodTensor_Int_basic",
     "EmptyStridedModule_basic",
     "EmptyStridedSizeIntStrideModule_basic",
@@ -1908,14 +1906,6 @@
     "MaxPool2dWithIndicesNonDefaultPaddingModule_basic",
     "MaxPool2dWithIndicesStaticModule_basic",
 
-    # Failure - onnx_lowering: onnx.Mod
"ElementwiseRemainderScalarModule_Bool_basic", - "ElementwiseRemainderScalarModule_Int_basic", - "UnflattenIntNegativeOneDimStaticModule_basic", - "UnflattenIntNegativeOneSizeStaticModule_basic", - "UnflattenIntStaticModule_basic", - "UnflattenStaticModule_basic", - # Failure - onnx_lowering: onnx.OneHot "OneHotModule_basic", diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir index 9dceff316eaa..a348f40e3018 100644 --- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir +++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir @@ -629,6 +629,24 @@ func.func @test_globalaveragepool_precomputed(%arg0: !torch.vtensor<[1,1,3,3],f3 // ----- +// CHECK-LABEL: func.func @test_mod_int64_fmod +func.func @test_mod_int64_fmod(%arg0: !torch.vtensor<[6],si64>, %arg1: !torch.vtensor<[6],si64>) -> !torch.vtensor<[6],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.fmod.Tensor %arg0, %arg1 : !torch.vtensor<[6],si64>, !torch.vtensor<[6],si64> -> !torch.vtensor<[6],si64> + %0 = torch.operator "onnx.Mod"(%arg0, %arg1) {torch.onnx.fmod = 1 : si64} : (!torch.vtensor<[6],si64>, !torch.vtensor<[6],si64>) -> !torch.vtensor<[6],si64> + return %0 : !torch.vtensor<[6],si64> +} + +// ----- + +// CHECK-LABEL: func.func @test_mod_int64_no_fmod +func.func @test_mod_int64_no_fmod(%arg0: !torch.vtensor<[6],si64>, %arg1: !torch.vtensor<[6],si64>) -> !torch.vtensor<[6],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { + // CHECK: torch.aten.remainder.Tensor %arg0, %arg1 : !torch.vtensor<[6],si64>, !torch.vtensor<[6],si64> -> !torch.vtensor<[6],si64> + %0 = torch.operator "onnx.Mod"(%arg0, %arg1) : (!torch.vtensor<[6],si64>, !torch.vtensor<[6],si64>) -> !torch.vtensor<[6],si64> + return %0 : !torch.vtensor<[6],si64> +} + +// ----- + // CHECK-LABEL: func.func @test_log func.func @test_log(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} { // CHECK: torch.aten.log %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32> From 895ea8663a7d350a0c46254b138018e0328a865f Mon Sep 17 00:00:00 2001 From: "Xida Ren (Cedar)" Date: Mon, 18 Mar 2024 11:25:22 -0700 Subject: [PATCH 4/9] add llvm style guide --- docs/add_ops.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/add_ops.md b/docs/add_ops.md index 1805f1700b47..37dee90817db 100644 --- a/docs/add_ops.md +++ b/docs/add_ops.md @@ -85,7 +85,7 @@ Recent Turbine Camp Attendees, from recent to less recent - Sungsoon.Cho@amd.com ## Links - +- IMPORTANT: read the LLVM style guide: https://llvm.org/docs/CodingStandards.html#use-early-exits-and-continue-to-simplify-code - Tutorials - [Sungsoon's Shark Getting Started Google Doc](https://docs.google.com/document/d/1H79DwW_wnVzUU81EogwY5ueXgnl-QzKet1p2lnqPar4/edit?pli=1) - This document contains commands that would help you set up shark and run demos From 8b96727d0d4fd945dd5fe2f39742d5c07067c3c2 Mon Sep 17 00:00:00 2001 From: Yuanqiang Liu Date: Tue, 19 Mar 2024 21:18:54 +0800 Subject: [PATCH 5/9] [Stablehlo] lowering chlo to stablehlo 
 in torch-to-stablehlo pipeline (#3037)

Stablehlo is a better boundary between the frontend compiler and the
backend compiler than chlo.
---
 lib/CMakeLists.txt                                    | 2 +-
 lib/Dialect/TorchConversion/Transforms/CMakeLists.txt | 1 +
 lib/Dialect/TorchConversion/Transforms/Passes.cpp     | 8 +++++++-
 lib/InitAll.cpp                                       | 2 --
 .../stablehlo_backends/linalg_on_tensors.py           | 2 --
 5 files changed, 9 insertions(+), 6 deletions(-)

diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt
index e4ba46138f34..2e2c108c9143 100644
--- a/lib/CMakeLists.txt
+++ b/lib/CMakeLists.txt
@@ -32,7 +32,7 @@ set(LinkedLibs
 )
 
 if(TORCH_MLIR_ENABLE_STABLEHLO)
-list(APPEND LinkedLibs StablehloPasses StablehloLinalgTransforms)
+list(APPEND LinkedLibs StablehloLinalgTransforms)
 endif()
 
 if(TORCH_MLIR_ENABLE_REFBACKEND)
diff --git a/lib/Dialect/TorchConversion/Transforms/CMakeLists.txt b/lib/Dialect/TorchConversion/Transforms/CMakeLists.txt
index 5858e29a496e..d01fdf06e618 100644
--- a/lib/Dialect/TorchConversion/Transforms/CMakeLists.txt
+++ b/lib/Dialect/TorchConversion/Transforms/CMakeLists.txt
@@ -20,6 +20,7 @@ set(LinkedLibs
 if(TORCH_MLIR_ENABLE_STABLEHLO)
   list(APPEND LinkedLibs
     StablehloOps
+    StablehloPasses
   )
 endif()
 
diff --git a/lib/Dialect/TorchConversion/Transforms/Passes.cpp b/lib/Dialect/TorchConversion/Transforms/Passes.cpp
index 55bedc1192eb..a1face8c8d79 100644
--- a/lib/Dialect/TorchConversion/Transforms/Passes.cpp
+++ b/lib/Dialect/TorchConversion/Transforms/Passes.cpp
@@ -25,6 +25,7 @@
 #include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h"
 #include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
 #ifdef TORCH_MLIR_ENABLE_STABLEHLO
+#include "stablehlo/transforms/Passes.h"
 #include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
 #endif
 #include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
@@ -134,9 +135,14 @@ void TorchConversion::createTorchBackendToTosaBackendPipeline(
 void TorchConversion::createTorchBackendToStablehloBackendPipeline(
     OpPassManager &pm,
    const TorchConversion::StablehloBackendPipelineOptions &options) {
-  // Generate Stablehlo ops.
+  // Generate Stablehlo & Chlo ops.
   pm.addNestedPass<func::FuncOp>(createConvertTorchToStablehloPass(
       options.enableStaticShape, options.enableI32Index));
+  // Lowering Chlo ops to Stablehlo
+  pm.addNestedPass<func::FuncOp>(
+      stablehlo::createChloLegalizeToStablehloPass());
+  pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
+
   // Lowering remained ops to Arith
   pm.addNestedPass<func::FuncOp>(createConvertTorchToArithPass());
diff --git a/lib/InitAll.cpp b/lib/InitAll.cpp
index eebfc940870c..ce29176c93c6 100644
--- a/lib/InitAll.cpp
+++ b/lib/InitAll.cpp
@@ -31,7 +31,6 @@
 
 #ifdef TORCH_MLIR_ENABLE_STABLEHLO
 #include "stablehlo/conversions/linalg/transforms/Passes.h"
-#include "stablehlo/transforms/Passes.h"
 #endif
 
 void mlir::torch::registerAllDialects(mlir::DialectRegistry &registry) {
@@ -58,7 +57,6 @@ void mlir::torch::registerAllPasses() {
   mlir::torch::TMTensor::registerPasses();
 
 #ifdef TORCH_MLIR_ENABLE_STABLEHLO
-  mlir::stablehlo::registerChloLegalizeToStablehloPass();
   mlir::stablehlo::registerStablehloLegalizeToLinalgPass();
 #endif
 
diff --git a/projects/pt1/python/torch_mlir_e2e_test/stablehlo_backends/linalg_on_tensors.py b/projects/pt1/python/torch_mlir_e2e_test/stablehlo_backends/linalg_on_tensors.py
index 7dee2041c724..efbf9dff2d84 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/stablehlo_backends/linalg_on_tensors.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/stablehlo_backends/linalg_on_tensors.py
@@ -18,8 +18,6 @@
 # The pipeline of func.func passes that lower the STABLEHLO backend contract to the
 # Linalg-on-Tensors backend contract accepted by RefBackend.
 STABLEHLO_TO_LINALG_FUNC_PIPELINE = ",".join([
-    "func.func(chlo-legalize-to-stablehlo)",
-    "canonicalize",
     "stablehlo-legalize-to-linalg"
 ])
 
From 7a9608bb6972f76b1bab8f7a616262d486f32676 Mon Sep 17 00:00:00 2001
From: zjgarvey <47986913+zjgarvey@users.noreply.github.com>
Date: Tue, 19 Mar 2024 15:35:05 -0500
Subject: [PATCH 6/9] [ONNX] Reduces onnx.Div sinceVersion to 7 (#3041)

The only difference between version 7 and newer versions is support for
different data types. We should allow this pattern to match as early as
7. Earlier versions have a more manual broadcast specification through
attributes, so I did not include those versions.

See: [onnx.Div docs](https://onnx.ai/onnx/operators/onnx__Div.html#l-onnx-doc-divl)
---
 lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
index 364714136264..740eaadb957e 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
@@ -1388,7 +1388,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
         return failure();
       });
 
-  patterns.onOp("Div", 14,
+  patterns.onOp("Div", 7,
                 [](OpBinder binder, ConversionPatternRewriter &rewriter) {
                   Torch::ValueTensorType resultType;
                   Value lhs, rhs;
From df026927268907eb99d490de6141626ac5aa804b Mon Sep 17 00:00:00 2001
From: Abhishek-TyRnT <58800592+Abhishek-TyRnT@users.noreply.github.com>
Date: Wed, 20 Mar 2024 03:49:29 +0530
Subject: [PATCH 7/9] Dynamic size support for flatten (#3005)

Added support for dynamic shapes in the `flatten.using_ints` op in the
TOSA dialect. Due to this, some Argmax tests now pass. This PR fixes
https://github.com/llvm/torch-mlir/issues/3004

The following tests pass after this PR:
```
1. "ArgmaxIntModule_basic"
2. "ArgmaxIntModule_multiple_maxs"
3. "ArgmaxModule_basic"
"ArgmaxModule_basic" ``` --- lib/Conversion/TorchToTosa/TorchToTosa.cpp | 12 +++++++----- projects/pt1/e2e_testing/xfail_sets.py | 5 +++++ .../torch_mlir_e2e_test/test_suite/basic.py | 19 +++++++++++++++++++ 3 files changed, 31 insertions(+), 5 deletions(-) diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp index 93fe9dc1c4e8..9ee8ad895966 100644 --- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp +++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp @@ -2485,10 +2485,9 @@ LogicalResult ConvertAtenOp::matchAndRewrite( // Not a ranked tensor type auto selfType = adaptor.getSelf().getType().dyn_cast(); - if (!selfType || !selfType.hasStaticShape()) - return rewriter.notifyMatchFailure( - op, - "Only ranked tensor types with static shapes are currently supported"); + if (!selfType) + return rewriter.notifyMatchFailure(op, + "Only ranked tensor types supported"); int64_t selfRank = selfType.getRank(); @@ -2520,8 +2519,11 @@ LogicalResult ConvertAtenOp::matchAndRewrite( } else { if (idx == start_dim) newShape.push_back(s.value()); - else + // Only updating when the shapes are static + else if (s.value() != kUnknownSize && newShape.back() != kUnknownSize) newShape.back() *= s.value(); + else + newShape.back() = kUnknownSize; } } diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 6c6666a28619..82d86e7391d8 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -885,6 +885,9 @@ "ArangeStartNegativeStepFloatModule_basic", "ArangeStartOutDtypeModule_basic", "ArangeStartStepFloatModule_basic", + "ArgmaxIntModule_basic", + "ArgmaxIntModule_multiple_maxs", + "ArgmaxModule_basic", "ArgmaxModule_keepDim", "ArgmaxModule_with_dim", "AtenComplex64Module_basic", @@ -1077,6 +1080,7 @@ "EmbeddingModuleI32Static_basic", "FlattenRank0Module_basic", "FlattenStaticModule_basic", + "FlattenDynamicModuleCollapseAll_basic", "FullLikeModuleFloat3DStatic_basic", "FullLikeModuleInt2DStatic_basic", "FullModuleDefaultDtype_basic", @@ -1292,6 +1296,7 @@ }) - { ### Test failing in make_fx_tosa but not in tosa + "FlattenDynamicModuleCollapseAll_basic", # Dynamic shape, has extra unsupported broadcast ops "Matmul_3d", "MatmulStaticBroadcast_basic", diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py index c5ef92d41637..fba52e2e711c 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py @@ -391,6 +391,25 @@ def forward(self, x): def FlattenDynamicModule_basic(module, tu: TestUtils): module.forward(tu.rand(10, 3, 8, 9, 3, 4)) +class FlattenDynamicModuleCollapseAll(torch.nn.Module): + + def __init__(self): + super().__init__() + self.flat = torch.nn.Flatten(0) + + @export + @annotate_args([ + None, + ([-1, -1, -1, 9, 3, -1], torch.float32, True), + ]) + def forward(self, x): + return self.flat(x) + + +@register_test_case(module_factory=lambda: FlattenDynamicModuleCollapseAll()) +def FlattenDynamicModuleCollapseAll_basic(module, tu: TestUtils): + module.forward(tu.rand(10, 3, 8, 9, 3, 4)) + # ============================================================================== From fe59f1ee0d052086da5dfc10445eda4161e99151 Mon Sep 17 00:00:00 2001 From: Aart Bik Date: Tue, 19 Mar 2024 15:59:07 -0700 Subject: [PATCH 8/9] [torch-mlir][sparse] higher dimension COO (#3042) Lift this from 2-dim only to n-dim for n>=2 --- 
 python/torch_mlir/extras/fx_importer.py | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py
index 4eaeb7ac8dfb..ac4a04cfac66 100644
--- a/python/torch_mlir/extras/fx_importer.py
+++ b/python/torch_mlir/extras/fx_importer.py
@@ -297,11 +297,14 @@ def sparsity_encoding(shape: torch.Size, sparsity: SparsityMeta) -> str:
     assert dim == len(shape)
     blocksize = sparsity.blocksize
 
-    dims = ",".join(f"d{d}" for d in range(0, dim))
+    dims = ",".join(f"d{d}" for d in range(dim))
 
     if sparsity.layout is torch.sparse_coo:
-        assert sparse_dim == 2 and blocksize is None  # TODO: deeper sparse dims
-        lvls = f"d{batch_dim}:compressed(nonunique),d{batch_dim+1}:singleton(soa)"
+        assert sparse_dim >= 2 and blocksize is None
+        trail_dim = batch_dim + sparse_dim - 1
+        coords = ",".join(f"d{d}:singleton(nonunique,soa)" for d in range(batch_dim+1, trail_dim))
+        sep = "," if sparse_dim > 2 else ""
+        lvls = f"d{batch_dim}:compressed(nonunique),{coords}{sep}d{trail_dim}:singleton(soa)"
     elif sparsity.layout is torch.sparse_csr:
         assert sparse_dim == 2 and blocksize is None
         lvls = f"d{batch_dim}:dense,d{batch_dim+1}:compressed"
@@ -322,7 +325,7 @@ def sparsity_encoding(shape: torch.Size, sparsity: SparsityMeta) -> str:
     )
 
     if batch_dim > 0:
-        batch = ",".join(f"d{d}:dense" for d in range(0, batch_dim))
+        batch = ",".join(f"d{d}:dense" for d in range(batch_dim))
         lvls = f"{batch},{lvls}"
 
     if dense_dim > 0:
From bdeca3b59c34364661d6c0bd72d200ed2250735d Mon Sep 17 00:00:00 2001
From: Matthias Gehre
Date: Fri, 26 Jul 2024 09:27:39 +0200
Subject: [PATCH 9/9] xfail: onnx: fix xpass
---
 projects/pt1/e2e_testing/xfail_sets.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 93f6d515222e..4e89f2e71927 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -2304,7 +2304,6 @@
     "AtenLinalgCrossDynamic_basic",
 
     # Only on feature/backport_ea1_ops
-    "AtenToDtypeModule_basic",
     "Conv1dNoPaddingGroupModule_basic",
     "ElementwiseAcosTensorIntModule_basic",
     "ElementwiseAsinTensorIntModule_basic",
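
As a closing illustration of patch 8: the COO branch of `sparsity_encoding` now builds its MLIR level string for any number of sparse dimensions, not just two. The sketch below is not part of the patches above; `coo_levels` is a hypothetical standalone helper that mirrors the inline logic the patch adds, shown only to make the resulting format concrete.

```python
# Minimal sketch mirroring the n-dim COO level string built in patch 8's
# `sparsity_encoding`. `coo_levels` is a hypothetical helper for
# illustration; the importer computes this inline.

def coo_levels(batch_dim: int, sparse_dim: int) -> str:
    assert sparse_dim >= 2  # invariant kept by the patch
    trail_dim = batch_dim + sparse_dim - 1
    # Coordinate levels strictly between the leading compressed level and
    # the trailing singleton level; empty when sparse_dim == 2.
    coords = ",".join(
        f"d{d}:singleton(nonunique,soa)" for d in range(batch_dim + 1, trail_dim)
    )
    sep = "," if sparse_dim > 2 else ""
    return (
        f"d{batch_dim}:compressed(nonunique),"
        f"{coords}{sep}d{trail_dim}:singleton(soa)"
    )

# 2-dim COO, the only case supported before the patch:
print(coo_levels(0, 2))  # d0:compressed(nonunique),d1:singleton(soa)
# Newly supported 3-dim COO:
print(coo_levels(0, 3))
# d0:compressed(nonunique),d1:singleton(nonunique,soa),d2:singleton(soa)
```

Note that only the middle coordinate levels carry `nonunique`; the trailing level stays a plain `singleton(soa)`, matching the string the old 2-dim-only branch hard-coded.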