Skip to content

Commit

Permalink
[LLVMGPU] Add vector transfer lowering patterns to LLVMGPUVectorLowering
Browse files Browse the repository at this point in the history
This adds to LLVMGPU a few peephole optimizations used in the SPIR-V vector
lowerings. At some future point it might be worth splitting up
SPIRVVectorize to unify the usage of various peephole/lowering patterns
between backends.

Additionally, this adds some missing dialect registrations to this pass.
  • Loading branch information
qedawkins committed Sep 12, 2023
1 parent 53215e2 commit 3001524
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,10 @@ namespace {
struct LLVMGPUVectorLoweringPass
: public LLVMGPUVectorLoweringBase<LLVMGPUVectorLoweringPass> {
// Adds the affine, memref, vector, and SCF dialects to `registry`, in that
// order, as the dialects this pass depends on.
// NOTE(review): presumably required because the lowering patterns run by this
// pass can create ops from these dialects — confirm against the populated
// pattern sets.
void getDependentDialects(DialectRegistry &registry) const override {
  registry.insert<affine::AffineDialect, memref::MemRefDialect,
                  vector::VectorDialect, scf::SCFDialect>();
}
void runOnOperation() override {
func::FuncOp funcOp = getOperation();
Expand All @@ -33,6 +36,10 @@ struct LLVMGPUVectorLoweringPass
// Lower high level vector operations like contract or multidim reduce ops
// to lower level vector ops.
RewritePatternSet contractLoweringPatterns(funcOp.getContext());
vector::populateVectorTransferPermutationMapLoweringPatterns(
contractLoweringPatterns);
vector::TransposeOp::getCanonicalizationPatterns(contractLoweringPatterns,
funcOp.getContext());
vector::populateVectorBroadcastLoweringPatterns(contractLoweringPatterns);
vector::populateVectorContractLoweringPatterns(
contractLoweringPatterns,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ iree_lit_test_suite(
"transform_gpu_pipelining.mlir",
"transform_vector_to_mma.mlir",
"transpose_pipeline_test.mlir",
"vector_lowering.mlir",
"vector_to_gpu.mlir",
"workgroup_specialization_pipeline_test.mlir",
],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ iree_lit_test_suite(
"transform_gpu_pipelining.mlir"
"transform_vector_to_mma.mlir"
"transpose_pipeline_test.mlir"
"vector_lowering.mlir"
"vector_to_gpu.mlir"
"workgroup_specialization_pipeline_test.mlir"
TOOLS
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-llvmgpu-vector-lowering))" --split-input-file %s | FileCheck %s

// Test input: a vector.transfer_read from a 2-D memref into vector<1x8xf16>
// whose permutation map (d0, d1) -> (d1, 0) has the constant 0 as its second
// result, i.e. the trailing vector dimension of size 8 is a broadcast.
module {
func.func @broadcast_read_lowering(%arg0: memref<4096x32xf16>) -> vector<1x8xf16> {
// Padding value passed to the transfer_read below.
%cst_1 = arith.constant 0.000000e+00 : f16
// The read indices are the HAL workgroup id (dim 0) and the GPU thread id
// (dim 1).
%0 = gpu.thread_id x
%workgroup_id_x = hal.interface.workgroup.id[0] : index
// Both result dimensions are marked in-bounds.
%broadcast_read = vector.transfer_read %arg0[%workgroup_id_x, %0], %cst_1 {in_bounds = [true, true], permutation_map = affine_map<(d0, d1) -> (d1, 0)>} : memref<4096x32xf16>, vector<1x8xf16>
return %broadcast_read : vector<1x8xf16>
}
}
// CHECK-LABEL: func.func @broadcast_read_lowering
// CHECK-SAME: (%[[ARG0:.+]]: memref<4096x32xf16>)
// CHECK: %[[INIT:.+]] = arith.constant dense<0.000000e+00> : vector<1x8xf16>
// CHECK: %[[ELEM:.+]] = memref.load %[[ARG0]]{{.*}} : memref<4096x32xf16>
// CHECK: %[[SPLAT:.+]] = vector.splat %[[ELEM]] : vector<8xf16>
// CHECK: %[[INSERT:.+]] = vector.insert %[[SPLAT]], %[[INIT]] [0] : vector<8xf16> into vector<1x8xf16>
// CHECK: return %[[INSERT]]

0 comments on commit 3001524

Please sign in to comment.