From 031accb09edf4b3ee42cf9c263e404223982857e Mon Sep 17 00:00:00 2001 From: Krzysztof Drewniak Date: Tue, 26 Nov 2024 14:38:07 -0600 Subject: [PATCH] [GPU] Use affine.linearize_index (and delinearize_index) where possible (#19122) There have been issues with the composition of affine maps being too general and losing important information, like the fact that affine_map<(s0 + s1 * 32 + ... - (s0 floorDiv 16) * 16)> really should be affine_map<(s0 mod 16 + s1 * 32 + ...)>, and other issues with the resulting IR that block low-level arithmetic optimizations. The affine.delinearize_index operation represents the div/mod chains needed to break a flat index into its component parts. A recently added affine.linearize_index operation is its inverse, combining multiple indices into a flat 1D value. Another advantage of linearize/delinearize is that they have simpler upstream canonicalizations, which lead to more streamlined generated code. This PR updates the vector distribution code and other GPU-related code that I could find to 1. Use affine.linearize_index to construct flat thread IDs. 2. Use affine.delinearize_index in places where there was a floorDiv/mod chain. 3. Plumb the subgroup size through the transfer_read and transfer_write distribution patterns to enable better reasoning about when you do/don't need to take a mod of the lane ID. --- .../Common/GPU/GPUDistributeForall.cpp | 41 ++-- .../GPU/GPUDistributeSharedMemoryCopy.cpp | 32 ++- .../GPU/test/gpu_distribute_forall.mlir | 63 +++--- .../test/gpu_distribute_shared_memory.mlir | 31 ++- ...ransform_gpu_distribute_shared_memory.mlir | 17 +- .../TransformExtensions/CommonExtensions.cpp | 9 +- .../CommonExtensionsOps.td | 3 +- .../Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp | 17 +- .../Dialect/GPU/Transforms/Transforms.cpp | 5 +- .../test/distribute_mma_to_lanes.mlir | 43 ++-- .../LLVMGPU/LLVMGPUVectorDistribute.cpp | 26 +-- .../TransformExtensions/LLVMGPUExtensions.cpp | 5 +- .../LLVMGPUExtensionsOps.td | 3 +- .../test/ROCDL/pipeline_tile_and_fuse.mlir | 8 +- .../LLVMGPU/test/transpose_pipeline_test.mlir | 186 ++++++++---------- 15 files changed, 224 insertions(+), 265 deletions(-) diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributeForall.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributeForall.cpp index 64623462a526..334427cfffb9 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributeForall.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributeForall.cpp @@ -17,6 +17,7 @@ #include "mlir/Dialect/GPU/Transforms/Passes.h" #include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h" #include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" namespace mlir::iree_compiler { @@ -87,9 +88,16 @@ LogicalResult resolveGPUMappedForallOp(RewriterBase &rewriter, assert(!(hasThreadMapping && hasWarpMapping)); Value flatId = linearThreadId; if (hasWarpMapping) { - OpFoldResult subgroupSizeVal = rewriter.getIndexAttr(subgroupSize); - flatId = affine::makeComposedAffineApply(rewriter, loc, d0.floorDiv(d1), - {flatId, subgroupSizeVal}); + if (flatWorkgroupSize % subgroupSize != 0) { + return forallOp->emitOpError( + "found warp mapped forall with non-multiple workgroup size"); + } + flatId = rewriter + .create( + loc, flatId, + ArrayRef{flatWorkgroupSize / subgroupSize, + subgroupSize}) + .getResult(0); } SmallVector delinSizes; @@ -190,23 +198,18 @@ void GPUDistributeForallPass::runOnOperation() { return signalPassFailure(); } - AffineExpr x, y, z; - bindSymbols(funcOp.getContext(), x, y, 
z); - // Compute the linearized thread id. - AffineExpr linearId = - x + workgroupSize[0] * y + workgroupSize[1] * workgroupSize[0] * z; - rewriter.setInsertionPointToStart(&funcOp.getFunctionBody().front()); - SmallVector threadGrid = { - rewriter.createOrFold(funcOp.getLoc(), - gpu::Dimension::x), - rewriter.createOrFold(funcOp.getLoc(), - gpu::Dimension::y), - rewriter.createOrFold(funcOp.getLoc(), - gpu::Dimension::z)}; - - Value linearThreadIdVal = affine::makeComposedAffineApply( - rewriter, funcOp.getLoc(), linearId, threadGrid); + SmallVector threadGrid = {rewriter.createOrFold( + funcOp.getLoc(), gpu::Dimension::z), + rewriter.createOrFold( + funcOp.getLoc(), gpu::Dimension::y), + rewriter.createOrFold( + funcOp.getLoc(), gpu::Dimension::x)}; + SmallVector threadGridBasis = {workgroupSize[2], workgroupSize[1], + workgroupSize[0]}; + + Value linearThreadIdVal = rewriter.create( + funcOp.getLoc(), threadGrid, threadGridBasis, /*disjoint=*/true); for (auto forall : forallOps) { rewriter.setInsertionPoint(forall); if (failed(resolveGPUMappedForallOp(rewriter, forall, linearThreadIdVal, diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributeSharedMemoryCopy.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributeSharedMemoryCopy.cpp index 4610c545e553..47329c84f189 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributeSharedMemoryCopy.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUDistributeSharedMemoryCopy.cpp @@ -189,10 +189,8 @@ SmallVector getIds(OpBuilder &b, Location loc, ArrayRef parallelLoopRanges, Value flatThreadId) { SmallVector infos; - Value id = flatThreadId; - AffineExpr d0 = b.getAffineDimExpr(0); - for (Range r : llvm::reverse(parallelLoopRanges)) { - linalg::ProcInfo info; + SmallVector delinSizes; + for (Range r : parallelLoopRanges) { auto offset = dyn_cast(r.offset); auto stride = dyn_cast(r.stride); auto size = dyn_cast(r.size); @@ -200,19 +198,20 @@ SmallVector getIds(OpBuilder &b, Location loc, int64_t numThreadsDim = (llvm::cast(size).getInt() - llvm::cast(offset).getInt()) / llvm::cast(stride).getInt(); - Value dimId = id; - if (infos.size() != parallelLoopRanges.size() - 1) - dimId = - affine::makeComposedAffineApply(b, loc, d0 % numThreadsDim, {dimId}); + delinSizes.push_back(numThreadsDim); + } + ValueRange dims = + b.create(loc, flatThreadId, delinSizes) + .getResults(); + + for (auto [dimId, numThreadsDim] : llvm::zip_equal(dims, delinSizes)) { + linalg::ProcInfo info; info.procId = dimId; info.nprocs = b.create(loc, numThreadsDim); info.distributionMethod = linalg::DistributionMethod::CyclicNumProcsEqNumIters; infos.push_back(info); - id = affine::makeComposedAffineApply(b, loc, d0.floorDiv(numThreadsDim), - {id}); } - std::reverse(infos.begin(), infos.end()); return infos; } @@ -288,19 +287,16 @@ static Value createFlatId(mlir::FunctionOpInterface funcOp, ArrayRef workgroupSize) { OpBuilder b(funcOp.getFunctionBody()); Type indexType = b.getIndexType(); - AffineExpr d0 = getAffineDimExpr(0, b.getContext()); - AffineExpr d1 = getAffineDimExpr(1, b.getContext()); - AffineExpr d2 = getAffineDimExpr(2, b.getContext()); Value threadX = b.create(funcOp.getLoc(), indexType, gpu::Dimension::x); Value threadY = b.create(funcOp.getLoc(), indexType, gpu::Dimension::y); Value threadZ = b.create(funcOp.getLoc(), indexType, gpu::Dimension::z); - Value flatThreadId = affine::makeComposedAffineApply( - b, funcOp.getLoc(), - d0 + workgroupSize[0] * d1 + (workgroupSize[0] * workgroupSize[1]) * d2, - {threadX, 
threadY, threadZ}); + Value flatThreadId = b.create( + funcOp.getLoc(), ValueRange{threadZ, threadY, threadX}, + ArrayRef{workgroupSize[2], workgroupSize[1], workgroupSize[0]}, + /*disjoint=*/true); return flatThreadId; } diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_distribute_forall.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_distribute_forall.mlir index 214337437b76..32bda8c90f05 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_distribute_forall.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_distribute_forall.mlir @@ -15,11 +15,9 @@ func.func @distribute_thread_forall(%out : memref) // CHECK-LABEL: func @distribute_thread_forall // CHECK-DAG: %[[TX:.+]] = gpu.thread_id x // CHECK-DAG: %[[TY:.+]] = gpu.thread_id y -// CHECK-DAG: %[[TZ:.+]] = gpu.thread_id z +// CHECK: %[[TFLAT:.+]] = affine.linearize_index disjoint [%[[TY]], %[[TX]]] by (2, 64) // CHECK: scf.for %[[I:.+]] = %c0 to %c1024 step %c128 { -// CHECK: %[[LINID:.+]] = affine.apply -// CHECK-SAME: affine_map<(d0)[s0, s1, s2] -> (d0 + s0 + s1 * 64 + s2 * 128)>(%[[I]]) -// CHECK-SAME: [%[[TX]], %[[TY]], %[[TZ]]] +// CHECK: %[[LINID:.+]] = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%[[I]])[%[[TFLAT]]] // CHECK: memref.store {{.*}}[%[[LINID]]] // ----- @@ -38,11 +36,10 @@ func.func @distribute_warp_forall(%out : memref) // CHECK-LABEL: func @distribute_warp_forall // CHECK-DAG: %[[TX:.+]] = gpu.thread_id x // CHECK-DAG: %[[TY:.+]] = gpu.thread_id y -// CHECK-DAG: %[[TZ:.+]] = gpu.thread_id z +// CHECK: %[[TFLAT:.+]] = affine.linearize_index disjoint [%[[TY]], %[[TX]]] by (2, 64) +// CHECK: %[[WARPSPLIT:.+]]:2 = affine.delinearize_index %[[TFLAT]] into (4, 32) // CHECK: scf.for %[[I:.+]] = %c0 to %c32 step %c4 { -// CHECK: %[[LINID:.+]] = affine.apply -// CHECK-SAME: affine_map<(d0)[s0, s1, s2] -> (d0 + s1 * 2 + s2 * 4 + s0 floordiv 32)>(%[[I]]) -// CHECK-SAME: [%[[TX]], %[[TY]], %[[TZ]]] +// CHECK: %[[LINID:.+]] = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%[[I]])[%[[WARPSPLIT]]#0] // CHECK: memref.store {{.*}}[%[[LINID]]] // ----- @@ -78,11 +75,7 @@ func.func @distribute_thread_forall_drop_for_loop(%out : memref) // CHECK-LABEL: func @distribute_thread_forall_drop_for_loop // CHECK-DAG: %[[TX:.+]] = gpu.thread_id x // CHECK-DAG: %[[TY:.+]] = gpu.thread_id y -// CHECK-DAG: %[[TZ:.+]] = gpu.thread_id z -// CHECK-NOT: scf.for -// CHECK: %[[LINID:.+]] = affine.apply -// CHECK-SAME: affine_map<()[s0, s1, s2] -> (s0 + s1 * 64 + s2 * 128)> -// CHECK-SAME: [%[[TX]], %[[TY]], %[[TZ]]] +// CHECK: %[[LINID:.+]] = affine.linearize_index disjoint [%[[TY]], %[[TX]]] by (2, 64) // CHECK: memref.store {{.*}}[%[[LINID]]] // ----- @@ -99,13 +92,32 @@ func.func @distribute_thread_forall_single_thread(%out : memref) } // CHECK-LABEL: func @distribute_thread_forall_single_thread +// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index // CHECK-DAG: %[[TX:.+]] = gpu.thread_id x // CHECK-DAG: %[[TY:.+]] = gpu.thread_id y -// CHECK-DAG: %[[TZ:.+]] = gpu.thread_id z -// CHECK: %[[LINID:.+]] = affine.apply -// CHECK-SAME: affine_map<()[s0, s1, s2] -> (s0 + s1 * 64 + s2 * 128)> -// CHECK-SAME: [%[[TX]], %[[TY]], %[[TZ]]] -// CHECK: scf.for %[[I:.+]] = %[[LINID]] to %c1 step %c128 { +// CHECK: %[[TFLAT:.+]] = affine.linearize_index disjoint [%[[TY]], %[[TX]]] by (2, 64) +// CHECK: scf.for %[[I:.+]] = %[[TFLAT]] to %c1 step %c128 { +// CHECK: memref.store {{.*}}[%[[I]]] + +// ----- + +#translation_info = #iree_codegen.translation_info + +func.func 
@distribute_thread_forall_overhang(%out : memref) + attributes {translation_info = #translation_info} { + %c0 = arith.constant 0 : i32 + scf.forall (%arg0) in (513) { + memref.store %c0, %out[%arg0] : memref + } {mapping = [#gpu.thread]} + return +} + +// CHECK-LABEL: func @distribute_thread_forall_overhang +// CHECK-DAG: %[[C513:.+]] = arith.constant 513 : index +// CHECK-DAG: %[[TX:.+]] = gpu.thread_id x +// CHECK-DAG: %[[TY:.+]] = gpu.thread_id y +// CHECK: %[[TFLAT:.+]] = affine.linearize_index disjoint [%[[TY]], %[[TX]]] by (2, 64) +// CHECK: scf.for %[[I:.+]] = %[[TFLAT]] to %[[C513]] step %c128 { // CHECK: memref.store {{.*}}[%[[I]]] // ----- @@ -124,11 +136,9 @@ func.func @distribute_thread_forall_multi_dim(%out : memref) // CHECK-LABEL: func @distribute_thread_forall_multi_dim // CHECK-DAG: %[[TX:.+]] = gpu.thread_id x // CHECK-DAG: %[[TY:.+]] = gpu.thread_id y -// CHECK-DAG: %[[TZ:.+]] = gpu.thread_id z +// CHECK: %[[TFLAT:.+]] = affine.linearize_index disjoint [%[[TY]], %[[TX]]] by (2, 64) // CHECK: scf.for %[[I:.+]] = %c0 to %c512 step %c128 { -// CHECK: %[[LINID:.+]] = affine.apply -// CHECK-SAME: affine_map<(d0)[s0, s1, s2] -> (d0 + s0 + s1 * 64 + s2 * 128)>(%[[I]]) -// CHECK-SAME: [%[[TX]], %[[TY]], %[[TZ]]] +// CHECK: %[[LINID:.+]] = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%[[I]])[%[[TFLAT]]] // CHECK: %[[DELIN:.+]]:3 = affine.delinearize_index %[[LINID]] into (16, 8, 4) : index // CHECK: memref.store {{.*}}[%[[DELIN]]#0, %[[DELIN]]#1, %[[DELIN]]#2] @@ -147,10 +157,5 @@ func.func @distribute_thread_forall_small_workgroup(%out : memref) } // CHECK-LABEL: func @distribute_thread_forall_small_workgroup -// CHECK-DAG: %[[TX:.+]] = gpu.thread_id x -// CHECK-DAG: %[[TY:.+]] = gpu.thread_id y -// CHECK-DAG: %[[TZ:.+]] = gpu.thread_id z -// CHECK: %[[LINID:.+]] = affine.apply -// CHECK-SAME: affine_map<()[s0, s1, s2] -> (s0 + s1 * 7 + s2 * 7)> -// CHECK-SAME: [%[[TX]], %[[TY]], %[[TZ]]] -// CHECK: memref.store {{.*}}[%[[LINID]]] +// CHECK: %[[TX:.+]] = gpu.thread_id x +// CHECK: memref.store {{.*}}[%[[TX]]] diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_distribute_shared_memory.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_distribute_shared_memory.mlir index 636add66dd0d..8f526bd4dd91 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_distribute_shared_memory.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/gpu_distribute_shared_memory.mlir @@ -49,12 +49,9 @@ module { } } -// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0, s1, s2] -> (s1 * 8 + s2 * 32 + s0 floordiv 4)> -// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 4) * 16)> -// CHECK-DAG: #[[$MAP2:.*]] = affine_map<()[s0, s1, s2] -> (s1 * 8 + s2 * 32 + s0 floordiv 4 + 32)> -// CHECK-DAG: #[[$MAP3:.*]] = affine_map<()[s0, s1, s2] -> (s0 + s1 * 32 + s2 * 128)> -// CHECK-DAG: #[[$MAP4:.*]] = affine_map<()[s0, s1, s2] -> (s0 + s1 * 32 + s2 * 128 + 128)> -// CHECK-DAG: #[[$MAP5:.*]] = affine_map<()[s0, s1, s2] -> (s0 * 4 + s1 * 128 + s2 * 512)> +// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0] -> (s0 * 4)> +// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 + 32)> +// CHECK-DAG: #[[$MAP2:.*]] = affine_map<()[s0] -> (s0 + 128)> // CHECK-LABEL: @shared_mem_cpy( // CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index @@ -62,24 +59,22 @@ module { // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index // CHECK-DAG: %[[TX:.*]] = gpu.thread_id x // CHECK-DAG: %[[TY:.*]] = gpu.thread_id y -// CHECK-DAG: %[[TZ:.*]] = gpu.thread_id z - -// CHECK-DAG: 
%[[Y0:.*]] = affine.apply #[[$MAP0]]()[%[[TX]], %[[TY]], %[[TZ]]] -// CHECK-DAG: %[[X0:.*]] = affine.apply #[[$MAP1]]()[%[[TX]]] -// CHECK: %[[R0:.*]] = vector.transfer_read %{{.*}}[%[[Y0]], %[[X0]]], %{{.*}} {in_bounds = [true, true]} : memref<64x16xf32>, vector<1x4xf32> -// CHECK: vector.transfer_write %[[R0]], %{{.*}}[%[[Y0]], %[[X0]]] {in_bounds = [true, true]} : vector<1x4xf32>, memref<64x16xf32, 3> -// CHECK-DAG: %[[Y1:.*]] = affine.apply #[[$MAP2]]()[%[[TX]], %[[TY]], %[[TZ]]] +// CHECK: %[[TFLAT:.*]] = affine.linearize_index disjoint [%[[TY]], %[[TX]]] by (4, 32) +// CHECK: %[[YX:.*]]:2 = affine.delinearize_index %[[TFLAT]] into (32, 4) +// CHECK: %[[X0:.*]] = affine.apply #[[$MAP0]]()[%[[YX]]#1] +// CHECK: %[[R0:.*]] = vector.transfer_read %{{.*}}[%[[YX]]#0, %[[X0]]], %{{.*}} {in_bounds = [true, true]} : memref<64x16xf32>, vector<1x4xf32> +// CHECK: vector.transfer_write %[[R0]], %{{.*}}[%[[YX]]#0, %[[X0]]] {in_bounds = [true, true]} : vector<1x4xf32>, memref<64x16xf32, 3> +// CHECK-DAG: %[[Y1:.*]] = affine.apply #[[$MAP1]]()[%[[YX]]#0] // CHECK: %[[R1:.*]] = vector.transfer_read %{{.*}}[%[[Y1]], %[[X0]]], %{{.*}} {in_bounds = [true, true]} : memref<64x16xf32>, vector<1x4xf32> // CHECK: vector.transfer_write %[[R1]], %{{.*}}[%[[Y1]], %[[X0]]] {in_bounds = [true, true]} : vector<1x4xf32>, memref<64x16xf32, 3> -// CHECK: %[[Y1:.*]] = affine.apply #[[$MAP3]]()[%[[TX]], %[[TY]], %[[TZ]]] -// CHECK: %[[R2:.*]] = vector.transfer_read %{{.*}}[%[[Y1]], %[[C0]]], %{{.*}} {in_bounds = [true, true]} : memref<256x4xf32>, vector<1x4xf32> -// CHECK: vector.transfer_write %[[R2]], %{{.*}}[%[[Y1]], %[[C0]]] {in_bounds = [true, true]} : vector<1x4xf32>, memref<256x4xf32, 3> -// CHECK: %[[Y2:.*]] = affine.apply #[[$MAP4]]()[%[[TX]], %[[TY]], %[[TZ]]] +// CHECK: %[[R2:.*]] = vector.transfer_read %{{.*}}[%[[TFLAT]], %[[C0]]], %{{.*}} {in_bounds = [true, true]} : memref<256x4xf32>, vector<1x4xf32> +// CHECK: vector.transfer_write %[[R2]], %{{.*}}[%[[TFLAT]], %[[C0]]] {in_bounds = [true, true]} : vector<1x4xf32>, memref<256x4xf32, 3> +// CHECK: %[[Y2:.*]] = affine.apply #[[$MAP2]]()[%[[TFLAT]]] // CHECK: %[[R3:.*]] = vector.transfer_read %{{.*}}[%[[Y2]], %[[C0]]], %{{.*}} {in_bounds = [true, true]} : memref<256x4xf32>, vector<1x4xf32> // CHECK: vector.transfer_write %[[R3]], %{{.*}}[%[[Y2]], %[[C0]]] {in_bounds = [true, true]} : vector<1x4xf32>, memref<256x4xf32, 3> -// CHECK: %[[X1:.*]] = affine.apply #[[$MAP5]]()[%[[TX]], %[[TY]], %[[TZ]]] +// CHECK: %[[X1:.*]] = affine.apply #[[$MAP0]]()[%[[TFLAT]]] // CHECK: %[[R4:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[X1]]], %{{.*}} {in_bounds = [true, true]} : memref<3x512xf32>, vector<1x4xf32> // CHECK: vector.transfer_write %[[R4]], %{{.*}}[%[[C0]], %[[X1]]] {in_bounds = [true, true]} : vector<1x4xf32>, memref<3x512xf32, 3> // CHECK: %[[R5:.*]] = vector.transfer_read %{{.*}}[%[[C1]], %[[X1]]], %{{.*}} {in_bounds = [true, true]} : memref<3x512xf32>, vector<1x4xf32> diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/test/transform_gpu_distribute_shared_memory.mlir b/compiler/src/iree/compiler/Codegen/Common/GPU/test/transform_gpu_distribute_shared_memory.mlir index ec765a1d5aa6..907070a35c5c 100644 --- a/compiler/src/iree/compiler/Codegen/Common/GPU/test/transform_gpu_distribute_shared_memory.mlir +++ b/compiler/src/iree/compiler/Codegen/Common/GPU/test/transform_gpu_distribute_shared_memory.mlir @@ -46,20 +46,19 @@ module attributes {transform.with_named_sequence} { transform.yield } } -// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0, s1, s2] -> 
(s1 * 8 + s2 * 32 + s0 floordiv 4)> -// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 4) * 16)> -// CHECK-DAG: #[[$MAP2:.*]] = affine_map<()[s0, s1, s2] -> (s1 * 8 + s2 * 32 + s0 floordiv 4 + 32)> +// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0] -> (s0 * 4)> +// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 + 32)> // CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1) -> (d0, d1)> // CHECK-LABEL: @shared_mem_cpy( // CHECK-DAG: %[[TX:.*]] = gpu.thread_id x // CHECK-DAG: %[[TY:.*]] = gpu.thread_id y -// CHECK-DAG: %[[TZ:.*]] = gpu.thread_id z -// CHECK-DAG: %[[Y0:.*]] = affine.apply #[[$MAP0]]()[%[[TX]], %[[TY]], %[[TZ]]] -// CHECK-DAG: %[[X0:.*]] = affine.apply #[[$MAP1]]()[%[[TX]]] -// CHECK: %[[R0:.*]] = vector.transfer_read %{{.*}}[%[[Y0]], %[[X0]]], %{{.*}} {in_bounds = [true, true]} : memref<64x16xf32, #hal.descriptor_type>, vector<1x4xf32> -// CHECK: vector.transfer_write %[[R0]], %{{.*}}[%[[Y0]], %[[X0]]] {in_bounds = [true, true]} : vector<1x4xf32>, memref<64x16xf32, #gpu.address_space> -// CHECK-DAG: %[[Y1:.*]] = affine.apply #[[$MAP2]]()[%[[TX]], %[[TY]], %[[TZ]]] +// CHECK-DAG: %[[TFLAT:.*]] = affine.linearize_index disjoint [%[[TY]], %[[TX]]] by (4, 32) +// CHECK-DAG: %[[YX:.*]]:2 = affine.delinearize_index %[[TFLAT]] into (32, 4) +// CHECK-DAG: %[[X0:.*]] = affine.apply #[[$MAP0]]()[%[[YX]]#1] +// CHECK: %[[R0:.*]] = vector.transfer_read %{{.*}}[%[[YX]]#0, %[[X0]]], %{{.*}} {in_bounds = [true, true]} : memref<64x16xf32, #hal.descriptor_type>, vector<1x4xf32> +// CHECK: vector.transfer_write %[[R0]], %{{.*}}[%[[YX]]#0, %[[X0]]] {in_bounds = [true, true]} : vector<1x4xf32>, memref<64x16xf32, #gpu.address_space> +// CHECK-DAG: %[[Y1:.*]] = affine.apply #[[$MAP1]]()[%[[YX]]#0] // CHECK: %[[R1:.*]] = vector.transfer_read %{{.*}}[%[[Y1]], %[[X0]]], %{{.*}} {in_bounds = [true, true]} : memref<64x16xf32, #hal.descriptor_type>, vector<1x4xf32> // CHECK: vector.transfer_write %[[R1]], %{{.*}}[%[[Y1]], %[[X0]]] {in_bounds = [true, true]} : vector<1x4xf32>, memref<64x16xf32, #gpu.address_space> // CHECK: linalg.generic diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp index cc2649823f4e..bb841bf10e72 100644 --- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp +++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensions.cpp @@ -1113,16 +1113,15 @@ transform_dialect::TestGpuVectorDistribution::applyToOne( rewriter.setInsertionPointToStart(&target.getFunctionBody().front()); // This is a test op so we unsafely use thread_id x as the lane ID. In // general this should linearize the thread IDs based on the workgroup size - // and divide by the subgroup size. i.e. + // and take the modulo by the subgroup size. i.e. // - // lane_id = (tid_x + tid_y * dim_x + tid_z * dim_y * dim_x) / subgroup_size; + // lane_id = (tid_x + tid_y * dim_x + tid_z * dim_y * dim_x) % subgroup_size; Value laneId = rewriter.create(target.getLoc(), gpu::Dimension::x); + int64_t subgroupSize = getSubgroupSize(); populateGPUDistributionPatterns(patterns); - // For testing we use subgroup size = 64. 
- populateGPUDistributeNestedLayoutAttrPatterns(patterns, laneId, - /*subgroupSize=*/64); + populateGPUDistributeNestedLayoutAttrPatterns(patterns, laneId, subgroupSize); populateGPUDistributeNestedLayoutContractAMDGPUPatterns(patterns); if (failed(distributeVectorOps(target, patterns, options))) { return emitDefaultDefiniteFailure(target); diff --git a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td index 5219b4a2da9c..0c05178043c8 100644 --- a/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td +++ b/compiler/src/iree/compiler/Codegen/Common/TransformExtensions/CommonExtensionsOps.td @@ -631,7 +631,8 @@ def TestGpuVectorDistribution : }]; let arguments = (ins TransformHandleTypeInterface:$target, - DefaultValuedOptionalAttr:$experimental); + DefaultValuedOptionalAttr:$experimental, + DefaultValuedOptionalAttr:$subgroup_size); let results = (outs); let assemblyFormat = [{ $target attr-dict `:` type($target)}]; diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp index eaa3f7249c05..803040d0451a 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp @@ -649,21 +649,16 @@ LogicalResult DataTiledMMAAttr::populateOperandOffsetsSizesStrides( getSubgroupSize() / intrinsicLayoutThreadBound); } - // AffineDelinearizeIndexOp requires an in-bounds input index, so we bound it. - OpFoldResult threadIdBound = - builder.getIndexAttr(ShapedType::getNumElements(distributionThreadSizes)); - AffineExpr d0 = builder.getAffineDimExpr(0), d1 = builder.getAffineDimExpr(1); - OpFoldResult boundedThreadId = affine::makeComposedFoldedAffineApply( - builder, loc, {d0 % d1}, {threadId, threadIdBound}); - // Obtain the offsets from delinearization along the distributionThreadSizes. + // Use a delinearize without outer bound and throw away its initial result + // to get clamping behavior. SmallVector tileOffsets = builder .create( - loc, - getValueOrCreateConstantIndexOp(builder, loc, boundedThreadId), - getAsIndexOpFoldResult(ctx, distributionThreadSizes)) - ->getResults(); + loc, getValueOrCreateConstantIndexOp(builder, loc, threadId), + distributionThreadSizes, /*hasOuterBound=*/false) + ->getResults() + .drop_front(); if (hasDistributionOnlyDim) { // Erase the delinearized index that corresponds to the extra distribution diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp index ff4f17648aa8..75bf5e51d54c 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/Transforms.cpp @@ -209,11 +209,10 @@ LogicalResult fuseForallIntoConsumer(RewriterBase &rewriter, // Compute the total producer loop worker count (P0 * ... * Pn). 
Value linearConsumerIdVal = getValueOrCreateConstantIndexOp(rewriter, loc, linearId); - SmallVector producerRanges; + SmallVector producerRanges; OpFoldResult producerWorkerCount = rewriter.getIndexAttr(1); for (auto workerCount : producer.getMixedUpperBound()) { - producerRanges.push_back( - getValueOrCreateConstantIndexOp(rewriter, loc, workerCount)); + producerRanges.push_back(workerCount); producerWorkerCount = affine::makeComposedFoldedAffineApply( rewriter, loc, d0 * d1, {producerWorkerCount, workerCount}); } diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/distribute_mma_to_lanes.mlir b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/distribute_mma_to_lanes.mlir index 31c5074972c6..07729a11e2b5 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/distribute_mma_to_lanes.mlir +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/Transforms/test/distribute_mma_to_lanes.mlir @@ -387,24 +387,21 @@ func.func @data_tiled_1x1x1_tensor_multi_mma(%lhs: tensor<1x1x4x16xf32>, %rhs: t return %0 : tensor<1x1x4x16x4xf32> } -// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0) -> (d0 mod 64)> - // CHECK-LABEL: func @data_tiled_1x1x1_tensor_multi_mma // CHECK-SAME: %[[LHS:[A-Za-z0-9]+]] // CHECK-SAME: %[[RHS:[A-Za-z0-9]+]] // CHECK-SAME: %[[ACC:[A-Za-z0-9]+]] // CHECK: scf.forall (%[[THREAD_ID:.+]]) in (64) shared_outs(%[[ACC_ARG:.+]] = %[[ACC]]) -> (tensor<1x1x4x16x4xf32>) -// CHECK: %[[ID_CLAMPED:.+]] = affine.apply #[[$MAP]](%[[THREAD_ID]]) -// CHECK-DAG: %[[IN_IDS:.+]]:2 = affine.delinearize_index %[[ID_CLAMPED]] into (4, 16) -// CHECK-DAG: %[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS]][0, 0, %[[IN_IDS]]#0, %[[IN_IDS]]#1] [1, 1, 1, 1] [1, 1, 1, 1] -// CHECK-DAG: %[[RHS_SLICE:.+]] = tensor.extract_slice %[[RHS]][0, 0, %[[IN_IDS]]#0, %[[IN_IDS]]#1] [1, 1, 1, 1] [1, 1, 1, 1] +// CHECK-DAG: %[[IN_IDS:.+]]:3 = affine.delinearize_index %[[THREAD_ID]] into (4, 16) +// CHECK-DAG: %[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS]][0, 0, %[[IN_IDS]]#1, %[[IN_IDS]]#2] [1, 1, 1, 1] [1, 1, 1, 1] +// CHECK-DAG: %[[RHS_SLICE:.+]] = tensor.extract_slice %[[RHS]][0, 0, %[[IN_IDS]]#1, %[[IN_IDS]]#2] [1, 1, 1, 1] [1, 1, 1, 1] // CHECK-DAG: %[[ACC_SLICE:.+]] = tensor.extract_slice %[[ACC_ARG]] -// CHECK-SAME: [0, 0, %[[IN_IDS]]#0, %[[IN_IDS]]#1, 0] [1, 1, 1, 1, 4] [1, 1, 1, 1, 1] +// CHECK-SAME: [0, 0, %[[IN_IDS]]#1, %[[IN_IDS]]#2, 0] [1, 1, 1, 1, 4] [1, 1, 1, 1, 1] // CHECK: %[[MMA:.+]] = iree_gpu.multi_mma %[[LHS_SLICE]], %[[RHS_SLICE]], %[[ACC_SLICE]] // CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout // CHECK-SAME: : tensor<1x1x1x1xf32>, tensor<1x1x1x1xf32> into tensor<1x1x1x1x4xf32> // CHECK: tensor.parallel_insert_slice %[[MMA]] into %[[ACC_ARG]] -// CHECK-SAME: [0, 0, %[[IN_IDS]]#0, %[[IN_IDS]]#1, 0] [1, 1, 1, 1, 4] [1, 1, 1, 1, 1] +// CHECK-SAME: [0, 0, %[[IN_IDS]]#1, %[[IN_IDS]]#2, 0] [1, 1, 1, 1, 4] [1, 1, 1, 1, 1] // CHECK: mapping = [#gpu.thread] // ----- @@ -424,26 +421,23 @@ func.func @data_tiled_2x2x4_tensor_multi_mma_unrolled(%lhs: tensor<1x1x2x4x16x4x return %0 : tensor<1x1x2x2x4x16x4xf32> } -// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0) -> (d0 mod 64)> - // CHECK-LABEL: func @data_tiled_2x2x4_tensor_multi_mma_unrolled // CHECK-SAME: %[[LHS:[A-Za-z0-9]+]] // CHECK-SAME: %[[RHS:[A-Za-z0-9]+]] // CHECK-SAME: %[[ACC:[A-Za-z0-9]+]] // CHECK: scf.forall (%[[THREAD_ID:.+]]) in (64) shared_outs(%[[ACC_ARG:.+]] = %[[ACC]]) -> (tensor<1x1x2x2x4x16x4xf32>) -// CHECK: %[[ID_CLAMPED:.+]] = affine.apply #[[$MAP]](%[[THREAD_ID]]) -// CHECK-DAG: 
%[[IN_IDS:.+]]:2 = affine.delinearize_index %[[ID_CLAMPED]] into (4, 16) +// CHECK-DAG: %[[IN_IDS:.+]]:3 = affine.delinearize_index %[[THREAD_ID]] into (4, 16) // CHECK-DAG: %[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS]] -// CHECK-SAME: [0, 0, 0, %[[IN_IDS]]#0, %[[IN_IDS]]#1, 0] [1, 1, 2, 1, 1, 4] [1, 1, 1, 1, 1, 1] +// CHECK-SAME: [0, 0, 0, %[[IN_IDS]]#1, %[[IN_IDS]]#2, 0] [1, 1, 2, 1, 1, 4] [1, 1, 1, 1, 1, 1] // CHECK-DAG: %[[RHS_SLICE:.+]] = tensor.extract_slice %[[RHS]] -// CHECK-SAME: [0, 0, 0, %[[IN_IDS]]#0, %[[IN_IDS]]#1, 0] [1, 1, 2, 1, 1, 4] [1, 1, 1, 1, 1, 1] +// CHECK-SAME: [0, 0, 0, %[[IN_IDS]]#1, %[[IN_IDS]]#2, 0] [1, 1, 2, 1, 1, 4] [1, 1, 1, 1, 1, 1] // CHECK-DAG: %[[ACC_SLICE:.+]] = tensor.extract_slice %[[ACC_ARG]] -// CHECK-SAME: [0, 0, 0, 0, %[[IN_IDS]]#0, %[[IN_IDS]]#1, 0] [1, 1, 2, 2, 1, 1, 4] [1, 1, 1, 1, 1, 1, 1] +// CHECK-SAME: [0, 0, 0, 0, %[[IN_IDS]]#1, %[[IN_IDS]]#2, 0] [1, 1, 2, 2, 1, 1, 4] [1, 1, 1, 1, 1, 1, 1] // CHECK: %[[MMA:.+]] = iree_gpu.multi_mma %[[LHS_SLICE]], %[[RHS_SLICE]], %[[ACC_SLICE]] // CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout // CHECK-SAME: : tensor<1x1x2x1x1x4xf32>, tensor<1x1x2x1x1x4xf32> into tensor<1x1x2x2x1x1x4xf32> // CHECK: tensor.parallel_insert_slice %[[MMA]] into %[[ACC_ARG]] -// CHECK-SAME: [0, 0, 0, 0, %[[IN_IDS]]#0, %[[IN_IDS]]#1, 0] [1, 1, 2, 2, 1, 1, 4] [1, 1, 1, 1, 1, 1, 1] +// CHECK-SAME: [0, 0, 0, 0, %[[IN_IDS]]#1, %[[IN_IDS]]#2, 0] [1, 1, 2, 2, 1, 1, 4] [1, 1, 1, 1, 1, 1, 1] // CHECK: mapping = [#gpu.thread] // ----- @@ -463,27 +457,22 @@ func.func @data_tiled_2x2x4_tensor_multi_mma_unrolled_to_subgroups(%lhs: tensor< return %0 : tensor<1x1x2x2x4x16x4xf32> } -// CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0) -> (d0 mod 128)> -// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0) -> (d0 mod 256)> - // CHECK-LABEL: func @data_tiled_2x2x4_tensor_multi_mma_unrolled_to_subgroups // CHECK-SAME: %[[LHS:[A-Za-z0-9]+]] // CHECK-SAME: %[[RHS:[A-Za-z0-9]+]] // CHECK-SAME: %[[ACC:[A-Za-z0-9]+]] // CHECK: scf.forall (%[[THREAD_ID:.+]]) in (256) shared_outs(%[[ACC_ARG:.+]] = %[[ACC]]) -> (tensor<1x1x2x2x4x16x4xf32>) -// CHECK: %[[ID_CLAMPED_128:.+]] = affine.apply #[[$MAP]](%[[THREAD_ID]]) -// CHECK-DAG: %[[IN_IDS:.+]]:3 = affine.delinearize_index %[[ID_CLAMPED_128]] into (2, 4, 16) +// CHECK-DAG: %[[IN_IDS:.+]]:4 = affine.delinearize_index %[[THREAD_ID]] into (2, 4, 16) // CHECK-DAG: %[[LHS_SLICE:.+]] = tensor.extract_slice %[[LHS]] -// CHECK-SAME: [0, 0, %[[IN_IDS]]#0, %[[IN_IDS]]#1, %[[IN_IDS]]#2, 0] [1, 1, 1, 1, 1, 4] [1, 1, 1, 1, 1, 1] +// CHECK-SAME: [0, 0, %[[IN_IDS]]#1, %[[IN_IDS]]#2, %[[IN_IDS]]#3, 0] [1, 1, 1, 1, 1, 4] [1, 1, 1, 1, 1, 1] // CHECK-DAG: %[[RHS_SLICE:.+]] = tensor.extract_slice %[[RHS]] -// CHECK-SAME: [0, 0, %[[IN_IDS]]#0, %[[IN_IDS]]#1, %[[IN_IDS]]#2, 0] [1, 1, 1, 1, 1, 4] [1, 1, 1, 1, 1, 1] -// CHECK: %[[ID_CLAMPED_256:.+]] = affine.apply #[[$MAP1]](%[[THREAD_ID]]) -// CHECK-DAG: %[[ACC_IDS:.+]]:4 = affine.delinearize_index %[[ID_CLAMPED_256]] into (2, 2, 4, 16) +// CHECK-SAME: [0, 0, %[[IN_IDS]]#1, %[[IN_IDS]]#2, %[[IN_IDS]]#3, 0] [1, 1, 1, 1, 1, 4] [1, 1, 1, 1, 1, 1] +// CHECK-DAG: %[[ACC_IDS:.+]]:5 = affine.delinearize_index %[[THREAD_ID]] into (2, 2, 4, 16) // CHECK-DAG: %[[ACC_SLICE:.+]] = tensor.extract_slice %[[ACC_ARG]] -// CHECK-SAME: [0, 0, %[[ACC_IDS]]#0, %[[ACC_IDS]]#1, %[[ACC_IDS]]#2, %[[ACC_IDS]]#3, 0] [1, 1, 1, 1, 1, 1, 4] [1, 1, 1, 1, 1, 1, 1] +// CHECK-SAME: [0, 0, %[[ACC_IDS]]#1, %[[ACC_IDS]]#2, %[[ACC_IDS]]#3, %[[ACC_IDS]]#4, 0] [1, 1, 1, 1, 1, 1, 4] [1, 1, 1, 1, 1, 1, 1] // CHECK: %[[MMA:.+]] = 
iree_gpu.multi_mma %[[LHS_SLICE]], %[[RHS_SLICE]], %[[ACC_SLICE]] // CHECK-SAME: kind = #iree_gpu.data_tiled_mma_layout} // CHECK-SAME: : tensor<1x1x1x1x1x4xf32>, tensor<1x1x1x1x1x4xf32> into tensor<1x1x1x1x1x1x4xf32> // CHECK: tensor.parallel_insert_slice %[[MMA]] into %[[ACC_ARG]] -// CHECK-SAME: [0, 0, %[[ACC_IDS]]#0, %[[ACC_IDS]]#1, %[[ACC_IDS]]#2, %[[ACC_IDS]]#3, 0] [1, 1, 1, 1, 1, 1, 4] [1, 1, 1, 1, 1, 1, 1] +// CHECK-SAME: [0, 0, %[[ACC_IDS]]#1, %[[ACC_IDS]]#2, %[[ACC_IDS]]#3, %[[ACC_IDS]]#4, 0] [1, 1, 1, 1, 1, 1, 4] [1, 1, 1, 1, 1, 1, 1] // CHECK: mapping = [#gpu.thread] diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorDistribute.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorDistribute.cpp index 1640656b71a8..d5e0af1bd119 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorDistribute.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/LLVMGPUVectorDistribute.cpp @@ -80,24 +80,18 @@ struct LLVMGPUVectorDistributePass final } } - AffineExpr x, y, z; - bindSymbols(func.getContext(), x, y, z); - // Construct the expression for linearizing the thread indices. - AffineExpr linearId = - x + workgroupSize[0] * y + workgroupSize[1] * workgroupSize[0] * z; - IRRewriter rewriter(func); rewriter.setInsertionPointToStart(&func.getFunctionBody().front()); - SmallVector threadGrid = { - rewriter.createOrFold(func.getLoc(), - gpu::Dimension::x), - rewriter.createOrFold(func.getLoc(), - gpu::Dimension::y), - rewriter.createOrFold(func.getLoc(), - gpu::Dimension::z)}; - - Value linearThreadIdVal = affine::makeComposedAffineApply( - rewriter, func.getLoc(), linearId, threadGrid); + SmallVector threadGrid = {rewriter.createOrFold( + func.getLoc(), gpu::Dimension::z), + rewriter.createOrFold( + func.getLoc(), gpu::Dimension::y), + rewriter.createOrFold( + func.getLoc(), gpu::Dimension::x)}; + std::reverse(workgroupSize.begin(), workgroupSize.end()); + + Value linearThreadIdVal = rewriter.create( + func.getLoc(), threadGrid, workgroupSize, /*disjoint=*/true); std::optional subgroupSize = getSubgroupSize(func); if (!subgroupSize) { diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp index c52ae4bcc157..5c4c3ff471dd 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensions.cpp @@ -1476,11 +1476,10 @@ transform_dialect::AMDGPUDistributeVectorsOp::applyToOne( rewriter.setInsertionPointToStart(&target.getFunctionBody().front()); Value laneId = rewriter.create(target.getLoc(), gpu::Dimension::x); + int64_t subgroupSize = getSubgroupSize(); populateGPUDistributionPatterns(patterns); - // For testing we use subgroup size = 64. 
- populateGPUDistributeNestedLayoutAttrPatterns(patterns, laneId, - /*subgroupSize=*/64); + populateGPUDistributeNestedLayoutAttrPatterns(patterns, laneId, subgroupSize); if (failed(distributeVectorOps(target, patterns, options))) { return emitDefaultSilenceableFailure(target); } diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensionsOps.td b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensionsOps.td index 69e766537c0b..28bd4eebdbbd 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensionsOps.td +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/TransformExtensions/LLVMGPUExtensionsOps.td @@ -699,7 +699,8 @@ def AMDGPUDistributeVectorsOp : }]; let arguments = (ins TransformHandleTypeInterface:$target, - UnitAttr:$test_conversion); + UnitAttr:$test_conversion, + DefaultValuedOptionalAttr:$subgroup_size); let results = (outs TransformHandleTypeInterface:$result); let assemblyFormat = [{ diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir index 4e9758f83c78..3f5b280b6342 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/pipeline_tile_and_fuse.mlir @@ -544,17 +544,13 @@ hal.executable public @main { } } -// CHECK: #[[$MAP:.+]] = affine_map<()[s0, s1, s2] -> (s0 + s1 * 8 + s2 * 32)> -// CHECK: #[[$MAP1:.+]] = affine_map<()[s0, s1] -> (s0 * 8 + s1)> +// CHECK: #[[$MAP0:.+]] = affine_map<()[s0, s1] -> (s0 * 8 + s1)> // CHECK-LABEL: func @skinny_matmul_config // CHECK-DAG: %[[IDX:.+]] = gpu.thread_id x // CHECK-DAG: %[[IDY:.+]] = gpu.thread_id y -// CHECK-DAG: %[[IDZ:.+]] = gpu.thread_id z -// CHECK: %[[LINID0:.+]] = affine.apply #[[$MAP]]()[%[[IDX]], %[[IDY]], %[[IDZ]]] -// CHECK: %[[IDS:.+]]:2 = affine.delinearize_index %[[LINID0:.+]] into (4, 8) : index, index -// CHECK: %[[LINID1:.+]] = affine.apply #[[$MAP1]]()[%[[IDS]]#0, %[[IDS]]#1] +// CHECK: %[[LINID1:.+]] = affine.apply #[[$MAP0]]()[%[[IDY]], %[[IDX]]] // CHECK: scf.forall ({{.*}}) in (32, 98) { // CHECK: scf.for %{{.*}} = %c0 to %c256 step %c4 {{.*}} -> (vector<1x4xf32>) // CHECK: scf.for %{{.*}} = %[[LINID1]] to %c4 step %c32 diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transpose_pipeline_test.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transpose_pipeline_test.mlir index 23c3977c8389..fd373f7ebadc 100644 --- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transpose_pipeline_test.mlir +++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/transpose_pipeline_test.mlir @@ -34,28 +34,25 @@ hal.executable @transpose_dispatch_0 { // CHECK-LABEL: hal.executable public @transpose_dispatch_0 // CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK-DAG: %[[D0:.*]] = gpu.thread_id x -// CHECK-DAG: %[[D1:.*]] = gpu.thread_id y -// CHECK-DAG: %[[D2:.*]] = gpu.thread_id z -// CHECK-DAG: %[[D3:.*]] = memref.alloc() : memref<32x33xf32, #gpu.address_space> -// CHECK: %[[D4:.*]] = hal.interface.binding.subspan layout({{.+}}) binding(0) alignment(64) offset(%[[C0]]) : memref<4096x4096xf32, #hal.descriptor_type> -// CHECK: memref.assume_alignment %[[D4]], 64 : memref<4096x4096xf32, #hal.descriptor_type> -// CHECK: %[[D5:.*]] = hal.interface.binding.subspan layout({{.+}}) binding(1) alignment(64) offset(%[[C0]]) : memref<4096x4096xf32, 
#hal.descriptor_type> -// CHECK: memref.assume_alignment %[[D5]], 64 : memref<4096x4096xf32, #hal.descriptor_type> +// CHECK-DAG: %[[TX:.*]] = gpu.thread_id x +// CHECK-DAG: %[[TY:.*]] = gpu.thread_id y +// CHECK-DAG: %[[ALLOC:.*]] = memref.alloc() : memref<32x33xf32, #gpu.address_space> +// CHECK: %[[D0:.*]] = hal.interface.binding.subspan layout({{.+}}) binding(0) alignment(64) offset(%[[C0]]) : memref<4096x4096xf32, #hal.descriptor_type> +// CHECK: memref.assume_alignment %[[D0]], 64 : memref<4096x4096xf32, #hal.descriptor_type> +// CHECK: %[[D1:.*]] = hal.interface.binding.subspan layout({{.+}}) binding(1) alignment(64) offset(%[[C0]]) : memref<4096x4096xf32, #hal.descriptor_type> +// CHECK: memref.assume_alignment %[[D1]], 64 : memref<4096x4096xf32, #hal.descriptor_type> // CHECK: gpu.barrier -// CHECK: %[[D6:.*]] = affine.apply #{{.*}}()[%{{.*}}, %[[D0]], %[[D1]], %[[D2]]] -// CHECK: %[[D7:.*]] = affine.apply #{{.*}}()[%{{.*}}, %[[D0]]] -// CHECK: %[[D8:.*]] = vector.transfer_read %[[D4]][%[[D6]], %[[D7]]], %[[CST]] {in_bounds = [true, true]} : memref<4096x4096xf32, #hal.descriptor_type>, vector<1x4xf32> -// CHECK: %[[D9:.*]] = affine.apply #{{.*}}()[%[[D0]], %[[D1]], %[[D2]]] -// CHECK: %[[D10:.*]] = affine.apply #{{.*}}()[%[[D0]]] -// CHECK: vector.transfer_write %[[D8]], %[[D3]][%[[D9]], %[[D10]]] {in_bounds = [true, true]} : vector<1x4xf32>, memref<32x33xf32, #gpu.address_space> +// CHECK: %[[D2:.*]] = affine.apply #{{.*}}()[%{{.+}}, %[[TY]]] +// CHECK: %[[D3:.*]] = affine.apply #{{.*}}()[%{{.+}}, %[[TX]]] +// CHECK: %[[D4:.*]] = vector.transfer_read %[[D0]][%[[D2]], %[[D3]]], %[[CST]] {in_bounds = [true, true]} : memref<4096x4096xf32, #hal.descriptor_type>, vector<1x4xf32> +// CHECK: %[[D5:.*]] = affine.apply #{{.*}}()[%[[TX]]] +// CHECK: vector.transfer_write %[[D4]], %[[ALLOC]][%[[TY]], %[[D5]]] {in_bounds = [true, true]} : vector<1x4xf32>, memref<32x33xf32, #gpu.address_space> // CHECK: gpu.barrier -// CHECK: %[[D11:.*]] = affine.apply #{{.*}}()[%[[D0]]] -// CHECK: %[[D12:.*]] = vector.transfer_read %[[D3]][%[[D11]], %[[D1]]], %[[CST]] {in_bounds = [true, true]} : memref<32x33xf32, #gpu.address_space>, vector<4x1xf32> -// CHECK: %[[D13:.*]] = vector.shape_cast %[[D12]] : vector<4x1xf32> to vector<4xf32> -// CHECK: %[[D15:.*]] = affine.apply #{{.*}}()[%{{.*}}, %[[D1]]] -// CHECK: %[[D16:.*]] = affine.apply #{{.*}}()[%{{.*}}, %[[D0]]] -// CHECK: vector.transfer_write %[[D13]], %[[D5]][%[[D15]], %[[D16]]] {in_bounds = [true]} : vector<4xf32>, memref<4096x4096xf32, #hal.descriptor_type> +// CHECK: %[[D6:.*]] = vector.transfer_read %[[ALLOC]][%[[D5]], %[[TY]]], %[[CST]] {in_bounds = [true, true]} : memref<32x33xf32, #gpu.address_space>, vector<4x1xf32> +// CHECK: %[[D7:.*]] = vector.shape_cast %[[D6]] : vector<4x1xf32> to vector<4xf32> +// CHECK: %[[D8:.*]] = affine.apply #{{.*}}()[%{{.+}}, %[[TY]]] +// CHECK: %[[D9:.*]] = affine.apply #{{.*}}()[%{{.+}}, %[[TX]]] +// CHECK: vector.transfer_write %[[D7]], %[[D1]][%[[D8]], %[[D9]]] {in_bounds = [true]} : vector<4xf32>, memref<4096x4096xf32, #hal.descriptor_type> // ----- @@ -96,34 +93,31 @@ hal.executable @transpose_single_operand_dispatch_0_generic_768x2048 { // CHECK-LABEL: hal.executable public @transpose_single_operand_dispatch_0_generic_768x2048 // CHECK: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 // CHECK: %[[C0:.*]] = arith.constant 0 : index -// CHECK: %[[D0:.*]] = gpu.thread_id x -// CHECK: %[[D1:.*]] = gpu.thread_id y -// CHECK: %[[D2:.*]] = gpu.thread_id z -// CHECK: %[[D3:.*]] = memref.alloc() : memref<32x33xf32, 
#gpu.address_space> -// CHECK: %[[D4:.*]] = hal.interface.binding.subspan layout({{.+}}) binding(0) alignment(64) offset(%[[C0]]) : memref<2048x768xf32, #hal.descriptor_type> -// CHECK: memref.assume_alignment %[[D4]], 64 : memref<2048x768xf32, #hal.descriptor_type> -// CHECK: %[[D5:.*]] = hal.interface.binding.subspan layout({{.+}}) binding(1) alignment(64) offset(%[[C0]]) : memref<768x2048xf32, #hal.descriptor_type> -// CHECK: memref.assume_alignment %[[D5]], 64 : memref<768x2048xf32, #hal.descriptor_type> -// CHECK: %[[D6:.*]] = hal.interface.binding.subspan layout({{.+}}) binding(2) alignment(64) offset(%[[C0]]) : memref<768x2048xf32, #hal.descriptor_type> -// CHECK: memref.assume_alignment %[[D6]], 64 : memref<768x2048xf32, #hal.descriptor_type> +// CHECK: %[[TX:.*]] = gpu.thread_id x +// CHECK: %[[TY:.*]] = gpu.thread_id y +// CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<32x33xf32, #gpu.address_space> +// CHECK: %[[D0:.*]] = hal.interface.binding.subspan layout({{.+}}) binding(0) alignment(64) offset(%[[C0]]) : memref<2048x768xf32, #hal.descriptor_type> +// CHECK: memref.assume_alignment %[[D0]], 64 : memref<2048x768xf32, #hal.descriptor_type> +// CHECK: %[[D1:.*]] = hal.interface.binding.subspan layout({{.+}}) binding(1) alignment(64) offset(%[[C0]]) : memref<768x2048xf32, #hal.descriptor_type> +// CHECK: memref.assume_alignment %[[D1]], 64 : memref<768x2048xf32, #hal.descriptor_type> +// CHECK: %[[D2:.*]] = hal.interface.binding.subspan layout({{.+}}) binding(2) alignment(64) offset(%[[C0]]) : memref<768x2048xf32, #hal.descriptor_type> +// CHECK: memref.assume_alignment %[[D2]], 64 : memref<768x2048xf32, #hal.descriptor_type> // CHECK: gpu.barrier -// CHECK: %[[D7:.*]] = affine.apply #{{.*}}()[%{{.*}}, %[[D0]], %[[D1]], %[[D2]]] -// CHECK: %[[D8:.*]] = affine.apply #{{.*}}()[%{{.*}}, %[[D0]]] -// CHECK: %[[D9:.*]] = vector.transfer_read %[[D4]][%[[D7]], %[[D8]]], %[[CST]] {in_bounds = [true, true]} : memref<2048x768xf32, #hal.descriptor_type>, vector<1x4xf32> -// CHECK: %[[D10:.*]] = affine.apply #{{.*}}()[%[[D0]], %[[D1]], %[[D2]]] -// CHECK: %[[D11:.*]] = affine.apply #{{.*}}()[%[[D0]]] -// CHECK: vector.transfer_write %[[D9]], %[[D3]][%[[D10]], %[[D11]]] {in_bounds = [true, true]} : vector<1x4xf32>, memref<32x33xf32, #gpu.address_space> +// CHECK: %[[D3:.*]] = affine.apply #{{.*}}()[%[[TX]]] +// CHECK: %[[D4:.*]] = affine.apply #{{.*}}()[%{{.*}}, %[[TY]]] +// CHECK: %[[D5:.*]] = affine.apply #{{.*}}()[%{{.*}}, %[[TX]]] +// CHECK: %[[D6:.*]] = vector.transfer_read %[[D0]][%[[D4]], %[[D5]]], %[[CST]] {in_bounds = [true, true]} : memref<2048x768xf32, #hal.descriptor_type>, vector<1x4xf32> +// CHECK: vector.transfer_write %[[D6]], %[[ALLOC]][%[[TY]], %[[D3]]] {in_bounds = [true, true]} : vector<1x4xf32>, memref<32x33xf32, #gpu.address_space> // CHECK: gpu.barrier -// CHECK: %[[D12:.*]] = affine.apply #{{.*}}()[%[[D0]]] -// CHECK: %[[D13:.*]] = vector.transfer_read %[[D3]][%[[D12]], %[[D1]]], %[[CST]] {in_bounds = [true, true]} : memref<32x33xf32, #gpu.address_space>, vector<4x1xf32> -// CHECK: %[[DUP_D15:.*]] = arith.addi %[[D1]], %{{.*}} : index -// CHECK: %[[DUP_D16:.*]] = arith.addi %[[D12]], %{{.*}} : index -// CHECK: %[[D17:.*]] = vector.transfer_read %[[D5]][%[[DUP_D15]], %[[DUP_D16]]], %[[CST]] {in_bounds = [true]} : memref<768x2048xf32, #hal.descriptor_type>, vector<4xf32> -// CHECK: %[[D14:.*]] = vector.shape_cast %[[D13]] : vector<4x1xf32> to vector<4xf32> -// CHECK: %[[D19:.*]] = arith.addf %[[D14]], %[[D17]] : vector<4xf32> -// CHECK: %[[D15:.*]] = affine.apply 
#{{.*}}()[%{{.*}}, %[[D1]]] -// CHECK: %[[D16:.*]] = affine.apply #{{.*}}()[%{{.*}}, %[[D0]]] -// CHECK: vector.transfer_write %[[D19]], %[[D6]][%[[D15]], %[[D16]]] {in_bounds = [true]} : vector<4xf32>, memref<768x2048xf32, #hal.descriptor_type> +// CHECK: %[[D7:.*]] = vector.transfer_read %[[ALLOC]][%[[D3]], %[[TY]]], %[[CST]] {in_bounds = [true, true]} : memref<32x33xf32, #gpu.address_space>, vector<4x1xf32> +// CHECK: %[[D8:.*]] = arith.addi %[[TY]], %{{.*}} +// CHECK: %[[D9:.*]] = arith.addi %[[D3]], %{{.*}} +// CHECK: %[[D10:.*]] = vector.transfer_read %[[D1]][%[[D8]], %[[D9]]], %[[CST]] {in_bounds = [true]} : memref<768x2048xf32, #hal.descriptor_type>, vector<4xf32> +// CHECK: %[[D11:.*]] = vector.shape_cast %[[D7]] : vector<4x1xf32> to vector<4xf32> +// CHECK: %[[D12:.*]] = arith.addf %[[D11]], %[[D10]] : vector<4xf32> +// CHECK: %[[D13:.*]] = affine.apply #{{.*}}()[%{{.*}}, %[[TY]]] +// CHECK: %[[D14:.*]] = affine.apply #{{.*}}()[%{{.*}}, %[[TX]]] +// CHECK: vector.transfer_write %[[D12]], %[[D2]][%[[D13]], %[[D14]]] {in_bounds = [true]} : vector<4xf32>, memref<768x2048xf32, #hal.descriptor_type> // ----- @@ -205,34 +199,31 @@ hal.executable @transpose_3d_yes_dispatch_0_generic_10x768x2048 { // CHECK-LABEL: hal.executable public @transpose_3d_yes_dispatch_0_generic_10x768x2048 { // CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK: %[[D0:.*]] = gpu.thread_id x -// CHECK: %[[D1:.*]] = gpu.thread_id y -// CHECK: %[[D2:.*]] = gpu.thread_id z -// CHECK: %[[D3:.*]] = memref.alloc() : memref<1x32x33xf32, #gpu.address_space> -// CHECK: %[[D4:.*]] = hal.interface.binding.subspan layout({{.+}}) binding(0) alignment(64) offset(%[[C0]]) : memref<10x2048x768xf32, #hal.descriptor_type> -// CHECK: memref.assume_alignment %[[D4]], 64 : memref<10x2048x768xf32, #hal.descriptor_type> -// CHECK: %[[D5:.*]] = hal.interface.binding.subspan layout({{.+}}) binding(1) alignment(64) offset(%[[C0]]) : memref<10x768x2048xf32, #hal.descriptor_type> -// CHECK: memref.assume_alignment %[[D5]], 64 : memref<10x768x2048xf32, #hal.descriptor_type> -// CHECK: %[[D6:.*]] = hal.interface.binding.subspan layout({{.+}}) binding(2) alignment(64) offset(%[[C0]]) : memref<10x768x2048xf32, #hal.descriptor_type> -// CHECK: memref.assume_alignment %[[D6]], 64 : memref<10x768x2048xf32, #hal.descriptor_type> +// CHECK: %[[TX:.*]] = gpu.thread_id x +// CHECK: %[[TY:.*]] = gpu.thread_id y +// CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<1x32x33xf32, #gpu.address_space> +// CHECK: %[[D0:.*]] = hal.interface.binding.subspan layout({{.+}}) binding(0) alignment(64) offset(%[[C0]]) : memref<10x2048x768xf32, #hal.descriptor_type> +// CHECK: memref.assume_alignment %[[D0]], 64 : memref<10x2048x768xf32, #hal.descriptor_type> +// CHECK: %[[D1:.*]] = hal.interface.binding.subspan layout({{.+}}) binding(1) alignment(64) offset(%[[C0]]) : memref<10x768x2048xf32, #hal.descriptor_type> +// CHECK: memref.assume_alignment %[[D1]], 64 : memref<10x768x2048xf32, #hal.descriptor_type> +// CHECK: %[[D2:.*]] = hal.interface.binding.subspan layout({{.+}}) binding(2) alignment(64) offset(%[[C0]]) : memref<10x768x2048xf32, #hal.descriptor_type> +// CHECK: memref.assume_alignment %[[D2]], 64 : memref<10x768x2048xf32, #hal.descriptor_type> // CHECK: gpu.barrier -// CHECK: %[[D7:.*]] = affine.apply #{{.*}}()[%{{.*}}, %[[D0]], %[[D1]], %[[D2]]] -// CHECK: %[[D8:.*]] = affine.apply #{{.*}}()[%{{.*}}, %[[D0]]] -// CHECK: %[[D9:.*]] = vector.transfer_read %[[D4]][%{{.*}}, %[[D7]], 
%[[D8]]], %[[CST]] {in_bounds = [true, true, true]} : memref<10x2048x768xf32, #hal.descriptor_type>, vector<1x1x4xf32> -// CHECK: %[[D10:.*]] = affine.apply #{{.*}}()[%[[D0]], %[[D1]], %[[D2]]] -// CHECK: %[[D11:.*]] = affine.apply #{{.*}}()[%[[D0]]] -// CHECK: vector.transfer_write %[[D9]], %[[D3]][%[[C0]], %[[D10]], %[[D11]]] {in_bounds = [true, true, true]} : vector<1x1x4xf32>, memref<1x32x33xf32, #gpu.address_space> +// CHECK: %[[D3:.*]] = affine.apply #{{.*}}()[%[[TX]]] +// CHECK: %[[D4:.*]] = affine.apply #{{.*}}()[%{{.*}}, %[[TY]]] +// CHECK: %[[D5:.*]] = affine.apply #{{.*}}()[%{{.*}}, %[[TX]]] +// CHECK: %[[D6:.*]] = vector.transfer_read %[[D0]][%{{.*}}, %[[D4]], %[[D5]]], %[[CST]] {in_bounds = [true, true, true]} : memref<10x2048x768xf32, #hal.descriptor_type>, vector<1x1x4xf32> +// CHECK: vector.transfer_write %[[D6]], %[[ALLOC]][%[[C0]], %[[TY]], %[[D3]]] {in_bounds = [true, true, true]} : vector<1x1x4xf32>, memref<1x32x33xf32, #gpu.address_space> // CHECK: gpu.barrier -// CHECK: %[[D12:.*]] = affine.apply #{{.*}}()[%[[D0]]] -// CHECK: %[[D13:.*]] = vector.transfer_read %[[D3]][%[[C0]], %[[D12]], %[[D1]]], %[[CST]] {in_bounds = [true, true]} : memref<1x32x33xf32, #gpu.address_space>, vector<4x1xf32> -// CHECK: %[[DUP_D16:.*]] = arith.addi %[[D1]], %{{.*}} : index -// CHECK: %[[DUP_D17:.*]] = arith.addi %[[D12]], %{{.*}} : index -// CHECK: %[[D18:.*]] = vector.transfer_read %[[D5]][%{{.*}}, %[[DUP_D16]], %[[DUP_D17]]], %[[CST]] {in_bounds = [true]} : memref<10x768x2048xf32, #hal.descriptor_type>, vector<4xf32> -// CHECK: %[[D15:.*]] = vector.shape_cast %[[D13]] : vector<4x1xf32> to vector<4xf32> -// CHECK: %[[D20:.*]] = arith.addf %[[D15]], %[[D18]] : vector<4xf32> -// CHECK: %[[D16:.*]] = affine.apply #{{.*}}()[%{{.*}}, %[[D1]]] -// CHECK: %[[D17:.*]] = affine.apply #{{.*}}()[%{{.*}}, %[[D0]]] -// CHECK: vector.transfer_write %[[D20]], %[[D6]][%{{.*}}, %[[D16]], %[[D17]]] {in_bounds = [true]} : vector<4xf32>, memref<10x768x2048xf32, #hal.descriptor_type> +// CHECK: %[[D7:.*]] = vector.transfer_read %[[ALLOC]][%[[C0]], %[[D3]], %[[TY]]], %[[CST]] {in_bounds = [true, true]} : memref<1x32x33xf32, #gpu.address_space>, vector<4x1xf32> +// CHECK: %[[D8:.*]] = arith.addi %[[TY]], %{{.*}} +// CHECK: %[[D9:.*]] = arith.addi %[[D3]], %{{.*}} +// CHECK: %[[D10:.*]] = vector.transfer_read %[[D1]][%{{.*}}, %[[D8]], %[[D9]]], %[[CST]] {in_bounds = [true]} : memref<10x768x2048xf32, #hal.descriptor_type>, vector<4xf32> +// CHECK: %[[D11:.*]] = vector.shape_cast %[[D7]] : vector<4x1xf32> to vector<4xf32> +// CHECK: %[[D12:.*]] = arith.addf %[[D11]], %[[D10]] : vector<4xf32> +// CHECK: %[[D13:.*]] = affine.apply #{{.*}}()[%{{.*}}, %[[TY]]] +// CHECK: %[[D14:.*]] = affine.apply #{{.*}}()[%{{.*}}, %[[TX]]] +// CHECK: vector.transfer_write %[[D12]], %[[D2]][%{{.*}}, %[[D13]], %[[D14]]] {in_bounds = [true]} : vector<4xf32>, memref<10x768x2048xf32, #hal.descriptor_type> // ----- @@ -273,35 +264,32 @@ hal.executable @transpose_3d_trans_out_dispatch_0_generic_10x2048x768 { // CHECK-LABEL: hal.executable public @transpose_3d_trans_out_dispatch_0_generic_10x2048x768 { // CHECK-DAG: %[[CST:.*]] = arith.constant 0.000000e+00 : f32 // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index -// CHECK: %[[D0:.*]] = gpu.thread_id x -// CHECK: %[[D1:.*]] = gpu.thread_id y -// CHECK: %[[D2:.*]] = gpu.thread_id z -// CHECK: %[[D3:.*]] = memref.alloc() : memref<1x32x33xf32, #gpu.address_space> -// CHECK: %[[D4:.*]] = memref.alloc() : memref<1x32x33xf32, #gpu.address_space> -// CHECK: %[[D5:.*]] = 
hal.interface.binding.subspan layout({{.+}}) binding(0) alignment(64) offset(%[[C0]]) : memref<10x768x2048xf32, #hal.descriptor_type> -// CHECK: memref.assume_alignment %[[D5]], 64 : memref<10x768x2048xf32, #hal.descriptor_type> -// CHECK: %[[D6:.*]] = hal.interface.binding.subspan layout({{.+}}) binding(1) alignment(64) offset(%[[C0]]) : memref<10x768x2048xf32, #hal.descriptor_type> -// CHECK: memref.assume_alignment %[[D6]], 64 : memref<10x768x2048xf32, #hal.descriptor_type> -// CHECK: %[[D7:.*]] = hal.interface.binding.subspan layout({{.+}}) binding(2) alignment(64) offset(%[[C0]]) : memref<10x2048x768xf32, #hal.descriptor_type> -// CHECK: memref.assume_alignment %[[D7]], 64 : memref<10x2048x768xf32, #hal.descriptor_type> +// CHECK: %[[TX:.*]] = gpu.thread_id x +// CHECK: %[[TY:.*]] = gpu.thread_id y +// CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<1x32x33xf32, #gpu.address_space> +// CHECK: %[[ALLOC1:.*]] = memref.alloc() : memref<1x32x33xf32, #gpu.address_space> +// CHECK: %[[D0:.*]] = hal.interface.binding.subspan layout({{.+}}) binding(0) alignment(64) offset(%[[C0]]) : memref<10x768x2048xf32, #hal.descriptor_type> +// CHECK: memref.assume_alignment %[[D0]], 64 : memref<10x768x2048xf32, #hal.descriptor_type> +// CHECK: %[[D1:.*]] = hal.interface.binding.subspan layout({{.+}}) binding(1) alignment(64) offset(%[[C0]]) : memref<10x768x2048xf32, #hal.descriptor_type> +// CHECK: memref.assume_alignment %[[D1]], 64 : memref<10x768x2048xf32, #hal.descriptor_type> +// CHECK: %[[D2:.*]] = hal.interface.binding.subspan layout({{.+}}) binding(2) alignment(64) offset(%[[C0]]) : memref<10x2048x768xf32, #hal.descriptor_type> +// CHECK: memref.assume_alignment %[[D2]], 64 : memref<10x2048x768xf32, #hal.descriptor_type> // CHECK: gpu.barrier -// CHECK: %[[D8:.*]] = affine.apply #{{.*}}()[%{{.*}}, %[[D0]], %[[D1]], %[[D2]]] -// CHECK: %[[D9:.*]] = affine.apply #{{.*}}()[%{{.*}}, %[[D0]]] -// CHECK: %[[D10:.*]] = vector.transfer_read %[[D5]][%{{.*}}, %[[D8]], %[[D9]]], %[[CST]] {in_bounds = [true, true, true]} : memref<10x768x2048xf32, #hal.descriptor_type>, vector<1x1x4xf32> -// CHECK: %[[D11:.*]] = affine.apply #{{.*}}()[%[[D0]], %[[D1]], %[[D2]]] -// CHECK: %[[D12:.*]] = affine.apply #{{.*}}()[%[[D0]]] -// CHECK: vector.transfer_write %[[D10]], %[[D4]][%[[C0]], %[[D11]], %[[D12]]] {in_bounds = [true, true, true]} : vector<1x1x4xf32>, memref<1x32x33xf32, #gpu.address_space> -// CHECK: %[[D13:.*]] = vector.transfer_read %[[D6]][%{{.*}}, %[[D8]], %[[D9]]], %[[CST]] {in_bounds = [true, true, true]} : memref<10x768x2048xf32, #hal.descriptor_type>, vector<1x1x4xf32> -// CHECK: vector.transfer_write %[[D13]], %[[D3]][%[[C0]], %[[D11]], %[[D12]]] {in_bounds = [true, true, true]} : vector<1x1x4xf32>, memref<1x32x33xf32, #gpu.address_space> +// CHECK: %[[D3:.*]] = affine.apply #{{.*}}()[%{{.*}}, %[[TY]]] +// CHECK: %[[D4:.*]] = affine.apply #{{.*}}()[%{{.*}}, %[[TX]]] +// CHECK: %[[D5:.*]] = vector.transfer_read %[[D0]][%{{.*}}, %[[D3]], %[[D4]]], %[[CST]] {in_bounds = [true, true, true]} : memref<10x768x2048xf32, #hal.descriptor_type>, vector<1x1x4xf32> +// CHECK: %[[D6:.*]] = affine.apply #{{.*}}()[%[[TX]]] +// CHECK: vector.transfer_write %[[D5]], %[[ALLOC1]][%[[C0]], %[[TY]], %[[D6]]] {in_bounds = [true, true, true]} : vector<1x1x4xf32>, memref<1x32x33xf32, #gpu.address_space> +// CHECK: %[[D7:.*]] = vector.transfer_read %[[D1]][%{{.*}}, %[[D3]], %[[D4]]], %[[CST]] {in_bounds = [true, true, true]} : memref<10x768x2048xf32, #hal.descriptor_type>, vector<1x1x4xf32> +// CHECK: vector.transfer_write 
%[[D7]], %[[ALLOC]][%[[C0]], %[[TY]], %[[D6]]] {in_bounds = [true, true, true]} : vector<1x1x4xf32>, memref<1x32x33xf32, #gpu.address_space> // CHECK: gpu.barrier -// CHECK: %[[D14:.*]] = affine.apply #{{.*}}()[%[[D0]]] -// CHECK: %[[D15:.*]] = vector.transfer_read %[[D4]][%[[C0]], %[[D14]], %[[D1]]], %[[CST]] {in_bounds = [true, true]} : memref<1x32x33xf32, #gpu.address_space>, vector<4x1xf32> -// CHECK: %[[D16:.*]] = vector.transfer_read %[[D3]][%[[C0]], %[[D14]], %[[D1]]], %[[CST]] {in_bounds = [true, true]} : memref<1x32x33xf32, #gpu.address_space>, vector<4x1xf32> -// CHECK: %[[D17:.*]] = arith.addf %[[D15]], %[[D16]] : vector<4x1xf32> -// CHECK: %[[D19:.*]] = vector.shape_cast %[[D17]] : vector<4x1xf32> to vector<4xf32> -// CHECK: %[[D21:.*]] = affine.apply #{{.*}}()[%{{.*}}, %[[D1]]] -// CHECK: %[[D22:.*]] = affine.apply #{{.*}}()[%{{.*}}, %[[D0]]] -// CHECK: vector.transfer_write %[[D19]], %[[D7]][%{{.*}}, %[[D21]], %[[D22]]] {in_bounds = [true]} : vector<4xf32>, memref<10x2048x768xf32, #hal.descriptor_type> +// CHECK: %[[D8:.*]] = vector.transfer_read %[[ALLOC1]][%[[C0]], %[[D6]], %[[TY]]], %[[CST]] {in_bounds = [true, true]} : memref<1x32x33xf32, #gpu.address_space>, vector<4x1xf32> +// CHECK: %[[D9:.*]] = vector.transfer_read %[[ALLOC]][%[[C0]], %[[D6]], %[[TY]]], %[[CST]] {in_bounds = [true, true]} : memref<1x32x33xf32, #gpu.address_space>, vector<4x1xf32> +// CHECK: %[[D10:.*]] = arith.addf %[[D8]], %[[D9]] : vector<4x1xf32> +// CHECK: %[[D11:.*]] = vector.shape_cast %[[D10]] : vector<4x1xf32> to vector<4xf32> +// CHECK: %[[D12:.*]] = affine.apply #{{.*}}()[%{{.*}}, %[[TY]]] +// CHECK: %[[D13:.*]] = affine.apply #{{.*}}()[%{{.*}}, %[[TX]]] +// CHECK: vector.transfer_write %[[D11]], %[[D2]][%{{.*}}, %[[D12]], %[[D13]]] {in_bounds = [true]} : vector<4xf32>, memref<10x2048x768xf32, #hal.descriptor_type> // -----