diff --git a/compiler/src/iree/compiler/Codegen/Common/EncodingUtils.h b/compiler/src/iree/compiler/Codegen/Common/EncodingUtils.h
index 10996a46babfa..b57cb90615418 100644
--- a/compiler/src/iree/compiler/Codegen/Common/EncodingUtils.h
+++ b/compiler/src/iree/compiler/Codegen/Common/EncodingUtils.h
@@ -20,6 +20,9 @@ struct MaterializeEncodingInfo {
   SmallVector<int64_t> innerTileSizes;
   SmallVector<int64_t> outerDimsPerm;
   unsigned srcRank = 0;
+  // Metadata for a generalized expand_shape + transpose.
+  SmallVector<int64_t> innerTileShapes;
+  SmallVector<int64_t> permutation;
 };
 
 using MaterializeEncodingFn =
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel b/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel
index 52d83acbc7df6..bfc0ee207569a 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel
@@ -59,8 +59,8 @@ iree_compiler_cc_library(
         "GPUDistributionPatterns.cpp",
         "GPUGeneralizeNamedOps.cpp",
         "GPULowerToUKernels.cpp",
-        "GPUMultiBuffering.cpp",
         "GPUMaterializeEncoding.cpp",
+        "GPUMultiBuffering.cpp",
         "GPUNestedLayoutDistributionPatterns.cpp",
         "GPUPatterns.cpp",
         "GPUPipelining.cpp",
diff --git a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUMaterializeEncoding.cpp b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUMaterializeEncoding.cpp
index 1d4c7e89c724e..548051d20c26f 100644
--- a/compiler/src/iree/compiler/Codegen/Common/GPU/GPUMaterializeEncoding.cpp
+++ b/compiler/src/iree/compiler/Codegen/Common/GPU/GPUMaterializeEncoding.cpp
@@ -10,9 +10,13 @@
 #include "mlir/Dialect/MemRef/Transforms/Transforms.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Dialect/Tensor/Transforms/Transforms.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
+#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"
 
+#include "mlir/IR/Operation.h"
+
 #define DEBUG_TYPE "iree-codegen-gpu-materialize-encoding"
 
 namespace mlir::iree_compiler {
@@ -33,18 +37,48 @@ static std::optional<TileMxNxK> getIntrinsicSize(TypeRange elementTypes) {
 
 // TODO: Query the value from GPU attributes.
 // TODO: Define a struct with meaningful name for the pair.
-std::optional<std::pair<int64_t, int64_t>>
-getIntrinsicVectorSize(TypeRange elementTypes, int64_t roleIdx) {
+SmallVector<int64_t> getIntrinsicVectorSize(TypeRange elementTypes,
+                                            int64_t roleIdx) {
   Type lhs = elementTypes[0];
   Type rhs = elementTypes[1];
   Type out = elementTypes[2];
   if (lhs.isF32() && rhs.isF32() && out.isF32()) {
-    if (roleIdx == 0 || roleIdx == 1)
-      return std::make_pair(1, 1);
-    if (roleIdx == 2)
-      return std::make_pair(4, 1);
+    if (roleIdx == 0 || roleIdx == 1) {
+      return {1, 1};
+    }
+    if (roleIdx == 2) {
+      return {4, 1};
+    }
+  }
+  return {};
+}
+
+// Given the encoding's role index and element types, return the transpose
+// permutation used in GPU materialization.
+SmallVector<int64_t> getTransposePermutation(int64_t roleIdx,
+                                             TypeRange elementTypes) {
+  // For now, check that all types are f32:
+  Type lhs = elementTypes[0];
+  Type rhs = elementTypes[1];
+  Type out = elementTypes[2];
+  if (!lhs.isF32() || !rhs.isF32() || !out.isF32()) {
+    return {};
+  }
+
+  switch (roleIdx) {
+  case 0: // A
+  case 1: // B
+    // OuterTileX x InnerTileX x OuterTileY x InnerTileY
+    // -> OuterTileY x OuterTileX x InnerTileY x InnerTileX
+    return {2, 0, 3, 1};
+  case 2: // C
+    // ACC:
+    // OuterTileX x InnerTileX x OuterTileY x InnerTileY
+    // -> OuterTileX x OuterTileY x InnerTileX x InnerTileY
+    return {0, 2, 1, 3};
+  default:
+    return {};
   }
-  return std::nullopt;
 }
 
 // TODO(hanchung): Pass an ExecutableTargetAttr attribute for the target
@@ -93,7 +127,17 @@ materializeEncodingForTarget(RankedTensorType tensorType) {
   // Map the matmul TileMxNxK to an actual tile shape for the tensor at hand,
   // based on its operand index in the matmul.
   auto rank = tensorType.getRank();
-  return getEncodingInfoForMatmul(encoding, rank, enumeratedTileMxNxK[0]);
+
+  auto encodingInfo =
+      getEncodingInfoForMatmul(encoding, rank, enumeratedTileMxNxK[0]);
+
+  // Insert the inner tile shapes and permutation info.
+  auto roleIdx = encoding.getOperandIndex().getInt();
+  auto intrinsicVectorSizes = getIntrinsicVectorSize(elementTypes, roleIdx);
+  auto permutation = getTransposePermutation(roleIdx, elementTypes);
+  encodingInfo.innerTileShapes = intrinsicVectorSizes;
+  encodingInfo.permutation = permutation;
+  return encodingInfo;
 }
 
 namespace {
@@ -119,6 +163,7 @@ struct GPUSetEncodingOpLoweringConversion
             getTypeConverter());
     MaterializeEncodingFn materializeEncodingFn =
         converter->getMaterializeEncodingFn();
+
     auto packOp = lowerSetEncodingOpToPackOp(
         rewriter, encodingOp, adaptor.getSource(), materializeEncodingFn,
         this->materializeEncodingValueFn);
@@ -141,6 +186,8 @@ struct GPUSetEncodingOpLoweringConversion
                                        "unhandled result encoding");
     }
     SmallVector<int64_t> innerTiles = maybeEncodingInfo->innerTileSizes;
+    SmallVector<int64_t> intrinsicVectorShape =
+        maybeEncodingInfo->innerTileShapes;
 
     // TODO(hanchung): Add a util to the encoding attribute, so we don't need
     // the map_to_vector method here.
@@ -150,10 +197,9 @@ struct GPUSetEncodingOpLoweringConversion
         encoding.getElementTypes().getValue(),
         [](Attribute a) { return cast<TypeAttr>(a).getValue(); });
     auto loc = encodingOp.getLoc();
+
     std::optional<TileMxNxK> intrinsicShape = getIntrinsicSize(elemTypes);
-    std::optional<std::pair<int64_t, int64_t>> intrinsicVectorShape =
-        getIntrinsicVectorSize(elemTypes, roleIdx);
-    if (!intrinsicShape || !intrinsicVectorShape) {
+    if (!intrinsicShape || intrinsicVectorShape.empty()) {
       return failure();
     }
 
@@ -181,50 +227,74 @@ struct GPUSetEncodingOpLoweringConversion
       assert(packedShape == targetShape);
     }
 
-    // Check that the dimensions of the matrix can be divided by the tile shape,
-    // if not then bail out.
-    auto sourceType = encodingOp.getSourceType().getShape();
-    assert(sourceType.size() == 2);
-    if (sourceType[0] % innerTiles[0] == 0 ||
-        sourceType[1] % innerTiles[1] == 0) {
-      return failure();
+    // Create an expand_shape op to tile the innermost two dimensions.
+    auto sourceShape = packOp->getDestType().getShape();
+    assert(intrinsicVectorShape.size() == 2); // TODO: relax this
+    auto iT1 = intrinsicVectorShape[0];
+    auto iT2 = intrinsicVectorShape[1];
+    auto oT1 = sourceShape[2] / iT1;
+    auto oT2 = sourceShape[3] / iT2;
+    SmallVector<int64_t> expandShapeShape = {
+        sourceShape[0], sourceShape[1], oT1, iT1, oT2, iT2};
+    assert(expandShapeShape.size() == 6);
+    auto expandShapeType = RankedTensorType::get(
+        expandShapeShape, encodingOp.getSourceType().getElementType());
+
+    std::optional<SmallVector<ReassociationIndices>> reassociationMap =
+        getReassociationIndicesForReshape(packOp->getDestType(),
+                                          expandShapeType);
+    assert(reassociationMap.has_value());
+    auto expandShapeOp = rewriter.create<tensor::ExpandShapeOp>(
+        loc, expandShapeType, packOp->getResult(), *reassociationMap);
+
+    // Create a linalg.transpose on the expanded shape.
+    size_t origRank = encodingOp.getSourceType().getRank();
+
+    SmallVector<int64_t> transposePerm;
+    transposePerm.push_back(0);
+    transposePerm.push_back(1);
+    for (auto perm : maybeEncodingInfo->permutation) {
+      transposePerm.push_back(origRank + perm);
     }
+    SmallVector<int64_t> transposeResultDims = expandShapeShape;
+    applyPermutationToVector(transposeResultDims, transposePerm);
 
-    // Create expand_shape
-    llvm::SmallVector<int64_t> expandShapeShape;
-    auto [iM, iK] = *intrinsicVectorShape;
-    auto oM = sourceType[0] / iM;
-    auto oK = sourceType[1] / iK;
-    expandShapeShape = {oM, iM, oK, iK};
-    assert(expandShapeShape.size() == 4);
-    RankedTensorType expandShapeType =
-        RankedTensorType::Builder(encodingOp.getSourceType())
-            .setShape(expandShapeShape);
-    Value expandShapeOp = rewriter.create<tensor::ExpandShapeOp>(
-        loc, expandShapeType, packOp->getResult());
-
-    // create linalg.transpose
-    // LHS: 16x1x4x1 -> 4x16x1x1 (perm = [2, 0, 3, 1])
-    // ACC: 4x4x16x1 -> 4x16x4x1 (perm = [0, 2, 1, 3])
-    auto permutation = roleIdx == 2 ? ArrayRef<int64_t>{0, 2, 1, 3}
-                                    : ArrayRef<int64_t>{2, 0, 3, 1};
     auto emptyTensor = rewriter.create<tensor::EmptyOp>(
-        loc, expandShapeShape, encodingOp.getSourceType());
-    [[maybe_unused]] auto transposeOp = rewriter.create<linalg::TransposeOp>(
-        loc, expandShapeOp, emptyTensor, permutation);
-
-    // TODO(hanchung): We want to make the shape consistent, so we need to
-    // collpase and expand the shape. This is the shape we materialize for Flow
-    // and HAL ops.
-    // 1. Create tensor.collapse_shape.
-    //    LHS: 4x16x1x1 -> 64
-    //    ACC: 4x16x4x1 -> 256
-    // 2. Create tensor.expand_shape to recover the shape (i.e., innerTiles).
-    //    LHS: 64 -> 16x4 (innerTiles[0]xinnerTiles[1])
-    //    ACC: 256 -> 16x16 (innerTiles[0]xinnerTiles[1])
-
-    // TODO(lialan): Replace the op with the tensor.expand_shape op.
-    rewriter.replaceOp(encodingOp, packOp->getResult());
+        loc, transposeResultDims, encodingOp.getSourceType().getElementType());
+    auto transposeOp = rewriter.create<linalg::TransposeOp>(
+        loc, expandShapeOp, emptyTensor, transposePerm);
+
+    // We want to make the shape consistent, so we append a `collapse_shape`
+    // and an `expand_shape`, to stay consistent with how we materialize for
+    // Flow and HAL ops.
+
+    // 1. Collapse the tiled dimensions into one dimension.
+    SmallVector<int64_t> collapsedShape = {sourceShape[0], sourceShape[1],
+                                           sourceShape[2] * sourceShape[3]};
+    auto revertShapeType = RankedTensorType::get(
+        collapsedShape, encodingOp.getSourceType().getElementType());
+
+    std::optional<SmallVector<ReassociationIndices>> collapseReassoc =
+        getReassociationIndicesForReshape(emptyTensor.getType(),
+                                          revertShapeType);
+    assert(collapseReassoc.has_value());
+
+    auto collapseShapeOp = rewriter.create<tensor::CollapseShapeOp>(
+        loc, revertShapeType, transposeOp->getResult(0), *collapseReassoc);
+
+    // 2. Expand the collapsed shape back to the shape intended by the encoding.
+    assert(innerTiles.size() == 2); // TODO: relax this
+    auto expandTileShapeType = RankedTensorType::get(
+        {sourceShape[0], sourceShape[1], innerTiles[0], innerTiles[1]},
+        encodingOp.getSourceType().getElementType());
+    std::optional<SmallVector<ReassociationIndices>> tileAssoc =
+        getReassociationIndicesForReshape(collapseShapeOp.getType(),
+                                          expandTileShapeType);
+    assert(tileAssoc.has_value());
+    auto expandTileShapeOp = rewriter.create<tensor::ExpandShapeOp>(
+        loc, expandTileShapeType, collapseShapeOp, *tileAssoc);
+
+    rewriter.replaceOp(encodingOp, expandTileShapeOp);
     return success();
   }
 };
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel
index 057ff1c7e2edc..16b4c72d3d076 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/BUILD.bazel
@@ -37,6 +37,7 @@ iree_lit_test_suite(
            "config_winograd.mlir",
            "extract_address_computation_gpu.mlir",
            "gpu_set_num_workgroups.mlir",
+            "gpu_materialize_encoding.mlir",
            "gpu_pipeline_generalize_named_ops.mlir",
            "nvvm_extract_address_computation.mlir",
            "nvvm_pipeline_test.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
index 5f603cfbf14ee..9fd707d8d11e7 100644
--- a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/CMakeLists.txt
@@ -33,6 +33,7 @@ iree_lit_test_suite(
     "distribute_to_thread.mlir"
     "elementwise_pipeline.mlir"
     "extract_address_computation_gpu.mlir"
+    "gpu_materialize_encoding.mlir"
     "gpu_pipeline_generalize_named_ops.mlir"
     "gpu_set_num_workgroups.mlir"
     "illegal_configuration.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/LLVMGPU/test/gpu_materialize_encoding.mlir b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/gpu_materialize_encoding.mlir
new file mode 100644
index 0000000000000..9df01277cc944
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/LLVMGPU/test/gpu_materialize_encoding.mlir
@@ -0,0 +1,79 @@
+// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-gpu-materialize-device-encoding))" --split-input-file %s | FileCheck %s
+
+#encoding = #iree_encoding.encoding<operand_index = 0 : index, op_type = matmul, element_types = [f32, f32, f32], original_type = tensor<255x513xf32>,
+    user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
+    round_dims_to = array<i64: 16, 16, 16>>
+
+#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
+  #hal.descriptor_set.layout<0, bindings = [
+    #hal.descriptor_set.binding<0, storage_buffer>,
+    #hal.descriptor_set.binding<1, storage_buffer>
+  ]>
+]>
+func.func @set_encoding_LHS() {
+  %c0 = arith.constant 0 : index
+  %0 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<255x513xf32>>
+  %1 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(1) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<255x513xf32, #encoding>>
+  %2 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [255, 513], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<255x513xf32>> -> tensor<255x513xf32>
+  %3 = iree_encoding.set_encoding %2 : tensor<255x513xf32> -> tensor<255x513xf32, #encoding>
+  flow.dispatch.tensor.store %3, %1, offsets = [0, 0], sizes = [255, 513], strides = [1, 1] : tensor<255x513xf32, #encoding> -> !flow.dispatch.tensor<writeonly:tensor<255x513xf32, #encoding>>
+  return
+}
+
+// CHECK-LABEL: func.func @set_encoding_LHS
+// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<33x64x16x4xf32>
+// CHECK: %[[PACK:.*]] = tensor.pack %2 padding_value(%cst : f32) outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [16, 4] into %[[EMPTY]] : tensor<255x513xf32> -> tensor<33x64x16x4xf32>
+// CHECK: %[[EXPAND_LHS:.*]] = tensor.expand_shape %[[PACK]]
+// CHECK-SAME: output_shape [33, 64, 16, 1, 4, 1] : tensor<33x64x16x4xf32> into tensor<33x64x16x1x4x1xf32>
+// CHECK: %[[TRANSPOSE:.*]] = linalg.transpose ins(%[[EXPAND_LHS]]
+
+// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[TRANSPOSE]]
+// CHECK: %[[EXPAND_LHS_2:.*]] = tensor.expand_shape %[[COLLAPSE]]
+// CHECK: flow.dispatch.tensor.store %[[EXPAND_LHS_2]]
+
+//---------
+
+func.func @set_encoding_RHS() {
+  %c0 = arith.constant 0 : index
+  %0 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<255x513xf32>>
+  %1 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(1) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<255x513xf32, #encoding>>
+  %2 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [255, 513], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<255x513xf32>> -> tensor<255x513xf32>
+  %3 = iree_encoding.set_encoding %2 : tensor<255x513xf32> -> tensor<255x513xf32, #encoding>
+  flow.dispatch.tensor.store %3, %1, offsets = [0, 0], sizes = [255, 513], strides = [1, 1] : tensor<255x513xf32, #encoding> -> !flow.dispatch.tensor<writeonly:tensor<255x513xf32, #encoding>>
+  return
+}
+
+// CHECK-LABEL: func.func @set_encoding_RHS
+// CHECK: %[[EMPTY_RHS:.*]] = tensor.empty() : tensor<33x64x16x4xf32>
+// CHECK: %[[PACK_RHS:.*]] = tensor.pack %2 padding_value(%cst : f32) outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [16, 4] into %[[EMPTY_RHS]] : tensor<255x513xf32> -> tensor<33x64x16x4xf32>
+// CHECK: %[[EXPAND_RHS:.*]] = tensor.expand_shape %[[PACK_RHS]]
+// CHECK-SAME: output_shape [33, 64, 16, 1, 4, 1] : tensor<33x64x16x4xf32> into tensor<33x64x16x1x4x1xf32>
+// CHECK: %[[EMPTY_RHS2:.*]] = tensor.empty() : tensor<33x64x4x16x1x1xf32>
+// CHECK: %[[TRANSPOSE_RHS:.*]] = linalg.transpose ins(%[[EXPAND_RHS]]
+
+// CHECK: %[[COLLAPSE_RHS:.*]] = tensor.collapse_shape %[[TRANSPOSE_RHS]]
+// CHECK: %[[EXPAND_RHS_2:.*]] = tensor.expand_shape %[[COLLAPSE_RHS]]
+// CHECK: flow.dispatch.tensor.store %[[EXPAND_RHS_2]]
+
+//---------
+
+func.func @set_encoding_ACC() {
+  %c0 = arith.constant 0 : index
+  %0 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<255x513xf32>>
+  %1 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(1) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<255x513xf32, #encoding>>
+  %2 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [255, 513], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<255x513xf32>> -> tensor<255x513xf32>
+  %3 = iree_encoding.set_encoding %2 : tensor<255x513xf32> -> tensor<255x513xf32, #encoding>
+  flow.dispatch.tensor.store %3, %1, offsets = [0, 0], sizes = [255, 513], strides = [1, 1] : tensor<255x513xf32, #encoding> -> !flow.dispatch.tensor<writeonly:tensor<255x513xf32, #encoding>>
+  return
+}
+
+// CHECK-LABEL: func.func @set_encoding_ACC
+// CHECK: %[[EMPTY_ACC:.*]] = tensor.empty() : tensor<33x64x16x4xf32>
+// CHECK: %[[PACK_ACC:.*]] = tensor.pack %2 padding_value(%cst : f32) outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [16, 4] into %[[EMPTY_ACC]] : tensor<255x513xf32> -> tensor<33x64x16x4xf32>
+// CHECK: %[[EXPAND_ACC:.*]] = tensor.expand_shape %[[PACK_ACC]]
+// CHECK: %[[EMPTY_ACC2:.*]] = tensor.empty() : tensor<33x64x4x16x1x1xf32>
+// CHECK: %[[TRANSPOSE_ACC:.*]] = linalg.transpose ins(%[[EXPAND_ACC]] : tensor<33x64x16x1x4x1xf32>) outs(%[[EMPTY_ACC2]] : tensor<33x64x4x16x1x1xf32>) permutation = [0, 1, 4, 2, 5, 3]
+
+// CHECK: %[[COLLAPSE_ACC:.*]] = tensor.collapse_shape %[[TRANSPOSE_ACC]]
+// CHECK: %[[EXPAND_ACC_2:.*]] = tensor.expand_shape %[[COLLAPSE_ACC]]
+// CHECK: flow.dispatch.tensor.store %[[EXPAND_ACC_2]]
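
Note: for the @set_encoding_LHS test above, the materialized sequence produced by this pattern is expected to look roughly like the hand-written sketch below (pack, expand_shape, transpose, collapse_shape, expand_shape). It assumes the f32 path with intrinsic vector size {1, 1} and inner tiles [16, 4] from the encoding above; it is not verbatim compiler output, and the SSA names and reassociation indices are illustrative.

  // Pad-and-pack 255x513 into 33x64 outer tiles of 16x4 inner elements.
  %pack = tensor.pack %src padding_value(%cst : f32)
      outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [16, 4]
      into %dest : tensor<255x513xf32> -> tensor<33x64x16x4xf32>
  // Split each inner tile dimension by the intrinsic vector size (1 here).
  %expanded = tensor.expand_shape %pack [[0], [1], [2, 3], [4, 5]]
      output_shape [33, 64, 16, 1, 4, 1]
      : tensor<33x64x16x4xf32> into tensor<33x64x16x1x4x1xf32>
  // Transpose the four innermost dims with permutation {2, 0, 3, 1},
  // offset by the two outer dims, i.e. [0, 1, 4, 2, 5, 3].
  %init = tensor.empty() : tensor<33x64x4x16x1x1xf32>
  %transposed = linalg.transpose
      ins(%expanded : tensor<33x64x16x1x4x1xf32>)
      outs(%init : tensor<33x64x4x16x1x1xf32>)
      permutation = [0, 1, 4, 2, 5, 3]
  // Collapse the tiled dims and re-expand to the innerTiles shape so the
  // result type matches what Flow/HAL materialization expects.
  %collapsed = tensor.collapse_shape %transposed [[0], [1], [2, 3, 4, 5]]
      : tensor<33x64x4x16x1x1xf32> into tensor<33x64x64xf32>
  %result = tensor.expand_shape %collapsed [[0], [1], [2, 3]]
      output_shape [33, 64, 16, 4]
      : tensor<33x64x64xf32> into tensor<33x64x16x4xf32>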