From 85d9d6266f0c27bca2ac2b9906a6f00e34941a77 Mon Sep 17 00:00:00 2001
From: giacs-epic <179146510+giacs-epic@users.noreply.github.com>
Date: Tue, 1 Oct 2024 16:18:52 +0000
Subject: [PATCH 1/5] Add torch.aten.mul.float_int via torch_ods_gen.py and
 implement its folding.

---
 .../Dialect/Torch/IR/GeneratedTorchOps.td    | 25 +++++++++++++++++++
 lib/Conversion/TorchToArith/TorchToArith.cpp |  7 ++++--
 lib/Dialect/Torch/IR/TorchOps.cpp            | 12 +++++++++
 .../Transforms/AbstractInterpLibrary.cpp     |  4 +--
 .../build_tools/torch_ods_gen.py             |  1 +
 test/Dialect/Torch/canonicalize.mlir         | 10 ++++++++
 6 files changed, 55 insertions(+), 4 deletions(-)

diff --git a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
index c9329ccb895d..0b52c997f7ef 100644
--- a/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
+++ b/include/torch-mlir/Dialect/Torch/IR/GeneratedTorchOps.td
@@ -15456,6 +15456,31 @@ def Torch_AtenAddFloatIntOp : Torch_Op<"aten.add.float_int", [
   let hasFolder = 1;
 }
 
+def Torch_AtenMulFloatIntOp : Torch_Op<"aten.mul.float_int", [
+    AllowsTypeRefinement,
+    HasValueSemantics,
+    ReadOnly
+  ]> {
+  let summary = "Generated op for `aten::mul.float_int : (float, int) -> (float)`";
+  let arguments = (ins
+    Torch_FloatType:$a,
+    Torch_IntType:$b
+  );
+  let results = (outs
+    Torch_FloatType:$result
+  );
+  let hasCustomAssemblyFormat = 1;
+  let extraClassDefinition = [{
+    ParseResult AtenMulFloatIntOp::parse(OpAsmParser &parser, OperationState &result) {
+      return parseDefaultTorchOp(parser, result, 2, 1);
+    }
+    void AtenMulFloatIntOp::print(OpAsmPrinter &printer) {
+      printDefaultTorchOp(printer, *this, 2, 1);
+    }
+  }];
+  let hasFolder = 1;
+}
+
 def Torch_AtenSubFloatOp : Torch_Op<"aten.sub.float", [
     AllowsTypeRefinement,
     HasValueSemantics,
diff --git a/lib/Conversion/TorchToArith/TorchToArith.cpp b/lib/Conversion/TorchToArith/TorchToArith.cpp
index a1af190e460a..ca388a2b97d9 100644
--- a/lib/Conversion/TorchToArith/TorchToArith.cpp
+++ b/lib/Conversion/TorchToArith/TorchToArith.cpp
@@ -74,7 +74,8 @@ class ConvertAtenBinaryOp : public OpConversionPattern<AtenOp> {
                   ConversionPatternRewriter &rewriter) const override {
     Value a = adaptor.getA();
     Value b = adaptor.getB();
-    if (llvm::is_one_of<AtenOp, AtenAddFloatIntOp>::value)
+    if (llvm::is_one_of<AtenOp, AtenAddFloatIntOp>::value ||
+        llvm::is_one_of<AtenOp, AtenMulFloatIntOp>::value)
       b = convertScalarToDtype(rewriter, op.getLoc(), b, a.getType());
     rewriter.template replaceOpWithNewOp<BinOp>(op, a, b);
     return success();
@@ -467,7 +468,7 @@ class ConvertTorchToArith
     patterns.add(typeConverter, context);
     target.addIllegalOp();
+                        AtenMulIntOp, AtenMulFloatIntOp>();
     patterns.add>(
         typeConverter, context);
     patterns.add>(
@@ -476,6 +477,8 @@ class ConvertTorchToArith
         typeConverter, context);
     patterns.add>(
         typeConverter, context);
+    patterns.add<ConvertAtenBinaryOp<AtenMulFloatIntOp, arith::MulFOp>>(
+        typeConverter, context);
     target.addIllegalOp();
     patterns.add>(
         typeConverter, context);
diff --git a/lib/Dialect/Torch/IR/TorchOps.cpp b/lib/Dialect/Torch/IR/TorchOps.cpp
index bed228671de1..d0f7c0573e04 100644
--- a/lib/Dialect/Torch/IR/TorchOps.cpp
+++ b/lib/Dialect/Torch/IR/TorchOps.cpp
@@ -3979,6 +3979,18 @@ OpFoldResult AtenAddFloatIntOp::fold(FoldAdaptor adaptor) {
       adaptor.getOperands(), [](double a, double b) { return a + b; });
 }
 
+//===----------------------------------------------------------------------===//
+// AtenMulFloatIntOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult AtenMulFloatIntOp::fold(FoldAdaptor adaptor) {
+  if (!adaptor.getA() || !adaptor.getB()) {
+    return nullptr;
+  }
+  return atenBinaryFloatOperatorFoldHelper(
+      adaptor.getOperands(), [](double a, double b) { return a * b; });
+}
+
 //===----------------------------------------------------------------------===//
 // AtenPowIntFloatOp
 //===----------------------------------------------------------------------===//
diff --git a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
index 995a7df283fd..feef150ad47f 100644
--- a/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
+++ b/lib/Dialect/Torch/Transforms/AbstractInterpLibrary.cpp
@@ -6897,12 +6897,12 @@ StringRef mlir::torch::Torch::getAbstractInterpLibrary() {
 "    }\n"
 "    %11 = torch.aten.__getitem__.t %9, %int0 : !torch.list<float>, !torch.int -> !torch.float\n"
 "    %12 = torch.aten.__getitem__.t %arg0, %int2 : !torch.list<int>, !torch.int -> !torch.int\n"
-"    %13 = torch.operator \"aten.mul.float_int\"(%11, %12) : (!torch.float, !torch.int) -> !torch.float \n"
+"    %13 = torch.aten.mul.float_int %11, %12 : !torch.float, !torch.int -> !torch.float\n"
 "    %14 = torch.aten.Int.float %13 : !torch.float -> !torch.int\n"
 "    %15 = torch.aten.append.t %3, %14 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
 "    %16 = torch.aten.__getitem__.t %9, %int1 : !torch.list<float>, !torch.int -> !torch.float\n"
 "    %17 = torch.aten.__getitem__.t %arg0, %int3 : !torch.list<int>, !torch.int -> !torch.int\n"
-"    %18 = torch.operator \"aten.mul.float_int\"(%16, %17) : (!torch.float, !torch.int) -> !torch.float \n"
+"    %18 = torch.aten.mul.float_int %16, %17 : !torch.float, !torch.int -> !torch.float\n"
 "    %19 = torch.aten.Int.float %18 : !torch.float -> !torch.int\n"
 "    %20 = torch.aten.append.t %3, %19 : !torch.list<int>, !torch.int -> !torch.list<int>\n"
 "    torch.prim.If.yield %true, %3 : !torch.bool, !torch.list<int>\n"
diff --git a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
index f3227f29b5ce..bcbfb00b1f81 100644
--- a/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
+++ b/projects/pt1/python/torch_mlir/jit_ir_importer/build_tools/torch_ods_gen.py
@@ -1088,6 +1088,7 @@ def emit_with_mutating_variants(key, **kwargs):
     emit("aten::neg.int : (int) -> (int)", has_folder=True)
     emit("aten::log.int : (int) -> (float)")
     emit("aten::add.float_int : (float, int) -> (float)", has_folder=True)
+    emit("aten::mul.float_int : (float, int) -> (float)", has_folder=True)
     emit("aten::sub.float : (float, float) -> (float)", has_folder=True)
     emit("aten::mul.float : (float, float) -> (float)", has_folder=True)
     emit("aten::div.float : (float, float) -> (float)", has_folder=True)
diff --git a/test/Dialect/Torch/canonicalize.mlir b/test/Dialect/Torch/canonicalize.mlir
index f13bf60cb15b..51ce43262af6 100644
--- a/test/Dialect/Torch/canonicalize.mlir
+++ b/test/Dialect/Torch/canonicalize.mlir
@@ -1191,6 +1191,16 @@ func.func @torch.aten.mul.float() -> !torch.float {
   return %ret : !torch.float
 }
 
+// CHECK-LABEL: func.func @torch.aten.mul.float_int() -> !torch.float {
+// CHECK: %[[CST6:.*]] = torch.constant.float 6.000000e+00
+// CHECK: return %[[CST6]] : !torch.float
+func.func @torch.aten.mul.float_int() -> !torch.float {
+  %cst2 = torch.constant.float 2.0
+  %cst3 = torch.constant.int 3
+  %ret = torch.aten.mul.float_int %cst2, %cst3: !torch.float, !torch.int -> !torch.float
+  return %ret : !torch.float
+}
+
 // CHECK-LABEL: func.func @torch.aten.neg.float() -> !torch.float {
 // CHECK: %[[CST_6:.*]] = torch.constant.float -6.000000e+00
 // CHECK: return %[[CST_6]] : !torch.float

From b0b0d74e321ee7d8963f402f040a29365e209994 Mon Sep 17 00:00:00 2001
From: giacs-epic <179146510+giacs-epic@users.noreply.github.com>
Date: Tue, 1 Oct 2024 16:37:08 +0000
Subject: [PATCH 2/5] Add Arith conversion test.

---
 test/Conversion/TorchToArith/basic.mlir | 13 +++++++++++++
 1 file changed, 13 insertions(+)

diff --git a/test/Conversion/TorchToArith/basic.mlir b/test/Conversion/TorchToArith/basic.mlir
index 3d9e9f22a858..d297c7df93ed 100644
--- a/test/Conversion/TorchToArith/basic.mlir
+++ b/test/Conversion/TorchToArith/basic.mlir
@@ -194,6 +194,19 @@ func.func @torch.aten.mul.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.in
   return %0 : !torch.int
 }
 
+// CHECK-LABEL: func.func @torch.aten.mul.float_int(
+// CHECK-SAME: %[[LHS:.*]]: !torch.float,
+// CHECK-SAME: %[[RHS:.*]]: !torch.int) -> !torch.float {
+// CHECK-DAG: %[[LHS_F64:.*]] = torch_c.to_f64 %[[LHS]]
+// CHECK-DAG: %[[RHS_I64:.*]] = torch_c.to_i64 %[[RHS]]
+// CHECK: %[[MUL:.*]] = arith.mulf %[[LHS_F64:.*]], [[RHS_I64:.*]] : f64
+// CHECK: %[[OUT:.*]] = torch_c.from_f64 %[[MUL:.*]]
+// CHECK: return %[[OUT:.*]] : !torch.float
+func.func @torch.aten.mul.float_int(%arg0: !torch.float, %arg1: !torch.int) -> !torch.float {
+  %0 = torch.aten.mul.float_int %arg0, %arg1 : !torch.float, !torch.int -> !torch.float
+  return %0 : !torch.float
+}
+
 // CHECK-LABEL: func.func @torch.aten.div.float(
 // CHECK-SAME: %[[LHS:.*]]: !torch.float,
 // CHECK-SAME: %[[RHS:.*]]: !torch.float) -> !torch.float {

From c6fdab1dd8b05276bcb43e07a211591d2429cb64 Mon Sep 17 00:00:00 2001
From: giacs-epic <179146510+giacs-epic@users.noreply.github.com>
Date: Fri, 15 Nov 2024 11:31:56 +0000
Subject: [PATCH 3/5] Correct missing arith.sitofp in unit test

---
 test/Conversion/TorchToArith/basic.mlir | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/test/Conversion/TorchToArith/basic.mlir b/test/Conversion/TorchToArith/basic.mlir
index 4eada1198099..b70ad70bad89 100644
--- a/test/Conversion/TorchToArith/basic.mlir
+++ b/test/Conversion/TorchToArith/basic.mlir
@@ -241,7 +241,8 @@ func.func @torch.aten.mul.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.in
 // CHECK-SAME: %[[RHS:.*]]: !torch.int) -> !torch.float {
 // CHECK-DAG: %[[LHS_F64:.*]] = torch_c.to_f64 %[[LHS]]
 // CHECK-DAG: %[[RHS_I64:.*]] = torch_c.to_i64 %[[RHS]]
-// CHECK: %[[MUL:.*]] = arith.mulf %[[LHS_F64:.*]], [[RHS_I64:.*]] : f64
+// CHECK: %[[RHS_F64:.*]] = arith.sitofp %[[RHS_I64]] : i64 to f64
+// CHECK: %[[MUL:.*]] = arith.mulf %[[LHS_F64:.*]], [[RHS_F64:.*]] : f64
 // CHECK: %[[OUT:.*]] = torch_c.from_f64 %[[MUL:.*]]
 // CHECK: return %[[OUT:.*]] : !torch.float
 func.func @torch.aten.mul.float_int(%arg0: !torch.float, %arg1: !torch.int) -> !torch.float {

From 74effb24e3f3bc16fbc9989c627a2b4fea3deba7 Mon Sep 17 00:00:00 2001
From: Giacomo Serafini <179146510+giacs-epic@users.noreply.github.com>
Date: Wed, 20 Nov 2024 08:58:01 +0000
Subject: [PATCH 4/5] Fix check in test/Conversion/TorchToArith/basic.mlir

Co-authored-by: zjgarvey <47986913+zjgarvey@users.noreply.github.com>
---
 test/Conversion/TorchToArith/basic.mlir | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/Conversion/TorchToArith/basic.mlir b/test/Conversion/TorchToArith/basic.mlir
index b70ad70bad89..3eafce01e34a 100644
--- a/test/Conversion/TorchToArith/basic.mlir
+++ b/test/Conversion/TorchToArith/basic.mlir
@@ -242,7 +242,7 @@ func.func @torch.aten.mul.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.in
 // CHECK-DAG: %[[LHS_F64:.*]] = torch_c.to_f64 %[[LHS]]
 // CHECK-DAG: %[[RHS_I64:.*]] = torch_c.to_i64 %[[RHS]]
 // CHECK: %[[RHS_F64:.*]] = arith.sitofp %[[RHS_I64]] : i64 to f64
-// CHECK: %[[MUL:.*]] = arith.mulf %[[LHS_F64:.*]], [[RHS_F64:.*]] : f64
+// CHECK: %[[MUL:.*]] = arith.mulf %[[LHS_F64]], [[RHS_F64]] : f64
 // CHECK: %[[OUT:.*]] = torch_c.from_f64 %[[MUL:.*]]
 // CHECK: return %[[OUT:.*]] : !torch.float
 func.func @torch.aten.mul.float_int(%arg0: !torch.float, %arg1: !torch.int) -> !torch.float {

From 009eef922aeec2c98717d5836e51951f95b5a3e0 Mon Sep 17 00:00:00 2001
From: giacs-epic <179146510+giacs-epic@users.noreply.github.com>
Date: Wed, 20 Nov 2024 15:02:45 +0000
Subject: [PATCH 5/5] Fix checks in basic.mlir

---
 test/Conversion/TorchToArith/basic.mlir | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/test/Conversion/TorchToArith/basic.mlir b/test/Conversion/TorchToArith/basic.mlir
index 3eafce01e34a..59fc9d7b756f 100644
--- a/test/Conversion/TorchToArith/basic.mlir
+++ b/test/Conversion/TorchToArith/basic.mlir
@@ -242,9 +242,9 @@ func.func @torch.aten.mul.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.in
 // CHECK-DAG: %[[LHS_F64:.*]] = torch_c.to_f64 %[[LHS]]
 // CHECK-DAG: %[[RHS_I64:.*]] = torch_c.to_i64 %[[RHS]]
 // CHECK: %[[RHS_F64:.*]] = arith.sitofp %[[RHS_I64]] : i64 to f64
-// CHECK: %[[MUL:.*]] = arith.mulf %[[LHS_F64]], [[RHS_F64]] : f64
-// CHECK: %[[OUT:.*]] = torch_c.from_f64 %[[MUL:.*]]
-// CHECK: return %[[OUT:.*]] : !torch.float
+// CHECK: %[[MUL:.*]] = arith.mulf %[[LHS_F64]], %[[RHS_F64]] : f64
+// CHECK: %[[OUT:.*]] = torch_c.from_f64 %[[MUL]]
+// CHECK: return %[[OUT]] : !torch.float
 func.func @torch.aten.mul.float_int(%arg0: !torch.float, %arg1: !torch.int) -> !torch.float {
   %0 = torch.aten.mul.float_int %arg0, %arg1 : !torch.float, !torch.int -> !torch.float
   return %0 : !torch.float
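
A brief summary sketch of what this series adds, distilled from the included tests. The torch-mlir-opt pass flags mentioned in the comments are assumptions based on the usual RUN lines of these test suites, and the function names are illustrative only, not part of the patches.

// With constant operands, the new AtenMulFloatIntOp folder (patch 1) lets
// -canonicalize rewrite the multiply into torch.constant.float 6.000000e+00,
// as checked in test/Dialect/Torch/canonicalize.mlir.
func.func @fold_sketch() -> !torch.float {
  %f = torch.constant.float 2.0
  %i = torch.constant.int 3
  %r = torch.aten.mul.float_int %f, %i : !torch.float, !torch.int -> !torch.float
  return %r : !torch.float
}

// With non-constant operands, -convert-torch-to-arith lowers the op through
// ConvertAtenBinaryOp: the operands are unwrapped with torch_c.to_f64 and
// torch_c.to_i64, the i64 value is converted to f64 with arith.sitofp, the
// product is computed with arith.mulf, and the result is wrapped again with
// torch_c.from_f64, as checked in test/Conversion/TorchToArith/basic.mlir.
func.func @lower_sketch(%arg0: !torch.float, %arg1: !torch.int) -> !torch.float {
  %r = torch.aten.mul.float_int %arg0, %arg1 : !torch.float, !torch.int -> !torch.float
  return %r : !torch.float
}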