Skip to content

Commit

Permalink
Bench - print more data types (#712)
Browse files Browse the repository at this point in the history
Expands tpp-run '-print' utility to support more data types.
  • Loading branch information
adam-smnk authored Sep 14, 2023
1 parent ab7a026 commit 5122399
Show file tree
Hide file tree
Showing 6 changed files with 53 additions and 16 deletions.
4 changes: 4 additions & 0 deletions include/TPP/BuilderUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ Value getConstIndex(OpBuilder &, int);
Value getConstInt(OpBuilder &, int, int);
Value getConstFloat(OpBuilder &, float, int);

// Return a typed attribute of specified type and value.
// For integer types, the value is rounded toward zero.
TypedAttr getTypedAttr(OpBuilder &builder, Type type, double value);

} // namespace mlir

#endif // TPP_BUILDER_UTILS_H
16 changes: 16 additions & 0 deletions lib/TPP/BuilderUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,4 +108,20 @@ Value createDenseMemref(OpBuilder &builder, ModuleOp module,
auto nameAttr = builder.getStringAttr(globalName);
return builder.create<memref::GetGlobalOp>(unkLoc, type, nameAttr);
}

TypedAttr getTypedAttr(OpBuilder &builder, Type type, double value) {
if (isa<FloatType>(type))
return builder.getFloatAttr(type, value);
if (isa<IndexType>(type))
return builder.getIndexAttr(value);
if (auto intTp = dyn_cast<IntegerType>(type))
return builder.getIntegerAttr(type, APInt(intTp.getWidth(), value));
if (isa<RankedTensorType, VectorType>(type)) {
auto shapedType = cast<ShapedType>(type);
if (auto one = getTypedAttr(builder, shapedType.getElementType(), value))
return DenseElementsAttr::get(shapedType, one);
}
llvm_unreachable("Unsupported attribute type");
}

} // namespace mlir
9 changes: 9 additions & 0 deletions test/Integration/tpp-run-print-f16.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
// RUN: tpp-run %s -e entry -entry-point-result=void -print | FileCheck %s

func.func @entry(%arg0: memref<8x8xf16>) {
%cst = arith.constant 9.0 : f16
linalg.fill ins(%cst : f16) outs(%arg0 : memref<8x8xf16>)
return
}

// CHECK-COUNT-8: ( 9, 9, 9, 9, 9, 9, 9, 9 )
9 changes: 9 additions & 0 deletions test/Integration/tpp-run-print-f32.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
// RUN: tpp-run %s -e entry -entry-point-result=void -print | FileCheck %s

func.func @entry(%arg0: memref<8x8xf32>) {
%cst = arith.constant 9.0 : f32
linalg.fill ins(%cst : f32) outs(%arg0 : memref<8x8xf32>)
return
}

// CHECK-COUNT-8: ( 9, 9, 9, 9, 9, 9, 9, 9 )
9 changes: 9 additions & 0 deletions test/Integration/tpp-run-print-i32.mlir
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
// RUN: tpp-run %s -e entry -entry-point-result=void -print | FileCheck %s

func.func @entry(%arg0: memref<8x8xi32>) {
%cst = arith.constant 9 : i32
linalg.fill ins(%cst : i32) outs(%arg0 : memref<8x8xi32>)
return
}

// CHECK-COUNT-8: ( 9, 9, 9, 9, 9, 9, 9, 9 )
22 changes: 6 additions & 16 deletions tools/tpp-run/MLIRBench.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,8 @@ LogicalResult MLIRBench::printShapedType(mlir::Value val) {
auto outputType = cast<ShapedType>(val.getType());
assert(outputType && "expected a shaped type");

Type outElmType = outputType.getElementType();

// Read into a vector and print output
// We don't want to alloc the whole tensor as a vector,
// so we pick the inner dimension and iterate through the outer ones.
Expand All @@ -422,29 +424,17 @@ LogicalResult MLIRBench::printShapedType(mlir::Value val) {
ArrayRef<int64_t> outerDims(1);
if (outputType.getRank() > 1) {
ArrayRef<int64_t> innerDims(&outputType.getShape()[lastDim], 1);
vecType = VectorType::get(innerDims, outputType.getElementType());
vecType = VectorType::get(innerDims, outElmType);
outerDims =
ArrayRef<int64_t>(&outputType.getShape()[0], outputType.getRank() - 1);
} else {
vecType =
VectorType::get(outputType.getShape(), outputType.getElementType());
vecType = VectorType::get(outputType.getShape(), outElmType);
}
assert(outerDims.size() == 1 && "Only supports 2D tensors for now");

// Vector undefined value
APFloat vectorFloatValue = APFloat(-1.0F);
Value minusOne;
if (outputType.getElementType().isBF16()) {
bool ignored;
vectorFloatValue.convert(APFloat::BFloat(), APFloat::rmNearestTiesToEven,
&ignored);

minusOne = builder.create<arith::ConstantFloatOp>(
unkLoc, vectorFloatValue, FloatType::getBF16(builder.getContext()));
} else {
minusOne = builder.create<arith::ConstantFloatOp>(unkLoc, vectorFloatValue,
builder.getF32Type());
}
Value minusOne = builder.create<arith::ConstantOp>(
unkLoc, getTypedAttr(builder, outElmType, -1.0));

// Loop through the shaped type, transfer each dim to vector
auto count = getConstIndex(builder, outerDims[0]);
Expand Down

0 comments on commit 5122399

Please sign in to comment.