From 9938abf25e1e7526ca7f43a8c49e9078c14fc55c Mon Sep 17 00:00:00 2001 From: "Xida Ren (Cedar)" Date: Thu, 26 Sep 2024 18:17:22 -0400 Subject: [PATCH] AtenCumprodOp (#3737) --- include/torch-mlir/Conversion/Utils/Utils.h | 2 + .../TorchToTMTensor/TorchToTMTensor.cpp | 75 +++++++++++++++++ lib/Conversion/Utils/Utils.cpp | 10 +++ .../Transforms/AbstractInterpLibrary.cpp | 22 +++++ projects/pt1/e2e_testing/xfail_sets.py | 21 +++++ .../build_tools/abstract_interp_lib_gen.py | 15 ++++ .../torch_mlir_e2e_test/test_suite/basic.py | 84 +++++++++++++++++++ 7 files changed, 229 insertions(+) diff --git a/include/torch-mlir/Conversion/Utils/Utils.h b/include/torch-mlir/Conversion/Utils/Utils.h index b76efe869a0f..d21dd5504dcd 100644 --- a/include/torch-mlir/Conversion/Utils/Utils.h +++ b/include/torch-mlir/Conversion/Utils/Utils.h @@ -40,6 +40,8 @@ Value createInitTensor(OpBuilder &b, Location loc, ValueRange sizes, Value createZeroInitTensor(OpBuilder &b, Location loc, ValueRange sizes, Type elemTy); +Value createOneInitTensor(OpBuilder &b, Location loc, ValueRange sizes, + Type elemTy); Value castIntToIndex(OpBuilder &b, Location loc, Value v); diff --git a/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp b/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp index b0b0b0df2ef0..94d7154115be 100644 --- a/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp +++ b/lib/Conversion/TorchToTMTensor/TorchToTMTensor.cpp @@ -1497,6 +1497,79 @@ class ConvertAtenSortOp : public OpConversionPattern { }; } // namespace +namespace { +class ConvertAtenCumprodOp : public OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + LogicalResult + matchAndRewrite(AtenCumprodOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + + Location loc = op.getLoc(); + Value input = adaptor.getSelf(); + auto resultType = cast( + getTypeConverter()->convertType(op->getResult(0).getType())); + Type elementType = resultType.getElementType(); + Type inputElementType = + cast(input.getType()).getElementType(); + + // Converting the input element type to the result's element type. + // The only possible mismatch would be when the input element type is an + // integer but not `si64`. Therefore, we directly convert the input to + // `si64`. Rest all cases are handled in the dtype definition for this op. + if (elementType != inputElementType) { + Value torchInput = convertTensorToDtype( + rewriter, loc, op.getSelf(), + rewriter.getIntegerType(64, IntegerType::Signed)); + input = typeConverter->materializeTargetConversion( + rewriter, loc, typeConverter->convertType(torchInput.getType()), + torchInput); + } + + int64_t inputRank = resultType.getRank(); + Value dtype = op.getDtype(); + if (!isa(dtype.getType())) + return rewriter.notifyMatchFailure( + op, "unsupported: dtype argument not supported"); + + int64_t dim; + if (!matchPattern(op.getDim(), m_TorchConstantInt(&dim))) + return rewriter.notifyMatchFailure( + op, "unimplemented: only constant dim value is supported"); + dim = toPositiveDim(dim, inputRank); + if (!isValidDim(dim, inputRank)) + return rewriter.notifyMatchFailure(op, "invalid dim"); + + SmallVector sizes = getTensorSizes(rewriter, loc, input); + Value output = createOneInitTensor(rewriter, loc, sizes, elementType); + output = rewriter.create(loc, resultType, output); + + SmallVector accSizes(sizes); + accSizes.erase(accSizes.begin() + dim); + SmallVector accStatic( + makeShapeTorchCompatible(resultType.getShape())); + accStatic.erase(accStatic.begin() + dim); + Value acc = createOneInitTensor(rewriter, loc, accSizes, elementType); + Type accType = + RankedTensorType::get(makeShapeLLVMCompatible(accStatic), elementType); + acc = rewriter.create(loc, accType, acc); + + Value result = createTMTensorScanOp( + rewriter, loc, input, output, acc, dim, /*inclusive=*/true, + [](OpBuilder &b, Location loc, Value input, Value acc) { + Value prod = + (isa(input.getType()) + ? b.create(loc, input, acc)->getResult(0) + : b.create(loc, input, acc)->getResult(0)); + b.create(loc, prod); + }); + + rewriter.replaceOpWithNewOp(op, resultType, result); + return success(); + } +}; +} // namespace + namespace { class ConvertAtenCumsumOp : public OpConversionPattern { public: @@ -2240,6 +2313,8 @@ class ConvertTorchToTMTensor patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); + target.addIllegalOp(); + patterns.add(typeConverter, context); target.addIllegalOp(); patterns.add(typeConverter, context); diff --git a/lib/Conversion/Utils/Utils.cpp b/lib/Conversion/Utils/Utils.cpp index 5ef0ab16963a..1a208f4ab127 100644 --- a/lib/Conversion/Utils/Utils.cpp +++ b/lib/Conversion/Utils/Utils.cpp @@ -138,6 +138,16 @@ Value createZeroInitTensor(OpBuilder &b, Location loc, ValueRange sizes, return b.create(loc, c0, initTensor).getResult(0); } +Value createOneInitTensor(OpBuilder &b, Location loc, ValueRange sizes, + Type elemTy) { + Value initTensor = + b.create(loc, getAsOpFoldResult(sizes), elemTy); + RankedTensorType type = cast(initTensor.getType()); + Value c1 = + b.create(loc, b.getOneAttr(type.getElementType())); + return b.create(loc, c1, initTensor).getResult(0); +} + Value castIntToIndex(OpBuilder &b, Location loc, Value v) { assert(isa(v.getType()) && "must be called with integer type"); return b.createOrFold(loc, b.getIndexType(), v); diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp index 59cf69393ded..995a7df283fd 100644 --- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp +++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp @@ -9134,6 +9134,9 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " func.func @\"__torch_mlir_shape_fn.aten.cumsum\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.optional) -> !torch.list {\n" " return %arg0 : !torch.list\n" " }\n" +" func.func @\"__torch_mlir_shape_fn.aten.cumprod\"(%arg0: !torch.list, %arg1: !torch.int, %arg2: !torch.optional) -> !torch.list {\n" +" return %arg0 : !torch.list\n" +" }\n" " func.func @\"__torch_mlir_shape_fn.aten.rand_like\"(%arg0: !torch.list, %arg1: !torch.optional, %arg2: !torch.optional, %arg3: !torch.optional, %arg4: !torch.optional, %arg5: !torch.optional) -> !torch.list {\n" " return %arg0 : !torch.list\n" " }\n" @@ -11844,6 +11847,25 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() { " }\n" " return %1 : !torch.int\n" " }\n" +" func.func @\"__torch_mlir_dtype_fn.aten.cumprod\"(%arg0: !torch.tuple, %arg1: !torch.int, %arg2: !torch.optional) -> !torch.int {\n" +" %int4 = torch.constant.int 4\n" +" %none = torch.constant.none\n" +" %0 = torch.aten.__isnot__ %arg2, %none : !torch.optional, !torch.none -> !torch.bool\n" +" %1 = torch.prim.If %0 -> (!torch.int) {\n" +" %2 = torch.prim.unchecked_cast %arg2 : !torch.optional -> !torch.int\n" +" torch.prim.If.yield %2 : !torch.int\n" +" } else {\n" +" %2:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" +" %3 = func.call @__torch__.torch_mlir.jit_ir_importer.build_tools.library_generator.is_integer_dtype(%2#1) : (!torch.int) -> !torch.bool\n" +" %4 = torch.prim.If %3 -> (!torch.int) {\n" +" torch.prim.If.yield %int4 : !torch.int\n" +" } else {\n" +" torch.prim.If.yield %2#1 : !torch.int\n" +" }\n" +" torch.prim.If.yield %4 : !torch.int\n" +" }\n" +" return %1 : !torch.int\n" +" }\n" " func.func @\"__torch_mlir_dtype_fn.aten.detach\"(%arg0: !torch.tuple) -> !torch.int {\n" " %0:2 = torch.prim.TupleUnpack %arg0 : !torch.tuple -> !torch.int, !torch.int\n" " return %0#1 : !torch.int\n" diff --git a/projects/pt1/e2e_testing/xfail_sets.py b/projects/pt1/e2e_testing/xfail_sets.py index 3b3e4611ea6b..0e741d0de36b 100644 --- a/projects/pt1/e2e_testing/xfail_sets.py +++ b/projects/pt1/e2e_testing/xfail_sets.py @@ -79,6 +79,7 @@ #### General TorchDynamo/PyTorch errors # torch._dynamo.exc.Unsupported: Tensor.item "CumsumModule_basic", + "CumprodModule_basic", # TypeError: new_empty(): argument 'size' (position 1) must be tuple of ints, but found element of type NoneType at pos 0 # RuntimeError: Failed running call_function aten.convolution_backward(... # https://github.com/pytorch/pytorch/issues/89629 @@ -432,6 +433,7 @@ "ConvolutionBackwardModule2DStrided_basic", "ConvolutionBackwardModule2D_basic", "CumsumModule_basic", + "CumprodModule_basic", "DeformConv2D_basic", "DivFloatModule_basic", "DivIntModule_basic", @@ -667,6 +669,10 @@ "ConvolutionBackwardModule2DStrided_basic", "ConvolutionBackwardModule2D_basic", "CumsumModule_basic", + "CumprodModule_basic", + "CumprodInputDtypeInt32Module_basic", + "CumprodStaticModule_basic", + "CumprodStaticNegativeDimModule_basic", "DeformConv2D_basic", "DeterminantBatchedModule_F32", "DeterminantDynamicModule_F32", @@ -1077,6 +1083,9 @@ "CumsumInputDtypeInt32Module_basic", "CumsumStaticModule_basic", "CumsumStaticNegativeDimModule_basic", + "CumprodInputDtypeInt32Module_basic", + "CumprodStaticModule_basic", + "CumprodStaticNegativeDimModule_basic", "DetachModule_basic", "DivFloatModule_basic", "DivIntModule_basic", @@ -3105,6 +3114,10 @@ "CopyWithDifferentDTypesModule_basic", "CosineSimilarityStaticBroadcastModule_basic", "CumsumInputDtypeInt32Module_basic", + "CumprodModule_basic", + "CumprodInputDtypeInt32Module_basic", + "CumprodStaticModule_basic", + "CumprodStaticNegativeDimModule_basic", "ElementwiseAcosIntModule_basic", "ElementwiseAsinIntModule_basic", "ElementwiseAtanTensorIntModule_basic", @@ -3378,6 +3391,10 @@ "CumsumModule_basic", "CumsumStaticModule_basic", "CumsumStaticNegativeDimModule_basic", + "CumprodModule_basic", + "CumprodInputDtypeInt32Module_basic", + "CumprodStaticModule_basic", + "CumprodStaticNegativeDimModule_basic", "DeformConv2D_basic", "DeterminantBatchedModule_F32", "DeterminantDynamicModule_F32", @@ -4110,6 +4127,10 @@ "CumsumModule_basic", "CumsumStaticModule_basic", "CumsumStaticNegativeDimModule_basic", + "CumprodModule_basic", + "CumprodInputDtypeInt32Module_basic", + "CumprodStaticModule_basic", + "CumprodStaticNegativeDimModule_basic", "DeformConv2D_basic", "DeterminantModule_F32", "DeterminantBatchedModule_F32", diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py index bc49757ee9d3..22fe8e299f07 100644 --- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py +++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/abstract_interp_lib_gen.py @@ -1434,6 +1434,9 @@ def aten〇multinomial〡shape(self: List[int], num_samples: int, replacement: b def aten〇cumsum〡shape(self: List[int], dim: int, dtype: Optional[int] = None) -> List[int]: return self +def aten〇cumprod〡shape(self: List[int], dim: int, dtype: Optional[int] = None) -> List[int]: + return self + def aten〇rand_like〡shape(self: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, memory_format: Optional[int] = None) -> List[int]: return self @@ -2926,6 +2929,18 @@ def aten〇cumsum〡dtype(self_rank_dtype: Tuple[int, int], dim: int, dtype: Opt return torch.int64 return self_dtype + +@check_dtype_function( + _check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0) + + _check_tensors_with_the_same_dtype(num_of_tensors=1, dim=0, dtype=torch.float32)) +def aten〇cumprod〡dtype(self_rank_dtype: Tuple[int, int], dim: int, dtype: Optional[int] = None) -> int: + if dtype is not None: + return dtype + self_rank, self_dtype = self_rank_dtype + if is_integer_dtype(self_dtype): + return torch.int64 + return self_dtype + @check_dtype_function(_check_tensors_with_the_same_dtype(num_of_tensors=1)) def aten〇detach〡dtype(self_rank_dtype: Tuple[int, int]) -> int: self_rank, self_dtype = self_rank_dtype diff --git a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py index cb6aa7fc15d7..ef20079b6f75 100644 --- a/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py +++ b/projects/pt1/python/torch_mlir_e2e_test/test_suite/basic.py @@ -4830,6 +4830,90 @@ def CumsumInputDtypeInt32Module_basic(module, tu: TestUtils): # ============================================================================== +class CumprodModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([-1, -1, -1], torch.float32, True), + ] + ) + def forward(self, val): + ones = torch.ones([1], dtype=torch.int32) + return torch.ops.aten.cumprod(val, ones.item()) + + +@register_test_case(module_factory=lambda: CumprodModule()) +def CumprodModule_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 7, 4)) + + +class CumprodStaticModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([2, 7, 4], torch.float32, True), + ] + ) + def forward(self, val): + return torch.ops.aten.cumprod(val, 1) + + +@register_test_case(module_factory=lambda: CumprodStaticModule()) +def CumprodStaticModule_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 7, 4)) + + +class CumprodStaticNegativeDimModule(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([2, 7, 4], torch.float32, True), + ] + ) + def forward(self, val): + return torch.ops.aten.cumprod(val, dim=-1) + + +@register_test_case(module_factory=lambda: CumprodStaticNegativeDimModule()) +def CumprodStaticNegativeDimModule_basic(module, tu: TestUtils): + module.forward(tu.rand(2, 7, 4)) + + +class CumprodInputDtypeInt32Module(torch.nn.Module): + def __init__(self): + super().__init__() + + @export + @annotate_args( + [ + None, + ([2, 7, 4], torch.int32, True), + ] + ) + def forward(self, val): + return torch.ops.aten.cumprod(val, 1) + + +@register_test_case(module_factory=lambda: CumprodInputDtypeInt32Module()) +def CumprodInputDtypeInt32Module_basic(module, tu: TestUtils): + module.forward(tu.randint(2, 7, 4).to(torch.int32)) + + +# ============================================================================== + + class AtenToDeviceModule(torch.nn.Module): def __init__(self): super().__init__()