[GPU] Use affine.linearize_index (and delinearize_index) where possible (iree-org#19122)

There have been issues with the composition of affine maps being too
general and losing important information, such as the fact that
affine_map<(s0 + s1 * 32 + ... - (s0 floorDiv 16) * 16)> really should be
affine_map<(s0 mod 16 + s1 * 32 + ...)>, as well as other issues with the
final IR that block low-level arithmetic optimizations.
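As a rough illustration (a hypothetical sketch, not taken verbatim from the
affected passes; the %tid_x/%tid_y names are placeholders), the two forms
look like:

    // What map composition tends to produce: the floordiv/mul pair hides the mod structure.
    %a = affine.apply affine_map<()[s0, s1] -> (s0 + s1 * 32 - (s0 floordiv 16) * 16)>()[%tid_x, %tid_y]
    // The equivalent form that later arithmetic optimizations can actually use.
    %b = affine.apply affine_map<()[s0, s1] -> (s0 mod 16 + s1 * 32)>()[%tid_x, %tid_y]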

The affine.delinearize_index operation represents the div/mod chains
needed to break a flat index into its component parts. A recently added
affine.linearize_index operation is its inverse - combining multiple
indices into a flat 1D value.
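For example (a minimal sketch; the basis sizes are made up for a 128-thread
workgroup with a subgroup size of 32):

    // Split a flat thread id into (warp id, lane id).
    %ids:2 = affine.delinearize_index %flat into (4, 32) : index
    // Inverse direction: rebuild the flat id from its components.
    %flat2 = affine.linearize_index disjoint [%ids#0, %ids#1] by (4, 32) : index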

Another advantage of linearize/delinearize is that they have simpler upstream
canonicalizations, which leads to more streamlined generated code.

This PR updates the vector distribution code and other GPU-related code
that I could find to

1. Use affine.linearize_index to construct flat thread IDs (see the sketch
after this list).
2. Use affine.delinearize_index in places where there was a floorDiv/mod
chain.
3. Plumb the subgroup size through the transfer_read and transfer_write
distribution patterns to enable better reasoning about when the lane ID
does or does not need to be computed with a modulo.
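
For a workgroup of size [64, 2, 1] with subgroup size 32 (the configuration
used by the tests below), the flat-ID construction from point 1 amounts to
something like the following (value names are illustrative):

    %tx = gpu.thread_id x
    %ty = gpu.thread_id y
    %tflat = affine.linearize_index disjoint [%ty, %tx] by (2, 64) : index
    // Warp-mapped foralls then peel the subgroup id off the flat id (point 2).
    %warp_lane:2 = affine.delinearize_index %tflat into (4, 32) : index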
krzysz00 authored Nov 26, 2024
1 parent 746ad1e commit 031accb
Showing 15 changed files with 224 additions and 265 deletions.
@@ -17,6 +17,7 @@
#include "mlir/Dialect/GPU/Transforms/Passes.h"
#include "mlir/Dialect/SCF/IR/DeviceMappingInterface.h"
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"

namespace mlir::iree_compiler {

@@ -87,9 +88,16 @@ LogicalResult resolveGPUMappedForallOp(RewriterBase &rewriter,
assert(!(hasThreadMapping && hasWarpMapping));
Value flatId = linearThreadId;
if (hasWarpMapping) {
OpFoldResult subgroupSizeVal = rewriter.getIndexAttr(subgroupSize);
flatId = affine::makeComposedAffineApply(rewriter, loc, d0.floorDiv(d1),
{flatId, subgroupSizeVal});
if (flatWorkgroupSize % subgroupSize != 0) {
return forallOp->emitOpError(
"found warp mapped forall with non-multiple workgroup size");
}
flatId = rewriter
.create<affine::AffineDelinearizeIndexOp>(
loc, flatId,
ArrayRef<int64_t>{flatWorkgroupSize / subgroupSize,
subgroupSize})
.getResult(0);
}

SmallVector<Value> delinSizes;
@@ -190,23 +198,18 @@ void GPUDistributeForallPass::runOnOperation() {
return signalPassFailure();
}

AffineExpr x, y, z;
bindSymbols(funcOp.getContext(), x, y, z);
// Compute the linearized thread id.
AffineExpr linearId =
x + workgroupSize[0] * y + workgroupSize[1] * workgroupSize[0] * z;

rewriter.setInsertionPointToStart(&funcOp.getFunctionBody().front());
SmallVector<OpFoldResult> threadGrid = {
rewriter.createOrFold<gpu::ThreadIdOp>(funcOp.getLoc(),
gpu::Dimension::x),
rewriter.createOrFold<gpu::ThreadIdOp>(funcOp.getLoc(),
gpu::Dimension::y),
rewriter.createOrFold<gpu::ThreadIdOp>(funcOp.getLoc(),
gpu::Dimension::z)};

Value linearThreadIdVal = affine::makeComposedAffineApply(
rewriter, funcOp.getLoc(), linearId, threadGrid);
SmallVector<Value> threadGrid = {rewriter.createOrFold<gpu::ThreadIdOp>(
funcOp.getLoc(), gpu::Dimension::z),
rewriter.createOrFold<gpu::ThreadIdOp>(
funcOp.getLoc(), gpu::Dimension::y),
rewriter.createOrFold<gpu::ThreadIdOp>(
funcOp.getLoc(), gpu::Dimension::x)};
SmallVector<int64_t> threadGridBasis = {workgroupSize[2], workgroupSize[1],
workgroupSize[0]};

Value linearThreadIdVal = rewriter.create<affine::AffineLinearizeIndexOp>(
funcOp.getLoc(), threadGrid, threadGridBasis, /*disjoint=*/true);
for (auto forall : forallOps) {
rewriter.setInsertionPoint(forall);
if (failed(resolveGPUMappedForallOp(rewriter, forall, linearThreadIdVal,
@@ -189,30 +189,29 @@ SmallVector<linalg::ProcInfo> getIds(OpBuilder &b, Location loc,
ArrayRef<Range> parallelLoopRanges,
Value flatThreadId) {
SmallVector<linalg::ProcInfo> infos;
Value id = flatThreadId;
AffineExpr d0 = b.getAffineDimExpr(0);
for (Range r : llvm::reverse(parallelLoopRanges)) {
linalg::ProcInfo info;
SmallVector<int64_t> delinSizes;
for (Range r : parallelLoopRanges) {
auto offset = dyn_cast<Attribute>(r.offset);
auto stride = dyn_cast<Attribute>(r.stride);
auto size = dyn_cast<Attribute>(r.size);
assert(offset && stride && size);
int64_t numThreadsDim = (llvm::cast<IntegerAttr>(size).getInt() -
llvm::cast<IntegerAttr>(offset).getInt()) /
llvm::cast<IntegerAttr>(stride).getInt();
Value dimId = id;
if (infos.size() != parallelLoopRanges.size() - 1)
dimId =
affine::makeComposedAffineApply(b, loc, d0 % numThreadsDim, {dimId});
delinSizes.push_back(numThreadsDim);
}
ValueRange dims =
b.create<affine::AffineDelinearizeIndexOp>(loc, flatThreadId, delinSizes)
.getResults();

for (auto [dimId, numThreadsDim] : llvm::zip_equal(dims, delinSizes)) {
linalg::ProcInfo info;
info.procId = dimId;
info.nprocs = b.create<arith::ConstantIndexOp>(loc, numThreadsDim);
info.distributionMethod =
linalg::DistributionMethod::CyclicNumProcsEqNumIters;
infos.push_back(info);
id = affine::makeComposedAffineApply(b, loc, d0.floorDiv(numThreadsDim),
{id});
}
std::reverse(infos.begin(), infos.end());
return infos;
}

@@ -288,19 +287,16 @@ static Value createFlatId(mlir::FunctionOpInterface funcOp,
ArrayRef<int64_t> workgroupSize) {
OpBuilder b(funcOp.getFunctionBody());
Type indexType = b.getIndexType();
AffineExpr d0 = getAffineDimExpr(0, b.getContext());
AffineExpr d1 = getAffineDimExpr(1, b.getContext());
AffineExpr d2 = getAffineDimExpr(2, b.getContext());
Value threadX =
b.create<gpu::ThreadIdOp>(funcOp.getLoc(), indexType, gpu::Dimension::x);
Value threadY =
b.create<gpu::ThreadIdOp>(funcOp.getLoc(), indexType, gpu::Dimension::y);
Value threadZ =
b.create<gpu::ThreadIdOp>(funcOp.getLoc(), indexType, gpu::Dimension::z);
Value flatThreadId = affine::makeComposedAffineApply(
b, funcOp.getLoc(),
d0 + workgroupSize[0] * d1 + (workgroupSize[0] * workgroupSize[1]) * d2,
{threadX, threadY, threadZ});
Value flatThreadId = b.create<affine::AffineLinearizeIndexOp>(
funcOp.getLoc(), ValueRange{threadZ, threadY, threadX},
ArrayRef<int64_t>{workgroupSize[2], workgroupSize[1], workgroupSize[0]},
/*disjoint=*/true);
return flatThreadId;
}

@@ -15,11 +15,9 @@ func.func @distribute_thread_forall(%out : memref<?xi32>)
// CHECK-LABEL: func @distribute_thread_forall
// CHECK-DAG: %[[TX:.+]] = gpu.thread_id x
// CHECK-DAG: %[[TY:.+]] = gpu.thread_id y
// CHECK-DAG: %[[TZ:.+]] = gpu.thread_id z
// CHECK: %[[TFLAT:.+]] = affine.linearize_index disjoint [%[[TY]], %[[TX]]] by (2, 64)
// CHECK: scf.for %[[I:.+]] = %c0 to %c1024 step %c128 {
// CHECK: %[[LINID:.+]] = affine.apply
// CHECK-SAME: affine_map<(d0)[s0, s1, s2] -> (d0 + s0 + s1 * 64 + s2 * 128)>(%[[I]])
// CHECK-SAME: [%[[TX]], %[[TY]], %[[TZ]]]
// CHECK: %[[LINID:.+]] = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%[[I]])[%[[TFLAT]]]
// CHECK: memref.store {{.*}}[%[[LINID]]]

// -----
@@ -38,11 +36,10 @@ func.func @distribute_warp_forall(%out : memref<?xi32>)
// CHECK-LABEL: func @distribute_warp_forall
// CHECK-DAG: %[[TX:.+]] = gpu.thread_id x
// CHECK-DAG: %[[TY:.+]] = gpu.thread_id y
// CHECK-DAG: %[[TZ:.+]] = gpu.thread_id z
// CHECK: %[[TFLAT:.+]] = affine.linearize_index disjoint [%[[TY]], %[[TX]]] by (2, 64)
// CHECK: %[[WARPSPLIT:.+]]:2 = affine.delinearize_index %[[TFLAT]] into (4, 32)
// CHECK: scf.for %[[I:.+]] = %c0 to %c32 step %c4 {
// CHECK: %[[LINID:.+]] = affine.apply
// CHECK-SAME: affine_map<(d0)[s0, s1, s2] -> (d0 + s1 * 2 + s2 * 4 + s0 floordiv 32)>(%[[I]])
// CHECK-SAME: [%[[TX]], %[[TY]], %[[TZ]]]
// CHECK: %[[LINID:.+]] = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%[[I]])[%[[WARPSPLIT]]#0]
// CHECK: memref.store {{.*}}[%[[LINID]]]

// -----
@@ -78,11 +75,7 @@ func.func @distribute_thread_forall_drop_for_loop(%out : memref<?xi32>)
// CHECK-LABEL: func @distribute_thread_forall_drop_for_loop
// CHECK-DAG: %[[TX:.+]] = gpu.thread_id x
// CHECK-DAG: %[[TY:.+]] = gpu.thread_id y
// CHECK-DAG: %[[TZ:.+]] = gpu.thread_id z
// CHECK-NOT: scf.for
// CHECK: %[[LINID:.+]] = affine.apply
// CHECK-SAME: affine_map<()[s0, s1, s2] -> (s0 + s1 * 64 + s2 * 128)>
// CHECK-SAME: [%[[TX]], %[[TY]], %[[TZ]]]
// CHECK: %[[LINID:.+]] = affine.linearize_index disjoint [%[[TY]], %[[TX]]] by (2, 64)
// CHECK: memref.store {{.*}}[%[[LINID]]]

// -----
@@ -99,13 +92,32 @@ func.func @distribute_thread_forall_single_thread(%out : memref<?xi32>)
}

// CHECK-LABEL: func @distribute_thread_forall_single_thread
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[TX:.+]] = gpu.thread_id x
// CHECK-DAG: %[[TY:.+]] = gpu.thread_id y
// CHECK-DAG: %[[TZ:.+]] = gpu.thread_id z
// CHECK: %[[LINID:.+]] = affine.apply
// CHECK-SAME: affine_map<()[s0, s1, s2] -> (s0 + s1 * 64 + s2 * 128)>
// CHECK-SAME: [%[[TX]], %[[TY]], %[[TZ]]]
// CHECK: scf.for %[[I:.+]] = %[[LINID]] to %c1 step %c128 {
// CHECK: %[[TFLAT:.+]] = affine.linearize_index disjoint [%[[TY]], %[[TX]]] by (2, 64)
// CHECK: scf.for %[[I:.+]] = %[[TFLAT]] to %c1 step %c128 {
// CHECK: memref.store {{.*}}[%[[I]]]

// -----

#translation_info = #iree_codegen.translation_info<pipeline = LLVMGPUTileAndFuse workgroup_size = [64, 2, 1] subgroup_size = 32>

func.func @distribute_thread_forall_overhang(%out : memref<?xi32>)
attributes {translation_info = #translation_info} {
%c0 = arith.constant 0 : i32
scf.forall (%arg0) in (513) {
memref.store %c0, %out[%arg0] : memref<?xi32>
} {mapping = [#gpu.thread<linear_dim_0>]}
return
}

// CHECK-LABEL: func @distribute_thread_forall_overhang
// CHECK-DAG: %[[C513:.+]] = arith.constant 513 : index
// CHECK-DAG: %[[TX:.+]] = gpu.thread_id x
// CHECK-DAG: %[[TY:.+]] = gpu.thread_id y
// CHECK: %[[TFLAT:.+]] = affine.linearize_index disjoint [%[[TY]], %[[TX]]] by (2, 64)
// CHECK: scf.for %[[I:.+]] = %[[TFLAT]] to %[[C513]] step %c128 {
// CHECK: memref.store {{.*}}[%[[I]]]

// -----
@@ -124,11 +136,9 @@ func.func @distribute_thread_forall_multi_dim(%out : memref<?x?x?xi32>)
// CHECK-LABEL: func @distribute_thread_forall_multi_dim
// CHECK-DAG: %[[TX:.+]] = gpu.thread_id x
// CHECK-DAG: %[[TY:.+]] = gpu.thread_id y
// CHECK-DAG: %[[TZ:.+]] = gpu.thread_id z
// CHECK: %[[TFLAT:.+]] = affine.linearize_index disjoint [%[[TY]], %[[TX]]] by (2, 64)
// CHECK: scf.for %[[I:.+]] = %c0 to %c512 step %c128 {
// CHECK: %[[LINID:.+]] = affine.apply
// CHECK-SAME: affine_map<(d0)[s0, s1, s2] -> (d0 + s0 + s1 * 64 + s2 * 128)>(%[[I]])
// CHECK-SAME: [%[[TX]], %[[TY]], %[[TZ]]]
// CHECK: %[[LINID:.+]] = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%[[I]])[%[[TFLAT]]]
// CHECK: %[[DELIN:.+]]:3 = affine.delinearize_index %[[LINID]] into (16, 8, 4) : index
// CHECK: memref.store {{.*}}[%[[DELIN]]#0, %[[DELIN]]#1, %[[DELIN]]#2]

@@ -147,10 +157,5 @@ func.func @distribute_thread_forall_small_workgroup(%out : memref<?xi32>)
}

// CHECK-LABEL: func @distribute_thread_forall_small_workgroup
// CHECK-DAG: %[[TX:.+]] = gpu.thread_id x
// CHECK-DAG: %[[TY:.+]] = gpu.thread_id y
// CHECK-DAG: %[[TZ:.+]] = gpu.thread_id z
// CHECK: %[[LINID:.+]] = affine.apply
// CHECK-SAME: affine_map<()[s0, s1, s2] -> (s0 + s1 * 7 + s2 * 7)>
// CHECK-SAME: [%[[TX]], %[[TY]], %[[TZ]]]
// CHECK: memref.store {{.*}}[%[[LINID]]]
// CHECK: %[[TX:.+]] = gpu.thread_id x
// CHECK: memref.store {{.*}}[%[[TX]]]
@@ -49,37 +49,32 @@ module {
}
}

// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0, s1, s2] -> (s1 * 8 + s2 * 32 + s0 floordiv 4)>
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 4) * 16)>
// CHECK-DAG: #[[$MAP2:.*]] = affine_map<()[s0, s1, s2] -> (s1 * 8 + s2 * 32 + s0 floordiv 4 + 32)>
// CHECK-DAG: #[[$MAP3:.*]] = affine_map<()[s0, s1, s2] -> (s0 + s1 * 32 + s2 * 128)>
// CHECK-DAG: #[[$MAP4:.*]] = affine_map<()[s0, s1, s2] -> (s0 + s1 * 32 + s2 * 128 + 128)>
// CHECK-DAG: #[[$MAP5:.*]] = affine_map<()[s0, s1, s2] -> (s0 * 4 + s1 * 128 + s2 * 512)>
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0] -> (s0 * 4)>
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 + 32)>
// CHECK-DAG: #[[$MAP2:.*]] = affine_map<()[s0] -> (s0 + 128)>
// CHECK-LABEL: @shared_mem_cpy(

// CHECK-DAG: %[[C2:.*]] = arith.constant 2 : index
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
// CHECK-DAG: %[[TX:.*]] = gpu.thread_id x
// CHECK-DAG: %[[TY:.*]] = gpu.thread_id y
// CHECK-DAG: %[[TZ:.*]] = gpu.thread_id z

// CHECK-DAG: %[[Y0:.*]] = affine.apply #[[$MAP0]]()[%[[TX]], %[[TY]], %[[TZ]]]
// CHECK-DAG: %[[X0:.*]] = affine.apply #[[$MAP1]]()[%[[TX]]]
// CHECK: %[[R0:.*]] = vector.transfer_read %{{.*}}[%[[Y0]], %[[X0]]], %{{.*}} {in_bounds = [true, true]} : memref<64x16xf32>, vector<1x4xf32>
// CHECK: vector.transfer_write %[[R0]], %{{.*}}[%[[Y0]], %[[X0]]] {in_bounds = [true, true]} : vector<1x4xf32>, memref<64x16xf32, 3>
// CHECK-DAG: %[[Y1:.*]] = affine.apply #[[$MAP2]]()[%[[TX]], %[[TY]], %[[TZ]]]
// CHECK: %[[TFLAT:.*]] = affine.linearize_index disjoint [%[[TY]], %[[TX]]] by (4, 32)
// CHECK: %[[YX:.*]]:2 = affine.delinearize_index %[[TFLAT]] into (32, 4)
// CHECK: %[[X0:.*]] = affine.apply #[[$MAP0]]()[%[[YX]]#1]
// CHECK: %[[R0:.*]] = vector.transfer_read %{{.*}}[%[[YX]]#0, %[[X0]]], %{{.*}} {in_bounds = [true, true]} : memref<64x16xf32>, vector<1x4xf32>
// CHECK: vector.transfer_write %[[R0]], %{{.*}}[%[[YX]]#0, %[[X0]]] {in_bounds = [true, true]} : vector<1x4xf32>, memref<64x16xf32, 3>
// CHECK-DAG: %[[Y1:.*]] = affine.apply #[[$MAP1]]()[%[[YX]]#0]
// CHECK: %[[R1:.*]] = vector.transfer_read %{{.*}}[%[[Y1]], %[[X0]]], %{{.*}} {in_bounds = [true, true]} : memref<64x16xf32>, vector<1x4xf32>
// CHECK: vector.transfer_write %[[R1]], %{{.*}}[%[[Y1]], %[[X0]]] {in_bounds = [true, true]} : vector<1x4xf32>, memref<64x16xf32, 3>

// CHECK: %[[Y1:.*]] = affine.apply #[[$MAP3]]()[%[[TX]], %[[TY]], %[[TZ]]]
// CHECK: %[[R2:.*]] = vector.transfer_read %{{.*}}[%[[Y1]], %[[C0]]], %{{.*}} {in_bounds = [true, true]} : memref<256x4xf32>, vector<1x4xf32>
// CHECK: vector.transfer_write %[[R2]], %{{.*}}[%[[Y1]], %[[C0]]] {in_bounds = [true, true]} : vector<1x4xf32>, memref<256x4xf32, 3>
// CHECK: %[[Y2:.*]] = affine.apply #[[$MAP4]]()[%[[TX]], %[[TY]], %[[TZ]]]
// CHECK: %[[R2:.*]] = vector.transfer_read %{{.*}}[%[[TFLAT]], %[[C0]]], %{{.*}} {in_bounds = [true, true]} : memref<256x4xf32>, vector<1x4xf32>
// CHECK: vector.transfer_write %[[R2]], %{{.*}}[%[[TFLAT]], %[[C0]]] {in_bounds = [true, true]} : vector<1x4xf32>, memref<256x4xf32, 3>
// CHECK: %[[Y2:.*]] = affine.apply #[[$MAP2]]()[%[[TFLAT]]]
// CHECK: %[[R3:.*]] = vector.transfer_read %{{.*}}[%[[Y2]], %[[C0]]], %{{.*}} {in_bounds = [true, true]} : memref<256x4xf32>, vector<1x4xf32>
// CHECK: vector.transfer_write %[[R3]], %{{.*}}[%[[Y2]], %[[C0]]] {in_bounds = [true, true]} : vector<1x4xf32>, memref<256x4xf32, 3>

// CHECK: %[[X1:.*]] = affine.apply #[[$MAP5]]()[%[[TX]], %[[TY]], %[[TZ]]]
// CHECK: %[[X1:.*]] = affine.apply #[[$MAP0]]()[%[[TFLAT]]]
// CHECK: %[[R4:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[X1]]], %{{.*}} {in_bounds = [true, true]} : memref<3x512xf32>, vector<1x4xf32>
// CHECK: vector.transfer_write %[[R4]], %{{.*}}[%[[C0]], %[[X1]]] {in_bounds = [true, true]} : vector<1x4xf32>, memref<3x512xf32, 3>
// CHECK: %[[R5:.*]] = vector.transfer_read %{{.*}}[%[[C1]], %[[X1]]], %{{.*}} {in_bounds = [true, true]} : memref<3x512xf32>, vector<1x4xf32>
@@ -46,20 +46,19 @@ module attributes {transform.with_named_sequence} {
transform.yield
}
}
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0, s1, s2] -> (s1 * 8 + s2 * 32 + s0 floordiv 4)>
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 * 4 - (s0 floordiv 4) * 16)>
// CHECK-DAG: #[[$MAP2:.*]] = affine_map<()[s0, s1, s2] -> (s1 * 8 + s2 * 32 + s0 floordiv 4 + 32)>
// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0] -> (s0 * 4)>
// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 + 32)>
// CHECK-DAG: #[[$MAP3:.*]] = affine_map<(d0, d1) -> (d0, d1)>
// CHECK-LABEL: @shared_mem_cpy(
// CHECK-DAG: %[[TX:.*]] = gpu.thread_id x
// CHECK-DAG: %[[TY:.*]] = gpu.thread_id y
// CHECK-DAG: %[[TZ:.*]] = gpu.thread_id z

// CHECK-DAG: %[[Y0:.*]] = affine.apply #[[$MAP0]]()[%[[TX]], %[[TY]], %[[TZ]]]
// CHECK-DAG: %[[X0:.*]] = affine.apply #[[$MAP1]]()[%[[TX]]]
// CHECK: %[[R0:.*]] = vector.transfer_read %{{.*}}[%[[Y0]], %[[X0]]], %{{.*}} {in_bounds = [true, true]} : memref<64x16xf32, #hal.descriptor_type<storage_buffer>>, vector<1x4xf32>
// CHECK: vector.transfer_write %[[R0]], %{{.*}}[%[[Y0]], %[[X0]]] {in_bounds = [true, true]} : vector<1x4xf32>, memref<64x16xf32, #gpu.address_space<workgroup>>
// CHECK-DAG: %[[Y1:.*]] = affine.apply #[[$MAP2]]()[%[[TX]], %[[TY]], %[[TZ]]]
// CHECK-DAG: %[[TFLAT:.*]] = affine.linearize_index disjoint [%[[TY]], %[[TX]]] by (4, 32)
// CHECK-DAG: %[[YX:.*]]:2 = affine.delinearize_index %[[TFLAT]] into (32, 4)
// CHECK-DAG: %[[X0:.*]] = affine.apply #[[$MAP0]]()[%[[YX]]#1]
// CHECK: %[[R0:.*]] = vector.transfer_read %{{.*}}[%[[YX]]#0, %[[X0]]], %{{.*}} {in_bounds = [true, true]} : memref<64x16xf32, #hal.descriptor_type<storage_buffer>>, vector<1x4xf32>
// CHECK: vector.transfer_write %[[R0]], %{{.*}}[%[[YX]]#0, %[[X0]]] {in_bounds = [true, true]} : vector<1x4xf32>, memref<64x16xf32, #gpu.address_space<workgroup>>
// CHECK-DAG: %[[Y1:.*]] = affine.apply #[[$MAP1]]()[%[[YX]]#0]
// CHECK: %[[R1:.*]] = vector.transfer_read %{{.*}}[%[[Y1]], %[[X0]]], %{{.*}} {in_bounds = [true, true]} : memref<64x16xf32, #hal.descriptor_type<storage_buffer>>, vector<1x4xf32>
// CHECK: vector.transfer_write %[[R1]], %{{.*}}[%[[Y1]], %[[X0]]] {in_bounds = [true, true]} : vector<1x4xf32>, memref<64x16xf32, #gpu.address_space<workgroup>>
// CHECK: linalg.generic
@@ -1113,16 +1113,15 @@ transform_dialect::TestGpuVectorDistribution::applyToOne(
rewriter.setInsertionPointToStart(&target.getFunctionBody().front());
// This is a test op so we unsafely use thread_id x as the lane ID. In
// general this should linearize the thread IDs based on the workgroup size
// and divide by the subgroup size. i.e.
// and take the modulo by the subgroup size. i.e.
//
// lane_id = (tid_x + tid_y * dim_x + tid_z * dim_y * dim_x) / subgroup_size;
// lane_id = (tid_x + tid_y * dim_x + tid_z * dim_y * dim_x) % subgroup_size;
Value laneId =
rewriter.create<gpu::ThreadIdOp>(target.getLoc(), gpu::Dimension::x);
int64_t subgroupSize = getSubgroupSize();

populateGPUDistributionPatterns(patterns);
// For testing we use subgroup size = 64.
populateGPUDistributeNestedLayoutAttrPatterns(patterns, laneId,
/*subgroupSize=*/64);
populateGPUDistributeNestedLayoutAttrPatterns(patterns, laneId, subgroupSize);
populateGPUDistributeNestedLayoutContractAMDGPUPatterns(patterns);
if (failed(distributeVectorOps(target, patterns, options))) {
return emitDefaultDefiniteFailure(target);
@@ -631,7 +631,8 @@ def TestGpuVectorDistribution :
}];

let arguments = (ins TransformHandleTypeInterface:$target,
DefaultValuedOptionalAttr<BoolAttr, "false">:$experimental);
DefaultValuedOptionalAttr<BoolAttr, "false">:$experimental,
DefaultValuedOptionalAttr<I64Attr, "64">:$subgroup_size);
let results = (outs);

let assemblyFormat = [{ $target attr-dict `:` type($target)}];