[Encoding] Implement sizeof calculation for pad encoding
Signed-off-by: Jakub Kuderski <[email protected]>
kuhar committed Feb 3, 2025
1 parent 976dd70 commit 25664ca
Showing 2 changed files with 94 additions and 3 deletions.
37 changes: 34 additions & 3 deletions compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.cpp
@@ -4,19 +4,22 @@
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include <cassert>
#include "iree/compiler/Dialect/Encoding/IR/EncodingTypes.h"

#include "iree/compiler/Dialect/Encoding/IR/EncodingDialect.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "mlir/Dialect/Affine/Utils.h"
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/Arith/Utils/Utils.h"
#include "mlir/Dialect/Linalg/Utils/Utils.h"
#include "mlir/Dialect/Utils/StructuredOpsUtils.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributeInterfaces.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypeInterfaces.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
@@ -313,9 +316,37 @@ std::string stringifyOperandIndex(IntegerAttr valueAttr) {
Value PadEncodingLayoutAttr::calculateStorageSizeInBytes(
Location loc, OpBuilder &builder, RankedTensorType type,
ValueRange dynamicDims) const {
// TODO(kuhar): Add sizeof calculation.
assert(false && "Unimplemented");
return nullptr;
ArrayRef<int32_t> padding = getPadding().asArrayRef();
assert(padding.size() == type.getRank() && "Invalid padding");

const int64_t elementSize =
llvm::divideCeil(type.getElementTypeBitWidth(), 8);
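// Fold the element size and every statically-known padded dimension into
// staticProduct; padded dynamic dimensions accumulate into dynamicProduct
// at runtime.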
int64_t staticProduct = elementSize;
Value dynamicProduct = builder.create<arith::ConstantIndexOp>(loc, 1);

size_t dynamicDimIdx = 0;
for (auto [dimSize, padValue] : llvm::zip_equal(type.getShape(), padding)) {
if (!ShapedType::isDynamic(dimSize)) {
staticProduct *= (dimSize + padValue);
continue;
}

Value dynamicDimSize = dynamicDims[dynamicDimIdx];
++dynamicDimIdx;

if (padValue != 0) {
dynamicDimSize = builder.create<arith::AddIOp>(
loc, dynamicDimSize,
builder.create<arith::ConstantIndexOp>(loc, padValue),
arith::IntegerOverflowFlags::nsw);
}
dynamicProduct = builder.createOrFold<arith::MulIOp>(
loc, dynamicProduct, dynamicDimSize, arith::IntegerOverflowFlags::nsw);
}

return builder.createOrFold<arith::MulIOp>(
loc, builder.create<arith::ConstantIndexOp>(loc, staticProduct),
dynamicProduct, arith::IntegerOverflowFlags::nsw);
}

//===---------------------------------------------------------------------===//
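A minimal standalone sketch (editor's illustration, not part of this commit) of the closed-form size the new method emits IR for, assuming all dimensions are statically known; the helper name is hypothetical:

#include <cstdint>
#include <vector>

// bytes = ceil(elementBitWidth / 8) * prod_i (dimSize[i] + padding[i])
int64_t paddedStorageSizeInBytes(const std::vector<int64_t> &shape,
                                 const std::vector<int32_t> &padding,
                                 int64_t elementBitWidth) {
  // Bytes per element, rounded up for sub-byte types (e.g., i4 -> 1 byte).
  int64_t size = (elementBitWidth + 7) / 8;
  // Pad each dimension before multiplying it into the running product.
  for (size_t i = 0; i < shape.size(); ++i)
    size *= shape[i] + padding[i];
  return size;
}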
@@ -270,6 +270,66 @@ util.func public @sizeof_lhs_encoding_with_bcast_across_m_dim_dynamic(%arg0: index

// -----

#map = affine_map<(d0, d1, d2) -> (d0, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
#no_pad_layout = #iree_encoding.pad_encoding_layout<[0, 0]>
#no_pad_encoding = #iree_encoding.encoding<operand_index = 0, op_type = matmul, element_types = [f16, f16, f32], user_indexing_maps = [#map, #map1, #map2], layouts = [#no_pad_layout]>
#pad_layout_a = #iree_encoding.pad_encoding_layout<[0, 64]>
#pad_encoding_a = #iree_encoding.encoding<operand_index = 0, op_type = matmul, element_types = [f16, f16, f32], user_indexing_maps = [#map, #map1, #map2], layouts = [#pad_layout_a]>
#pad_layout_b = #iree_encoding.pad_encoding_layout<[64, 0]>
#pad_encoding_b = #iree_encoding.encoding<operand_index = 0, op_type = matmul, element_types = [f16, f16, f32], user_indexing_maps = [#map, #map1, #map2], layouts = [#pad_layout_b]>
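// Editor's note (not part of the test): per the implementation above,
// pad_encoding_layout<[p0, p1]> pads dimension i of the tensor by p_i
// elements, so [0, 64] widens only the last dimension.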
util.func public @sizeof_lhs_pad_encoding_static() -> index, index, index {
%0 = stream.tensor.sizeof tensor<2048x4096xf16, #no_pad_encoding>{} : index
%1 = stream.tensor.sizeof tensor<2048x4096xf16, #pad_encoding_a>{} : index
%2 = stream.tensor.sizeof tensor<2048x4096xf16, #pad_encoding_b>{} : index
util.return %0, %1, %2 : index, index, index
}

// We expect (2048 + pad[0]) * (4096 + pad[1]) * (16 / 8), where 16 / 8 is
// the f16 element size in bytes.
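// Editor's cross-check (not part of the test):
//   no padding:   2048 * 4096 * 2        = 16777216
//   pad [0, 64]:  2048 * (4096 + 64) * 2 = 17039360
//   pad [64, 0]:  (2048 + 64) * 4096 * 2 = 17301504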

// CHECK-LABEL: @sizeof_lhs_pad_encoding_static
// CHECK-DAG: %[[CST_A:.+]] = arith.constant 16777216 : index
// CHECK-DAG: %[[CST_B:.+]] = arith.constant 17039360 : index
// CHECK-DAG: %[[CST_C:.+]] = arith.constant 17301504 : index
// CHECK: return %[[CST_A]], %[[CST_B]], %[[CST_C]]

// -----

#map = affine_map<(d0, d1, d2) -> (d0, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d1, d2)>
#map2 = affine_map<(d0, d1, d2) -> (d0, d1)>
#no_pad_layout = #iree_encoding.pad_encoding_layout<[0, 0]>
#no_pad_encoding = #iree_encoding.encoding<operand_index = 1, op_type = matmul, element_types = [f16, f16, f32], user_indexing_maps = [#map, #map1, #map2], layouts = [#no_pad_layout]>
#pad_layout_a = #iree_encoding.pad_encoding_layout<[0, 64]>
#pad_encoding_a = #iree_encoding.encoding<operand_index = 1, op_type = matmul, element_types = [f16, f16, f32], user_indexing_maps = [#map, #map1, #map2], layouts = [#pad_layout_a]>
#pad_layout_b = #iree_encoding.pad_encoding_layout<[64, 0]>
#pad_encoding_b = #iree_encoding.encoding<operand_index = 1, op_type = matmul, element_types = [f16, f16, f32], user_indexing_maps = [#map, #map1, #map2], layouts = [#pad_layout_b]>
util.func public @sizeof_rhs_pad_encoding_dynamic(%arg0 : index, %arg1 : index) -> index, index, index, index {
%0 = stream.tensor.sizeof tensor<2048x?xf16, #no_pad_encoding>{%arg0} : index
%1 = stream.tensor.sizeof tensor<?x4096xf16, #pad_encoding_a>{%arg0} : index
%2 = stream.tensor.sizeof tensor<?x4096xf16, #pad_encoding_b>{%arg0} : index
%3 = stream.tensor.sizeof tensor<?x?xf16, #pad_encoding_b>{%arg0, %arg1} : index
util.return %0, %1, %2, %3 : index, index, index, index
}

// CHECK-LABEL: @sizeof_rhs_pad_encoding_dynamic
// CHECK-DAG: %[[CST_2:.+]] = arith.constant 2 : index
// CHECK-DAG: %[[CST_64:.+]] = arith.constant 64 : index
// CHECK-DAG: %[[CST_4096:.+]] = arith.constant 4096 : index
// CHECK-DAG: %[[CST_8192:.+]] = arith.constant 8192 : index
// CHECK-DAG: %[[CST_8320:.+]] = arith.constant 8320 : index
// CHECK: %[[A:.+]] = arith.muli %arg0, %[[CST_4096]] overflow<nsw>
// CHECK: %[[B:.+]] = arith.muli %arg0, %[[CST_8320]] overflow<nsw>
// CHECK: %[[C_0:.+]] = arith.addi %arg0, %[[CST_64]] overflow<nsw>
// CHECK: %[[C_1:.+]] = arith.muli %[[C_0]], %[[CST_8192]] overflow<nsw>
// CHECK: %[[D_0:.+]] = arith.addi %arg0, %[[CST_64]] overflow<nsw>
// CHECK: %[[D_1:.+]] = arith.muli %[[D_0]], %arg1 overflow<nsw>
// CHECK: %[[D_2:.+]] = arith.muli %[[D_1]], %[[CST_2]] overflow<nsw>
// CHECK: return %[[A]], %[[B]], %[[C_1]], %[[D_2]]
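
// Editor's cross-check (not part of the test): with 2 bytes per f16 element,
//   %0: 2048x?, pad [0, 0]:  %arg0 * (2048 * 2)        = %arg0 * 4096
//   %1: ?x4096, pad [0, 64]: %arg0 * ((4096 + 64) * 2) = %arg0 * 8320
//   %2: ?x4096, pad [64, 0]: (%arg0 + 64) * (4096 * 2) = (%arg0 + 64) * 8192
//   %3: ?x?,    pad [64, 0]: ((%arg0 + 64) * %arg1) * 2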

// -----

#encoding_layout_0 = #iree_cpu.cpu_encoding_layout<configuration = {encoding_info = {innerDimsPos = [0, 1], innerTileSizes = [4, 8], outerDimsPerm = [0, 1]}}>
#encoding_layout_1 = #iree_cpu.vmvx_encoding_layout<configuration = {encoding_info = {innerDimsPos = [0, 1], innerTileSizes = [2, 16], outerDimsPerm = [0, 1]}}>
#encoding = #iree_encoding.encoding<operand_index = 0, op_type = matmul, element_types = [f32, f32, f32], layouts = [#encoding_layout_0, #encoding_layout_1]>