Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TosaLayerwiseConstantFoldPass: speed up folding of transpose for bf16 #75

Merged
merged 1 commit on Sep 4, 2023 (branch names not captured in this extract)
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 62 additions & 31 deletions mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,17 +127,6 @@ template DenseElementsAttr applyElementWise<APFloat, APFloat, FloatType>(
const std::function<APFloat(const APFloat &, FloatType)> &toApply,
FloatType targetType);

/// Function that checks if the type contained in \p toCheck is float.
LogicalResult notifyIfNotFloat(TypedValue<TensorType> toCheck, TosaOp location,
mgehre-amd marked this conversation as resolved.
Show resolved Hide resolved
PatternRewriter &rewriter) {
if (isa<FloatType>(toCheck.getType().getElementType())) {
return success();
}
return rewriter.notifyMatchFailure(location,
"Unexpected input tensor type: the "
"TOSA spec only allows floats");
}

template <class ElementType, class ResultType>
DenseElementsAttr applyElementWise(
const DenseElementsAttr &first, const DenseElementsAttr &second,
Expand Down Expand Up @@ -201,14 +190,10 @@ LogicalResult notifyIfNoTosaDenseConstantTensor(Value toCheck,
"it operates on a TOSA constant");
}

template <typename BaseType>
DenseElementsAttr transposeType(ElementsAttr attr, ShapedType inputType,
ShapedType outputType,
llvm::ArrayRef<int64_t> permValues) {
if (inputType.getNumElements() == 0)
return DenseElementsAttr::get(outputType, llvm::ArrayRef<BaseType>{});

auto attrValues = attr.getValues<BaseType>();
template <typename BaseType, typename RangeT>
void transposeArray(RangeT inputValues, ShapedType inputType,
SmallVector<BaseType> &outputValues, ShapedType outputType,
llvm::ArrayRef<int64_t> permValues) {
auto inputShape = inputType.getShape();

// The inverted permutation map and strides of the output are used to compute
Expand All @@ -217,13 +202,10 @@ DenseElementsAttr transposeType(ElementsAttr attr, ShapedType inputType,
auto outputStrides = computeStrides(outputType.getShape());
auto invertedPermValues = invertPermutationVector(permValues);

auto initialValue = *std::begin(attrValues);
SmallVector<BaseType> outputValues(inputType.getNumElements(), initialValue);

for (const auto &it : llvm::enumerate(attrValues)) {
for (auto it : llvm::enumerate(inputValues)) {
auto srcLinearIndex = it.index();

uint64_t dstLinearIndex = 0;

for (int64_t dim = inputShape.size() - 1; dim >= 0; --dim) {
// Compute the index into the current dimension of the source vector.
auto sourceIndexForDim = srcLinearIndex % inputShape[dim];
Expand All @@ -237,7 +219,37 @@ DenseElementsAttr transposeType(ElementsAttr attr, ShapedType inputType,

outputValues[dstLinearIndex] = it.value();
}
}

/// Transpose a dense constant by operating directly on its raw backing
/// storage, reinterpreted as a flat span of \p BaseType. This avoids
/// materializing one Attribute object per element, which is the slow path
/// for large tensors (e.g. bf16 weights).
///
/// Precondition (enforced by the caller): \p attr is not a splat, since
/// DenseElementsAttr raw-data access is not valid for splats.
template <typename BaseType>
DenseElementsAttr transposeTypeRaw(DenseElementsAttr attr, ShapedType inputType,
                                   ShapedType outputType,
                                   llvm::ArrayRef<int64_t> permValues) {
  // View the attribute's raw bytes as a contiguous array of BaseType.
  const auto *srcBegin =
      reinterpret_cast<const BaseType *>(attr.getRawData().data());
  ArrayRef<BaseType> srcValues(srcBegin, attr.getNumElements());

  // Destination buffer is left uninitialized; transposeArray writes every
  // element exactly once.
  SmallVector<BaseType> dstValues;
  dstValues.resize_for_overwrite(inputType.getNumElements());
  transposeArray<BaseType>(srcValues, inputType, /*out*/ dstValues, outputType,
                           permValues);

  // Hand the permuted elements back as the raw buffer of the result.
  ArrayRef<char> rawResult(reinterpret_cast<const char *>(dstValues.data()),
                           dstValues.size() * sizeof(BaseType));
  return DenseElementsAttr::getFromRawBuffer(outputType, rawResult);
}

/// Transpose a dense constant via the element-wise Attribute API
/// (attr.getValues<BaseType>()). Slower than transposeTypeRaw but works for
/// types without a fixed-width raw representation (APInt, APFloat, bool).
template <typename BaseType>
DenseElementsAttr transposeType(DenseElementsAttr attr, ShapedType inputType,
                                ShapedType outputType,
                                llvm::ArrayRef<int64_t> permValues) {
  // Guard the zero-element case: dereferencing std::begin() of an empty
  // range below would be invalid. (The pre-refactor implementation carried
  // this same guard.)
  if (inputType.getNumElements() == 0)
    return DenseElementsAttr::get(outputType, llvm::ArrayRef<BaseType>{});

  auto inputValues = attr.getValues<BaseType>();
  // Seed the output with an arbitrary valid value; every slot is
  // overwritten by transposeArray.
  SmallVector<BaseType> outputValues(inputType.getNumElements(),
                                     *std::begin(inputValues));
  transposeArray<BaseType>(inputValues, inputType, /*out*/ outputValues,
                           outputType, permValues);
  return DenseElementsAttr::get(outputType,
                                llvm::ArrayRef<BaseType>(outputValues));
}
Expand All @@ -246,32 +258,46 @@ DenseElementsAttr transposeType(ElementsAttr attr, ShapedType inputType,
// This implementation tries to operate on the underlying data in its raw
// representation when possible to avoid allocating a large number of Attribute
// objects.
DenseElementsAttr transpose(ElementsAttr attr, ShapedType inputType,
DenseElementsAttr transpose(DenseElementsAttr attr, ShapedType inputType,
ShapedType outputType,
llvm::ArrayRef<int64_t> permValues) {

assert(outputType.getNumElements() == inputType.getNumElements());
assert(outputType.getElementType() == inputType.getElementType());

auto baseType = inputType.getElementType();

// Handle possible integer types
if (auto intType = dyn_cast<IntegerType>(baseType)) {
switch (intType.getWidth()) {
case 1:
// i1 has special alignment which is not handled by transposeTypeRaw.
return transposeType<bool>(attr, inputType, outputType, permValues);
mgehre-amd marked this conversation as resolved.
Show resolved Hide resolved
case 8:
return transposeType<int8_t>(attr, inputType, outputType, permValues);
return transposeTypeRaw<uint8_t>(attr, inputType, outputType, permValues);
case 16:
return transposeType<int16_t>(attr, inputType, outputType, permValues);
return transposeTypeRaw<uint16_t>(attr, inputType, outputType,
permValues);
case 32:
return transposeType<int32_t>(attr, inputType, outputType, permValues);
return transposeTypeRaw<uint32_t>(attr, inputType, outputType,
permValues);
case 64:
return transposeType<int64_t>(attr, inputType, outputType, permValues);
return transposeTypeRaw<uint64_t>(attr, inputType, outputType,
permValues);
default:
return transposeType<APInt>(attr, inputType, outputType, permValues);
}
}

// Handle possible float types
if (baseType.isF32()) {
return transposeType<float>(attr, inputType, outputType, permValues);
return transposeTypeRaw<uint32_t>(attr, inputType, outputType, permValues);
}
if (baseType.isF64()) {
return transposeTypeRaw<uint64_t>(attr, inputType, outputType, permValues);
}
if (baseType.isBF16()) {
return transposeTypeRaw<uint16_t>(attr, inputType, outputType, permValues);
}

return transposeType<APFloat>(attr, inputType, outputType, permValues);
Expand Down Expand Up @@ -501,9 +527,14 @@ struct TosaFoldConstantTranspose : public TosaFoldConstantBase<tosa::TransposeOp
if (!outputType.getElementType().isIntOrIndexOrFloat())
return failure();

ElementsAttr inputValues;
DenseElementsAttr inputValues;
if (!matchPattern(op.getInput1(), m_Constant(&inputValues)))
return failure();
// Splats are already handled in the fold() method of each op.
// We cannot handle them here because the use of DenseElementsAttr::getRawData
// is invalid for them.
if (inputValues.isSplat())
return failure();
// Make sure the input is a constant that has a single user.
if (!llvm::hasSingleElement(op.getInput1().getDefiningOp()->getUsers()))
return failure();
Expand Down
12 changes: 12 additions & 0 deletions mlir/test/Dialect/Tosa/constant-op-fold.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,18 @@ func.func @transpose_fold_splat() -> tensor<3x2xf32> {
return %1 : tensor<3x2xf32>
}

// Constant-fold a tosa.transpose of a non-splat bf16 constant. bf16 is
// 16 bits wide, so folding is expected to take the fast raw-data path
// (transposeTypeRaw<uint16_t>) and still yield the correctly permuted values.
// CHECK-LABEL: @transpose_fold_2d_bfloat16
func.func @transpose_fold_2d_bfloat16() -> tensor<3x2xbf16> {
%input = "tosa.const"() {value = dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xbf16>} : () -> tensor<2x3xbf16>
%perms = "tosa.const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32>
// CHECK: %[[CST:.+]] = "tosa.const"()
// CHECK-SAME{LITERAL}: value = dense<[[0.000000e+00, 3.000000e+00], [1.000000e+00, 4.000000e+00], [2.000000e+00, 5.000000e+00]]> : tensor<3x2xbf16>
%1 = "tosa.transpose"(%input, %perms) : (tensor<2x3xbf16>, tensor<2xi32>) -> tensor<3x2xbf16>
// CHECK: return %[[CST]]
return %1 : tensor<3x2xbf16>
}


// CHECK-LABEL: @transpose_fold_2d_float
func.func @transpose_fold_2d_float() -> tensor<3x2xf32> {
%input = "tosa.const"() {value = dense<[[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]]> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
Expand Down