diff --git a/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.cpp b/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.cpp index b76fd1dbc0a8..be1462cd30a4 100644 --- a/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.cpp +++ b/compiler/src/iree/compiler/Dialect/Encoding/IR/EncodingAttrs.cpp @@ -4,12 +4,14 @@ // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +#include #include "iree/compiler/Dialect/Encoding/IR/EncodingTypes.h" #include "iree/compiler/Dialect/Encoding/IR/EncodingDialect.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/TypeSwitch.h" #include "mlir/Dialect/Affine/Utils.h" +#include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Arith/Utils/Utils.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/Dialect/Utils/StructuredOpsUtils.h" @@ -17,6 +19,7 @@ #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributeInterfaces.h" #include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" #include "mlir/IR/MLIRContext.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Interfaces/InferTypeOpInterface.h" @@ -313,9 +316,36 @@ std::string stringifyOperandIndex(IntegerAttr valueAttr) { Value PadEncodingLayoutAttr::calculateStorageSizeInBytes( Location loc, OpBuilder &builder, RankedTensorType type, ValueRange dynamicDims) const { - // TODO(kuhar): Add sizeof calculation. - assert(false && "Unimplemented"); - return nullptr; + ArrayRef padding = getPadding().asArrayRef(); + assert(padding.size() == type.getRank() && "Invalid padding"); + + const int64_t elementSize = getRoundedElementByteWidth(type.getElementType()); + int64_t staticProduct = elementSize; + Value dynamicProduct = builder.create(loc, 1); + + size_t dynamicDimIdx = 0; + for (auto [dimSize, padValue] : llvm::zip_equal(type.getShape(), padding)) { + if (!ShapedType::isDynamic(dimSize)) { + staticProduct *= (dimSize + padValue); + continue; + } + + Value dynamicDimSize = dynamicDims[dynamicDimIdx]; + ++dynamicDimIdx; + + if (padValue != 0) { + dynamicDimSize = builder.create( + loc, dynamicDimSize, + builder.create(loc, padValue), + arith::IntegerOverflowFlags::nsw); + } + dynamicProduct = builder.createOrFold( + loc, dynamicProduct, dynamicDimSize, arith::IntegerOverflowFlags::nsw); + } + + return builder.createOrFold( + loc, builder.create(loc, staticProduct), + dynamicProduct, arith::IntegerOverflowFlags::nsw); } //===---------------------------------------------------------------------===// diff --git a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/encode_host_tensors_encoding.mlir b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/encode_host_tensors_encoding.mlir index 8d670ebd6d1c..46186a82204e 100644 --- a/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/encode_host_tensors_encoding.mlir +++ b/compiler/src/iree/compiler/Dialect/Stream/Transforms/test/encode_host_tensors_encoding.mlir @@ -270,6 +270,66 @@ util.func public @sizeof_lhs_encoding_with_bcast_across_m_dim_dynamic(%arg0: ind // ----- +#map = affine_map<(d0, d1, d2) -> (d0, d2)> +#map1 = affine_map<(d0, d1, d2) -> (d1, d2)> +#map2 = affine_map<(d0, d1, d2) -> (d0, d1)> +#no_pad_layout = #iree_encoding.pad_encoding_layout<[0, 0]> +#no_pad_encoding = #iree_encoding.encoding +#pad_layout_a = #iree_encoding.pad_encoding_layout<[0, 64]> +#pad_encoding_a = #iree_encoding.encoding +#pad_layout_b = #iree_encoding.pad_encoding_layout<[64, 0]> +#pad_encoding_b = #iree_encoding.encoding +util.func public @sizeof_lhs_pad_encoding_static() -> index, index, index { + %0 = stream.tensor.sizeof tensor<2048x4096xf16, #no_pad_encoding>{} : index + %1 = stream.tensor.sizeof tensor<2048x4096xf16, #pad_encoding_a>{} : index + %2 = stream.tensor.sizeof tensor<2048x4096xf16, #pad_encoding_b>{} : index + util.return %0, %1, %2 : index, index, index +} + +// We expect (2048 + pad[0]) * (4096 + pad[1]) * (16 / 8). + +// CHECK-LABEL: @sizeof_lhs_pad_encoding_static +// CHECK-DAG: %[[CST_A:.+]] = arith.constant 16777216 : index +// CHECK-DAG: %[[CST_B:.+]] = arith.constant 17039360 : index +// CHECK-DAG: %[[CST_C:.+]] = arith.constant 17301504 : index +// CHECK: return %[[CST_A]], %[[CST_B]], %[[CST_C]] + +// ----- + +#map = affine_map<(d0, d1, d2) -> (d0, d2)> +#map1 = affine_map<(d0, d1, d2) -> (d1, d2)> +#map2 = affine_map<(d0, d1, d2) -> (d0, d1)> +#no_pad_layout = #iree_encoding.pad_encoding_layout<[0, 0]> +#no_pad_encoding = #iree_encoding.encoding +#pad_layout_a = #iree_encoding.pad_encoding_layout<[0, 64]> +#pad_encoding_a = #iree_encoding.encoding +#pad_layout_b = #iree_encoding.pad_encoding_layout<[64, 0]> +#pad_encoding_b = #iree_encoding.encoding +util.func public @sizeof_rhs_pad_encoding_dynamic(%arg0 : index, %arg1 : index) -> index, index, index, index { + %0 = stream.tensor.sizeof tensor<2048x?xf16, #no_pad_encoding>{%arg0} : index + %1 = stream.tensor.sizeof tensor{%arg0} : index + %2 = stream.tensor.sizeof tensor{%arg0} : index + %3 = stream.tensor.sizeof tensor{%arg0, %arg1} : index + util.return %0, %1, %2, %3 : index, index, index, index +} + +// CHECK-LABEL: @sizeof_rhs_pad_encoding_dynamic +// CHECK-DAG: %[[CST_2:.+]] = arith.constant 2 : index +// CHECK-DAG: %[[CST_64:.+]] = arith.constant 64 : index +// CHECK-DAG: %[[CST_4096:.+]] = arith.constant 4096 : index +// CHECK-DAG: %[[CST_8192:.+]] = arith.constant 8192 : index +// CHECK-DAG: %[[CST_8320:.+]] = arith.constant 8320 : index +// CHECK: %[[A:.+]] = arith.muli %arg0, %[[CST_4096]] overflow +// CHECK: %[[B:.+]] = arith.muli %arg0, %[[CST_8320]] overflow +// CHECK: %[[C_0:.+]] = arith.addi %arg0, %[[CST_64]] overflow +// CHECK: %[[C_1:.+]] = arith.muli %[[C_0]], %[[CST_8192]] overflow +// CHECK: %[[D_0:.+]] = arith.addi %arg0, %[[CST_64]] overflow +// CHECK: %[[D_1:.+]] = arith.muli %[[D_0]], %arg1 overflow +// CHECK: %[[D_2:.+]] = arith.muli %[[D_1]], %[[CST_2]] overflow +// CHECK: return %[[A]], %[[B]], %[[C_1]], %[[D_2]] + +// ----- + #encoding_layout_0 = #iree_cpu.cpu_encoding_layout #encoding_layout_1 = #iree_cpu.vmvx_encoding_layout #encoding = #iree_encoding.encoding