Skip to content

Commit

Permalink
Materialize batch_matmul to batch_mmt4d (iree-org#14731)
Browse files Browse the repository at this point in the history
Add a pattern to materialize `batch_matmul` with a data-tiling encoding to
`batch_mmt4d`

Tracking issue: iree-org#14431
  • Loading branch information
Jerry Wu authored Aug 25, 2023
1 parent eadbfcb commit a705de3
Show file tree
Hide file tree
Showing 4 changed files with 175 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,12 @@ namespace iree_compiler {
namespace IREE {
namespace LinalgExt {

// Check if encoding user is one of matmul encodings.
bool isMatmulEncodingUser(EncodingUser user);

// Check if encoding user is one of batch matmul encodings.
bool isBatchMatmulEncodingUser(EncodingUser user);

struct MatmulTileParams {
int64_t M = 1;
int64_t K = 1;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,10 @@ lowerOpWithEncoding(RewriterBase &rewriter, linalg::MatmulOp matmulOp,
if (!lhsEncoding || !rhsEncoding || !resultEncoding) {
return failure();
}
if (lhsEncoding.getRole().getValue() !=
if (!isMatmulEncodingUser(lhsEncoding.getUser().getValue()) ||
!isMatmulEncodingUser(rhsEncoding.getUser().getValue()) ||
!isMatmulEncodingUser(resultEncoding.getUser().getValue()) ||
lhsEncoding.getRole().getValue() !=
mlir::iree_compiler::IREE::LinalgExt::EncodingRole::LHS ||
rhsEncoding.getRole().getValue() !=
mlir::iree_compiler::IREE::LinalgExt::EncodingRole::RHS ||
Expand All @@ -262,8 +265,46 @@ lowerOpWithEncoding(RewriterBase &rewriter, linalg::MatmulOp matmulOp,
return mmt4DOp;
}

/// Utility method to convert from `linalg.fill` on `tensor` type with encoding
/// to fill of the materialized type
/// Converts a `linalg.batch_matmul` whose operands carry data-tiling
/// encodings:
/// - lhs encoding with user=BATCH_MATMUL_*, role=LHS
/// - rhs encoding with user=BATCH_MATMUL_*, role=RHS
/// - result encoding with user=BATCH_MATMUL_*, role=RESULT
/// into a `linalg.batch_mmt4d` op on the already-materialized operands.
static FailureOr<Operation *>
lowerOpWithEncoding(RewriterBase &rewriter, linalg::BatchMatmulOp batchMatmulOp,
                    ValueRange convertedInputOperands,
                    ValueRange convertedOutputOperands, MaterializeEncodingFn,
                    MaterializeEncodingValueFn) {
  // Only ops with tensor semantics participate in encoding materialization.
  if (!batchMatmulOp.hasTensorSemantics()) {
    return failure();
  }
  auto ins = batchMatmulOp.getDpsInputOperands();
  auto outs = batchMatmulOp.getDpsInitOperands();
  auto lhsAttr =
      getEncodingAttr(ins[0]->get().getType().cast<RankedTensorType>());
  auto rhsAttr =
      getEncodingAttr(ins[1]->get().getType().cast<RankedTensorType>());
  auto outAttr =
      getEncodingAttr(outs[0]->get().getType().cast<RankedTensorType>());
  // All three operands must be encoded for this pattern to apply.
  if (!lhsAttr || !rhsAttr || !outAttr) {
    return failure();
  }

  // Every operand must use a batch_matmul encoding with the role that
  // matches its operand position.
  bool usersOk = isBatchMatmulEncodingUser(lhsAttr.getUser().getValue()) &&
                 isBatchMatmulEncodingUser(rhsAttr.getUser().getValue()) &&
                 isBatchMatmulEncodingUser(outAttr.getUser().getValue());
  bool rolesOk = lhsAttr.getRole().getValue() == EncodingRole::LHS &&
                 rhsAttr.getRole().getValue() == EncodingRole::RHS &&
                 outAttr.getRole().getValue() == EncodingRole::RESULT;
  if (!usersOk || !rolesOk) {
    return failure();
  }
  // The converted operands are already in the packed (tiled) layout, so a
  // batch_mmt4d can be emitted directly.
  Operation *result = rewriter.create<linalg::BatchMmt4DOp>(
      batchMatmulOp.getLoc(), convertedOutputOperands[0].getType(),
      convertedInputOperands, convertedOutputOperands);
  return result;
}

/// Utility method to convert from `linalg.fill` on `tensor` type with
/// encoding to fill of the materialized type
static FailureOr<Operation *>
lowerOpWithEncoding(RewriterBase &rewriter, linalg::FillOp fillOp,
ValueRange convertedInputOperands,
Expand Down Expand Up @@ -515,9 +556,11 @@ void populateMaterializeEncodingPatterns(
MaterializeEncodingTypeConverter &typeConverter,
MaterializeEncodingValueFn materializeEncodingValueFn) {

// Add all patterns for converting from encoded type to the materialized type
// Add all patterns for converting from encoded type to the materialized
// type
patterns.insert<MaterializeDPSOperation<linalg::FillOp>,
MaterializeDPSOperation<linalg::MatmulOp>,
MaterializeDPSOperation<linalg::BatchMatmulOp>,
MaterializeOperation<tensor::EmptyOp>,
SetEncodingOpToPackOpConversion,
UnsetEncodingOpToPackOpConversion>(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,23 +11,39 @@ namespace iree_compiler {
namespace IREE {
namespace LinalgExt {

MaterializeEncodingInfo
chooseEncodingInfoForMatmul(EncodingUser user, EncodingRole role,
MatmulTileParams tileParams) {
// Start dim of the MxK (LHS), KxN (RHS), or MxN (RESULT) 2D matrix.
int64_t matmulDimBase = 0;
// Returns true if `user` is one of the (non-batch) matmul encoding
// variants, across all supported element-type combinations.
bool isMatmulEncodingUser(EncodingUser user) {
  return user == EncodingUser::MATMUL_F32F32F32 ||
         user == EncodingUser::MATMUL_F16F16F32 ||
         user == EncodingUser::MATMUL_F16F16F16 ||
         user == EncodingUser::MATMUL_BF16BF16F32 ||
         user == EncodingUser::MATMUL_BF16BF16BF16 ||
         user == EncodingUser::MATMUL_I8I8I32;
}

// Returns true if `user` is one of the batch_matmul encoding variants,
// across all supported element-type combinations.
//
// Fix: stray statements left over from the previous implementation
// (`matmulDimBase = 1;` / `break;`) referenced a variable that no longer
// exists in this scope and made the function ill-formed; the switch now
// contains only the case labels and their return statements.
bool isBatchMatmulEncodingUser(EncodingUser user) {
  switch (user) {
  case EncodingUser::BATCH_MATMUL_F32F32F32:
  case EncodingUser::BATCH_MATMUL_F16F16F32:
  case EncodingUser::BATCH_MATMUL_F16F16F16:
  case EncodingUser::BATCH_MATMUL_BF16BF16F32:
  case EncodingUser::BATCH_MATMUL_BF16BF16BF16:
  case EncodingUser::BATCH_MATMUL_I8I8I32:
    return true;
  default:
    return false;
  }
}

MaterializeEncodingInfo
chooseEncodingInfoForMatmul(EncodingUser user, EncodingRole role,
MatmulTileParams tileParams) {
// Start dim of the MxK (LHS), KxN (RHS), or MxN (RESULT) 2D matrix.
int64_t matmulDimBase = isBatchMatmulEncodingUser(user) ? 1 : 0;

MaterializeEncodingInfo encodingInfo;
encodingInfo.innerDimsPos = {matmulDimBase, matmulDimBase + 1};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -240,3 +240,101 @@ func.func @pack_unpack_batch_matmul_result(%arg0 : tensor<?x?x?xf32>) -> tensor<
// CHECK: %[[UNPACK_DEST:.+]] = tensor.empty(%[[D0]], %[[D1]], %[[D2]]) : tensor<?x?x?xf32>
// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[PACK]] inner_dims_pos = [1, 2] inner_tiles = [8, 8] into %[[UNPACK_DEST]]
// CHECK: return %[[UNPACK]]

// -----

// Tests materialization of an encoded linalg.batch_matmul with static
// shapes: each set_encoding becomes a tensor.pack, the batch_matmul becomes
// a linalg.batch_mmt4d on the packed operands, and unset_encoding becomes a
// tensor.unpack of the result.
func.func @pack_batch_matmul(%arg0 : tensor<128x80x32xf32>, %arg1 : tensor<128x32x320xf32>, %arg2 : tensor<128x80x320xf32>) -> tensor<128x80x320xf32> {
  %0 = iree_linalg_ext.set_encoding %arg0 : tensor<128x80x32xf32> -> tensor<128x80x32xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = LHS>>
  %1 = iree_linalg_ext.set_encoding %arg1 : tensor<128x32x320xf32> -> tensor<128x32x320xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RHS>>
  %2 = iree_linalg_ext.set_encoding %arg2 : tensor<128x80x320xf32> -> tensor<128x80x320xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RESULT>>
  %3 = linalg.batch_matmul ins(%0, %1 : tensor<128x80x32xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = LHS>>, tensor<128x32x320xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RHS>>)
      outs(%2 : tensor<128x80x320xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RESULT>>) -> tensor<128x80x320xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RESULT>>
  %4 = iree_linalg_ext.unset_encoding %3 : tensor<128x80x320xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RESULT>> -> tensor<128x80x320xf32>
  return %4 : tensor<128x80x320xf32>
}
// CHECK: func @pack_batch_matmul(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<128x80x32xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<128x32x320xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<128x80x320xf32>
// CHECK: %[[PACK_LHS:.+]] = tensor.pack
// CHECK-SAME: %[[ARG0]]
// CHECK: %[[PACK_RHS:.+]] = tensor.pack
// CHECK-SAME: %[[ARG1]]
// CHECK: %[[PACK_RESULT:.+]] = tensor.pack
// CHECK-SAME: %[[ARG2]]
// CHECK: %[[BATCH_MMT4D:.+]] = linalg.batch_mmt4d
// CHECK-SAME: ins(%[[PACK_LHS]], %[[PACK_RHS]] :
// CHECK-SAME: outs(%[[PACK_RESULT]] :
// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[BATCH_MMT4D]]
// CHECK: return %[[UNPACK]]

// -----

// Same as @pack_batch_matmul but with fully dynamic shapes: the pack /
// batch_mmt4d / unpack structure must still be produced for tensor<?x?x?xf32>
// operands.
func.func @pack_batch_matmul_dynamic(%arg0 : tensor<?x?x?xf32>, %arg1 : tensor<?x?x?xf32>, %arg2 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
  %0 = iree_linalg_ext.set_encoding %arg0 : tensor<?x?x?xf32> -> tensor<?x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = LHS>>
  %1 = iree_linalg_ext.set_encoding %arg1 : tensor<?x?x?xf32> -> tensor<?x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RHS>>
  %2 = iree_linalg_ext.set_encoding %arg2 : tensor<?x?x?xf32> -> tensor<?x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RESULT>>
  %3 = linalg.batch_matmul ins(%0, %1 : tensor<?x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = LHS>>, tensor<?x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RHS>>)
      outs(%2 : tensor<?x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RESULT>>) -> tensor<?x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RESULT>>
  %4 = iree_linalg_ext.unset_encoding %3 : tensor<?x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RESULT>> -> tensor<?x?x?xf32>
  return %4 : tensor<?x?x?xf32>
}
// CHECK: func @pack_batch_matmul_dynamic(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?x?xf32>
// CHECK-SAME: %[[ARG2:[a-zA-Z0-9]+]]: tensor<?x?x?xf32>
// CHECK: %[[PACK_LHS:.+]] = tensor.pack
// CHECK-SAME: %[[ARG0]]
// CHECK: %[[PACK_RHS:.+]] = tensor.pack
// CHECK-SAME: %[[ARG1]]
// CHECK: %[[PACK_RESULT:.+]] = tensor.pack
// CHECK-SAME: %[[ARG2]]
// CHECK: %[[BATCH_MMT4D:.+]] = linalg.batch_mmt4d
// CHECK-SAME: ins(%[[PACK_LHS]], %[[PACK_RHS]] :
// CHECK-SAME: outs(%[[PACK_RESULT]] :
// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[BATCH_MMT4D]]
// CHECK: return %[[UNPACK]]

// -----

// Tests the dynamic-shape case where the init operand is produced by
// tensor.empty + linalg.fill on the encoded type: the fill must be
// materialized onto the packed 5-D tensor (with its dims computed via
// affine.apply ceildiv of the tile sizes) before feeding batch_mmt4d.
func.func @pack_batch_matmul_fill_dynamic(%arg0 : tensor<?x?x?xf32>, %arg1 : tensor<?x?x?xf32>) -> tensor<?x?x?xf32> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %cst = arith.constant 0.0 : f32
  %d0 = tensor.dim %arg0, %c0 : tensor<?x?x?xf32>
  %d1 = tensor.dim %arg0, %c1 : tensor<?x?x?xf32>
  %d2 = tensor.dim %arg1, %c2 : tensor<?x?x?xf32>
  %0 = iree_linalg_ext.set_encoding %arg0 : tensor<?x?x?xf32> -> tensor<?x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = LHS>>
  %1 = iree_linalg_ext.set_encoding %arg1 : tensor<?x?x?xf32> -> tensor<?x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RHS>>
  %2 = tensor.empty(%d0, %d1, %d2) : tensor<?x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RESULT>>
  %3 = linalg.fill ins(%cst : f32) outs(%2 : tensor<?x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RESULT>>)
      -> tensor<?x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RESULT>>
  %4 = linalg.batch_matmul ins(%0, %1 : tensor<?x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = LHS>>, tensor<?x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RHS>>)
      outs(%3 : tensor<?x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RESULT>>) -> tensor<?x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RESULT>>
  %5 = iree_linalg_ext.unset_encoding %4 : tensor<?x?x?xf32, #iree_linalg_ext.encoding<user = BATCH_MATMUL_F32F32F32, role = RESULT>> -> tensor<?x?x?xf32>
  return %5 : tensor<?x?x?xf32>
}
// CHECK-DAG: #[[MAP0:.+]] = affine_map<()[s0] -> (s0 ceildiv 8)>
// CHECK-DAG: #[[MAP1:.+]] = affine_map<()[s0] -> (s0 ceildiv 4)>
// CHECK: func @pack_batch_matmul_fill_dynamic(
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9]+]]: tensor<?x?x?xf32>
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9]+]]: tensor<?x?x?xf32>
// CHECK-DAG: %[[C0:.+]] = arith.constant 0 : index
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
// CHECK-DAG: %[[C2:.+]] = arith.constant 2 : index
// CHECK-DAG: %[[D0:.+]] = tensor.dim %[[ARG0]], %[[C0]]
// CHECK-DAG: %[[D1:.+]] = tensor.dim %[[ARG0]], %[[C1]]
// CHECK-DAG: %[[D2:.+]] = tensor.dim %[[ARG1]], %[[C2]]
// CHECK-DAG: %[[OUT_D1:.+]] = affine.apply #[[MAP0]]()[%[[D1]]]
// CHECK-DAG: %[[OUT_D2:.+]] = affine.apply #[[MAP0]]()[%[[D2]]]
// CHECK-DAG: %[[PACK_LHS:.+]] = tensor.pack %[[ARG0]]
// CHECK-DAG: %[[PACK_RHS:.+]] = tensor.pack %[[ARG1]]
// CHECK-DAG: %[[EMPTY:.+]] = tensor.empty(%[[D0]], %[[OUT_D1]], %[[OUT_D2]]) : tensor<?x?x?x8x8xf32>
// CHECK: %[[FILL:.+]] = linalg.fill
// CHECK-SAME: outs(%[[EMPTY]] : tensor<?x?x?x8x8xf32>)
// CHECK: %[[BATCH_MMT4D:.+]] = linalg.batch_mmt4d
// CHECK-SAME: ins(%[[PACK_LHS]], %[[PACK_RHS]] :
// CHECK-SAME: outs(%[[FILL]] :
// CHECK: %[[UNPACK:.+]] = tensor.unpack %[[BATCH_MMT4D]]
// CHECK: return %[[UNPACK]]

0 comments on commit a705de3

Please sign in to comment.