From 6e8c7bed4b12117764274e79bc60a93443d5bdd5 Mon Sep 17 00:00:00 2001 From: zjgarvey <47986913+zjgarvey@users.noreply.github.com> Date: Fri, 4 Oct 2024 11:27:00 -0500 Subject: [PATCH 1/2] [TorchToLinalg] perform rank0 elementwise computations outside linalg generic ops (#3762) This is motivated by the fact that shapes are stored as tensors in ONNX, and IREE tries to perform tensor arithmetic on the device. This causes unnecessary dispatches, and makes it harder for the compiler to reason about shapes. Here is a small snippet of torch-IR that is typical seen coming from ONNX models: ```mlir module { func.func @main_graph(%arg0: !torch.vtensor<[?,?,768],f32>, %arg1: !torch.vtensor<[?,?,768],f32>) -> !torch.vtensor<[],si64> { %int0 = torch.constant.int 0 %0 = torch.vtensor.literal(dense<0> : tensor<1xsi64>) : !torch.vtensor<[1],si64> %1 = torch.aten._shape_as_tensor %arg1 : !torch.vtensor<[?,?,768],f32> -> !torch.vtensor<[3],si64> %2 = torch.aten.index_select %1, %int0, %0 : !torch.vtensor<[3],si64>, !torch.int, !torch.vtensor<[1],si64> -> !torch.vtensor<[1],si64> %3 = torch.aten.squeeze.dim %2, %int0 : !torch.vtensor<[1],si64>, !torch.int -> !torch.vtensor<[],si64> %4 = torch.aten.item %3 : !torch.vtensor<[],si64> -> !torch.int %5 = torch.aten.eq.int %4, %int0 : !torch.int, !torch.int -> !torch.bool %6 = torch.aten.Int.bool %5 : !torch.bool -> !torch.int %7 = torch.aten.size.int %arg0, %int0 : !torch.vtensor<[?,?,768],f32>, !torch.int -> !torch.int %8 = torch.prim.NumToTensor.Scalar %6 : !torch.int -> !torch.vtensor<[],i1> %9 = torch.prim.NumToTensor.Scalar %7 : !torch.int -> !torch.vtensor<[],si64> %10 = torch.prim.NumToTensor.Scalar %4 : !torch.int -> !torch.vtensor<[],si64> %11 = torch.aten.where.self %8, %9, %10 : !torch.vtensor<[],i1>, !torch.vtensor<[],si64>, !torch.vtensor<[],si64> -> !torch.vtensor<[],si64> return %11 : !torch.vtensor<[],si64> } } ``` Without the change in this PR, the result would be: ```mlir #map = affine_map<() -> ()> module { ml_program.global private mutable @global_seed(dense<0> : tensor) : tensor func.func @main_graph(%arg0: tensor, %arg1: tensor) -> tensor { %c0_i64 = arith.constant 0 : i64 %c0 = arith.constant 0 : index %dim = tensor.dim %arg1, %c0 : tensor %0 = arith.index_cast %dim : index to i64 %1 = tensor.empty() : tensor<1xi64> %collapsed = tensor.collapse_shape %1 [] : tensor<1xi64> into tensor %2 = linalg.fill ins(%0 : i64) outs(%collapsed : tensor) -> tensor %extracted = tensor.extract %2[] : tensor %3 = arith.cmpi eq, %extracted, %c0_i64 : i64 %dim_0 = tensor.dim %arg0, %c0 : tensor %4 = arith.index_cast %dim_0 : index to i64 %5 = tensor.empty() : tensor %6 = linalg.fill ins(%3 : i1) outs(%5 : tensor) -> tensor %7 = tensor.empty() : tensor %8 = linalg.fill ins(%4 : i64) outs(%7 : tensor) -> tensor %9 = linalg.fill ins(%extracted : i64) outs(%7 : tensor) -> tensor %10 = linalg.generic {indexing_maps = [#map, #map, #map, #map], iterator_types = []} ins(%6, %8, %9 : tensor, tensor, tensor) outs(%7 : tensor) { ^bb0(%in: i1, %in_1: i64, %in_2: i64, %out: i64): %11 = arith.select %in, %in_1, %in_2 : i64 linalg.yield %11 : i64 } -> tensor return %10 : tensor } } ``` With the change in this PR, we would instead get: ```mlir module { ml_program.global private mutable @global_seed(dense<0> : tensor) : tensor func.func @main_graph(%arg0: tensor, %arg1: tensor) -> tensor { %c0_i64 = arith.constant 0 : i64 %c0 = arith.constant 0 : index %dim = tensor.dim %arg1, %c0 : tensor %0 = arith.index_cast %dim : index to i64 %1 = tensor.empty() : tensor<1xi64> %collapsed = tensor.collapse_shape %1 [] : tensor<1xi64> into tensor %2 = linalg.fill ins(%0 : i64) outs(%collapsed : tensor) -> tensor %extracted = tensor.extract %2[] : tensor %3 = arith.cmpi eq, %extracted, %c0_i64 : i64 %dim_0 = tensor.dim %arg0, %c0 : tensor %4 = arith.index_cast %dim_0 : index to i64 %5 = arith.select %3, %4, %extracted : i64 %6 = tensor.empty() : tensor %7 = linalg.fill ins(%5 : i64) outs(%6 : tensor) -> tensor return %7 : tensor } } ``` Some related issues for context: 1. 2. --- .../TorchToLinalg/Uncategorized.cpp | 19 +++++++++++++++++++ .../Conversion/TorchToLinalg/elementwise.mlir | 12 +++++------- 2 files changed, 24 insertions(+), 7 deletions(-) diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index 0f6f92bd7c2c..0532b4b19d94 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -1627,6 +1627,25 @@ class ConvertElementwiseOp : public ConversionPattern { operands, [](Value v) { return isa(v.getType()); })); auto resultType = cast( getTypeConverter()->convertType(op->getResult(0).getType())); + bool isScalarOp = resultType.getShape().size() == 0; + if (isScalarOp) { + // for elementwise ops that are actually rank0 scalar computations, + // perform the payload outside a linalg generic op. + SmallVector payloadArgs; + for (auto t : tensorOperands) { + payloadArgs.push_back(rewriter.create(loc, t)); + } + Value scalarResult = createLinalgPayloadCalculationForElementwiseOp( + rewriter, loc, getTypeConverter(), payloadArgs, op, operands); + if (!scalarResult) + return rewriter.notifyMatchFailure( + op, "Failed to create payload for scalar elementwise op"); + Value rank0Result = + createInitTensor(rewriter, loc, ValueRange{}, + resultType.getElementType(), scalarResult); + rewriter.replaceOpWithNewOp(op, resultType, rank0Result); + return success(); + } bool hadErrorCreatingPayload = false; Value generic = torch_to_linalg::createElementwiseLinalgGeneric( rewriter, loc, tensorOperands, resultType.getElementType(), diff --git a/test/Conversion/TorchToLinalg/elementwise.mlir b/test/Conversion/TorchToLinalg/elementwise.mlir index aa2be74f5d7e..ecf4caa58389 100644 --- a/test/Conversion/TorchToLinalg/elementwise.mlir +++ b/test/Conversion/TorchToLinalg/elementwise.mlir @@ -4,13 +4,11 @@ // CHECK-LABEL: func.func @elementwise$unary( // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32> { // CHECK-DAG: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[],f32> -> tensor -// CHECK: %[[INIT_TENSOR:.*]] = tensor.empty() : tensor -// CHECK: %[[GENERIC:.*]] = linalg.generic {indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>], iterator_types = []} ins(%[[BUILTIN_TENSOR]] : tensor) outs(%[[INIT_TENSOR]] : tensor) { -// CHECK: ^bb0(%[[BBARG0:.*]]: f32, %{{.*}}: f32): -// CHECK: %[[TANH:.*]] = math.tanh %[[BBARG0]] : f32 -// CHECK: linalg.yield %[[TANH]] : f32 -// CHECK: } -> tensor -// CHECK: %[[CASTED:.*]] = tensor.cast %[[GENERIC:.*]] : tensor to tensor +// CHECK: %[[EXTRACT:.*]] = tensor.extract %[[BUILTIN_TENSOR]][] : tensor +// CHECK: %[[TANH:.*]] = math.tanh %[[EXTRACT]] : f32 +// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor +// CHECK: %[[FILL:.*]] = linalg.fill ins(%[[TANH]] : f32) outs(%[[EMPTY]] : tensor) -> tensor +// CHECK: %[[CASTED:.*]] = tensor.cast %[[FILL:.*]] : tensor to tensor // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[CASTED]] : tensor -> !torch.vtensor<[],f32> // CHECK: return %[[RESULT]] : !torch.vtensor<[],f32> // CHECK: } From e1267ce7323868e19c11089600b14b0251a74f01 Mon Sep 17 00:00:00 2001 From: Rob Suderman Date: Fri, 4 Oct 2024 14:48:02 -0700 Subject: [PATCH 2/2] Revert "[TorchToLinalg] perform rank0 elementwise computations outside linalg generic ops (#3762)" (#3767) Reverted due to downstream model changes. Will reland with fixes post integration. This reverts commit 6e8c7bed4b12117764274e79bc60a93443d5bdd5. --- .../TorchToLinalg/Uncategorized.cpp | 19 ------------------- .../Conversion/TorchToLinalg/elementwise.mlir | 12 +++++++----- 2 files changed, 7 insertions(+), 24 deletions(-) diff --git a/lib/Conversion/TorchToLinalg/Uncategorized.cpp b/lib/Conversion/TorchToLinalg/Uncategorized.cpp index 2886c2835897..4292a8dde0d8 100644 --- a/lib/Conversion/TorchToLinalg/Uncategorized.cpp +++ b/lib/Conversion/TorchToLinalg/Uncategorized.cpp @@ -1635,25 +1635,6 @@ class ConvertElementwiseOp : public ConversionPattern { operands, [](Value v) { return isa(v.getType()); })); auto resultType = cast( getTypeConverter()->convertType(op->getResult(0).getType())); - bool isScalarOp = resultType.getShape().size() == 0; - if (isScalarOp) { - // for elementwise ops that are actually rank0 scalar computations, - // perform the payload outside a linalg generic op. - SmallVector payloadArgs; - for (auto t : tensorOperands) { - payloadArgs.push_back(rewriter.create(loc, t)); - } - Value scalarResult = createLinalgPayloadCalculationForElementwiseOp( - rewriter, loc, getTypeConverter(), payloadArgs, op, operands); - if (!scalarResult) - return rewriter.notifyMatchFailure( - op, "Failed to create payload for scalar elementwise op"); - Value rank0Result = - createInitTensor(rewriter, loc, ValueRange{}, - resultType.getElementType(), scalarResult); - rewriter.replaceOpWithNewOp(op, resultType, rank0Result); - return success(); - } bool hadErrorCreatingPayload = false; Value generic = torch_to_linalg::createElementwiseLinalgGeneric( rewriter, loc, tensorOperands, resultType.getElementType(), diff --git a/test/Conversion/TorchToLinalg/elementwise.mlir b/test/Conversion/TorchToLinalg/elementwise.mlir index ecf4caa58389..aa2be74f5d7e 100644 --- a/test/Conversion/TorchToLinalg/elementwise.mlir +++ b/test/Conversion/TorchToLinalg/elementwise.mlir @@ -4,11 +4,13 @@ // CHECK-LABEL: func.func @elementwise$unary( // CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32> { // CHECK-DAG: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[],f32> -> tensor -// CHECK: %[[EXTRACT:.*]] = tensor.extract %[[BUILTIN_TENSOR]][] : tensor -// CHECK: %[[TANH:.*]] = math.tanh %[[EXTRACT]] : f32 -// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor -// CHECK: %[[FILL:.*]] = linalg.fill ins(%[[TANH]] : f32) outs(%[[EMPTY]] : tensor) -> tensor -// CHECK: %[[CASTED:.*]] = tensor.cast %[[FILL:.*]] : tensor to tensor +// CHECK: %[[INIT_TENSOR:.*]] = tensor.empty() : tensor +// CHECK: %[[GENERIC:.*]] = linalg.generic {indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>], iterator_types = []} ins(%[[BUILTIN_TENSOR]] : tensor) outs(%[[INIT_TENSOR]] : tensor) { +// CHECK: ^bb0(%[[BBARG0:.*]]: f32, %{{.*}}: f32): +// CHECK: %[[TANH:.*]] = math.tanh %[[BBARG0]] : f32 +// CHECK: linalg.yield %[[TANH]] : f32 +// CHECK: } -> tensor +// CHECK: %[[CASTED:.*]] = tensor.cast %[[GENERIC:.*]] : tensor to tensor // CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[CASTED]] : tensor -> !torch.vtensor<[],f32> // CHECK: return %[[RESULT]] : !torch.vtensor<[],f32> // CHECK: }