Add expand_shape to encoding (#18135)
Signed-off-by: Alan Li <[email protected]>
lialan authored Aug 15, 2024
1 parent 9cc6549 commit caf1b2a
Showing 6 changed files with 225 additions and 44 deletions.
3 changes: 3 additions & 0 deletions compiler/src/iree/compiler/Codegen/Common/EncodingUtils.h
@@ -20,6 +20,9 @@ struct MaterializeEncodingInfo {
SmallVector<int64_t> innerTileSizes;
SmallVector<int64_t> outerDimsPerm;
unsigned srcRank = 0;
// Metadata for a generalized expand_shape + transpose
SmallVector<int64_t> innerTileShapes;
SmallVector<int64_t> permutation;
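// Illustrative values, assuming the f32 MFMA path implemented in
// GPUMaterializeEncoding.cpp: the ACC operand gets innerTileShapes = {4, 1}
// and permutation = {0, 2, 1, 3}, while LHS/RHS get innerTileShapes = {1, 1}
// and permutation = {2, 0, 3, 1}.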
};

using MaterializeEncodingFn =
2 changes: 1 addition & 1 deletion compiler/src/iree/compiler/Codegen/Common/GPU/BUILD.bazel
@@ -59,8 +59,8 @@ iree_compiler_cc_library(
"GPUDistributionPatterns.cpp",
"GPUGeneralizeNamedOps.cpp",
"GPULowerToUKernels.cpp",
"GPUMultiBuffering.cpp",
"GPUMaterializeEncoding.cpp",
"GPUMultiBuffering.cpp",
"GPUNestedLayoutDistributionPatterns.cpp",
"GPUPatterns.cpp",
"GPUPipelining.cpp",
184 changes: 141 additions & 43 deletions compiler/src/iree/compiler/Codegen/Common/GPU/GPUMaterializeEncoding.cpp
@@ -6,10 +6,17 @@

#include "iree/compiler/Codegen/Common/EncodingUtils.h"
#include "iree/compiler/Codegen/Common/GPU/Passes.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/Dialect/Utils/ReshapeOpsUtils.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "mlir/IR/Operation.h"

#define DEBUG_TYPE "iree-codegen-gpu-materialize-encoding"

namespace mlir::iree_compiler {
@@ -30,16 +30,48 @@ static std::optional<TileMxNxK> getIntrinsicSize(TypeRange elementTypes) {

// TODO: Query the value from GPU attributes.
// TODO: Define a struct with meaningful name for the pair.
std::optional<std::pair<int64_t, int64_t>>
getIntrinsicVectorSize(TypeRange elementTypes, int64_t roleIdx) {
SmallVector<int64_t> getIntrinsicVectorSize(TypeRange elementTypes,
int64_t roleIdx) {
Type lhs = elementTypes[0];
Type rhs = elementTypes[1];
Type out = elementTypes[2];
if (lhs.isF32() && rhs.isF32() && out.isF32()) {
if (roleIdx == 0 || roleIdx == 1) return std::make_pair(1, 1);
if (roleIdx == 2) return std::make_pair(4, 1);
if (roleIdx == 0 || roleIdx == 1) {
return {1, 1};
}
if (roleIdx == 2) {
return {4, 1};
}
}
return {};
}

// Given the encoding's role index and element types, return the transpose
// permutation used in GPU materialization.
SmallVector<int64_t> getTransposePermutation(int64_t roleIdx,
TypeRange elementTypes) {
// For now, check that all types are f32:
Type lhs = elementTypes[0];
Type rhs = elementTypes[1];
Type out = elementTypes[2];
if (!lhs.isF32() || !rhs.isF32() || !out.isF32()) {
return {};
}

switch (roleIdx) {
case 0: // A
case 1: // B
// OuterTileX x InnerTileX x OuterTileY x InnerTileY
// -> OuterTileY x OuterTileX x InnerTileY x InnerTileX
return {2, 0, 3, 1};
case 2: // C
// ACC:
// OuterTileX x InnerTileX x OuterTileY x InnerTileY
// -> OuterTileX x OuterTileY x InnerTileX x InnerTileY
return {0, 2, 1, 3};
default:
return {};
}
return std::nullopt;
}
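// Worked example (illustrative): linalg.transpose with permutation p maps
// output dim i to input dim p[i]. For an LHS tile laid out as
// OuterTileX x InnerTileX x OuterTileY x InnerTileY = 16x1x4x1, perm
// {2, 0, 3, 1} produces 4x16x1x1; for an ACC tile 4x4x16x1, perm {0, 2, 1, 3}
// produces 4x16x4x1.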

// TODO(hanchung): Pass an ExecutableTargetAttr attribute for the target
@@ -88,7 +127,17 @@ materializeEncodingForTarget(RankedTensorType tensorType) {
// Map the matmul TileMxNxK to an actual tile shape for the tensor at hand,
// based on its operand index in the matmul.
auto rank = tensorType.getRank();
return getEncodingInfoForMatmul(encoding, rank, enumeratedTileMxNxK[0]);

auto encodingInfo =
getEncodingInfoForMatmul(encoding, rank, enumeratedTileMxNxK[0]);

// Insert inner tile shapes and permutation info.
auto roleIdx = encoding.getOperandIndex().getInt();
auto intrinsicVectorSizes = getIntrinsicVectorSize(elementTypes, roleIdx);
auto permutation = getTransposePermutation(roleIdx, elementTypes);
encodingInfo.innerTileShapes = intrinsicVectorSizes;
encodingInfo.permutation = permutation;
return encodingInfo;
}

namespace {
@@ -114,6 +163,7 @@ struct GPUSetEncodingOpLoweringConversion
getTypeConverter());
MaterializeEncodingFn materializeEncodingFn =
converter->getMaterializeEncodingFn();

auto packOp = lowerSetEncodingOpToPackOp(
rewriter, encodingOp, adaptor.getSource(), materializeEncodingFn,
this->materializeEncodingValueFn);
@@ -136,6 +186,8 @@ struct GPUSetEncodingOpLoweringConversion
"unhandled result encoding");
}
SmallVector<int64_t> innerTiles = maybeEncodingInfo->innerTileSizes;
SmallVector<int64_t> intrinsicVectorShape =
maybeEncodingInfo->innerTileShapes;

// TODO(hanchung): Add a util to the encoding attribute, so we don't need
// the map_to_vector method here.
@@ -144,26 +196,26 @@ struct GPUSetEncodingOpLoweringConversion
auto elemTypes = llvm::map_to_vector(
encoding.getElementTypes().getValue(),
[](Attribute a) { return cast<TypeAttr>(a).getValue(); });
auto loc = encodingOp.getLoc();

std::optional<TileMxNxK> intrinsicShape = getIntrinsicSize(elemTypes);
std::optional<std::pair<int64_t, int64_t>> intrinsicVectorShape =
getIntrinsicVectorSize(elemTypes, roleIdx);
if (!intrinsicShape || !intrinsicVectorShape) {
if (!intrinsicShape || intrinsicVectorShape.empty()) {
return failure();
}

SmallVector<int64_t> targetShape; // for unrolling
switch(roleIdx) {
case 0:
targetShape = {intrinsicShape->M, intrinsicShape->K};
break;
case 1:
targetShape = {intrinsicShape->N, intrinsicShape->K};
break;
case 2:
targetShape = {intrinsicShape->M, intrinsicShape->N};
break;
default:
return failure();
switch (roleIdx) {
case 0: // A
targetShape = {intrinsicShape->M, intrinsicShape->K};
break;
case 1: // B
targetShape = {intrinsicShape->N, intrinsicShape->K};
break;
case 2: // C
targetShape = {intrinsicShape->M, intrinsicShape->N};
break;
default:
return failure();
}
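// Illustrative sketch, assuming the f32 intrinsic tile used elsewhere in this
// patch (M = 16, N = 16, K = 4): roleIdx 0 (A) yields targetShape = {16, 4},
// roleIdx 1 (B) yields {16, 4}, and roleIdx 2 (C) yields {16, 16}.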

assert(innerTiles.size() == targetShape.size());
@@ -175,28 +227,74 @@ struct GPUSetEncodingOpLoweringConversion
assert(packedShape == targetShape);
}

// TODO(lialan): create expand_shape. Take LHS as an example:
// 16x4xf32 -> 16x1x4x1. Because the vector size used in the intrinsic is
// 1x1.
// For C-Matrix (i.e., ACC), it is 16x16xf32 -> 4x4x16x1xf32. Because the
// vector size is 4x1.

// TODO(lialan): create linalg.transpose op.
// LHS: 16x1x4x1 -> 4x16x1x1 (perm = [2, 0, 3, 1])
// ACC: 4x4x16x1 -> 4x16x4x1 (perm = [0, 2, 1, 3])

// TODO(hanchung): We want to make the shape consistent, so we need to
// collapse and expand the shape. This is the shape we materialize for Flow
// and HAL ops.
// 1. Create tensor.collapse_shape.
// LHS: 4x16x1x1 -> 64
// ACC: 4x16x4x1 -> 256
// 2. Create tensor.expand_shape to recover the shape (i.e., innerTiles).
// LHS: 64 -> 16x4 (innerTiles[0]xinnerTiles[1])
// ACC: 256 -> 16x16 (innerTiles[0]xinnerTiles[1])

// TODO(lialan): Replace the op with the tensor.expand_shape op.
rewriter.replaceOp(encodingOp, packOp->getResult());
// Create expand_shape op to tile the innermost two dimensions.
auto sourceShape = packOp->getDestType().getShape();
assert(intrinsicVectorShape.size() == 2); // TODO: relax this
auto iT1 = intrinsicVectorShape[0];
auto iT2 = intrinsicVectorShape[1];
auto oT1 = sourceShape[2] / iT1;
auto oT2 = sourceShape[3] / iT2;
SmallVector<int64_t> expandShapeShape = {
sourceShape[0], sourceShape[1], oT1, iT1, oT2, iT2};
assert(expandShapeShape.size() == 6);
auto expandShapeType = RankedTensorType::get(
expandShapeShape, encodingOp.getSourceType().getElementType());

std::optional<SmallVector<ReassociationIndices>> reassociationMap =
getReassociationIndicesForReshape(packOp->getDestType(),
expandShapeType);
assert(reassociationMap.has_value());
auto expandShapeOp = rewriter.create<tensor::ExpandShapeOp>(
loc, expandShapeType, packOp->getResult(), *reassociationMap);
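// For instance, matching the lit test added in this change: a packed tile
// tensor<33x64x16x4xf32> with intrinsicVectorShape = {1, 1} has oT1 = 16,
// iT1 = 1, oT2 = 4, iT2 = 1 and expands to tensor<33x64x16x1x4x1xf32>.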

// Create linalg.transpose on expandShapeShape.
size_t origRank = encodingOp.getSourceType().getRank();

SmallVector<int64_t> transposePerm;
transposePerm.push_back(0);
transposePerm.push_back(1);
for (auto perm : maybeEncodingInfo->permutation) {
transposePerm.push_back(origRank + perm);
}
SmallVector<int64_t> transposeResultDims = expandShapeShape;
applyPermutationToVector(transposeResultDims, transposePerm);

auto emptyTensor = rewriter.create<tensor::EmptyOp>(
loc, transposeResultDims, encodingOp.getSourceType().getElementType());
auto transposeOp = rewriter.create<linalg::TransposeOp>(
loc, expandShapeOp, emptyTensor, transposePerm);
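// Continuing the example: origRank = 2 and permutation = {2, 0, 3, 1} give
// transposePerm = [0, 1, 4, 2, 5, 3], so tensor<33x64x16x1x4x1xf32> is
// transposed into tensor<33x64x4x16x1x1xf32>, as the lit test checks.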

// We want to make the shape consistent, so we append a `collapse_shape` and
// an `expand_shape`, to stay conformant with how we materialize for Flow and
// HAL ops.

// 1. Collapse tiled dimensions into one dim.
SmallVector<int64_t> collapsedShape = {sourceShape[0], sourceShape[1],
sourceShape[2] * sourceShape[3]};
auto revertShapeType = RankedTensorType::get(
collapsedShape, encodingOp.getSourceType().getElementType());

std::optional<SmallVector<ReassociationIndices>> collapseReassoc =
getReassociationIndicesForReshape(emptyTensor.getType(),
revertShapeType);
assert(collapseReassoc.has_value());

auto collapseShapeOp = rewriter.create<tensor::CollapseShapeOp>(
loc, revertShapeType, transposeOp->getResult(0), *collapseReassoc);

// 2. Expand the collapsed shape to the shape intended by the encoding.
assert(innerTiles.size() == 2); // TODO: relax this
auto expandTileShapeType = RankedTensorType::get(
{sourceShape[0], sourceShape[1], innerTiles[0], innerTiles[1]},
encodingOp.getSourceType().getElementType());
std::optional<SmallVector<ReassociationIndices>> tileAssoc =
getReassociationIndicesForReshape(collapseShapeOp.getType(),
expandTileShapeType);
assert(tileAssoc.has_value());
auto expandTileShapeOp = rewriter.create<tensor::ExpandShapeOp>(
loc, expandTileShapeType, collapseShapeOp, *tileAssoc);
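// Continuing the example: tensor<33x64x4x16x1x1xf32> is collapsed to
// tensor<33x64x64xf32> (64 = 16 * 4) and then re-expanded to
// tensor<33x64x16x4xf32>, i.e. back to the innerTiles = [16, 4] shape that
// Flow/HAL expect.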

rewriter.replaceOp(encodingOp, expandTileShapeOp);
return success();
}
};
@@ -37,6 +37,7 @@ iree_lit_test_suite(
"config_winograd.mlir",
"extract_address_computation_gpu.mlir",
"gpu_set_num_workgroups.mlir",
"gpu_materialize_encoding.mlir",
"gpu_pipeline_generalize_named_ops.mlir",
"nvvm_extract_address_computation.mlir",
"nvvm_pipeline_test.mlir",
@@ -33,6 +33,7 @@ iree_lit_test_suite(
"distribute_to_thread.mlir"
"elementwise_pipeline.mlir"
"extract_address_computation_gpu.mlir"
"gpu_materialize_encoding.mlir"
"gpu_pipeline_generalize_named_ops.mlir"
"gpu_set_num_workgroups.mlir"
"illegal_configuration.mlir"
@@ -0,0 +1,78 @@
// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-gpu-materialize-device-encoding))" --split-input-file %s | FileCheck %s

#encoding = #iree_encoding.encoding<operand_index = 1, op_type = matmul, element_types = [f32, f32, f32], original_type = tensor<255x513xf32>,
user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>],
round_dims_to = array<i64: 16, 16, 16>>

#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
#hal.descriptor_set.layout<0, bindings = [
#hal.descriptor_set.binding<0, storage_buffer>,
#hal.descriptor_set.binding<1, storage_buffer>
]>
]>
func.func @set_encoding_LHS() {
%c0 = arith.constant 0 : index
%0 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<255x513xf32>>
%1 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(1) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<255x513xf32, #encoding>>
%2 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [255, 513], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<255x513xf32>> -> tensor<255x513xf32>
%3 = iree_encoding.set_encoding %2 : tensor<255x513xf32> -> tensor<255x513xf32, #encoding>
flow.dispatch.tensor.store %3, %1, offsets = [0, 0], sizes = [255, 513], strides = [1, 1] : tensor<255x513xf32, #encoding> -> !flow.dispatch.tensor<writeonly:tensor<255x513xf32, #encoding>>
return
}

// CHECK-LABEL: func.func @set_encoding_LHS
// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<33x64x16x4xf32>
// CHECK: %[[PACK:.*]] = tensor.pack %2 padding_value(%cst : f32) outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [16, 4] into %[[EMPTY]] : tensor<255x513xf32> -> tensor<33x64x16x4xf32>
// CHECK: %[[EXPAND_LHS:.*]] = tensor.expand_shape %[[PACK]]
// CHECK-SAME: output_shape [33, 64, 16, 1, 4, 1] : tensor<33x64x16x4xf32> into tensor<33x64x16x1x4x1xf32>
// CHECK: %[[EMPTY_LHS2:.*]] = tensor.empty() : tensor<33x64x4x16x1x1xf32>
// CHECK: %[[TRANSPOSE:.*]] = linalg.transpose ins(%[[EXPAND_LHS]] : tensor<33x64x16x1x4x1xf32>) outs(%[[EMPTY_LHS2]] : tensor<33x64x4x16x1x1xf32>) permutation = [0, 1, 4, 2, 5, 3]
// CHECK: %[[COLLAPSE:.*]] = tensor.collapse_shape %[[TRANSPOSE]]
// CHECK: %[[EXPAND_LHS_2:.*]] = tensor.expand_shape %[[COLLAPSE]]
// CHECK: flow.dispatch.tensor.store %[[EXPAND_LHS_2]]

//---------

func.func @set_encoding_RHS() {
%c0 = arith.constant 0 : index
%0 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<255x513xf32>>
%1 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(1) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<255x513xf32, #encoding>>
%2 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [255, 513], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<255x513xf32>> -> tensor<255x513xf32>
%3 = iree_encoding.set_encoding %2 : tensor<255x513xf32> -> tensor<255x513xf32, #encoding>
flow.dispatch.tensor.store %3, %1, offsets = [0, 0], sizes = [255, 513], strides = [1, 1] : tensor<255x513xf32, #encoding> -> !flow.dispatch.tensor<writeonly:tensor<255x513xf32, #encoding>>
return
}

// CHECK-LABEL: func.func @set_encoding_RHS
// CHECK: %[[EMPTY_RHS:.*]] = tensor.empty() : tensor<33x64x16x4xf32>
// CHECK: %[[PACK_RHS:.*]] = tensor.pack %2 padding_value(%cst : f32) outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [16, 4] into %3 : tensor<255x513xf32> -> tensor<33x64x16x4xf32>
// CHECK: %[[EXPAND_RHS:.*]] = tensor.expand_shape %[[PACK_RHS]]
// CHECK-SAME: output_shape [33, 64, 16, 1, 4, 1] : tensor<33x64x16x4xf32> into tensor<33x64x16x1x4x1xf32>
// CHECK: %[[EMPTY_RHS2:.*]] = tensor.empty() : tensor<33x64x4x16x1x1xf32>
// CHECK: %[[TRANSPOSE_RHS:.*]] = linalg.transpose ins(%[[EXPAND_RHS]] : tensor<33x64x16x1x4x1xf32>) outs(%[[EMPTY_RHS2]] : tensor<33x64x4x16x1x1xf32>) permutation = [0, 1, 4, 2, 5, 3]
// CHECK: %[[COLLAPSE_RHS:.*]] = tensor.collapse_shape %[[TRANSPOSE_RHS]]
// CHECK: %[[EXPAND_RHS_2:.*]] = tensor.expand_shape %[[COLLAPSE_RHS]]
// CHECK: flow.dispatch.tensor.store %[[EXPAND_RHS_2]]

//---------

func.func @set_encoding_ACC() {
%c0 = arith.constant 0 : index
%0 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<255x513xf32>>
%1 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(1) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<255x513xf32, #encoding>>
%2 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [255, 513], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<255x513xf32>> -> tensor<255x513xf32>
%3 = iree_encoding.set_encoding %2 : tensor<255x513xf32> -> tensor<255x513xf32, #encoding>
flow.dispatch.tensor.store %3, %1, offsets = [0, 0], sizes = [255, 513], strides = [1, 1] : tensor<255x513xf32, #encoding> -> !flow.dispatch.tensor<writeonly:tensor<255x513xf32, #encoding>>
return
}

// CHECK-LABEL: func.func @set_encoding_ACC
// CHECK: %[[EMPTY_ACC:.*]] = tensor.empty() : tensor<33x64x16x4xf32>
// CHECK: %[[PACK_ACC:.*]] = tensor.pack %2 padding_value(%cst : f32) outer_dims_perm = [1, 0] inner_dims_pos = [1, 0] inner_tiles = [16, 4] into %[[EMPTY_ACC]] : tensor<255x513xf32> -> tensor<33x64x16x4xf32>
// CHECK: %[[EXPAND_ACC:.*]] = tensor.expand_shape %[[PACK_ACC]]
// CHECK: %[[EMPTY_ACC2:.*]] = tensor.empty() : tensor<33x64x4x16x1x1xf32>
// CHECK: %[[TRANSPOSE_ACC:.*]] = linalg.transpose ins(%[[EXPAND_ACC]] : tensor<33x64x16x1x4x1xf32>) outs(%[[EMPTY_ACC2]] : tensor<33x64x4x16x1x1xf32>) permutation = [0, 1, 4, 2, 5, 3]

// CHECK: %[[COLLAPSE_RHS:.*]] = tensor.collapse_shape %[[TRANSPOSE_ACC]]
// CHECK: %[[EXPAND_ACC_2:.*]] = tensor.expand_shape %[[COLLAPSE_RHS]]
// CHECK: flow.dispatch.tensor.store %[[EXPAND_ACC_2]]
