Skip to content

Commit

Permalink
[Codegen] Add the bitcast -> extui to shuffle folding patterns to…
Browse files Browse the repository at this point in the history
… EmulateNarrowTypes pass. (iree-org#15102)

Folding the `bitcast -> arith.extui` to `shuffle` seems like worth doing
across all backends (all backends support shuffle better). Also add a
pattern to push broadcasts past `extui`-like operations to increase the
coverage of cases where this kicks in.
  • Loading branch information
MaheshRavishankar authored Oct 7, 2023
1 parent 17e758b commit b5bbea2
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 6 deletions.
21 changes: 20 additions & 1 deletion compiler/src/iree/compiler/Codegen/Common/EmulateNarrowType.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

namespace mlir {
namespace iree_compiler {
Expand Down Expand Up @@ -137,8 +138,26 @@ struct EmulateNarrowTypePass
populateIreeNarrowTypeEmulationPatterns(typeConverter, patterns);

if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
std::move(patterns)))) {
getOperation()->emitOpError("failed to emulate bit width");
return signalPassFailure();
}

RewritePatternSet sinkBroadcast(ctx);
vector::populateSinkVectorBroadcastPatterns(sinkBroadcast);
if (failed(applyPatternsAndFoldGreedily(getOperation(),
std::move(sinkBroadcast)))) {
getOperation()->emitOpError("failed in sinking of broadcasts");
return signalPassFailure();
}

// Also do the `bitcast -> extui/extsi` rewrite.
RewritePatternSet foldExtPatterns(ctx);
vector::populateVectorNarrowTypeRewritePatterns(foldExtPatterns);
if (failed(applyPatternsAndFoldGreedily(getOperation(),
std::move(foldExtPatterns)))) {
return signalPassFailure();
}
}
};
} // namespace
Expand Down
Original file line number Diff line number Diff line change
@@ -1,18 +1,21 @@
// RUN: iree-opt --split-input-file --iree-codegen-emulate-narrow-type %s | FileCheck %s

func.func @memref_i4_to_i8() {
func.func @memref_i4_to_i8() -> i4 {
%c0 = arith.constant 0 : index
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%c0) flags(ReadOnly) : memref<3x15xi4>
return
%1 = memref.load %0[%c0, %c0] : memref<3x15xi4>
return %1 : i4
}
// CHECK-LABEL: func.func @memref_i4_to_i8
// CHECK: hal.interface.binding.subspan {{.+}} memref<23xi8>

// -----

func.func @memref_i4_to_i8_dynamic(%arg0 : index, %arg1 : index, %arg2 : index) {
func.func @memref_i4_to_i8_dynamic(%arg0 : index, %arg1 : index, %arg2 : index) -> i4 {
%c0 = arith.constant 0 : index
%0 = hal.interface.binding.subspan set(0) binding(0) type(storage_buffer) alignment(64) offset(%arg0) flags(ReadOnly) : memref<?x?xi4, strided<[?, 1], offset: ?>>{%arg1, %arg2}
return
%1 = memref.load %0[%c0, %c0] : memref<?x?xi4, strided<[?, 1], offset: ?>>
return %1 : i4
}
// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0, s1] -> ((s0 * s1) floordiv 2)>
// CHECK: func.func @memref_i4_to_i8_dynamic
Expand All @@ -23,3 +26,17 @@ func.func @memref_i4_to_i8_dynamic(%arg0 : index, %arg1 : index, %arg2 : index)
// CHECK: hal.interface.binding.subspan
// CHECK-SAME: offset(%[[ARG0]])
// CHECK-SAME: memref<?xi8, strided<[1], offset: ?>>{%[[SIZE]]}

// -----

func.func @broadcast_extui() -> vector<1x1x64xi32> {
%c0 = arith.constant 0 : index
%0 = memref.alloc() : memref<64xi4>
%1 = vector.load %0[%c0] : memref<64xi4>, vector<64xi4>
%2 = vector.broadcast %1 : vector<64xi4> to vector<1x1x64xi4>
%3 = arith.extui %2 : vector<1x1x64xi4> to vector<1x1x64xi32>
return %3 : vector<1x1x64xi32>
}
// CHECK-LABEL: func @broadcast_extui()
// CHECK-NOT: vector.bitcast
// CHECK: vector.shuffle
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,10 @@ void ConvertToSPIRVPass::runOnOperation() {
/// rely on this rewrite for the cases seen today.
/// TODO: Support general emulation of compute on sub-byte types. This is
/// not mutually exclusive with this pattern, but does mean it is no longer
/// load bearing.
/// load bearing. Also these patterns are already run during
/// `EmulateNarrotType` pass but dont trigger there due to missing support for
/// emulation of `vector.transfer_read` in the emulation path. Remove the
/// patterns from here after that is done.
for (auto funcOp : moduleOp.getOps<func::FuncOp>()) {
RewritePatternSet narrowingPatterns(context);
vector::populateVectorNarrowTypeRewritePatterns(narrowingPatterns);
Expand Down

0 comments on commit b5bbea2

Please sign in to comment.