Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FXML-4791] Add printout of references with emitc.reference attr #316

Merged
merged 20 commits into from
Sep 9, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions mlir/include/mlir/Dialect/EmitC/IR/EmitC.td
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,7 @@ def EmitC_CastOp : EmitC_Op<"cast",
let arguments = (ins EmitCType:$source);
let results = (outs EmitCType:$dest);
let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)";
cferry-AMD marked this conversation as resolved.
Show resolved Hide resolved
let hasVerifier = 1;
cferry-AMD marked this conversation as resolved.
Show resolved Hide resolved
}

def EmitC_CmpOp : EmitC_BinaryOp<"cmp", [CExpression]> {
Expand Down
9 changes: 9 additions & 0 deletions mlir/lib/Dialect/EmitC/IR/EmitC.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -232,6 +232,15 @@ bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
emitc::isSupportedFloatType(output) || isa<emitc::PointerType>(output)));
}

LogicalResult CastOp::verify() {
bool isReference = (*this)->hasAttrOfType<UnitAttr>("emitc.reference");

if (isReference)
return emitOpError("cast of value type must not bear a reference");

return success();
}

//===----------------------------------------------------------------------===//
// CallOp
//===----------------------------------------------------------------------===//
Expand Down
111 changes: 98 additions & 13 deletions mlir/lib/Target/Cpp/TranslateToCpp.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "mlir/IR/Dialect.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/IR/Value.h"
#include "mlir/Support/IndentedOstream.h"
#include "mlir/Support/LLVM.h"
#include "mlir/Target/Cpp/CppEmitter.h"
Expand Down Expand Up @@ -122,6 +123,9 @@ struct CppEmitter {
/// Emits operation 'op' with/without training semicolon or returns failure.
LogicalResult emitOperation(Operation &op, bool trailingSemicolon);

/// Emits a reference to type 'type' or returns failure.
LogicalResult emitReferenceToType(Location loc, Type type);

/// Emits type 'type' or returns failure.
LogicalResult emitType(Location loc, Type type);

Expand All @@ -143,8 +147,8 @@ struct CppEmitter {
bool trailingSemicolon);

/// Emits a declaration of a variable with the given type and name.
LogicalResult emitVariableDeclaration(Location loc, Type type,
StringRef name);
LogicalResult emitVariableDeclaration(Location loc, Type type, StringRef name,
bool isReference);

/// Emits the variable declaration and assignment prefix for 'op'.
/// - emits separate variable followed by std::tie for multi-valued operation;
Expand Down Expand Up @@ -719,8 +723,15 @@ static LogicalResult printOperation(CppEmitter &emitter, emitc::CastOp castOp) {
if (failed(emitter.emitAssignPrefix(op)))
return failure();
os << "(";
if (failed(emitter.emitType(op.getLoc(), op.getResult(0).getType())))
return failure();
// Cast of lvalues return lvalues and therefore references.
cferry-AMD marked this conversation as resolved.
Show resolved Hide resolved
cferry-AMD marked this conversation as resolved.
Show resolved Hide resolved
if (castOp->hasAttrOfType<UnitAttr>("emitc.reference")) {
if (failed(emitter.emitReferenceToType(op.getLoc(),
op.getResult(0).getType())))
return failure();
} else {
if (failed(emitter.emitType(op.getLoc(), op.getResult(0).getType())))
return failure();
}
os << ") ";
return emitter.emitOperand(castOp.getOperand());
}
Expand Down Expand Up @@ -907,29 +918,78 @@ static LogicalResult printOperation(CppEmitter &emitter, ModuleOp moduleOp) {
return success();
}

template <class FuncOpClass>
static LogicalResult printFunctionArgs(CppEmitter &emitter,
Operation *functionOp,
FuncOpClass functionOp,
cferry-AMD marked this conversation as resolved.
Show resolved Hide resolved
ArrayRef<Type> arguments) {
raw_indented_ostream &os = emitter.ostream();

uint32_t index = 0;

return (
interleaveCommaWithError(arguments, os, [&](Type arg) -> LogicalResult {
cferry-AMD marked this conversation as resolved.
Show resolved Hide resolved
bool hasReference = functionOp.template getArgAttrOfType<UnitAttr>(
index, "emitc.reference") != nullptr;
index += 1;
if (hasReference)
return emitter.emitReferenceToType(functionOp->getLoc(), arg);
return emitter.emitType(functionOp->getLoc(), arg);
}));
}

static LogicalResult printFunctionArgs(CppEmitter &emitter,
Operation *functionOp,
ArrayRef<Type> arguments) {
if (auto emitCDialectFunc = dyn_cast<emitc::FuncOp>(functionOp)) {
return printFunctionArgs(emitter, emitCDialectFunc, arguments);
}
if (auto funcDialectFunc = dyn_cast<func::FuncOp>(functionOp)) {
return printFunctionArgs(emitter, funcDialectFunc, arguments);
}

raw_indented_ostream &os = emitter.ostream();
return (
interleaveCommaWithError(arguments, os, [&](Type arg) -> LogicalResult {
return emitter.emitType(functionOp->getLoc(), arg);
}));
}

template <class FuncOpClass>
static LogicalResult printFunctionArgs(CppEmitter &emitter,
FuncOpClass functionOp,
Region::BlockArgListType arguments) {
raw_indented_ostream &os = emitter.ostream();

return (interleaveCommaWithError(
arguments, os, [&](BlockArgument arg) -> LogicalResult {
bool hasReference =
functionOp.template getArgAttrOfType<UnitAttr>(
arg.getArgNumber(), "emitc.reference") != nullptr;
return emitter.emitVariableDeclaration(
functionOp->getLoc(), arg.getType(), emitter.getOrCreateName(arg));
functionOp->getLoc(), arg.getType(), emitter.getOrCreateName(arg),
hasReference);
}));
}

static LogicalResult printFunctionArgs(CppEmitter &emitter,
Operation *functionOp,
Region::BlockArgListType arguments) {
if (auto emitCDialectFunc = dyn_cast<emitc::FuncOp>(functionOp)) {
return printFunctionArgs(emitter, emitCDialectFunc, arguments);
}
if (auto funcDialectFunc = dyn_cast<func::FuncOp>(functionOp)) {
return printFunctionArgs(emitter, funcDialectFunc, arguments);
}

raw_indented_ostream &os = emitter.ostream();
return (interleaveCommaWithError(arguments, os,
[&](BlockArgument arg) -> LogicalResult {
return emitter.emitVariableDeclaration(
functionOp->getLoc(), arg.getType(),
emitter.getOrCreateName(arg), false);
cferry-AMD marked this conversation as resolved.
Show resolved Hide resolved
}));
}

static LogicalResult printFunctionBody(CppEmitter &emitter,
Operation *functionOp,
Region::BlockListType &blocks) {
Expand Down Expand Up @@ -1394,9 +1454,10 @@ LogicalResult CppEmitter::emitVariableDeclaration(OpResult result,
return result.getDefiningOp()->emitError(
"result variable for the operation already declared");
}
if (failed(emitVariableDeclaration(result.getOwner()->getLoc(),
result.getType(),
getOrCreateName(result))))
if (failed(emitVariableDeclaration(
result.getOwner()->getLoc(), result.getType(),
getOrCreateName(result),
result.getDefiningOp()->hasAttrOfType<UnitAttr>("emitc.reference"))))
return failure();
if (trailingSemicolon)
os << ";\n";
Expand All @@ -1412,7 +1473,7 @@ LogicalResult CppEmitter::emitGlobalVariable(GlobalOp op) {
os << "const ";

if (failed(emitVariableDeclaration(op->getLoc(), op.getType(),
op.getSymName()))) {
op.getSymName(), false))) {
cferry-AMD marked this conversation as resolved.
Show resolved Hide resolved
return failure();
}

Expand Down Expand Up @@ -1518,19 +1579,43 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) {
}

LogicalResult CppEmitter::emitVariableDeclaration(Location loc, Type type,
StringRef name) {
StringRef name,
bool isReference) {
if (auto arrType = dyn_cast<emitc::ArrayType>(type)) {
if (failed(emitType(loc, arrType.getElementType())))
return failure();
os << " " << name;
os << " ";
if (isReference)
os << "(&";
os << name;
if (isReference)
os << ")";
for (auto dim : arrType.getShape()) {
os << "[" << dim << "]";
}
return success();
}
if (failed(emitType(loc, type)))
return failure();
os << " " << name;
os << " ";
if (isReference)
os << "&";
os << name;
return success();
}

LogicalResult CppEmitter::emitReferenceToType(Location loc, Type type) {
if (auto aType = dyn_cast<ArrayType>(type)) {
if (failed(emitType(loc, aType.getElementType())))
return failure();
os << " (&)";
for (auto dim : aType.getShape())
os << "[" << dim << "]";
return success();
}
if (failed(emitType(loc, type)))
return failure();
os << " &";
return success();
}

Expand Down
5 changes: 5 additions & 0 deletions mlir/test/Target/Cpp/common-cpp.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -94,3 +94,8 @@ func.func @apply(%arg0: i32) -> !emitc.ptr<i32> {
func.func @array_type(%arg0: !emitc.array<3xi32>, %arg1: !emitc.array<10x20xf32>) {
return
}

// CHECK: void arg_references(int32_t (&v1)[3], float (&v2)[10][20], int32_t &v3)
func.func @arg_references(%arg0: !emitc.array<3xi32> {emitc.reference}, %arg1: !emitc.array<10x20xf32> {emitc.reference}, %arg2: i32 {emitc.reference}) {
return
}
14 changes: 14 additions & 0 deletions mlir/test/Target/Cpp/declare_func.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,17 @@ emitc.declare_func @array_arg
emitc.func @array_arg(%arg0: !emitc.array<3xi32>) {
emitc.return
}

// CHECK: void reference_scalar_arg(int32_t &[[V2:[^ ]*]]);
emitc.declare_func @reference_scalar_arg
// CHECK: void reference_scalar_arg(int32_t &[[V2:[^ ]*]]) {
emitc.func @reference_scalar_arg(%arg0: i32 {emitc.reference}) {
cferry-AMD marked this conversation as resolved.
Show resolved Hide resolved
cferry-AMD marked this conversation as resolved.
Show resolved Hide resolved
emitc.return
}

// CHECK: void reference_array_arg(int32_t (&[[V2:[^ ]*]])[3]);
emitc.declare_func @reference_array_arg
// CHECK: void reference_array_arg(int32_t (&[[V2:[^ ]*]])[3]) {
emitc.func @reference_array_arg(%arg0: !emitc.array<3xi32> {emitc.reference}) {
emitc.return
}
9 changes: 9 additions & 0 deletions mlir/test/Target/Cpp/func.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -43,3 +43,12 @@ emitc.func private @extern_func(i32) attributes {specifiers = ["extern"]}

emitc.func private @array_arg(!emitc.array<3xi32>) attributes {specifiers = ["extern"]}
// CPP-DEFAULT: extern void array_arg(int32_t[3]);

emitc.func private @reference_scalar_arg(i32 {emitc.reference}) attributes {specifiers = ["extern"]}
// CPP-DEFAULT: extern void reference_scalar_arg(int32_t &);

emitc.func private @reference_array_arg(!emitc.array<3xi32> {emitc.reference}) attributes {specifiers = ["extern"]}
// CPP-DEFAULT: extern void reference_array_arg(int32_t (&)[3]);

emitc.func private @reference_multi_arg(!emitc.array<3xi32> {emitc.reference}, !emitc.array<3xi32>, i32 {emitc.reference}) attributes {specifiers = ["extern"]}
// CPP-DEFAULT: extern void reference_multi_arg(int32_t (&)[3], int32_t[3], int32_t &);
8 changes: 8 additions & 0 deletions mlir/test/Target/Cpp/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,11 @@ func.func @ptr_to_array() {
%v = "emitc.variable"(){value = #emitc.opaque<"NULL">} : () -> !emitc.ptr<!emitc.array<9xi16>>
return
}

// -----

func.func @cast_ref(%arg0 : i32) {
// expected-error@+1 {{'emitc.cast' op cast of value type must not bear a reference}}
cferry-AMD marked this conversation as resolved.
Show resolved Hide resolved
%1 = emitc.cast %arg0 {emitc.reference} : i32 to i32
return
}
Loading