diff --git a/docs/add_ops.md b/docs/add_ops.md
index be939c4ed244..9e34551c0fe8 100644
--- a/docs/add_ops.md
+++ b/docs/add_ops.md
@@ -76,7 +76,7 @@ Helpful examples:
 `. Please don't just paste the generated tests - reference them to write your own
 
 ## Links
-
+- IMPORTANT: read the LLVM style guide: https://llvm.org/docs/CodingStandards.html#use-early-exits-and-continue-to-simplify-code
 - Tutorials
   - [Sungsoon's Shark Getting Started Google Doc](https://docs.google.com/document/d/1H79DwW_wnVzUU81EogwY5ueXgnl-QzKet1p2lnqPar4/edit?pli=1)
     - This document contains commands that would help you set up shark and run demos
diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt
index e4ba46138f34..2e2c108c9143 100644
--- a/lib/CMakeLists.txt
+++ b/lib/CMakeLists.txt
@@ -32,7 +32,7 @@ set(LinkedLibs
 )
 
 if(TORCH_MLIR_ENABLE_STABLEHLO)
-list(APPEND LinkedLibs StablehloPasses StablehloLinalgTransforms)
+list(APPEND LinkedLibs StablehloLinalgTransforms)
 endif()
 
 if(TORCH_MLIR_ENABLE_REFBACKEND)
diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
index f998240b3472..4d856fca6521 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp
@@ -43,8 +43,10 @@ static int64_t onnxDtypeIntToTorchDtypeInt(int64_t dtypeIntOnnx) {
   switch (dtypeIntOnnx) {
   case 1:
     return 6; // float
+  case 6:
+    return 3; // int32
   case 7:
-    return 5; // int64
+    return 4; // int64
   case 9:
     return 11; // bool
   case 10:
@@ -1387,7 +1389,7 @@ void mlir::torch::onnx_c::populateDefaultDomainAtoF(
                     return failure();
                   });
-  patterns.onOp("Div", 14,
+  patterns.onOp("Div", 7,
                 [](OpBinder binder, ConversionPatternRewriter &rewriter) {
                   Torch::ValueTensorType resultType;
                   Value lhs, rhs;
diff --git a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
index 60738a579687..d37deb53adff 100644
--- a/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
+++ b/lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp
@@ -1153,4 +1153,25 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
                       binder.op, resultType, tensor, slope);
                   return success();
                 });
+  patterns.onOp("Mod", 13,
+                [](OpBinder binder, ConversionPatternRewriter &rewriter) {
+                  Torch::ValueTensorType resultType;
+                  Value self, other;
+                  int64_t fmod;
+                  if (binder.tensorOperands(self, other) ||
+                      binder.tensorResultType(resultType) ||
+                      binder.s64IntegerAttr(fmod, "fmod", 0)) {
+                    return failure();
+                  }
+
+                  if (fmod) {
+                    rewriter.replaceOpWithNewOp<Torch::AtenFmodTensorOp>(
+                        binder.op, resultType, self, other);
+                    return success();
+                  }
+
+                  rewriter.replaceOpWithNewOp<Torch::AtenRemainderTensorOp>(
+                      binder.op, resultType, self, other);
+                  return success();
+                });
 }
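The `fmod` attribute above selects between two ops with different sign behaviour; the op names in the template arguments follow the CHECK lines of the new MLIR tests later in this patch. As a rough illustration (standard PyTorch semantics; the printed values are what the ops should return): torch.fmod takes the sign of the dividend, torch.remainder the sign of the divisor.

    import torch

    a = torch.tensor([-7, 7])
    b = torch.tensor([3, -3])
    print(torch.fmod(a, b))       # tensor([-1,  1])  - sign follows the dividend (fmod = 1 path)
    print(torch.remainder(a, b))  # tensor([ 2, -2])  - sign follows the divisor (default path)
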
diff --git a/lib/Conversion/TorchToTosa/TorchToTosa.cpp b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
index a25bbe402a73..3f50a7f504dd 100644
--- a/lib/Conversion/TorchToTosa/TorchToTosa.cpp
+++ b/lib/Conversion/TorchToTosa/TorchToTosa.cpp
@@ -2631,10 +2631,9 @@ LogicalResult ConvertAtenOp<AtenFlattenUsingIntsOp>::matchAndRewrite(
 
   // Not a ranked tensor type
   auto selfType = adaptor.getSelf().getType().dyn_cast<RankedTensorType>();
-  if (!selfType || !selfType.hasStaticShape())
-    return rewriter.notifyMatchFailure(
-        op,
-        "Only ranked tensor types with static shapes are currently supported");
+  if (!selfType)
+    return rewriter.notifyMatchFailure(op,
+                                       "Only ranked tensor types supported");
 
   int64_t selfRank = selfType.getRank();
 
@@ -2666,8 +2665,11 @@ LogicalResult ConvertAtenOp<AtenFlattenUsingIntsOp>::matchAndRewrite(
     } else {
       if (idx == start_dim)
         newShape.push_back(s.value());
-      else
+      // Only updating when the shapes are static
+      else if (s.value() != kUnknownSize && newShape.back() != kUnknownSize)
         newShape.back() *= s.value();
+      else
+        newShape.back() = kUnknownSize;
     }
   }
 
diff --git a/lib/Dialect/TorchConversion/Transforms/CMakeLists.txt b/lib/Dialect/TorchConversion/Transforms/CMakeLists.txt
index 5858e29a496e..d01fdf06e618 100644
--- a/lib/Dialect/TorchConversion/Transforms/CMakeLists.txt
+++ b/lib/Dialect/TorchConversion/Transforms/CMakeLists.txt
@@ -20,6 +20,7 @@ set(LinkedLibs
 if(TORCH_MLIR_ENABLE_STABLEHLO)
   list(APPEND LinkedLibs
     StablehloOps
+    StablehloPasses
   )
 endif()
 
diff --git a/lib/Dialect/TorchConversion/Transforms/Passes.cpp b/lib/Dialect/TorchConversion/Transforms/Passes.cpp
index 18fbe7809e23..d8d1fe0b4c35 100644
--- a/lib/Dialect/TorchConversion/Transforms/Passes.cpp
+++ b/lib/Dialect/TorchConversion/Transforms/Passes.cpp
@@ -25,6 +25,7 @@
 #include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h"
 #include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
 #ifdef TORCH_MLIR_ENABLE_STABLEHLO
+#include "stablehlo/transforms/Passes.h"
 #include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
 #endif
 #include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
@@ -136,9 +137,14 @@ void TorchConversion::createTorchBackendToTosaBackendPipeline(
 
 void TorchConversion::createTorchBackendToStablehloBackendPipeline(
     OpPassManager &pm,
     const TorchConversion::StablehloBackendPipelineOptions &options) {
-  // Generate Stablehlo ops.
+  // Generate Stablehlo & Chlo ops.
   pm.addNestedPass<func::FuncOp>(createConvertTorchToStablehloPass(
       options.enableStaticShape, options.enableI32Index));
+  // Lowering Chlo ops to Stablehlo
+  pm.addNestedPass<func::FuncOp>(
+      stablehlo::createChloLegalizeToStablehloPass());
+  pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());
+  // Lowering remaining ops to Arith
   pm.addNestedPass<func::FuncOp>(createConvertTorchToArithPass());
 
diff --git a/lib/InitAll.cpp b/lib/InitAll.cpp
index 1205d6343e43..17e746fb16d4 100644
--- a/lib/InitAll.cpp
+++ b/lib/InitAll.cpp
@@ -32,7 +32,6 @@
 
 #ifdef TORCH_MLIR_ENABLE_STABLEHLO
 #include "stablehlo/conversions/linalg/transforms/Passes.h"
-#include "stablehlo/transforms/Passes.h"
 #endif
 
 void mlir::torch::registerAllDialects(mlir::DialectRegistry &registry) {
@@ -60,7 +59,6 @@ void mlir::torch::registerAllPasses() {
   mlir::torch::TMTensor::registerPasses();
 
 #ifdef TORCH_MLIR_ENABLE_STABLEHLO
-  mlir::stablehlo::registerChloLegalizeToStablehloPass();
   mlir::stablehlo::registerStablehloLegalizeToLinalgPass();
 #endif
 
diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py
index 15ddcdf1ecea..4e89f2e71927 100644
--- a/projects/pt1/e2e_testing/xfail_sets.py
+++ b/projects/pt1/e2e_testing/xfail_sets.py
@@ -954,6 +954,9 @@
     "ArangeStartNegativeStepFloatModule_basic",
     "ArangeStartOutDtypeModule_basic",
     "ArangeStartStepFloatModule_basic",
+    "ArgmaxIntModule_basic",
+    "ArgmaxIntModule_multiple_maxs",
+    "ArgmaxModule_basic",
     "ArgmaxModule_keepDim",
     "ArgmaxModule_with_dim",
     "AtenComplex64Module_basic",
@@ -1200,6 +1203,7 @@
     "Fill_TensorFloat64WithInt64Static_basic",
     "FlattenRank0Module_basic",
     "FlattenStaticModule_basic",
+    "FlattenDynamicModuleCollapseAll_basic",
     "FullLikeModuleFloat3DStatic_basic",
     "FullLikeModuleInt2DStatic_basic",
     "FullModuleDefaultDtype_basic",
@@ -1485,6 +1489,7 @@
 }) - {
     ### Test failing in make_fx_tosa but not in tosa
+    "FlattenDynamicModuleCollapseAll_basic",
 
     # Dynamic shape, has extra unsupported broadcast ops
     "Matmul_3d",
@@ -1806,8 +1811,6 @@
"ElementwiseOrTensorStaticShapeModule_basic", "ElementwiseQuantizePerTensorModule_basic", "ElementwiseRemainderTensorModule_Int_basic", - "ElementwiseFmodTensor_Float_basic", - "ElementwiseFmodTensor_Int_Float_basic", "ElementwiseFmodTensor_Int_basic", "EmptyStridedModule_basic", "EmptyStridedSizeIntStrideModule_basic", @@ -2074,10 +2077,7 @@ "BucketizeTensorOutInt32RightModule_basic", "ElementwiseToDtypeI64ToI8Module_basic", "ElementwiseToDtypeI64ToUI8Module_basic", - "HBC_basic", "QuantizedMLP_basic", - "TypeConversionI1ToI32Module_basic", - "TypeConversionI64ToI32Module_basic", # Failure - onnx_lowering: onnx.Clip "NormalizeModule_basic", @@ -2102,14 +2102,6 @@ "MaxPool2dWithIndicesNonDefaultPaddingModule_basic", "MaxPool2dWithIndicesStaticModule_basic", - # Failure - onnx_lowering: onnx.Mod - "ElementwiseRemainderScalarModule_Bool_basic", - "ElementwiseRemainderScalarModule_Int_basic", - "UnflattenIntNegativeOneDimStaticModule_basic", - "UnflattenIntNegativeOneSizeStaticModule_basic", - "UnflattenIntStaticModule_basic", - "UnflattenStaticModule_basic", - # Failure - onnx_lowering: onnx.OneHot "OneHotModule_basic", @@ -2312,7 +2304,6 @@ "AtenLinalgCrossDynamic_basic", # Only on feature/backport_ea1_ops - "AtenToDtypeModule_basic", "Conv1dNoPaddingGroupModule_basic", "ElementwiseAcosTensorIntModule_basic", "ElementwiseAsinTensorIntModule_basic", diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index b27db645adc6..ffe4be8ed2ba 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception # Also available under a BSD-style license. See LICENSE. -from typing import List, Optional, Any, Tuple, Union +from typing import List, Optional, Any, Tuple, Union, Dict, Set import argparse import os @@ -1875,9 +1875,9 @@ def aten〇upsample_nearest2d〡shape(self: List[int], output_size: List[int], s def _check_tensors_with_the_same_dtype( num_of_tensors: Optional[int] = None, - tensor_shapes: Optional[list[tuple[int]]] = None, + tensor_shapes: Optional[List[Tuple[int]]] = None, tensor_device: Optional[torch.device] = None, - error_types: Optional[set[int]] = None, *args, **kwargs): + error_types: Optional[Set[int]] = None, *args, **kwargs): """Create invocations where all tensors have the same dtype. This function generates invocations with `num_of_tensors` tensors @@ -1909,10 +1909,10 @@ def _check_tensors_with_the_same_dtype( return invocations def _check_two_tensor_op( - tensor_shapes: Optional[list[tuple[int]]] = None, + tensor_shapes: Optional[List[Tuple[int]]] = None, tensor_device: Optional[torch.device] = None, - input_error_types: Optional[set[int]] = None, - output_error_types: Optional[set[int]] = None, **kwargs): + input_error_types: Optional[Set[int]] = None, + output_error_types: Optional[Set[int]] = None, **kwargs): """Generate invocations for basic two-tensor dtype functions. 
     This helper function is meant to be used to check dtype functions that
diff --git a/projects/pt1/python/torch_mlir_e2e_test/stablehlo_backends/linalg_on_tensors.py b/projects/pt1/python/torch_mlir_e2e_test/stablehlo_backends/linalg_on_tensors.py
index 7dee2041c724..efbf9dff2d84 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/stablehlo_backends/linalg_on_tensors.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/stablehlo_backends/linalg_on_tensors.py
@@ -18,8 +18,6 @@
 # The pipeline of func.func passes that lower the STABLEHLO backend contract to the
 # Linalg-on-Tensors backend contract accepted by RefBackend.
 STABLEHLO_TO_LINALG_FUNC_PIPELINE = ",".join([
-    "func.func(chlo-legalize-to-stablehlo)",
-    "canonicalize",
     "stablehlo-legalize-to-linalg"
 ])
 
diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
index f1e3700e0a4a..a83393851d32 100644
--- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
+++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py
@@ -391,6 +391,25 @@ def forward(self, x):
 def FlattenDynamicModule_basic(module, tu: TestUtils):
     module.forward(tu.rand(10, 3, 8, 9, 3, 4))
 
+class FlattenDynamicModuleCollapseAll(torch.nn.Module):
+
+    def __init__(self):
+        super().__init__()
+        self.flat = torch.nn.Flatten(0)
+
+    @export
+    @annotate_args([
+        None,
+        ([-1, -1, -1, 9, 3, -1], torch.float32, True),
+    ])
+    def forward(self, x):
+        return self.flat(x)
+
+
+@register_test_case(module_factory=lambda: FlattenDynamicModuleCollapseAll())
+def FlattenDynamicModuleCollapseAll_basic(module, tu: TestUtils):
+    module.forward(tu.rand(10, 3, 8, 9, 3, 4))
+
 
 # ==============================================================================
 
diff --git a/python/torch_mlir/extras/fx_importer.py b/python/torch_mlir/extras/fx_importer.py
index 952b638c1988..ac4a04cfac66 100644
--- a/python/torch_mlir/extras/fx_importer.py
+++ b/python/torch_mlir/extras/fx_importer.py
@@ -276,7 +276,7 @@ class SparsityMeta:
     batch_dim: int
     sparse_dim: int
     dense_dim: int
-    blocksize: Optional[tuple[int, int]]
+    blocksize: Optional[Tuple[int, int]]
     pos_dtype: torch.dtype
     crd_dtype: torch.dtype
 
@@ -297,11 +297,14 @@ def sparsity_encoding(shape: torch.Size, sparsity: SparsityMeta) -> str:
     assert dim == len(shape)
     blocksize = sparsity.blocksize
 
-    dims = ",".join(f"d{d}" for d in range(0, dim))
+    dims = ",".join(f"d{d}" for d in range(dim))
 
     if sparsity.layout is torch.sparse_coo:
-        assert sparse_dim == 2 and blocksize is None  # TODO: deeper sparse dims
-        lvls = f"d{batch_dim}:compressed(nonunique),d{batch_dim+1}:singleton(soa)"
+        assert sparse_dim >= 2 and blocksize is None
+        trail_dim = batch_dim + sparse_dim - 1
+        coords = ",".join(f"d{d}:singleton(nonunique,soa)" for d in range(batch_dim+1, trail_dim))
+        sep = "," if sparse_dim > 2 else ""
+        lvls = f"d{batch_dim}:compressed(nonunique),{coords}{sep}d{trail_dim}:singleton(soa)"
     elif sparsity.layout is torch.sparse_csr:
         assert sparse_dim == 2 and blocksize is None
         lvls = f"d{batch_dim}:dense,d{batch_dim+1}:compressed"
@@ -322,7 +325,7 @@ def sparsity_encoding(shape: torch.Size, sparsity: SparsityMeta) -> str:
     )
 
     if batch_dim > 0:
-        batch = ",".join(f"d{d}:dense" for d in range(0, batch_dim))
+        batch = ",".join(f"d{d}:dense" for d in range(batch_dim))
         lvls = f"{batch},{lvls}"
 
     if dense_dim > 0:
@@ -489,8 +492,8 @@ def import_program(
         default policy is to capture them as frozen values.
         """
         # Create lookaside table of placeholders/outputs.
-        placeholder_nodes: dict[str, Node] = {}
-        all_producer_nodes: dict[str, Node] = {}
+        placeholder_nodes: Dict[str, Node] = {}
+        all_producer_nodes: Dict[str, Node] = {}
         loc: Optional[Location] = None
         for node in prog.graph.nodes:
             if loc is None:
@@ -522,15 +525,15 @@ def import_program(
         }
 
         # Additional bindings that we need to set up after the function is created.
-        mutable_buffer_target_producers: dict[str, str] = {}
-        constant_tensors: dict[Node, torch.Tensor] = {}
-        parameter_bindings: dict[Node, tuple[Any, InputInfo]] = {}
-        buffer_bindings: dict[Node, tuple[Any, InputInfo]] = {}
+        mutable_buffer_target_producers: Dict[str, str] = {}
+        constant_tensors: Dict[Node, torch.Tensor] = {}
+        parameter_bindings: Dict[Node, Tuple[Any, InputInfo]] = {}
+        buffer_bindings: Dict[Node, Tuple[Any, InputInfo]] = {}
 
         # Derive user outputs that we preserve. These will be nodes of the
         # producer for the output.
-        user_outputs: list[Node] = []
-        user_output_types: list[IrType] = []
+        user_outputs: List[Node] = []
+        user_output_types: List[IrType] = []
         for output_spec in sig.output_specs:
             kind = output_spec.kind
             arg = output_spec.arg
@@ -548,8 +551,8 @@ def import_program(
                     mutable_buffer_target_producers[output_spec.target] = arg.name
 
         # Derive user inputs. These will be op=='placeholder' nodes.
-        user_inputs: list[Node] = []
-        user_input_types: list[IrType] = []
+        user_inputs: List[Node] = []
+        user_input_types: List[IrType] = []
         for input_spec in sig.input_specs:
             arg = input_spec.arg
             if input_spec.kind == InputKind.USER_INPUT:
@@ -700,7 +703,7 @@ def import_frozen_program(
         """
         sig = prog.graph_signature
         state_dict = prog.state_dict
-        arg_replacements: dict[str, Any] = {}
+        arg_replacements: Dict[str, Any] = {}
 
         # If there is no "constants" attribute, consult the "state_dict". Otherwise, only look
         # at "constants". Relevant upstream patch: https://github.com/pytorch/pytorch/pull/118969
@@ -1003,7 +1006,7 @@ def __init__(
         # constructs and returns a value.
        self._v: Dict[Union[Callable[[], Value], Tuple[torch_fx.Node, int]], Value] = {}
         # Map of node name to hook that should be called when it is produced.
-        self._on_node_produced: dict[str, Callable[[Value], None]] = {}
+        self._on_node_produced: Dict[str, Callable[[Value], None]] = {}
         # Statically multi-result nodes which we have de-tupled are noted here.
         # They will have their getitem calls short-circuited.
         self._multi_result_nodes: Set[torch_fx.Node] = set()
@@ -1118,7 +1121,7 @@ def on_produced(value: Value):
 
         self._on_node_produced[info.mutable_producer_node_name] = on_produced
 
-    def return_node_values(self, loc, nodes: list[Node]):
+    def return_node_values(self, loc, nodes: List[Node]):
         with loc, InsertionPoint(self._b):
             operands = [self.resolve_node_value(n) for n in nodes]
             func_dialect.ReturnOp(operands, loc=loc)
diff --git a/python/torch_mlir/extras/onnx_importer.py b/python/torch_mlir/extras/onnx_importer.py
index 289e5722efce..f53922d2a818 100644
--- a/python/torch_mlir/extras/onnx_importer.py
+++ b/python/torch_mlir/extras/onnx_importer.py
@@ -33,7 +33,7 @@
     "The onnx package (`pip install onnx`) is required to use the onnx importer"
 ) from e
 
-from typing import Optional
+from typing import Optional, List, Dict, Tuple
 
 from dataclasses import dataclass
 
@@ -113,16 +113,16 @@ class GraphInfo:
     def __init__(self, model_info: ModelInfo, graph_proto: onnx.GraphProto):
         self.model_info = model_info
         self.graph_proto = graph_proto
-        self.initializer_map: dict[str, onnx.TensorProto] = {
+        self.initializer_map: Dict[str, onnx.TensorProto] = {
             n.name: n for n in graph_proto.initializer
         }
-        self.value_info_map: dict[str, onnx.ValueInfoProto] = {
+        self.value_info_map: Dict[str, onnx.ValueInfoProto] = {
             n.name: n for n in graph_proto.value_info
         }
-        self.declared_input_map: dict[str, onnx.ValueInfoProto] = {
+        self.declared_input_map: Dict[str, onnx.ValueInfoProto] = {
             n.name: n for n in graph_proto.input
         }
-        self.output_map: dict[str, onnx.ValueInfoProto] = {
+        self.output_map: Dict[str, onnx.ValueInfoProto] = {
             n.name: n for n in graph_proto.output
         }
 
@@ -191,7 +191,7 @@ def __init__(
         self._gi = graph_info
         self._p = parent_op
         self._b = block
-        self._nv_map: dict[str, Value] = {}
+        self._nv_map: Dict[str, Value] = {}
 
     @classmethod
     def define_function(
@@ -225,7 +225,7 @@ def _populate_graph_attrs(self, container_op: Operation):
         with container_op.context:
             i64_type = IntegerType.get_signed(64)
             default_opset_version = 0
-            opset_versions: dict[str, IntegerAttr] = {}
+            opset_versions: Dict[str, IntegerAttr] = {}
             for opset_import in m.opset_import:
                 if opset_import.domain:
                     opset_versions[opset_import.domain] = IntegerAttr.get(
@@ -335,7 +335,7 @@ def import_node(self, node: onnx.NodeProto):
             for output_name, output_value in zip(output_names, custom_op.results):
                 self._nv_map[output_name] = output_value
 
-    def import_attributes(self, onnx_attrs: list[onnx.AttributeProto]):
+    def import_attributes(self, onnx_attrs: List[onnx.AttributeProto]):
         attrs = {}
         for onnx_attr in onnx_attrs:
             attr_type = onnx_attr.type
@@ -358,14 +358,14 @@ def import_attributes(self, onnx_attrs: list[onnx.AttributeProto]):
                 attrs[f"torch.onnx.{onnx_attr.name}"] = result
         return attrs
 
-    def count_regions(self, onnx_attrs: list[onnx.AttributeProto]):
+    def count_regions(self, onnx_attrs: List[onnx.AttributeProto]):
         count = 0
         for onnx_attr in onnx_attrs:
             if onnx_attr.type == onnx.AttributeProto.AttributeType.GRAPH:
                 count += 1
         return count
 
-    def import_regions(self, onnx_attrs: list[onnx.AttributeProto], op):
+    def import_regions(self, onnx_attrs: List[onnx.AttributeProto], op):
         attr_map = {}
         for onnx_attr in onnx_attrs:
             attr_type = onnx_attr.type
@@ -458,10 +458,10 @@ class ContextCache:
 
     def __init__(self, context: Context):
         self._c = context
-        self._elem_type_map: dict[int, IrType] = {}
-        self._list_type_map:dict[str, IrType] = {}
-        self._optional_type_map:dict[str, IrType] = {}
-        self._vtensor_type_map: dict[tuple[tuple[Optional[int]],
-                                           IrType], IrType] = {}
+        self._elem_type_map: Dict[int, IrType] = {}
+        self._list_type_map:Dict[str, IrType] = {}
+        self._optional_type_map:Dict[str, IrType] = {}
+        self._vtensor_type_map: Dict[Tuple[Tuple[Optional[int]], IrType], IrType] = {}
 
     def tensor_element_type(self, elem_type: int) -> IrType:
         t = self._elem_type_map.get(elem_type)
@@ -539,7 +539,7 @@ def get_optional_element_type(self, tp: onnx.TypeProto) -> IrType:
             f"Unsupport optional element type")
 
     def get_vtensor_type(
-        self, dims: tuple[Optional[int]], element_type: IrType
+        self, dims: Tuple[Optional[int]], element_type: IrType
     ) -> IrType:
         key = (dims, element_type)
         t = self._vtensor_type_map.get(key)
diff --git a/setup.py b/setup.py
index 4863a9807522..3cd5f2eca2d4 100644
--- a/setup.py
+++ b/setup.py
@@ -229,6 +229,7 @@ def build_extension(self, ext):
         "build_py": CMakeBuild,
     },
     ext_modules=EXT_MODULES,
+    python_requires=">=3.8",
     install_requires=INSTALL_REQUIRES,
     extras_require={
         "onnx": [
diff --git a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
index 9dceff316eaa..a348f40e3018 100644
--- a/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
+++ b/test/Conversion/TorchOnnxToTorch/simple_ops_g_to_p.mlir
@@ -629,6 +629,24 @@ func.func @test_globalaveragepool_precomputed(%arg0: !torch.vtensor<[1,1,3,3],f3
 
 // -----
 
+// CHECK-LABEL: func.func @test_mod_int64_fmod
+func.func @test_mod_int64_fmod(%arg0: !torch.vtensor<[6],si64>, %arg1: !torch.vtensor<[6],si64>) -> !torch.vtensor<[6],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: torch.aten.fmod.Tensor %arg0, %arg1 : !torch.vtensor<[6],si64>, !torch.vtensor<[6],si64> -> !torch.vtensor<[6],si64>
+  %0 = torch.operator "onnx.Mod"(%arg0, %arg1) {torch.onnx.fmod = 1 : si64} : (!torch.vtensor<[6],si64>, !torch.vtensor<[6],si64>) -> !torch.vtensor<[6],si64>
+  return %0 : !torch.vtensor<[6],si64>
+}
+
+// -----
+
+// CHECK-LABEL: func.func @test_mod_int64_no_fmod
+func.func @test_mod_int64_no_fmod(%arg0: !torch.vtensor<[6],si64>, %arg1: !torch.vtensor<[6],si64>) -> !torch.vtensor<[6],si64> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
+  // CHECK: torch.aten.remainder.Tensor %arg0, %arg1 : !torch.vtensor<[6],si64>, !torch.vtensor<[6],si64> -> !torch.vtensor<[6],si64>
+  %0 = torch.operator "onnx.Mod"(%arg0, %arg1) : (!torch.vtensor<[6],si64>, !torch.vtensor<[6],si64>) -> !torch.vtensor<[6],si64>
+  return %0 : !torch.vtensor<[6],si64>
+}
+
+// -----
+
 // CHECK-LABEL: func.func @test_log
 func.func @test_log(%arg0: !torch.vtensor<[3,4,5],f32>) -> !torch.vtensor<[3,4,5],f32> attributes {torch.onnx_meta.ir_version = 7 : si64, torch.onnx_meta.opset_version = 13 : si64, torch.onnx_meta.producer_name = "backend-test", torch.onnx_meta.producer_version = ""} {
   // CHECK: torch.aten.log %arg0 : !torch.vtensor<[3,4,5],f32> -> !torch.vtensor<[3,4,5],f32>
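The switch from built-in generics (`dict[...]`, `list[...]`, `tuple[...]`) to `typing.Dict`/`List`/`Tuple` across the Python sources goes together with the new `python_requires=">=3.8"` in setup.py: PEP 585 subscripting of built-in container types only works from Python 3.9 onward. A minimal sketch of the failure mode this avoids:

    from typing import Dict

    def ok(x: Dict[str, int]) -> int:        # evaluates fine on Python 3.8
        return len(x)

    # def broken(x: dict[str, int]) -> int:  # on 3.8: TypeError: 'type' object is not subscriptable
    #     return len(x)
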
diff --git a/test/python/fx_importer/sparse_test.py b/test/python/fx_importer/sparse_test.py
index 40c633cfc778..c76cd8584a96 100644
--- a/test/python/fx_importer/sparse_test.py
+++ b/test/python/fx_importer/sparse_test.py
@@ -7,7 +7,7 @@
 # UNSUPPORTED: true
 # RUN: %PYTHON %s | FileCheck %s
 
-from typing import Any, Callable, Optional
+from typing import Any, Callable, Optional, Tuple, Dict
 
 import torch
 import torch.export
@@ -82,7 +82,7 @@ def sparse_metadata(a: torch.Tensor) -> SparsityMeta:
 
 def sparse_export(
-    f: Callable, args: tuple[Any, ...], kwargs: Optional[dict[str, Any]] = None
+    f: Callable, args: Tuple[Any, ...], kwargs: Optional[Dict[str, Any]] = None
 ) -> torch.export.ExportedProgram:
     """
     This is a ***temporary*** wrapper around `torch.export.export`
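For reference, a small standalone sketch of the sparse level format that the updated `sparsity_encoding` logic in fx_importer.py now produces once COO tensors with more than two sparse dimensions are allowed. The batch_dim/sparse_dim values are hypothetical; the code simply mirrors the string construction shown earlier in this patch:

    batch_dim, sparse_dim = 0, 3  # assumed: 3-D COO tensor, no batch or dense dims
    trail_dim = batch_dim + sparse_dim - 1
    coords = ",".join(f"d{d}:singleton(nonunique,soa)" for d in range(batch_dim + 1, trail_dim))
    sep = "," if sparse_dim > 2 else ""
    print(f"d{batch_dim}:compressed(nonunique),{coords}{sep}d{trail_dim}:singleton(soa)")
    # d0:compressed(nonunique),d1:singleton(nonunique,soa),d2:singleton(soa)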