From ff2485b200f017c9a24f06c01f9241bb1837a0ca Mon Sep 17 00:00:00 2001 From: Matthias Gehre Date: Wed, 11 Dec 2024 00:50:45 +0100 Subject: [PATCH 1/2] TOSA: fold cast-to-bf16(cast-to-f32(x)) -> cast-to-bf16(x) --- mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp | 10 ++++++++++ mlir/test/Dialect/Tosa/canonicalize.mlir | 8 ++++++++ 2 files changed, 18 insertions(+) diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index c3d9d2a773ae70..d84f4629bb2756 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -1008,6 +1008,16 @@ OpFoldResult CastOp::fold(FoldAdaptor adaptor) { } } + // cast-to-bf16(cast-to-f32(x)) -> cast-to-bf16(x) + if (auto cast = getInput().getDefiningOp()) { + auto intermediateElTy = cast.getType().getElementType(); + auto finalElTy = getType().getElementType(); + if (isa(intermediateElTy) && isa(finalElTy)) { + getInputMutable().assign(cast.getInput()); + return getResult(); + } + } + auto operand = llvm::dyn_cast_if_present(adaptor.getInput()); if (!operand) return {}; diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir index 2035659f17146e..49a664ab4a409c 100644 --- a/mlir/test/Dialect/Tosa/canonicalize.mlir +++ b/mlir/test/Dialect/Tosa/canonicalize.mlir @@ -55,6 +55,14 @@ func.func @cast_fold_double(%arg0: tensor) -> tensor { return %1 : tensor } +// CHECK-LABEL: @cast_fold_double2 +func.func @cast_fold_double2(%arg0: tensor) -> tensor { + // CHECK: tosa.cast{{.*}} (tensor) -> tensor + %0 = tosa.cast %arg0 : (tensor) -> tensor + %1 = tosa.cast %0 : (tensor) -> tensor + return %1 : tensor +} + // CHECK-LABEL: @cast_no_fold_double1 func.func @cast_no_fold_double1(%arg0: tensor) -> tensor { // CHECK: tosa.cast{{.*}} (tensor) -> tensor From 729187cc2d68d2b37dab9b728af28726398a2b66 Mon Sep 17 00:00:00 2001 From: Matthias Gehre Date: Wed, 11 Dec 2024 11:19:01 +0100 
Subject: [PATCH 2/2] Restrict to bf16-f32-bf16 --- mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp | 6 ++++-- mlir/test/Dialect/Tosa/canonicalize.mlir | 6 +++--- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp index d84f4629bb2756..c3dd3d00e7b8ea 100644 --- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp +++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp @@ -1008,11 +1008,13 @@ OpFoldResult CastOp::fold(FoldAdaptor adaptor) { } } - // cast-to-bf16(cast-to-f32(x)) -> cast-to-bf16(x) + // Fold cast from bf16 -> f32 -> bf16 into no-op. if (auto cast = getInput().getDefiningOp()) { + auto sourceElTy = cast.getInput().getType().getElementType(); auto intermediateElTy = cast.getType().getElementType(); auto finalElTy = getType().getElementType(); - if (isa(intermediateElTy) && isa(finalElTy)) { + if (isa(sourceElTy) && isa(intermediateElTy) && + isa(finalElTy)) { getInputMutable().assign(cast.getInput()); return getResult(); } diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir index 49a664ab4a409c..f35df639cca523 100644 --- a/mlir/test/Dialect/Tosa/canonicalize.mlir +++ b/mlir/test/Dialect/Tosa/canonicalize.mlir @@ -56,9 +56,9 @@ func.func @cast_fold_double(%arg0: tensor) -> tensor { } // CHECK-LABEL: @cast_fold_double2 -func.func @cast_fold_double2(%arg0: tensor) -> tensor { - // CHECK: tosa.cast{{.*}} (tensor) -> tensor - %0 = tosa.cast %arg0 : (tensor) -> tensor +func.func @cast_fold_double2(%arg0: tensor) -> tensor { + // CHECK: return %arg0 + %0 = tosa.cast %arg0 : (tensor) -> tensor %1 = tosa.cast %0 : (tensor) -> tensor return %1 : tensor }