Skip to content

Commit

Permalink
Add expand_shape to encoding
Browse files Browse the repository at this point in the history
Signed-off-by: Alan Li <[email protected]>
  • Loading branch information
lialan committed Aug 8, 2024
1 parent 6b85ee7 commit 644bd0f
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 22 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,11 @@

#include "iree/compiler/Codegen/Common/EncodingUtils.h"
#include "iree/compiler/Codegen/Common/GPU/Passes.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/Dialect/MemRef/Transforms/Transforms.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Tensor/Transforms/Transforms.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#define DEBUG_TYPE "iree-codegen-gpu-materialize-encoding"
Expand Down Expand Up @@ -36,12 +39,31 @@ getIntrinsicVectorSize(TypeRange elementTypes, int64_t roleIdx) {
Type rhs = elementTypes[1];
Type out = elementTypes[2];
if (lhs.isF32() && rhs.isF32() && out.isF32()) {
if (roleIdx == 0 || roleIdx == 1) return std::make_pair(1, 1);
if (roleIdx == 2) return std::make_pair(4, 1);
if (roleIdx == 0 || roleIdx == 1)
return std::make_pair(1, 1);
if (roleIdx == 2)
return std::make_pair(4, 1);
}
return std::nullopt;
}

ArrayRef<int64_t> getTransposePermutation(int64_t roleIdx) {
switch (roleIdx) {
case 0: // A
case 1: // B
// OuterTileX x InnerTileX x OuterTileY x InnerTileY
// -> OuterTileY x OuterTileX x InnerTileY x InnerTileX
return {2, 0, 3, 1};
case 2: // C
// ACC:
// OuterTileX x InnerTileX x OuterTileY x InnerTileY
// -> OuterTileX x OuterTileY x InnerTileX x InnerTileY
return {0, 2, 1, 3};
default:
llvm_unreachable("unexpected roleIdx");
}
}

// TODO(hanchung): Pass an ExecutableTargetAttr attribute for the target
// encoding. Here we assume that every mfma op is available.
// TODO(hanchung): Handle wmma ops.
Expand Down Expand Up @@ -144,6 +166,7 @@ struct GPUSetEncodingOpLoweringConversion
auto elemTypes = llvm::map_to_vector(
encoding.getElementTypes().getValue(),
[](Attribute a) { return cast<TypeAttr>(a).getValue(); });
auto loc = encodingOp.getLoc();
std::optional<TileMxNxK> intrinsicShape = getIntrinsicSize(elemTypes);
std::optional<std::pair<int64_t, int64_t>> intrinsicVectorShape =
getIntrinsicVectorSize(elemTypes, roleIdx);
Expand All @@ -152,18 +175,18 @@ struct GPUSetEncodingOpLoweringConversion
}

SmallVector<int64_t> targetShape; // for unrolling
switch(roleIdx) {
case 0:
targetShape = {intrinsicShape->M, intrinsicShape->K};
break;
case 1:
targetShape = {intrinsicShape->N, intrinsicShape->K};
break;
case 2:
targetShape = {intrinsicShape->M, intrinsicShape->N};
break;
default:
return failure();
switch (roleIdx) {
case 0: // A
targetShape = {intrinsicShape->M, intrinsicShape->K};
break;
case 1: // B
targetShape = {intrinsicShape->N, intrinsicShape->K};
break;
case 2: // C
targetShape = {intrinsicShape->M, intrinsicShape->N};
break;
default:
return failure();
}

assert(innerTiles.size() == targetShape.size());
Expand All @@ -175,15 +198,33 @@ struct GPUSetEncodingOpLoweringConversion
assert(packedShape == targetShape);
}

// TODO(lialan): create expand_shape. Take LHS as an example:
// 16x4xf32 -> 16x1x4x1. Because the vector size used in the intrinsic is
// 1x1.
// For C-Matrix (i.e., ACC), it is 16x16xf32 -> 4x4x16x1xf32. Because the
// vector size is 4x1.
// Check that the dimensions of the matrix can be divided by the tile shape,
// if not then bail out.
auto sourceType = encodingOp.getSourceType().getShape();
assert(sourceType.size() == 2);
if (sourceType[0] % innerTiles[0] == 0 ||
sourceType[1] % innerTiles[1] == 0) {
return failure();
}

// Create expand_shape
llvm::SmallVector<int64_t> expandShapeShape;
auto [iT1, iT2] = *intrinsicVectorShape;
auto oT1 = sourceType[0] / iT1;
auto oT2 = sourceType[1] / iT2;
expandShapeShape = {oT1, iT1, oT2, iT2};
assert(expandShapeShape.size() == 4);
RankedTensorType expandShapeType =
RankedTensorType::Builder(encodingOp.getSourceType())
.setShape(expandShapeShape);
Value expandShapeOp = rewriter.create<tensor::ExpandShapeOp>(
loc, expandShapeType, packOp->getResult());

// TODO(lialan): create linalg.transpose op.
// LHS: 16x1x4x1 -> 4x16x1x1 (perm = [2, 0, 3, 1])
// ACC: 4x4x16x1 -> 4x16x4x1 (perm = [0, 2, 1, 3])
// create linalg.transpose
auto emptyTensor = rewriter.create<tensor::EmptyOp>(
loc, expandShapeShape, encodingOp.getSourceType().getElementType());
[[maybe_unused]] auto transposeOp = rewriter.create<linalg::TransposeOp>(
loc, expandShapeOp, emptyTensor, getTransposePermutation(roleIdx));

// TODO(hanchung): We want to make the shape consistent, so we need to
// collpase and expand the shape. This is the shape we materialize for Flow
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ iree_lit_test_suite(
"check_ir_before_llvm_conversion.mlir",
"check_ir_before_llvm_conversion_not_fail_unbound.mlir",
"convert_to_llvm.mlir",
"data_tile.mlir",
"emit_vectorization_remarks.mlir",
"expand_f16_op_to_f32.mlir",
"hal_executable_constants.mlir",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ iree_lit_test_suite(
"check_ir_before_llvm_conversion.mlir"
"check_ir_before_llvm_conversion_not_fail_unbound.mlir"
"convert_to_llvm.mlir"
"data_tile.mlir"
"emit_vectorization_remarks.mlir"
"expand_f16_op_to_f32.mlir"
"hal_executable_constants.mlir"
Expand Down
22 changes: 22 additions & 0 deletions compiler/src/iree/compiler/Codegen/LLVMCPU/test/data_tile.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
// RUN: iree-opt --pass-pipeline="builtin.module(func.func(iree-codegen-gpu-materialize-device-encoding))" --split-input-file %s | FileCheck %s
#pipeline_layout = #hal.pipeline.layout<push_constants = 0, sets = [
#hal.descriptor_set.layout<0, bindings = [
#hal.descriptor_set.binding<0, storage_buffer>,
#hal.descriptor_set.binding<1, storage_buffer>
]>
]>
func.func @set_encoding() {
%c0 = arith.constant 0 : index
%0 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(0) alignment(64) offset(%c0) flags(ReadOnly) : !flow.dispatch.tensor<readonly:tensor<255x513xf32>>
%1 = hal.interface.binding.subspan layout(#pipeline_layout) set(0) binding(1) alignment(64) offset(%c0) : !flow.dispatch.tensor<writeonly:tensor<255x513xf32, #iree_encoding.encoding<operand_index = 0, op_type = matmul, element_types = [f32, f32, f32], original_type = tensor<255x513xf32>, user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], round_dims_to = array<i64: 16, 16, 16>>>>
%2 = flow.dispatch.tensor.load %0, offsets = [0, 0], sizes = [255, 513], strides = [1, 1] : !flow.dispatch.tensor<readonly:tensor<255x513xf32>> -> tensor<255x513xf32>
%3 = iree_encoding.set_encoding %2 : tensor<255x513xf32> -> tensor<255x513xf32, #iree_encoding.encoding<operand_index = 0, op_type = matmul, element_types = [f32, f32, f32], original_type = tensor<255x513xf32>, user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], round_dims_to = array<i64: 16, 16, 16>>>
flow.dispatch.tensor.store %3, %1, offsets = [0, 0], sizes = [255, 513], strides = [1, 1] : tensor<255x513xf32, #iree_encoding.encoding<operand_index = 0, op_type = matmul, element_types = [f32, f32, f32], original_type = tensor<255x513xf32>, user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], round_dims_to = array<i64: 16, 16, 16>>> -> !flow.dispatch.tensor<writeonly:tensor<255x513xf32, #iree_encoding.encoding<operand_index = 0, op_type = matmul, element_types = [f32, f32, f32], original_type = tensor<255x513xf32>, user_indexing_maps = [affine_map<(d0, d1, d2) -> (d0, d2)>, affine_map<(d0, d1, d2) -> (d2, d1)>, affine_map<(d0, d1, d2) -> (d0, d1)>], round_dims_to = array<i64: 16, 16, 16>>>>
return
}


// CHECK-LABEL: func.func @set_encoding
// CHECK: %[[EMPTY:.*]] = tensor.empty() : tensor<16x129x16x4xf32>
// CHECK: %[[PACK:.*]] = tensor.pack %2 padding_value(%cst : f32) outer_dims_perm = [0, 1] inner_dims_pos = [0, 1] inner_tiles = [16, 4] into %[[EMPTY]] : tensor<255x513xf32> -> tensor<16x129x16x4xf32>

0 comments on commit 644bd0f

Please sign in to comment.