From d8d140764c40c328537e374e9143ca6bab8f5a14 Mon Sep 17 00:00:00 2001 From: Han-Chung Wang Date: Wed, 31 Jul 2024 10:43:49 -0700 Subject: [PATCH] [GPU] Fix offsets calculation formula in MultiMmaOp distribution. (#18055) The previous formula was ``` vtid: virtual thread id tid: lane id vtid = (tid floordiv stride_i) mod size_i ``` However, it does not take `element` into account. Each thread grabs `element` contiguous data elements, so the vtid must be multiplied by `element` to land on the start of that thread's chunk. The formula therefore becomes ``` vtid: virtual thread id tid: lane id vtid = ((tid floordiv stride_i) mod size_i) * element_i ``` Fixes https://github.com/iree-org/iree/issues/17973 --------- Signed-off-by: hanhanW --- .../Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp | 17 ++++++++++------- .../test/distribute_multi_mma.mlir | 2 +- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp index 86e7bc074fb9..966bfdae93b6 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp @@ -677,27 +677,30 @@ static LogicalResult populateCanonicalOffsetsSizesAndStrides( OpFoldResult one = builder.getIndexAttr(1); canonicalStrides.append(rankReducedShape.size(), one); + // Each thread grabs `element` contiguous data, so the vtid needs to be + // multiplied by `element` to get the next bunch of data. // vtid: virtual thread id // tid: lane id - // vtid = (tid floordiv stride_i) mod size_i. + // vtid = ((tid floordiv stride_i) mod size_i) * element_i. 
SmallVector<Value> vtids; - for (auto [dimSize, dimStride] : - llvm::zip_equal(subgroupLayout.thread, subgroupLayout.tstrides)) { + for (auto [dimSize, dimStride, element] : + llvm::zip_equal(subgroupLayout.thread, subgroupLayout.tstrides, + subgroupLayout.element)) { if (dimSize == 1) { vtids.push_back(zero); } - // (tid floordiv stride) mod size + // ((tid floordiv stride) mod size) * element. AffineExpr tidExpr = builder.getAffineDimExpr(0); AffineMap vtidMap = AffineMap::get( - /*dims=*/1, /*syms=*/0, tidExpr.floorDiv(dimStride) % dimSize); + /*dims=*/1, /*syms=*/0, + (tidExpr.floorDiv(dimStride) % dimSize) * element); Value vtid = builder.create<affine::AffineApplyOp>(loc, vtidMap, laneId); vtids.push_back(vtid); } int64_t idx = 0; - for (auto [thread, element] : - llvm::zip_equal(subgroupLayout.thread, subgroupLayout.element)) { + for (int64_t element : subgroupLayout.element) { canonicalSizes.push_back(builder.getIndexAttr(element)); canonicalOffsets.push_back(vtids[idx++]); } diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/distribute_multi_mma.mlir b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/distribute_multi_mma.mlir index ccf4082f7609..a723155b8bf7 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/distribute_multi_mma.mlir +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/TransformExtensions/test/distribute_multi_mma.mlir @@ -29,7 +29,7 @@ module attributes { transform.with_named_sequence } { } // CHECK-DAG: #[[$MAP:.+]] = affine_map<(d0) -> (d0 mod 16)> -// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0) -> ((d0 floordiv 16) mod 4)> +// CHECK-DAG: #[[$MAP1:.+]] = affine_map<(d0) -> ((d0 floordiv 16) * 4 - ((d0 floordiv 16) floordiv 4) * 16)> // CHECK-LABEL: func @distribute_multi_mma_16x16x16 // CHECK-SAME: %[[LHS:[A-Za-z0-9]+]]: tensor<2x2x16x16xf16> // CHECK-SAME: %[[RHS:[A-Za-z0-9]+]]: tensor<2x2x16x16xf16>