Auto Bump (no conflicts/test failures) (2) #205

Closed. Wants to merge 10 commits.
docs/add_ops.md (1 addition, 1 deletion)
@@ -76,7 +76,7 @@ Helpful examples:
`. Please don't just paste the generated tests - reference them to write your own

## Links

- IMPORTANT: read the LLVM style guide: https://llvm.org/docs/CodingStandards.html#use-early-exits-and-continue-to-simplify-code
- Tutorials
- [Sungsoon's Shark Getting Started Google Doc](https://docs.google.com/document/d/1H79DwW_wnVzUU81EogwY5ueXgnl-QzKet1p2lnqPar4/edit?pli=1)
- This document contains commands that would help you set up shark and run demos
lib/CMakeLists.txt (1 addition, 1 deletion)
@@ -32,7 +32,7 @@ set(LinkedLibs
)

if(TORCH_MLIR_ENABLE_STABLEHLO)
list(APPEND LinkedLibs StablehloPasses StablehloLinalgTransforms)
list(APPEND LinkedLibs StablehloLinalgTransforms)
endif()

if(TORCH_MLIR_ENABLE_REFBACKEND)
lib/Conversion/TorchOnnxToTorch/DefaultDomainAtoF.cpp (4 additions, 2 deletions)
@@ -43,8 +43,10 @@ static int64_t onnxDtypeIntToTorchDtypeInt(int64_t dtypeIntOnnx) {
switch (dtypeIntOnnx) {
case 1:
return 6; // float
case 6:
return 3; // int32
case 7:
return 5; // int64
return 4; // int64
case 9:
return 11; // bool
case 10:
@@ -1387,7 +1389,7 @@

return failure();
});
patterns.onOp("Div", 14,
patterns.onOp("Div", 7,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Value lhs, rhs;
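Note on the two hunks above: ONNX encodes element types as TensorProto integer codes, while the Torch dialect uses torch ScalarType codes (0 = uint8, 1 = int8, 2 = int16, 3 = int32, 4 = int64, 5 = float16, 6 = float32, 11 = bool). The old mapping sent ONNX int64 (code 7) to 5, i.e. torch float16, and did not handle ONNX int32 (code 6) at all. A minimal Python sketch of the corrected table (the dict and its name are illustrative, not part of the patch):

```python
# ONNX TensorProto dtype code -> torch ScalarType code, per the hunk above.
ONNX_TO_TORCH_DTYPE = {
    1: 6,   # onnx float32 -> torch float32
    6: 3,   # onnx int32   -> torch int32 (newly handled)
    7: 4,   # onnx int64   -> torch int64 (was 5, i.e. torch float16)
    9: 11,  # onnx bool    -> torch bool
}
```

The Div hunk lowers the minimum opset version the pattern accepts from 14 to 7; Div has used standard numpy-style broadcasting since opset 7, so the same lowering should apply to those older models as well.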
lib/Conversion/TorchOnnxToTorch/DefaultDomainGtoP.cpp (21 additions)
@@ -1153,4 +1153,25 @@ void mlir::torch::onnx_c::populateDefaultDomainGtoP(
binder.op, resultType, tensor, slope);
return success();
});
patterns.onOp("Mod", 13,
[](OpBinder binder, ConversionPatternRewriter &rewriter) {
Torch::ValueTensorType resultType;
Value self, other;
int64_t fmod;
if (binder.tensorOperands(self, other) ||
binder.tensorResultType(resultType) ||
binder.s64IntegerAttr(fmod, "fmod", 0)) {
return failure();
}

if (fmod) {
rewriter.replaceOpWithNewOp<Torch::AtenFmodTensorOp>(
binder.op, resultType, self, other);
return success();
}

rewriter.replaceOpWithNewOp<Torch::AtenRemainderTensorOp>(
binder.op, resultType, self, other);
return success();
});
}
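The new Mod lowering keys off ONNX's fmod attribute: fmod=1 selects C fmod semantics, where the result takes the sign of the dividend (hence AtenFmodTensorOp), while the default fmod=0 is Python-style modulus, where the result takes the sign of the divisor (hence AtenRemainderTensorOp). A quick PyTorch illustration of the distinction:

```python
import torch

a, b = torch.tensor([-7.0]), torch.tensor([3.0])
print(torch.fmod(a, b))       # tensor([-1.]): sign of the dividend (fmod=1)
print(torch.remainder(a, b))  # tensor([2.]):  sign of the divisor  (fmod=0)
```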
lib/Conversion/TorchToTosa/TorchToTosa.cpp (7 additions, 5 deletions)
@@ -2631,10 +2631,9 @@ LogicalResult ConvertAtenOp<AtenFlattenUsingIntsOp>::matchAndRewrite(

// Not a ranked tensor type
auto selfType = adaptor.getSelf().getType().dyn_cast<RankedTensorType>();
if (!selfType || !selfType.hasStaticShape())
return rewriter.notifyMatchFailure(
op,
"Only ranked tensor types with static shapes are currently supported");
if (!selfType)
return rewriter.notifyMatchFailure(op,
"Only ranked tensor types supported");

int64_t selfRank = selfType.getRank();

@@ -2666,8 +2665,11 @@
} else {
if (idx == start_dim)
newShape.push_back(s.value());
else
// Only updating when the shapes are static
else if (s.value() != kUnknownSize && newShape.back() != kUnknownSize)
newShape.back() *= s.value();
else
newShape.back() = kUnknownSize;
}
}

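The TorchToTosa change above replaces the static-shape-only guard on aten.flatten.using_ints with per-dimension folding: collapsed dimensions are multiplied together only while every factor is static, and a single dynamic factor marks the whole collapsed dimension as kUnknownSize. A sketch of that folding in Python, with -1 standing in for kUnknownSize (the helper name is ours, not torch-mlir's):

```python
# Illustrative re-implementation of the newShape folding in the hunk above.
def flatten_shape(shape, start_dim, end_dim):
    out = list(shape[:start_dim]) + [shape[start_dim]]
    for s in shape[start_dim + 1 : end_dim + 1]:
        # Fold only when both factors are static; otherwise mark dynamic.
        out[-1] = out[-1] * s if s != -1 and out[-1] != -1 else -1
    return out + list(shape[end_dim + 1 :])

assert flatten_shape([10, 3, 4, 9], 1, 2) == [10, 12, 9]
assert flatten_shape([10, 3, -1, 9], 1, 2) == [10, -1, 9]
```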
lib/Dialect/TorchConversion/Transforms/CMakeLists.txt (1 addition)
@@ -20,6 +20,7 @@ set(LinkedLibs
if(TORCH_MLIR_ENABLE_STABLEHLO)
list(APPEND LinkedLibs
StablehloOps
StablehloPasses
)
endif()

lib/Dialect/TorchConversion/Transforms/Passes.cpp (7 additions, 1 deletion)
@@ -25,6 +25,7 @@
#include "torch-mlir/Conversion/TorchToTosa/TorchToTosa.h"
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
#ifdef TORCH_MLIR_ENABLE_STABLEHLO
#include "stablehlo/transforms/Passes.h"
#include "torch-mlir/Conversion/TorchToStablehlo/TorchToStablehlo.h"
#endif
#include "torch-mlir/Dialect/Torch/Transforms/Passes.h"
@@ -136,9 +137,14 @@ void TorchConversion::createTorchBackendToTosaBackendPipeline(
void TorchConversion::createTorchBackendToStablehloBackendPipeline(
OpPassManager &pm,
const TorchConversion::StablehloBackendPipelineOptions &options) {
// Generate Stablehlo ops.
// Generate Stablehlo & Chlo ops.
pm.addNestedPass<func::FuncOp>(createConvertTorchToStablehloPass(
options.enableStaticShape, options.enableI32Index));
// Lowering Chlo ops to Stablehlo
pm.addNestedPass<func::FuncOp>(
stablehlo::createChloLegalizeToStablehloPass());
pm.addNestedPass<func::FuncOp>(createCanonicalizerPass());

// Lowering remaining ops to Arith
pm.addNestedPass<func::FuncOp>(createConvertTorchToArithPass());

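Ordering note for the hunk above: ConvertTorchToStablehlo can emit CHLO ops for some patterns, so the pipeline now expands CHLO into StableHLO and canonicalizes immediately afterwards, instead of leaving that to the Python-side backend pipeline (see the STABLEHLO_TO_LINALG_FUNC_PIPELINE hunk below). Expressed as a textual pass pipeline, a sketch of the new func-level ordering (assuming the usual registered pass names):

```python
# Approximate ordering assembled by createTorchBackendToStablehloBackendPipeline.
TORCH_TO_STABLEHLO_FUNC_PIPELINE = ",".join([
    "func.func(convert-torch-to-stablehlo)",  # StableHLO plus some CHLO ops
    "func.func(chlo-legalize-to-stablehlo)",  # expand CHLO into StableHLO
    "func.func(canonicalize)",                # clean up after the expansion
])
```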
lib/InitAll.cpp (2 deletions)
@@ -32,7 +32,6 @@

#ifdef TORCH_MLIR_ENABLE_STABLEHLO
#include "stablehlo/conversions/linalg/transforms/Passes.h"
#include "stablehlo/transforms/Passes.h"
#endif

void mlir::torch::registerAllDialects(mlir::DialectRegistry &registry) {
@@ -60,7 +59,6 @@ void mlir::torch::registerAllPasses() {
mlir::torch::TMTensor::registerPasses();

#ifdef TORCH_MLIR_ENABLE_STABLEHLO
mlir::stablehlo::registerChloLegalizeToStablehloPass();
mlir::stablehlo::registerStablehloLegalizeToLinalgPass();
#endif

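Taken together with the two CMakeLists.txt hunks and the Passes.cpp hunk above: the CHLO-to-StableHLO legalization is no longer registered globally here but is constructed directly where the StableHLO pipeline is built, so the StablehloPasses library now links into TorchConversion/Transforms instead of the top-level lib/.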
projects/pt1/e2e_testing/xfail_sets.py (5 additions, 14 deletions)
@@ -954,6 +954,9 @@
"ArangeStartNegativeStepFloatModule_basic",
"ArangeStartOutDtypeModule_basic",
"ArangeStartStepFloatModule_basic",
"ArgmaxIntModule_basic",
"ArgmaxIntModule_multiple_maxs",
"ArgmaxModule_basic",
"ArgmaxModule_keepDim",
"ArgmaxModule_with_dim",
"AtenComplex64Module_basic",
@@ -1200,6 +1203,7 @@
"Fill_TensorFloat64WithInt64Static_basic",
"FlattenRank0Module_basic",
"FlattenStaticModule_basic",
"FlattenDynamicModuleCollapseAll_basic",
"FullLikeModuleFloat3DStatic_basic",
"FullLikeModuleInt2DStatic_basic",
"FullModuleDefaultDtype_basic",
@@ -1485,6 +1489,7 @@
}) - {
### Test failing in make_fx_tosa but not in tosa

"FlattenDynamicModuleCollapseAll_basic",
# Dynamic shape, has extra unsupported broadcast ops
"Matmul_3d",

@@ -1806,8 +1811,6 @@
"ElementwiseOrTensorStaticShapeModule_basic",
"ElementwiseQuantizePerTensorModule_basic",
"ElementwiseRemainderTensorModule_Int_basic",
"ElementwiseFmodTensor_Float_basic",
"ElementwiseFmodTensor_Int_Float_basic",
"ElementwiseFmodTensor_Int_basic",
"EmptyStridedModule_basic",
"EmptyStridedSizeIntStrideModule_basic",
@@ -2074,10 +2077,7 @@
"BucketizeTensorOutInt32RightModule_basic",
"ElementwiseToDtypeI64ToI8Module_basic",
"ElementwiseToDtypeI64ToUI8Module_basic",
"HBC_basic",
"QuantizedMLP_basic",
"TypeConversionI1ToI32Module_basic",
"TypeConversionI64ToI32Module_basic",

# Failure - onnx_lowering: onnx.Clip
"NormalizeModule_basic",
@@ -2102,14 +2102,6 @@
"MaxPool2dWithIndicesNonDefaultPaddingModule_basic",
"MaxPool2dWithIndicesStaticModule_basic",

# Failure - onnx_lowering: onnx.Mod
"ElementwiseRemainderScalarModule_Bool_basic",
"ElementwiseRemainderScalarModule_Int_basic",
"UnflattenIntNegativeOneDimStaticModule_basic",
"UnflattenIntNegativeOneSizeStaticModule_basic",
"UnflattenIntStaticModule_basic",
"UnflattenStaticModule_basic",

# Failure - onnx_lowering: onnx.OneHot
"OneHotModule_basic",

@@ -2312,7 +2304,6 @@
"AtenLinalgCrossDynamic_basic",

# Only on feature/backport_ea1_ops
"AtenToDtypeModule_basic",
"Conv1dNoPaddingGroupModule_basic",
"ElementwiseAcosTensorIntModule_basic",
"ElementwiseAsinTensorIntModule_basic",
[file path not shown in the original rendering]
@@ -3,7 +3,7 @@
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
# Also available under a BSD-style license. See LICENSE.

from typing import List, Optional, Any, Tuple, Union
from typing import List, Optional, Any, Tuple, Union, Dict, Set
import argparse
import os

@@ -1875,9 +1875,9 @@ def aten〇upsample_nearest2d〡shape(self: List[int], output_size: List[int], s

def _check_tensors_with_the_same_dtype(
num_of_tensors: Optional[int] = None,
tensor_shapes: Optional[list[tuple[int]]] = None,
tensor_shapes: Optional[List[Tuple[int]]] = None,
tensor_device: Optional[torch.device] = None,
error_types: Optional[set[int]] = None, *args, **kwargs):
error_types: Optional[Set[int]] = None, *args, **kwargs):
"""Create invocations where all tensors have the same dtype.

This function generates invocations with `num_of_tensors` tensors
@@ -1909,10 +1909,10 @@ def _check_tensors_with_the_same_dtype(
return invocations

def _check_two_tensor_op(
tensor_shapes: Optional[list[tuple[int]]] = None,
tensor_shapes: Optional[List[Tuple[int]]] = None,
tensor_device: Optional[torch.device] = None,
input_error_types: Optional[set[int]] = None,
output_error_types: Optional[set[int]] = None, **kwargs):
input_error_types: Optional[Set[int]] = None,
output_error_types: Optional[Set[int]] = None, **kwargs):
"""Generate invocations for basic two-tensor dtype functions.

This helper function is meant to be used to check dtype functions that
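The list[...]/tuple[...]/set[...] to List/Tuple/Set changes in this file (and in fx_importer.py below) are not cosmetic: subscripting the builtin types in annotations only works on Python 3.9+, so this presumably restores compatibility with Python 3.8. For example:

```python
from typing import List, Optional, Tuple

# On Python 3.8 the builtin-generic form fails at function definition time:
#   def f(shapes: Optional[list[tuple[int]]] = None): ...
#   TypeError: 'type' object is not subscriptable
# The typing aliases evaluate on every supported version:
def f(shapes: Optional[List[Tuple[int]]] = None) -> None:
    ...
```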
[file path not shown in the original rendering]
@@ -18,8 +18,6 @@
# The pipeline of func.func passes that lower the STABLEHLO backend contract to the
# Linalg-on-Tensors backend contract accepted by RefBackend.
STABLEHLO_TO_LINALG_FUNC_PIPELINE = ",".join([
"func.func(chlo-legalize-to-stablehlo)",
"canonicalize",
"stablehlo-legalize-to-linalg"
])

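With CHLO legalization and canonicalization now run inside createTorchBackendToStablehloBackendPipeline (Passes.cpp above), this Python backend pipeline only needs stablehlo-legalize-to-linalg.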
projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py (19 additions)
@@ -391,6 +391,25 @@ def forward(self, x):
def FlattenDynamicModule_basic(module, tu: TestUtils):
module.forward(tu.rand(10, 3, 8, 9, 3, 4))

class FlattenDynamicModuleCollapseAll(torch.nn.Module):

def __init__(self):
super().__init__()
self.flat = torch.nn.Flatten(0)

@export
@annotate_args([
None,
([-1, -1, -1, 9, 3, -1], torch.float32, True),
])
def forward(self, x):
return self.flat(x)


@register_test_case(module_factory=lambda: FlattenDynamicModuleCollapseAll())
def FlattenDynamicModuleCollapseAll_basic(module, tu: TestUtils):
module.forward(tu.rand(10, 3, 8, 9, 3, 4))


# ==============================================================================

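The new test drives the dynamic-shape flatten path enabled in the TorchToTosa.cpp hunk above: Flatten(0) collapses all six dimensions, three of which are annotated as dynamic, so the collapsed size must fold to a dynamic dimension. It appears to be added to the TOSA pass set and excluded from make_fx_tosa in the xfail_sets.py hunks above.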
python/torch_mlir/extras/fx_importer.py (21 additions, 18 deletions)
@@ -276,7 +276,7 @@ class SparsityMeta:
batch_dim: int
sparse_dim: int
dense_dim: int
blocksize: Optional[tuple[int, int]]
blocksize: Optional[Tuple[int, int]]
pos_dtype: torch.dtype
crd_dtype: torch.dtype

@@ -297,11 +297,14 @@ def sparsity_encoding(shape: torch.Size, sparsity: SparsityMeta) -> str:
assert dim == len(shape)
blocksize = sparsity.blocksize

dims = ",".join(f"d{d}" for d in range(0, dim))
dims = ",".join(f"d{d}" for d in range(dim))

if sparsity.layout is torch.sparse_coo:
assert sparse_dim == 2 and blocksize is None # TODO: deeper sparse dims
lvls = f"d{batch_dim}:compressed(nonunique),d{batch_dim+1}:singleton(soa)"
assert sparse_dim >= 2 and blocksize is None
trail_dim = batch_dim + sparse_dim - 1
coords = ",".join(f"d{d}:singleton(nonunique,soa)" for d in range(batch_dim+1, trail_dim))
sep = "," if sparse_dim > 2 else ""
lvls = f"d{batch_dim}:compressed(nonunique),{coords}{sep}d{trail_dim}:singleton(soa)"
elif sparsity.layout is torch.sparse_csr:
assert sparse_dim == 2 and blocksize is None
lvls = f"d{batch_dim}:dense,d{batch_dim+1}:compressed"
@@ -322,7 +325,7 @@
)

if batch_dim > 0:
batch = ",".join(f"d{d}:dense" for d in range(0, batch_dim))
batch = ",".join(f"d{d}:dense" for d in range(batch_dim))
lvls = f"{batch},{lvls}"

if dense_dim > 0:
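The sparse_coo branch above previously hard-coded exactly two sparse dimensions; it now handles any sparse_dim >= 2 by emitting a compressed(nonunique) leading coordinate, singleton(nonunique,soa) middle coordinates, and a singleton(soa) trailing coordinate. A sketch of the string it builds for a 3-dim COO tensor with no batch dimensions:

```python
# Mirrors the lvls construction in the hunk above for batch_dim=0, sparse_dim=3.
batch_dim, sparse_dim = 0, 3
trail_dim = batch_dim + sparse_dim - 1
coords = ",".join(
    f"d{d}:singleton(nonunique,soa)" for d in range(batch_dim + 1, trail_dim)
)
sep = "," if sparse_dim > 2 else ""
lvls = f"d{batch_dim}:compressed(nonunique),{coords}{sep}d{trail_dim}:singleton(soa)"
print(lvls)  # d0:compressed(nonunique),d1:singleton(nonunique,soa),d2:singleton(soa)
```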
@@ -489,8 +492,8 @@ def import_program(
default policy is to capture them as frozen values.
"""
# Create lookaside table of placeholders/outputs.
placeholder_nodes: dict[str, Node] = {}
all_producer_nodes: dict[str, Node] = {}
placeholder_nodes: Dict[str, Node] = {}
all_producer_nodes: Dict[str, Node] = {}
loc: Optional[Location] = None
for node in prog.graph.nodes:
if loc is None:
@@ -522,15 +525,15 @@ def import_program(
}

# Additional bindings that we need to set up after the function is created.
mutable_buffer_target_producers: dict[str, str] = {}
constant_tensors: dict[Node, torch.Tensor] = {}
parameter_bindings: dict[Node, tuple[Any, InputInfo]] = {}
buffer_bindings: dict[Node, tuple[Any, InputInfo]] = {}
mutable_buffer_target_producers: Dict[str, str] = {}
constant_tensors: Dict[Node, torch.Tensor] = {}
parameter_bindings: Dict[Node, Tuple[Any, InputInfo]] = {}
buffer_bindings: Dict[Node, Tuple[Any, InputInfo]] = {}

# Derive user outputs that we preserve. These will be nodes of the
# producer for the output.
user_outputs: list[Node] = []
user_output_types: list[IrType] = []
user_outputs: List[Node] = []
user_output_types: List[IrType] = []
for output_spec in sig.output_specs:
kind = output_spec.kind
arg = output_spec.arg
@@ -548,8 +551,8 @@ def import_program(
mutable_buffer_target_producers[output_spec.target] = arg.name

# Derive user inputs. These will be op=='placeholder' nodes.
user_inputs: list[Node] = []
user_input_types: list[IrType] = []
user_inputs: List[Node] = []
user_input_types: List[IrType] = []
for input_spec in sig.input_specs:
arg = input_spec.arg
if input_spec.kind == InputKind.USER_INPUT:
@@ -700,7 +703,7 @@ def import_frozen_program(
"""
sig = prog.graph_signature
state_dict = prog.state_dict
arg_replacements: dict[str, Any] = {}
arg_replacements: Dict[str, Any] = {}

# If there is no "constants" attribute, consult the "state_dict". Otherwise, only look
# at "constants". Relevant upstream patch: https://github.com/pytorch/pytorch/pull/118969
@@ -1003,7 +1006,7 @@ def __init__(
# constructs and returns a value.
self._v: Dict[Union[Callable[[], Value], Tuple[torch_fx.Node, int]], Value] = {}
# Map of node name to hook that should be called when it is produced.
self._on_node_produced: dict[str, Callable[[Value], None]] = {}
self._on_node_produced: Dict[str, Callable[[Value], None]] = {}
# Statically multi-result nodes which we have de-tupled are noted here.
# They will have their getitem calls short-circuited.
self._multi_result_nodes: Set[torch_fx.Node] = set()
@@ -1118,7 +1121,7 @@ def on_produced(value: Value):

self._on_node_produced[info.mutable_producer_node_name] = on_produced

def return_node_values(self, loc, nodes: list[Node]):
def return_node_values(self, loc, nodes: List[Node]):
with loc, InsertionPoint(self._b):
operands = [self.resolve_node_value(n) for n in nodes]
func_dialect.ReturnOp(operands, loc=loc)