Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bench - print more data types #712

Merged
merged 1 commit into from
Sep 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions include/TPP/BuilderUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ Value getConstIndex(OpBuilder &, int);
Value getConstInt(OpBuilder &, int, int);
Value getConstFloat(OpBuilder &, float, int);

// Return a typed attribute of specified type and value.
// For integer types, the value is rounded toward zero.
TypedAttr getTypedAttr(OpBuilder &builder, Type type, double value);

} // namespace mlir

#endif // TPP_BUILDER_UTILS_H
16 changes: 16 additions & 0 deletions lib/TPP/BuilderUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,4 +108,20 @@ Value createDenseMemref(OpBuilder &builder, ModuleOp module,
auto nameAttr = builder.getStringAttr(globalName);
return builder.create<memref::GetGlobalOp>(unkLoc, type, nameAttr);
}

// Returns a typed attribute of the given type holding `value`.
// Supported types: floats, index, integers, and ranked tensor/vector
// types (as a splat of the element attribute). For integer and index
// types the value is truncated toward zero. Unsupported types abort
// via llvm_unreachable.
TypedAttr getTypedAttr(OpBuilder &builder, Type type, double value) {
  if (isa<FloatType>(type))
    return builder.getFloatAttr(type, value);
  if (isa<IndexType>(type))
    return builder.getIndexAttr(static_cast<int64_t>(value));
  if (auto intTp = dyn_cast<IntegerType>(type)) {
    // Convert through a signed 64-bit integer. The previous implicit
    // double -> uint64_t conversion inside APInt's constructor is
    // undefined behavior for negative values (e.g. the -1.0 fill
    // value used by the benchmark printer).
    return builder.getIntegerAttr(
        type, APInt(intTp.getWidth(), static_cast<int64_t>(value),
                    /*isSigned=*/true));
  }
  if (isa<RankedTensorType, VectorType>(type)) {
    auto shapedType = cast<ShapedType>(type);
    // Recurse on the element type; an unsupported element type hits
    // the unreachable below instead of returning here.
    if (auto elemAttr =
            getTypedAttr(builder, shapedType.getElementType(), value))
      return DenseElementsAttr::get(shapedType, elemAttr);
  }
  llvm_unreachable("Unsupported attribute type");
}

} // namespace mlir
9 changes: 9 additions & 0 deletions test/Integration/tpp-run-print-f16.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
// RUN: tpp-run %s -e entry -entry-point-result=void -print | FileCheck %s

// Fill an 8x8 f16 buffer with a uniform constant so the runner's
// -print flag emits a matrix of nines (checked below).
func.func @entry(%arg0: memref<8x8xf16>) {
  %nine = arith.constant 9.0 : f16
  linalg.fill ins(%nine : f16) outs(%arg0 : memref<8x8xf16>)
  return
}

// CHECK-COUNT-8: ( 9, 9, 9, 9, 9, 9, 9, 9 )
9 changes: 9 additions & 0 deletions test/Integration/tpp-run-print-f32.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
// RUN: tpp-run %s -e entry -entry-point-result=void -print | FileCheck %s

// Fill an 8x8 f32 buffer with a uniform constant so the runner's
// -print flag emits a matrix of nines (checked below).
func.func @entry(%arg0: memref<8x8xf32>) {
  %nine = arith.constant 9.0 : f32
  linalg.fill ins(%nine : f32) outs(%arg0 : memref<8x8xf32>)
  return
}

// CHECK-COUNT-8: ( 9, 9, 9, 9, 9, 9, 9, 9 )
9 changes: 9 additions & 0 deletions test/Integration/tpp-run-print-i32.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
// RUN: tpp-run %s -e entry -entry-point-result=void -print | FileCheck %s

// Fill an 8x8 i32 buffer with a uniform constant so the runner's
// -print flag emits a matrix of nines (checked below).
func.func @entry(%arg0: memref<8x8xi32>) {
  %nine = arith.constant 9 : i32
  linalg.fill ins(%nine : i32) outs(%arg0 : memref<8x8xi32>)
  return
}

// CHECK-COUNT-8: ( 9, 9, 9, 9, 9, 9, 9, 9 )
22 changes: 6 additions & 16 deletions tools/tpp-run/MLIRBench.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,8 @@ LogicalResult MLIRBench::printShapedType(mlir::Value val) {
auto outputType = cast<ShapedType>(val.getType());
assert(outputType && "expected a shaped type");

Type outElmType = outputType.getElementType();

// Read into a vector and print output
// We don't want to alloc the whole tensor as a vector,
// so we pick the inner dimension and iterate through the outer ones.
Expand All @@ -422,29 +424,17 @@ LogicalResult MLIRBench::printShapedType(mlir::Value val) {
ArrayRef<int64_t> outerDims(1);
if (outputType.getRank() > 1) {
ArrayRef<int64_t> innerDims(&outputType.getShape()[lastDim], 1);
vecType = VectorType::get(innerDims, outputType.getElementType());
vecType = VectorType::get(innerDims, outElmType);
outerDims =
ArrayRef<int64_t>(&outputType.getShape()[0], outputType.getRank() - 1);
} else {
vecType =
VectorType::get(outputType.getShape(), outputType.getElementType());
vecType = VectorType::get(outputType.getShape(), outElmType);
}
assert(outerDims.size() == 1 && "Only supports 2D tensors for now");

// Vector undefined value
APFloat vectorFloatValue = APFloat(-1.0F);
Value minusOne;
if (outputType.getElementType().isBF16()) {
bool ignored;
vectorFloatValue.convert(APFloat::BFloat(), APFloat::rmNearestTiesToEven,
&ignored);

minusOne = builder.create<arith::ConstantFloatOp>(
unkLoc, vectorFloatValue, FloatType::getBF16(builder.getContext()));
} else {
minusOne = builder.create<arith::ConstantFloatOp>(unkLoc, vectorFloatValue,
builder.getF32Type());
}
Value minusOne = builder.create<arith::ConstantOp>(
unkLoc, getTypedAttr(builder, outElmType, -1.0));

// Loop through the shaped type, transfer each dim to vector
auto count = getConstIndex(builder, outerDims[0]);
Expand Down