Skip to content

Commit

Permalink
[AutoBump] Merge with fixes of 6e8c7be (Oct 04)
Browse files Browse the repository at this point in the history
  • Loading branch information
mgehre-amd committed Dec 17, 2024
2 parents bf1ee4a + 6e8c7be commit efc5746
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 7 deletions.
19 changes: 19 additions & 0 deletions lib/Conversion/TorchToLinalg/Uncategorized.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1635,6 +1635,25 @@ class ConvertElementwiseOp : public ConversionPattern {
operands, [](Value v) { return isa<RankedTensorType>(v.getType()); }));
auto resultType = cast<RankedTensorType>(
getTypeConverter()->convertType(op->getResult(0).getType()));
bool isScalarOp = resultType.getShape().size() == 0;
if (isScalarOp) {
// for elementwise ops that are actually rank0 scalar computations,
// perform the payload outside a linalg generic op.
SmallVector<Value> payloadArgs;
for (auto t : tensorOperands) {
payloadArgs.push_back(rewriter.create<tensor::ExtractOp>(loc, t));
}
Value scalarResult = createLinalgPayloadCalculationForElementwiseOp(
rewriter, loc, getTypeConverter(), payloadArgs, op, operands);
if (!scalarResult)
return rewriter.notifyMatchFailure(
op, "Failed to create payload for scalar elementwise op");
Value rank0Result =
createInitTensor(rewriter, loc, ValueRange{},
resultType.getElementType(), scalarResult);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, resultType, rank0Result);
return success();
}
bool hadErrorCreatingPayload = false;
Value generic = torch_to_linalg::createElementwiseLinalgGeneric(
rewriter, loc, tensorOperands, resultType.getElementType(),
Expand Down
12 changes: 5 additions & 7 deletions test/Conversion/TorchToLinalg/elementwise.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,11 @@
// CHECK-LABEL: func.func @elementwise$unary(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[],f32>) -> !torch.vtensor<[],f32> {
// CHECK-DAG: %[[BUILTIN_TENSOR:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[],f32> -> tensor<f32>
// CHECK: %[[INIT_TENSOR:.*]] = tensor.empty() : tensor<f32>
// CHECK: %[[GENERIC:.*]] = linalg.generic {indexing_maps = [affine_map<() -> ()>, affine_map<() -> ()>], iterator_types = []} ins(%[[BUILTIN_TENSOR]] : tensor<f32>) outs(%[[INIT_TENSOR]] : tensor<f32>) {
// CHECK: ^bb0(%[[BBARG0:.*]]: f32, %{{.*}}: f32):
// CHECK: %[[TANH:.*]] = math.tanh %[[BBARG0]] : f32
// CHECK: linalg.yield %[[TANH]] : f32
// CHECK: } -> tensor<f32>
// CHECK: %[[CASTED:.*]] = tensor.cast %[[GENERIC:.*]] : tensor<f32> to tensor<f32>
// CHECK: %[[EXTRACT:.*]] = tensor.extract %[[BUILTIN_TENSOR]][] : tensor<f32>
// CHECK: %[[TANH:.*]] = math.tanh %[[EXTRACT]] : f32
// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<f32>
// CHECK: %[[FILL:.*]] = linalg.fill ins(%[[TANH]] : f32) outs(%[[EMPTY]] : tensor<f32>) -> tensor<f32>
// CHECK: %[[CASTED:.*]] = tensor.cast %[[FILL:.*]] : tensor<f32> to tensor<f32>
// CHECK: %[[RESULT:.*]] = torch_c.from_builtin_tensor %[[CASTED]] : tensor<f32> -> !torch.vtensor<[],f32>
// CHECK: return %[[RESULT]] : !torch.vtensor<[],f32>
// CHECK: }
Expand Down

0 comments on commit efc5746

Please sign in to comment.