diff --git a/include/TPP/BuilderUtils.h b/include/TPP/BuilderUtils.h index 6f599a761..5bf5ffb05 100644 --- a/include/TPP/BuilderUtils.h +++ b/include/TPP/BuilderUtils.h @@ -42,6 +42,10 @@ Value getConstIndex(OpBuilder &, int); Value getConstInt(OpBuilder &, int, int); Value getConstFloat(OpBuilder &, float, int); +// Return a typed attribute of specified type and value. +// For integer types, the value is rounded toward zero. +TypedAttr getTypedAttr(OpBuilder &builder, Type type, double value); + } // namespace mlir #endif // TPP_BUILDER_UTILS_H diff --git a/lib/TPP/BuilderUtils.cpp b/lib/TPP/BuilderUtils.cpp index 6d55923be..d59b316bf 100644 --- a/lib/TPP/BuilderUtils.cpp +++ b/lib/TPP/BuilderUtils.cpp @@ -108,4 +108,20 @@ Value createDenseMemref(OpBuilder &builder, ModuleOp module, auto nameAttr = builder.getStringAttr(globalName); return builder.create(unkLoc, type, nameAttr); } + +TypedAttr getTypedAttr(OpBuilder &builder, Type type, double value) { + if (isa(type)) + return builder.getFloatAttr(type, value); + if (isa(type)) + return builder.getIndexAttr(value); + if (auto intTp = dyn_cast(type)) + return builder.getIntegerAttr(type, APInt(intTp.getWidth(), value)); + if (isa(type)) { + auto shapedType = cast(type); + if (auto one = getTypedAttr(builder, shapedType.getElementType(), value)) + return DenseElementsAttr::get(shapedType, one); + } + llvm_unreachable("Unsupported attribute type"); +} + } // namespace mlir diff --git a/test/Integration/tpp-run-print-f16.mlir b/test/Integration/tpp-run-print-f16.mlir new file mode 100644 index 000000000..e9da47028 --- /dev/null +++ b/test/Integration/tpp-run-print-f16.mlir @@ -0,0 +1,9 @@ +// RUN: tpp-run %s -e entry -entry-point-result=void -print | FileCheck %s + +func.func @entry(%arg0: memref<8x8xf16>) { + %cst = arith.constant 9.0 : f16 + linalg.fill ins(%cst : f16) outs(%arg0 : memref<8x8xf16>) + return +} + +// CHECK-COUNT-8: ( 9, 9, 9, 9, 9, 9, 9, 9 ) diff --git a/test/Integration/tpp-run-print-f32.mlir b/test/Integration/tpp-run-print-f32.mlir new file mode 100644 index 000000000..983b030d1 --- /dev/null +++ b/test/Integration/tpp-run-print-f32.mlir @@ -0,0 +1,9 @@ +// RUN: tpp-run %s -e entry -entry-point-result=void -print | FileCheck %s + +func.func @entry(%arg0: memref<8x8xf32>) { + %cst = arith.constant 9.0 : f32 + linalg.fill ins(%cst : f32) outs(%arg0 : memref<8x8xf32>) + return +} + +// CHECK-COUNT-8: ( 9, 9, 9, 9, 9, 9, 9, 9 ) diff --git a/test/Integration/tpp-run-print-i32.mlir b/test/Integration/tpp-run-print-i32.mlir new file mode 100644 index 000000000..d4959aa7c --- /dev/null +++ b/test/Integration/tpp-run-print-i32.mlir @@ -0,0 +1,9 @@ +// RUN: tpp-run %s -e entry -entry-point-result=void -print | FileCheck %s + +func.func @entry(%arg0: memref<8x8xi32>) { + %cst = arith.constant 9 : i32 + linalg.fill ins(%cst : i32) outs(%arg0 : memref<8x8xi32>) + return +} + +// CHECK-COUNT-8: ( 9, 9, 9, 9, 9, 9, 9, 9 ) diff --git a/tools/tpp-run/MLIRBench.cpp b/tools/tpp-run/MLIRBench.cpp index 38d46b6ab..166618058 100644 --- a/tools/tpp-run/MLIRBench.cpp +++ b/tools/tpp-run/MLIRBench.cpp @@ -414,6 +414,8 @@ LogicalResult MLIRBench::printShapedType(mlir::Value val) { auto outputType = cast(val.getType()); assert(outputType && "expected a shaped type"); + Type outElmType = outputType.getElementType(); + // Read into a vector and print output // We don't want to alloc the whole tensor as a vector, // so we pick the inner dimension and iterate through the outer ones. @@ -422,29 +424,17 @@ LogicalResult MLIRBench::printShapedType(mlir::Value val) { ArrayRef outerDims(1); if (outputType.getRank() > 1) { ArrayRef innerDims(&outputType.getShape()[lastDim], 1); - vecType = VectorType::get(innerDims, outputType.getElementType()); + vecType = VectorType::get(innerDims, outElmType); outerDims = ArrayRef(&outputType.getShape()[0], outputType.getRank() - 1); } else { - vecType = - VectorType::get(outputType.getShape(), outputType.getElementType()); + vecType = VectorType::get(outputType.getShape(), outElmType); } assert(outerDims.size() == 1 && "Only supports 2D tensors for now"); // Vector undefined value - APFloat vectorFloatValue = APFloat(-1.0F); - Value minusOne; - if (outputType.getElementType().isBF16()) { - bool ignored; - vectorFloatValue.convert(APFloat::BFloat(), APFloat::rmNearestTiesToEven, - &ignored); - - minusOne = builder.create( - unkLoc, vectorFloatValue, FloatType::getBF16(builder.getContext())); - } else { - minusOne = builder.create(unkLoc, vectorFloatValue, - builder.getF32Type()); - } + Value minusOne = builder.create( + unkLoc, getTypedAttr(builder, outElmType, -1.0)); // Loop through the shaped type, transfer each dim to vector auto count = getConstIndex(builder, outerDims[0]);