From a77b3be926d1d352f89688f6bb28e98d860dc1ad Mon Sep 17 00:00:00 2001
From: Alan Li
Date: Wed, 7 Aug 2024 03:18:29 +0000
Subject: [PATCH] Add `expand_shape` to GPU encoding pipeline

Signed-off-by: Alan Li
---
 .../compiler/Codegen/Common/EncodingUtils.h   |   3 +
 .../compiler/Codegen/Common/GPU/BUILD.bazel   |   2 +-
 .../Common/GPU/GPUMaterializeEncoding.cpp     | 177 +++++++++++++++---
 .../compiler/Codegen/LLVMCPU/test/BUILD.bazel |   1 +
 .../Codegen/LLVMCPU/test/CMakeLists.txt       |   1 +
 .../Codegen/LLVMCPU/test/data_tile.mlir       |  27 +++
 6 files changed, 179 insertions(+), 32 deletions(-)
 create mode 100644 compiler/src/iree/compiler/Codegen/LLVMCPU/test/data_tile.mlir

diff --git a/compiler/src/iree/compiler/Codegen/Common/EncodingUtils.h b/compiler/src/iree/compiler/Codegen/Common/EncodingUtils.h
index 10996a46babfa..b57cb90615418 100644
--- a/compiler/src/iree/compiler/Codegen/Common/EncodingUtils.h
+++ b/compiler/src/iree/compiler/Codegen/Common/EncodingUtils.h
@@ -20,6 +20,9 @@ struct MaterializeEncodingInfo {
   SmallVector<int64_t> innerTileSizes;
   SmallVector<int64_t> outerDimsPerm;
   unsigned srcRank = 0;
+  // Metadata for a generalized expand_shape + transpose.
+  SmallVector<int64_t> innerTileShapes;
+  SmallVector<int64_t> permutation;
 };
 
 using MaterializeEncodingFn =
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel
index 52d83acbc7df6..bfc0ee207569a 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel
@@ -59,8 +59,8 @@ iree_compiler_cc_library(
         "GPUDistributionPatterns.cpp",
         "GPUGeneralizeNamedOps.cpp",
         "GPULowerToUKernels.cpp",
-        "GPUMultiBuffering.cpp",
         "GPUMaterializeEncoding.cpp",
+        "GPUMultiBuffering.cpp",
         "GPUNestedLayoutDistributionPatterns.cpp",
         "GPUPatterns.cpp",
         "GPUPipelining.cpp",
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUMaterializeEncoding.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUMaterializeEncoding.cpp
index 3a76f44dea692..08013c0335c2f 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUMaterializeEncoding.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUMaterializeEncoding.cpp
@@ -6,10 +6,16 @@
 
 #include "iree/compiler/Codegen/Common/EncodingUtils.h"
 #include "iree/compiler/Codegen/Common/GPU/Passes.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/SmallVector.h"
 #include "mlir/Dialect/MemRef/Transforms/Transforms.h"
+#include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
+#include "mlir/IR/BuiltinTypes.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
+#include "mlir/IR/Operation.h"
+
 #define DEBUG_TYPE "iree-codegen-gpu-materialize-encoding"
 
 namespace mlir::iree_compiler {
@@ -30,16 +36,49 @@ static std::optional<TileMxNxK> getIntrinsicSize(TypeRange elementTypes) {
 
 // TODO: Query the value from GPU attributes.
 // TODO: Define a struct with meaningful name for the pair.
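+// Returns the per-lane vector shape used by the mma intrinsic for the given
+// matmul operand (roleIdx 0 = LHS, 1 = RHS, 2 = ACC); e.g. for the f32
+// 16x16x4 case below, each lane owns a 1x1 slice of LHS/RHS and a 4x1 slice
+// of ACC. Returns an empty vector for unsupported element type combinations.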
-std::optional<std::pair<int64_t, int64_t>>
-getIntrinsicVectorSize(TypeRange elementTypes, int64_t roleIdx) {
+SmallVector<int64_t> getIntrinsicVectorSize(TypeRange elementTypes,
+                                            int64_t roleIdx) {
   Type lhs = elementTypes[0];
   Type rhs = elementTypes[1];
   Type out = elementTypes[2];
   if (lhs.isF32() && rhs.isF32() && out.isF32()) {
-    if (roleIdx == 0 || roleIdx == 1) return std::make_pair(1, 1);
-    if (roleIdx == 2) return std::make_pair(4, 1);
+    if (roleIdx == 0 || roleIdx == 1) {
+      return {1, 1};
+    }
+    if (roleIdx == 2) {
+      return {4, 1};
+    }
+  }
+  return {};
+}
+
+SmallVector<int64_t> getTransposePermutation(int64_t roleIdx) {
+  switch (roleIdx) {
+  case 0: // A
+  case 1: // B
+    // OuterTileX x InnerTileX x OuterTileY x InnerTileY
+    // -> OuterTileY x OuterTileX x InnerTileY x InnerTileX
+    return {2, 0, 3, 1};
+  case 2: // C
+    // ACC:
+    // OuterTileX x InnerTileX x OuterTileY x InnerTileY
+    // -> OuterTileX x OuterTileY x InnerTileX x InnerTileY
+    return {0, 2, 1, 3};
+  default:
+    llvm_unreachable("unexpected roleIdx");
+  }
+}
+
+SmallVector<int64_t> getReverseTransposePermutation(int64_t roleIdx) {
+  switch (roleIdx) {
+  case 0: // A
+  case 1: // B
+    return {1, 2, 0, 3};
+  case 2: // C
+    return {0, 2, 1, 3};
+  default:
+    llvm_unreachable("unexpected roleIdx");
   }
-  return std::nullopt;
 }
 
 // TODO(hanchung): Pass an ExecutableTargetAttr attribute for the target
@@ -88,7 +127,17 @@ materializeEncodingForTarget(RankedTensorType tensorType) {
   // Map the matmul TileMxNxK to an actual tile shape for the tensor at hand,
   // based on its operand index in the matmul.
   auto rank = tensorType.getRank();
-  return getEncodingInfoForMatmul(encoding, rank, enumeratedTileMxNxK[0]);
+
+  auto encodingInfo =
+      getEncodingInfoForMatmul(encoding, rank, enumeratedTileMxNxK[0]);
+
+  // Insert the inner tile shapes and permutation info.
+  auto roleIdx = encoding.getOperandIndex().getInt();
+  auto intrinsicVectorSizes = getIntrinsicVectorSize(elementTypes, roleIdx);
+  auto permutation = getTransposePermutation(roleIdx);
+  encodingInfo.innerTileShapes = intrinsicVectorSizes;
+  encodingInfo.permutation = permutation;
+  return encodingInfo;
 }
 
 namespace {
@@ -114,6 +163,7 @@ struct GPUSetEncodingOpLoweringConversion
         getTypeConverter());
     MaterializeEncodingFn materializeEncodingFn =
         converter->getMaterializeEncodingFn();
+
     auto packOp = lowerSetEncodingOpToPackOp(
         rewriter, encodingOp, adaptor.getSource(), materializeEncodingFn,
         this->materializeEncodingValueFn);
@@ -136,6 +186,9 @@ struct GPUSetEncodingOpLoweringConversion
                                          "unhandled result encoding");
     }
     SmallVector<int64_t> innerTiles = maybeEncodingInfo->innerTileSizes;
+    SmallVector<int64_t> intrinsicVectorShape =
+        maybeEncodingInfo->innerTileShapes;
+    SmallVector<int64_t> transposePermutation = maybeEncodingInfo->permutation;
 
     // TODO(hanchung): Add a util to the encoding attribute, so we don't need
     // the map_to_vector method here.
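+    // For example, with the f32 16x16x4 intrinsic these work out to:
+    //   LHS: innerTiles = [16, 4],  intrinsicVectorShape = [1, 1],
+    //        transposePermutation = [2, 0, 3, 1]
+    //   ACC: innerTiles = [16, 16], intrinsicVectorShape = [4, 1],
+    //        transposePermutation = [0, 2, 1, 3]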
@@ -144,26 +197,26 @@ struct GPUSetEncodingOpLoweringConversion
     auto elemTypes = llvm::map_to_vector(
         encoding.getElementTypes().getValue(),
         [](Attribute a) { return cast<TypeAttr>(a).getValue(); });
+    auto loc = encodingOp.getLoc();
+
     std::optional<TileMxNxK> intrinsicShape = getIntrinsicSize(elemTypes);
-    std::optional<std::pair<int64_t, int64_t>> intrinsicVectorShape =
-        getIntrinsicVectorSize(elemTypes, roleIdx);
-    if (!intrinsicShape || !intrinsicVectorShape) {
+    if (!intrinsicShape || intrinsicVectorShape.empty()) {
       return failure();
     }
 
     SmallVector<int64_t> targetShape; // for unrolling
-    switch(roleIdx) {
-    case 0:
-      targetShape = {intrinsicShape->M, intrinsicShape->K};
-      break;
-    case 1:
-      targetShape = {intrinsicShape->N, intrinsicShape->K};
-      break;
-    case 2:
-      targetShape = {intrinsicShape->M, intrinsicShape->N};
-      break;
-    default:
-      return failure();
+    switch (roleIdx) {
+    case 0: // A
+      targetShape = {intrinsicShape->M, intrinsicShape->K};
+      break;
+    case 1: // B
+      targetShape = {intrinsicShape->N, intrinsicShape->K};
+      break;
+    case 2: // C
+      targetShape = {intrinsicShape->M, intrinsicShape->N};
+      break;
+    default:
+      return failure();
     }
 
     assert(innerTiles.size() == targetShape.size());
@@ -175,15 +228,57 @@ struct GPUSetEncodingOpLoweringConversion
       assert(packedShape == targetShape);
     }
 
-    // TODO(lialan): create expand_shape. Take LHS as an example:
-    // 16x4xf32 -> 16x1x4x1. Because the vector size used in the intrinsic is
-    // 1x1.
-    // For C-Matrix (i.e., ACC), it is 16x16xf32 -> 4x4x16x1xf32. Because the
-    // vector size is 4x1.
+    // Check that the matrix dimensions are divisible by the tile shape;
+    // if not, bail out.
+    auto sourceShape = packOp->getResult().getType().getShape();
+    assert(sourceShape.size() == 4);
+    // The innermost two dimensions must be divisible.
+    if (sourceShape[2] % innerTiles[0] != 0 ||
+        sourceShape[3] % innerTiles[1] != 0) {
+      return failure();
+    }
+
+    // Create expand_shape to tile the innermost two dimensions.
+    llvm::SmallVector<int64_t> expandShapeShape;
+    assert(intrinsicVectorShape.size() == 2); // TODO: relax this
+    auto iT1 = intrinsicVectorShape[0];
+    auto iT2 = intrinsicVectorShape[1];
+    auto oT1 = sourceShape[2] / iT1;
+    auto oT2 = sourceShape[3] / iT2;
+    expandShapeShape = {sourceShape[0], sourceShape[1], oT1, iT1, oT2, iT2};
+    assert(expandShapeShape.size() == 6);
+    RankedTensorType expandShapeType = RankedTensorType::get(
+        expandShapeShape, encodingOp.getSourceType().getElementType());
 
-    // TODO(lialan): create linalg.transpose op.
-    // LHS: 16x1x4x1 -> 4x16x1x1 (perm = [2, 0, 3, 1])
-    // ACC: 4x4x16x1 -> 4x16x4x1 (perm = [0, 2, 1, 3])
+    SmallVector<ReassociationExprs> reassociationMap(4);
+    reassociationMap[0].push_back(rewriter.getAffineDimExpr(0));
+    reassociationMap[1].push_back(rewriter.getAffineDimExpr(1));
+    reassociationMap[2].push_back(rewriter.getAffineDimExpr(2));
+    reassociationMap[2].push_back(rewriter.getAffineDimExpr(3));
+    reassociationMap[3].push_back(rewriter.getAffineDimExpr(4));
+    reassociationMap[3].push_back(rewriter.getAffineDimExpr(5));
+    auto expandShapeOp = rewriter.create<tensor::ExpandShapeOp>(
+        loc, expandShapeType, packOp->getResult(), reassociationMap);
+
+    // Create a linalg.transpose on the expanded shape.
+    SmallVector<int64_t> transposeResultDims;
+    transposeResultDims.push_back(expandShapeShape[0]);
+    transposeResultDims.push_back(expandShapeShape[1]);
+    for (int i = 0; i < transposePermutation.size(); i++) {
+      transposeResultDims.push_back(
+          expandShapeShape[2 + transposePermutation[i]]);
+    }
+    SmallVector<int64_t> newTransposePermutation(transposePermutation.size() +
+                                                 2);
+    newTransposePermutation[0] = 0;
+    newTransposePermutation[1] = 1;
+    for (int i = 0; i < transposePermutation.size(); i++) {
+      newTransposePermutation[i + 2] = transposePermutation[i] + 2;
+    }
+    auto emptyTensor = rewriter.create<tensor::EmptyOp>(
+        loc, transposeResultDims, encodingOp.getSourceType().getElementType());
+    auto transposeOp = rewriter.create<linalg::TransposeOp>(
+        loc, expandShapeOp, emptyTensor, newTransposePermutation);
 
     // TODO(hanchung): We want to make the shape consistent, so we need to
     // collapse and expand the shape. This is the shape we materialize for Flow
@@ -195,8 +290,28 @@ struct GPUSetEncodingOpLoweringConversion
     // LHS: 64 -> 16x4 (innerTiles[0]xinnerTiles[1])
     // ACC: 256 -> 16x16 (innerTiles[0]xinnerTiles[1])
 
-    // TODO(lialan): Replace the op with the tensor.expand_shape op.
-    rewriter.replaceOp(encodingOp, packOp->getResult());
+    SmallVector<int64_t> reverseTransposePermutation = {0, 1, 3, 4, 2, 5};
+    auto reverseEmptyTensor = rewriter.create<tensor::EmptyOp>(
+        loc, expandShapeShape, encodingOp.getSourceType().getElementType());
+    auto reverseTransposeOp = rewriter.create<linalg::TransposeOp>(
+        loc, transposeOp->getResult(0), reverseEmptyTensor,
+        reverseTransposePermutation);
+
+    // Collapse the tiled dimensions back to the pack op's result shape.
+    SmallVector<ReassociationIndices> reassoc;
+    reassoc.push_back(ReassociationIndices{0});
+    reassoc.push_back(ReassociationIndices{1});
+    reassoc.push_back(ReassociationIndices{2, 3});
+    reassoc.push_back(ReassociationIndices{4, 5});
+
+    RankedTensorType revertShapeType = RankedTensorType::get(
+        sourceShape, encodingOp.getSourceType().getElementType());
+
+    auto collapseShapeOp = rewriter.create<tensor::CollapseShapeOp>(
+        loc, revertShapeType, reverseTransposeOp->getResult(0),
+        ArrayRef<ReassociationIndices>(reassoc));
+
+    rewriter.replaceOp(encodingOp, collapseShapeOp);
     return success();
   }
 };
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/BUILD.bazel
index cd9c4e923ec55..9e2c3d9036ce5 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/BUILD.bazel
@@ -28,6 +28,7 @@ iree_lit_test_suite(
             "check_ir_before_llvm_conversion.mlir",
             "check_ir_before_llvm_conversion_not_fail_unbound.mlir",
             "convert_to_llvm.mlir",
+            "data_tile.mlir",
             "emit_vectorization_remarks.mlir",
             "expand_f16_op_to_f32.mlir",
             "hal_executable_constants.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/CMakeLists.txt
index de22f621fd83f..8a2ede1990ecb 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/CMakeLists.txt
@@ -23,6 +23,7 @@ iree_lit_test_suite(
     "check_ir_before_llvm_conversion.mlir"
     "check_ir_before_llvm_conversion_not_fail_unbound.mlir"
     "convert_to_llvm.mlir"
+    "data_tile.mlir"
     "emit_vectorization_remarks.mlir"
     "expand_f16_op_to_f32.mlir"
    "hal_executable_constants.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/LLVMCPU/test/data_tile.mlir b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/data_tile.mlir
new file mode 100644
index 0000000000000..8b42ee9e8bb10
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/LLVMCPU/test/data_tile.mlir
@@ -0,0 +1,27 @@
+// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-gpu-materialize-device-encoding))" --split-input-file %s | FileCheck %s
+#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
+  #hal.descriptor_set.layout<0, bindings = [
+    #hal.descriptor_set.binding<0, storage_buffer>,
+    #hal.descriptor_set.binding<1, storage_buffer>
+  ]>
+]>
+func.func @set_encoding() {
+  %c0 = arith.constant 0 : index
+  %0 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<255x513xf32>>
+  %1 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(1) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<255x513xf32, #iree_encoding.encoding<operand_index = 0 : index, op_type = matmul, element_types = [f32, f32, f32], original_type = tensor<255x513xf32>, user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], round_dims_to = array<i64: 16, 16, 16>>>>>
+  %2 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [255, 513], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<255x513xf32>> -> tensor<255x513xf32>
+  %3 = iree_encoding.set_encoding %2 : tensor<255x513xf32> -> tensor<255x513xf32, #iree_encoding.encoding<operand_index = 0 : index, op_type = matmul, element_types = [f32, f32, f32], original_type = tensor<255x513xf32>, user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], round_dims_to = array<i64: 16, 16, 16>>>
+  flow.dispatch.tensor.store %3, %1, offsets = [0, 0], sizes = [255, 513], strides = [1, 1] : tensor<255x513xf32, #iree_encoding.encoding<operand_index = 0 : index, op_type = matmul, element_types = [f32, f32, f32], original_type = tensor<255x513xf32>, user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], round_dims_to = array<i64: 16, 16, 16>>> -> !flow.dispatch.tensor<writeonly:tensor<255x513xf32, #iree_encoding.encoding<operand_index = 0 : index, op_type = matmul, element_types = [f32, f32, f32], original_type = tensor<255x513xf32>, user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], round_dims_to = array<i64: 16, 16, 16>>>>>
+  return
+}
+
+// CHECK-LABEL: func.func @set_encoding
+// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<16x129x16x4xf32>
+// CHECK: %[[PACK:.*]] = tensor.pack %2 padding_value(%cst : f32) outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [16, 4] into %[[EMPTY]] : tensor<255x513xf32> -> tensor<16x129x16x4xf32>
+// CHECK: %[[EXPAND:.*]] = tensor.expand_shape %[[PACK]]
+// CHECK: %[[TRANSPOSE:.*]] = linalg.transpose ins(%[[EXPAND]]
+// -- The following generated ops exist only to keep the rest of the pipeline working.
+// CHECK: %[[REVERSE_TRANSPOSE:.*]] = linalg.transpose ins(%[[TRANSPOSE]]
+// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[REVERSE_TRANSPOSE]]
+// CHECK: flow.dispatch.tensor.store %[[COLLAPSE]]
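For reference, a sketch of the op sequence this lowering produces for the LHS
case in the test above (SSA names are illustrative; %src is the loaded tensor,
%cst the f32 padding value, and %dest/%init0/%init1 come from tensor.empty,
which is omitted here). The trailing reverse transpose and collapse_shape only
undo the layout change, so downstream Flow/HAL ops still see the pack's
original result type:

  %pack = tensor.pack %src padding_value(%cst : f32)
      outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [16, 4]
      into %dest : tensor<255x513xf32> -> tensor<16x129x16x4xf32>
  %expanded = tensor.expand_shape %pack [[0], [1], [2, 3], [4, 5]]
      output_shape [16, 129, 16, 1, 4, 1]
      : tensor<16x129x16x4xf32> into tensor<16x129x16x1x4x1xf32>
  %transposed = linalg.transpose ins(%expanded : tensor<16x129x16x1x4x1xf32>)
      outs(%init0 : tensor<16x129x4x16x1x1xf32>)
      permutation = [0, 1, 4, 2, 5, 3]
  %reversed = linalg.transpose ins(%transposed : tensor<16x129x4x16x1x1xf32>)
      outs(%init1 : tensor<16x129x16x1x4x1xf32>)
      permutation = [0, 1, 3, 4, 2, 5]
  %collapsed = tensor.collapse_shape %reversed [[0], [1], [2, 3], [4, 5]]
      : tensor<16x129x16x1x4x1xf32> into tensor<16x129x16x4xf32>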