[LLVMGPU] Use scf.forall for workgroup distribution (iree-org#18826)
Enable scf.forall distribution for the `tileAndBufferize`,
`GPUWinogradVectorizePassPipeline`, `GPUMatmulSimtPassPipeline`,
`GPUTransposePassPipeline`, and `GPUPackUnPackPasses` pipelines.
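
For reference, a minimal sketch of the IR shape that scf.forall workgroup distribution produces; the loop bounds, value names, and mapping order below are illustrative and not taken from this change:

  scf.forall (%wg_y, %wg_x) in (%grid_y, %grid_x) {
    // one workgroup's tile of the computation
  } {mapping = [#iree_codegen.workgroup_mapping<y>, #iree_codegen.workgroup_mapping<x>]}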
pashu123 authored Nov 22, 2024
1 parent e179a6e commit 2602a2a
Showing 5 changed files with 84 additions and 33 deletions.
69 changes: 64 additions & 5 deletions compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistribute.cpp
@@ -20,6 +20,52 @@ namespace mlir::iree_compiler {
namespace {
static constexpr int64_t kCudaWarpSize = 32;

+static void
+replaceUnitMappingIdsHelper(RewriterBase &rewriter, Location loc, Block *parent,
+                            Value replacement,
+                            ArrayRef<int64_t> availableMappingSizes) {
+  parent->walk([&](gpu::ThreadIdOp idOp) {
+    if (availableMappingSizes[static_cast<int64_t>(idOp.getDimension())] == 1)
+      rewriter.replaceAllUsesWith(idOp.getResult(), replacement);
+  });
+}
+
+// This is an upstream method adapted from
+// https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/GPU/TransformOps/GPUTransformOps.cpp#L846
+// to fix the ASAN error.
+DiagnosedSilenceableFailure static mapNestedForallToThreadsImpl(
+    RewriterBase &rewriter, Operation *target, ArrayRef<int64_t> blockDims,
+    int64_t warpSize, bool syncAfterDistribute) {
+
+  if (blockDims.size() != 3) {
+    return emitDefiniteFailure(target, "requires size-3 thread mapping");
+  }
+
+  Block *parentBlock = target->getBlock();
+
+  // Create an early zero index value for replacements.
+  Location loc = target->getLoc();
+  Value zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
+  DiagnosedSilenceableFailure diag = DiagnosedSilenceableFailure::success();
+  WalkResult walkResult = target->walk([&](scf::ForallOp forallOp) {
+    diag = mlir::transform::gpu::mapOneForallToThreadsImpl(
+        rewriter, std::nullopt, forallOp, blockDims, warpSize,
+        syncAfterDistribute);
+    if (diag.isDefiniteFailure())
+      return WalkResult::interrupt();
+    if (diag.succeeded())
+      return WalkResult::skip();
+    return WalkResult::advance();
+  });
+  if (walkResult.wasInterrupted())
+    return diag;
+
+  // Replace ids of dimensions known to be 1 by 0 to simplify the IR.
+  // Here, the result of mapping determines the available mapping sizes.
+  replaceUnitMappingIdsHelper(rewriter, loc, parentBlock, zero, blockDims);
+  return DiagnosedSilenceableFailure::success();
+}
+
struct GPUDistributePass final
    : impl::GPUDistributePassBase<GPUDistributePass> {
  void runOnOperation() override {
@@ -41,11 +87,24 @@ struct GPUDistributePass final
    int64_t subgroupSize = maybeSubgroupSize.value_or(kCudaWarpSize);

    rewriter.setInsertionPointToStart(&funcOp.front());
-    DiagnosedSilenceableFailure result =
-        mlir::transform::gpu::mapNestedForallToThreadsImpl(
-            rewriter, std::nullopt, funcOp, workgroupSize.value(), subgroupSize,
-            false);
-    if (!result.succeeded())
+
+    DiagnosedSilenceableFailure result = DiagnosedSilenceableFailure::success();
+    WalkResult walkResult = funcOp->walk([&](scf::ForallOp forallOp) {
+      bool hasWorkgroupMapping =
+          llvm::any_of(forallOp.getMapping().value(),
+                       llvm::IsaPred<IREE::Codegen::WorkgroupMappingAttr>);
+      if (!hasWorkgroupMapping) {
+        result = mapNestedForallToThreadsImpl(
+            rewriter, forallOp, workgroupSize.value(), subgroupSize, false);
+        if (result.isDefiniteFailure())
+          return WalkResult::interrupt();
+        if (result.succeeded())
+          return WalkResult::skip();
+      }
+      return WalkResult::advance();
+    });
+
+    if (walkResult.wasInterrupted())
      return signalPassFailure();
  }
};
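
With this change the pass skips scf.forall ops that carry #iree_codegen.workgroup_mapping attributes (those remain as workgroup-level loops) and only lowers thread-mapped foralls onto thread ids. A rough sketch of that rewrite, with an illustrative (32, 4, 1) workgroup size that is not taken from this change:

  // before distribution
  scf.forall (%ty, %tx) in (4, 32) {
    // per-thread work indexed by (%ty, %tx)
  } {mapping = [#gpu.thread<y>, #gpu.thread<x>]}

  // after mapNestedForallToThreadsImpl, roughly
  %tx = gpu.thread_id x
  %ty = gpu.thread_id y
  // per-thread work now indexed by the thread ids; ids of unit-sized
  // dimensions are later replaced by the constant 0 created up front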
10 changes: 5 additions & 5 deletions compiler/src/iree/compiler/Codegen/LLVMGPU/Passes.cpp
@@ -204,7 +204,7 @@ static void tileAndDistributeToWorkgroup(
static void tileAndBufferize(OpPassManager &funcPassManager) {
  ConvertToDestinationPassingStylePassOptions options;
  options.useWARForCooperativeMatrixCodegen = true;
-  tileAndDistributeToWorkgroup(funcPassManager, /*useForall=*/false,
+  tileAndDistributeToWorkgroup(funcPassManager, /*useForall=*/true,
                               /*convertToDpsOptions=*/options);
  addBufferizePasses(funcPassManager);
}
@@ -487,7 +487,7 @@ void addGPUTileAndFusePassPipeline(OpPassManager &funcPassManager,
//===---------------------------------------------------------------------===//

void addGPUWinogradVectorizePassPipeline(OpPassManager &funcPassManager) {
-  tileAndDistributeToWorkgroup(funcPassManager, /*useForall=*/false);
+  tileAndDistributeToWorkgroup(funcPassManager, /*useForall=*/true);

  funcPassManager.addPass(createConfigTrackingCanonicalizerPass());
  funcPassManager.addPass(createConfigTrackingCanonicalizerPass());
@@ -524,7 +524,7 @@ void addGPUWinogradVectorizePassPipeline(OpPassManager &funcPassManager) {

void addGPUMatmulSimtPassPipeline(OpPassManager &funcPassManager,
                                  const GPUPipelineOptions &options) {
-  tileAndDistributeToWorkgroup(funcPassManager, /*useForall=*/false);
+  tileAndDistributeToWorkgroup(funcPassManager, /*useForall=*/true);

  funcPassManager.addPass(createConfigTrackingCanonicalizerPass());
  funcPassManager.addPass(createConfigTrackingCanonicalizerPass());
@@ -725,7 +725,7 @@ void addGPUMatmulTensorCoreMmaSyncPassPipeline(

void addGPUTransposePassPipeline(OpPassManager &funcPassManager,
                                 const GPUPipelineOptions &options) {
-  tileAndDistributeToWorkgroup(funcPassManager, /*useForall=*/false);
+  tileAndDistributeToWorkgroup(funcPassManager, /*useForall=*/true);

  funcPassManager.addPass(createConfigTrackingCanonicalizerPass());
  funcPassManager.addPass(createConfigTrackingCanonicalizerPass());
@@ -969,7 +969,7 @@ void addGPUWarpReductionPassPipeline(OpPassManager &funcPassManager) {
}

void addGPUPackUnPackPasses(OpPassManager &funcPassManager) {
-  tileAndDistributeToWorkgroup(funcPassManager, /*useForall=*/false);
+  tileAndDistributeToWorkgroup(funcPassManager, /*useForall=*/true);
  funcPassManager.addPass(createConfigTrackingCanonicalizerPass());
  funcPassManager.addPass(createCSEPass());

@@ -79,7 +79,7 @@ hal.executable @mma_fused_fp16 {
// CHECK: llvm.br
// CHECK-NOT: nvvm.mma.sync
// CHECK-COUNT-4: llvm.store {{.*}} : vector<2xf16>, !llvm.ptr<3>
-// CHECK: llvm.load {{.*}} : !llvm.ptr<3> -> vector<8xf16>
+// CHECK: llvm.load {{.*}} : !llvm.ptr<1> -> vector<16xf16>
// CHECK: llvm.store {{.*}} : vector<8xf16>, !llvm.ptr

// -----
@@ -462,14 +462,7 @@ hal.executable @mma_fused {
// SM80: nvvm.cp.async.commit.group
// SM80: llvm.br
// SM80-NOT: nvvm.wmma.mma
-// SM80-COUNT-1: nvvm.wmma.store {{.*}} : !llvm.ptr<3>, f32, f32, f32, f32, f32, f32, f32, f32
-// SM80: vvm.barrier0
-// SM80: llvm.load {{.*}} : !llvm.ptr<3> -> vector<4xf32>
-// SM80: llvm.fadd {{.*}} : vector<4xf32>
-// SM80: llvm.store {{.*}} : vector<4xf32>, !llvm.ptr<1>
-// SM80: llvm.load {{.*}} : !llvm.ptr<3> -> vector<4xf32>
-// SM80: llvm.fadd {{.*}} : vector<4xf32>
-// SM80: llvm.store {{.*}} : vector<4xf32>, !llvm.ptr<1>
+// SM80-COUNT-1: nvvm.wmma.store {{.*}} : !llvm.ptr<1>, f32, f32, f32, f32, f32, f32, f32, f32



@@ -547,12 +540,7 @@ hal.executable @mma_fused_fp16 {
// SM80: nvvm.cp.async.commit.group
// SM80: llvm.br
// SM80-NOT: nvvm.wmma.mma
-// SM80-COUNT-1: nvvm.wmma.store {{.*}} : !llvm.ptr<3>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>
-// SM80: vvm.barrier0
-// SM80: llvm.load {{.*}} : !llvm.ptr<3> -> vector<8xf16>
-// SM80: llvm.fadd {{.*}} : vector<8xf16>
-// SM80: llvm.store {{.*}} : vector<8xf16>, !llvm.ptr<1>
-// SM80: vvm.barrier0
+// SM80-COUNT-1: nvvm.wmma.store {{.*}} : !llvm.ptr<1>, vector<2xf16>, vector<2xf16>, vector<2xf16>, vector<2xf16>

// -----

@@ -53,7 +53,7 @@ hal.executable @transpose_dispatch_0 {
// CHECK: %[[D11:.*]] = affine.apply #{{.*}}()[%[[D0]]]
// CHECK: %[[D12:.*]] = vector.transfer_read %[[D3]][%[[D11]], %[[D1]]], %[[CST]] {in_bounds = [true, true]} : memref<32x33xf32, #gpu.address_space<workgroup>>, vector<4x1xf32>
// CHECK: %[[D13:.*]] = vector.shape_cast %[[D12]] : vector<4x1xf32> to vector<4xf32>
-// CHECK: %[[D15:.*]] = affine.apply #{{.*}}()[%[[D1]], %{{.*}}]
+// CHECK: %[[D15:.*]] = affine.apply #{{.*}}()[%{{.*}}, %[[D1]]]
// CHECK: %[[D16:.*]] = affine.apply #{{.*}}()[%{{.*}}, %[[D0]]]
// CHECK: vector.transfer_write %[[D13]], %[[D5]][%[[D15]], %[[D16]]] {in_bounds = [true]} : vector<4xf32>, memref<4096x4096xf32, #hal.descriptor_type<storage_buffer>>

@@ -116,11 +116,13 @@ hal.executable @transpose_single_operand_dispatch_0_generic_768x2048 {
// CHECK: gpu.barrier
// CHECK: %[[D12:.*]] = affine.apply #{{.*}}()[%[[D0]]]
// CHECK: %[[D13:.*]] = vector.transfer_read %[[D3]][%[[D12]], %[[D1]]], %[[CST]] {in_bounds = [true, true]} : memref<32x33xf32, #gpu.address_space<workgroup>>, vector<4x1xf32>
-// CHECK: %[[D15:.*]] = affine.apply #{{.*}}()[%[[D1]], %{{.*}}]
-// CHECK: %[[D16:.*]] = affine.apply #{{.*}}()[%{{.*}}, %[[D0]]]
-// CHECK: %[[D17:.*]] = vector.transfer_read %[[D5]][%[[D15]], %[[D16]]], %[[CST]] {in_bounds = [true]} : memref<768x2048xf32, #hal.descriptor_type<storage_buffer>>, vector<4xf32>
+// CHECK: %[[DUP_D15:.*]] = arith.addi %[[D1]], %{{.*}} : index
+// CHECK: %[[DUP_D16:.*]] = arith.addi %[[D12]], %{{.*}} : index
+// CHECK: %[[D17:.*]] = vector.transfer_read %[[D5]][%[[DUP_D15]], %[[DUP_D16]]], %[[CST]] {in_bounds = [true]} : memref<768x2048xf32, #hal.descriptor_type<storage_buffer>>, vector<4xf32>
// CHECK: %[[D14:.*]] = vector.shape_cast %[[D13]] : vector<4x1xf32> to vector<4xf32>
// CHECK: %[[D19:.*]] = arith.addf %[[D14]], %[[D17]] : vector<4xf32>
+// CHECK: %[[D15:.*]] = affine.apply #{{.*}}()[%{{.*}}, %[[D1]]]
+// CHECK: %[[D16:.*]] = affine.apply #{{.*}}()[%{{.*}}, %[[D0]]]
// CHECK: vector.transfer_write %[[D19]], %[[D6]][%[[D15]], %[[D16]]] {in_bounds = [true]} : vector<4xf32>, memref<768x2048xf32, #hal.descriptor_type<storage_buffer>>

// -----
@@ -223,11 +225,13 @@ hal.executable @transpose_3d_yes_dispatch_0_generic_10x768x2048 {
// CHECK: gpu.barrier
// CHECK: %[[D12:.*]] = affine.apply #{{.*}}()[%[[D0]]]
// CHECK: %[[D13:.*]] = vector.transfer_read %[[D3]][%[[C0]], %[[D12]], %[[D1]]], %[[CST]] {in_bounds = [true, true]} : memref<1x32x33xf32, #gpu.address_space<workgroup>>, vector<4x1xf32>
-// CHECK: %[[D16:.*]] = affine.apply #{{.*}}()[%[[D1]], %{{.*}}]
-// CHECK: %[[D17:.*]] = affine.apply #{{.*}}()[%{{.*}}, %[[D0]]]
-// CHECK: %[[D18:.*]] = vector.transfer_read %[[D5]][%{{.*}}, %[[D16]], %[[D17]]], %[[CST]] {in_bounds = [true]} : memref<10x768x2048xf32, #hal.descriptor_type<storage_buffer>>, vector<4xf32>
+// CHECK: %[[DUP_D16:.*]] = arith.addi %[[D1]], %{{.*}} : index
+// CHECK: %[[DUP_D17:.*]] = arith.addi %[[D12]], %{{.*}} : index
+// CHECK: %[[D18:.*]] = vector.transfer_read %[[D5]][%{{.*}}, %[[DUP_D16]], %[[DUP_D17]]], %[[CST]] {in_bounds = [true]} : memref<10x768x2048xf32, #hal.descriptor_type<storage_buffer>>, vector<4xf32>
// CHECK: %[[D15:.*]] = vector.shape_cast %[[D13]] : vector<4x1xf32> to vector<4xf32>
// CHECK: %[[D20:.*]] = arith.addf %[[D15]], %[[D18]] : vector<4xf32>
+// CHECK: %[[D16:.*]] = affine.apply #{{.*}}()[%{{.*}}, %[[D1]]]
+// CHECK: %[[D17:.*]] = affine.apply #{{.*}}()[%{{.*}}, %[[D0]]]
// CHECK: vector.transfer_write %[[D20]], %[[D6]][%{{.*}}, %[[D16]], %[[D17]]] {in_bounds = [true]} : vector<4xf32>, memref<10x768x2048xf32, #hal.descriptor_type<storage_buffer>>

// -----
@@ -295,7 +299,7 @@ hal.executable @transpose_3d_trans_out_dispatch_0_generic_10x2048x768 {
// CHECK: %[[D16:.*]] = vector.transfer_read %[[D3]][%[[C0]], %[[D14]], %[[D1]]], %[[CST]] {in_bounds = [true, true]} : memref<1x32x33xf32, #gpu.address_space<workgroup>>, vector<4x1xf32>
// CHECK: %[[D17:.*]] = arith.addf %[[D15]], %[[D16]] : vector<4x1xf32>
// CHECK: %[[D19:.*]] = vector.shape_cast %[[D17]] : vector<4x1xf32> to vector<4xf32>
-// CHECK: %[[D21:.*]] = affine.apply #{{.*}}()[%[[D1]], %{{.*}}]
+// CHECK: %[[D21:.*]] = affine.apply #{{.*}}()[%{{.*}}, %[[D1]]]
// CHECK: %[[D22:.*]] = affine.apply #{{.*}}()[%{{.*}}, %[[D0]]]
// CHECK: vector.transfer_write %[[D19]], %[[D7]][%{{.*}}, %[[D21]], %[[D22]]] {in_bounds = [true]} : vector<4xf32>, memref<10x2048x768xf32, #hal.descriptor_type<storage_buffer>>

