diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h index 87a4078f280f65..0c595a6b109caa 100644 --- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.h @@ -47,6 +47,9 @@ bool isSupportedFloatType(mlir::Type type); /// Determines whether \p type is a emitc.size_t/ssize_t type. bool isPointerWideType(mlir::Type type); +/// Give the name of the EmitC reference attribute. +StringRef getReferenceAttributeName(); + } // namespace emitc } // namespace mlir diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td index 452302c565139c..0c945ab2c40304 100644 --- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td @@ -285,9 +285,11 @@ def EmitC_CastOp : EmitC_Op<"cast", ``` }]; - let arguments = (ins EmitCType:$source); + let arguments = (ins EmitCType:$source, + UnitAttr:$reference); let results = (outs EmitCType:$dest); - let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)"; + let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest) (`ref` $reference^)?"; + let hasVerifier = 1; } def EmitC_CmpOp : EmitC_BinaryOp<"cmp", [CExpression]> { @@ -1050,7 +1052,8 @@ def EmitC_GlobalOp : EmitC_Op<"global", [Symbol]> { OptionalAttr:$initial_value, UnitAttr:$extern_specifier, UnitAttr:$static_specifier, - UnitAttr:$const_specifier); + UnitAttr:$const_specifier, + UnitAttr:$reference); let assemblyFormat = [{ (`extern` $extern_specifier^)? @@ -1058,6 +1061,7 @@ def EmitC_GlobalOp : EmitC_Op<"global", [Symbol]> { (`const` $const_specifier^)? $sym_name `:` custom($type, $initial_value) + (`ref` $reference^)? attr-dict }]; diff --git a/mlir/include/mlir/Dialect/EmitC/IR/FunctionOpAssembly.h b/mlir/include/mlir/Dialect/EmitC/IR/FunctionOpAssembly.h new file mode 100644 index 00000000000000..22567c97a21ad7 --- /dev/null +++ b/mlir/include/mlir/Dialect/EmitC/IR/FunctionOpAssembly.h @@ -0,0 +1,43 @@ +//===---------- FunctionOpAssembly.h - Parser for `emitc.func` op ---------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// + +#ifndef MLIR_INCLUDE_MLIR_DIALECT_EMITC_IR_FUNCTIONOPASSEMBLY_H +#define MLIR_INCLUDE_MLIR_DIALECT_EMITC_IR_FUNCTIONOPASSEMBLY_H + +#include "mlir/IR/OperationSupport.h" +#include "mlir/Interfaces/FunctionImplementation.h" +#include "mlir/Support/LogicalResult.h" + +#include "mlir/IR/Builders.h" + +namespace mlir::emitc { + +class FuncOp; + +ParseResult +parseFunctionSignature(OpAsmParser &parser, bool allowVariadic, + SmallVectorImpl &arguments, + bool &isVariadic, SmallVectorImpl &resultTypes, + SmallVectorImpl &resultAttrs); + +ParseResult +parseFunctionOp(OpAsmParser &parser, OperationState &result, bool allowVariadic, + StringAttr typeAttrName, + function_interface_impl::FuncTypeBuilder funcTypeBuilder, + StringAttr argAttrsName, StringAttr resAttrsName); + +void printFunctionSignature(OpAsmPrinter &p, FuncOp op, ArrayRef argTypes, + bool isVariadic, ArrayRef resultTypes); + +void printFunctionOp(OpAsmPrinter &p, FuncOp op, bool isVariadic, + StringRef typeAttrName, StringAttr argAttrsName, + StringAttr resAttrsName); + +} // namespace mlir::emitc + +#endif // MLIR_INCLUDE_MLIR_DIALECT_EMITC_IR_FUNCTIONOPASSEMBLY_H diff --git a/mlir/lib/Dialect/EmitC/IR/CMakeLists.txt b/mlir/lib/Dialect/EmitC/IR/CMakeLists.txt index 4cc54201d2745d..e1bef7f6851cb2 100644 --- a/mlir/lib/Dialect/EmitC/IR/CMakeLists.txt +++ b/mlir/lib/Dialect/EmitC/IR/CMakeLists.txt @@ -1,5 +1,6 @@ add_mlir_dialect_library(MLIREmitCDialect EmitC.cpp + FunctionOpAssembly.cpp ADDITIONAL_HEADER_DIRS ${MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/EmitC diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp index 03f96704ab4f6d..aa2495bc42ba03 100644 --- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp +++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp @@ -8,6 +8,7 @@ #include "mlir/Dialect/EmitC/IR/EmitC.h" #include "mlir/Dialect/EmitC/IR/EmitCTraits.h" +#include "mlir/Dialect/EmitC/IR/FunctionOpAssembly.h" #include "mlir/IR/Builders.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypes.h" @@ -122,6 +123,8 @@ bool mlir::emitc::isPointerWideType(Type type) { type); } +StringRef mlir::emitc::getReferenceAttributeName() { return "emitc.reference"; } + /// Check that the type of the initial value is compatible with the operations /// result type. static LogicalResult verifyInitializationAttribute(Operation *op, @@ -232,6 +235,13 @@ bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) { emitc::isSupportedFloatType(output) || isa(output))); } +LogicalResult CastOp::verify() { + if (getReference()) + return emitOpError("cast of value type must not bear a reference"); + + return success(); +} + //===----------------------------------------------------------------------===// // CallOpaqueOp //===----------------------------------------------------------------------===// @@ -518,16 +528,15 @@ ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) { function_interface_impl::VariadicFlag, std::string &) { return builder.getFunctionType(argTypes, results); }; - return function_interface_impl::parseFunctionOp( - parser, result, /*allowVariadic=*/false, - getFunctionTypeAttrName(result.name), buildFuncType, - getArgAttrsAttrName(result.name), getResAttrsAttrName(result.name)); + return parseFunctionOp(parser, result, /*allowVariadic=*/false, + getFunctionTypeAttrName(result.name), buildFuncType, + getArgAttrsAttrName(result.name), + getResAttrsAttrName(result.name)); } void FuncOp::print(OpAsmPrinter &p) { - function_interface_impl::printFunctionOp( - p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), - getArgAttrsAttrName(), getResAttrsAttrName()); + printFunctionOp(p, *this, /*isVariadic=*/false, getFunctionTypeAttrName(), + getArgAttrsAttrName(), getResAttrsAttrName()); } LogicalResult FuncOp::verify() { @@ -1029,6 +1038,12 @@ LogicalResult GlobalOp::verify() { } if (getInitialValue().has_value()) { Attribute initValue = getInitialValue().value(); + if (getReference() && !isa(initValue)) { + return emitOpError("global reference initial value must be an opaque " + "attribute, got ") + << initValue; + } + // Check that the type of the initial value is compatible with the type of // the global variable. if (auto elementsAttr = llvm::dyn_cast(initValue)) { @@ -1057,6 +1072,8 @@ LogicalResult GlobalOp::verify() { "or opaque attribute, but got ") << initValue; } + } else if (getReference()) { + return emitOpError("global reference must be initialized"); } if (getStaticSpecifier() && getExternSpecifier()) { return emitOpError("cannot have both static and extern specifiers"); diff --git a/mlir/lib/Dialect/EmitC/IR/FunctionOpAssembly.cpp b/mlir/lib/Dialect/EmitC/IR/FunctionOpAssembly.cpp new file mode 100644 index 00000000000000..0db97a5890868c --- /dev/null +++ b/mlir/lib/Dialect/EmitC/IR/FunctionOpAssembly.cpp @@ -0,0 +1,310 @@ +//===--------- FunctionOpAssembly.cpp - Parser for `emitc.func` op --------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// This function printer/parser are copies of those in +// Interfaces/FunctionImplementation.cpp, except that they print out arguments +// followed by "ref" if they bear the emitc.reference attribute. +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/EmitC/IR/EmitC.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/OperationSupport.h" +#include "mlir/IR/SymbolTable.h" +#include "mlir/Interfaces/FunctionImplementation.h" +#include "mlir/Interfaces/FunctionInterfaces.h" +#include "llvm/ADT/SmallVector.h" + +#include "mlir/Dialect/EmitC/IR/FunctionOpAssembly.h" + +using namespace mlir; + +namespace mlir::emitc { + +static ParseResult +parseFunctionArgumentList(OpAsmParser &parser, bool allowVariadic, + SmallVectorImpl &arguments, + bool &isVariadic) { + + // Parse the function arguments. The argument list either has to consistently + // have ssa-id's followed by types, or just be a type list. It isn't ok to + // sometimes have SSA ID's and sometimes not. + isVariadic = false; + + return parser.parseCommaSeparatedList( + OpAsmParser::Delimiter::Paren, [&]() -> ParseResult { + // Ellipsis must be at end of the list. + if (isVariadic) + return parser.emitError( + parser.getCurrentLocation(), + "variadic arguments must be in the end of the argument list"); + + // Handle ellipsis as a special case. + if (allowVariadic && succeeded(parser.parseOptionalEllipsis())) { + // This is a variadic designator. + isVariadic = true; + return success(); // Stop parsing arguments. + } + // Parse argument name if present. + OpAsmParser::Argument argument; + auto argPresent = parser.parseOptionalArgument( + argument, /*allowType=*/true, /*allowAttrs=*/true); + if (argPresent.has_value()) { + if (failed(argPresent.value())) + return failure(); // Present but malformed. + + // Reject this if the preceding argument was missing a name. + if (!arguments.empty() && arguments.back().ssaName.name.empty()) + return parser.emitError(argument.ssaName.location, + "expected type instead of SSA identifier"); + if (succeeded(parser.parseOptionalKeyword("ref"))) { + llvm::ArrayRef origAttrs; + if (!argument.attrs.empty()) + origAttrs = argument.attrs.getValue(); + + SmallVector attrs(origAttrs); + attrs.push_back(NamedAttribute( + StringAttr::get(parser.getContext(), + emitc::getReferenceAttributeName()), + UnitAttr::get(parser.getContext()))); + argument.attrs = DictionaryAttr::get(parser.getContext(), attrs); + } + } else { + argument.ssaName.location = parser.getCurrentLocation(); + // Otherwise we just have a type list without SSA names. Reject + // this if the preceding argument had a name. + if (!arguments.empty() && !arguments.back().ssaName.name.empty()) + return parser.emitError(argument.ssaName.location, + "expected SSA identifier"); + + NamedAttrList attrs; + if (parser.parseType(argument.type) || + parser.parseOptionalAttrDict(attrs) || + parser.parseOptionalLocationSpecifier(argument.sourceLoc)) + return failure(); + if (succeeded(parser.parseOptionalKeyword("ref"))) { + // Add attribute to argument + attrs.push_back(NamedAttribute( + StringAttr::get(parser.getContext(), + emitc::getReferenceAttributeName()), + UnitAttr::get(parser.getContext()))); + } + argument.attrs = attrs.getDictionary(parser.getContext()); + } + arguments.push_back(argument); + return success(); + }); +} + +/// Parse a function result. +/// +/// function-result ::= type | `(` type attribute-dict? `)` +/// +static ParseResult +parseFunctionResult(OpAsmParser &parser, SmallVectorImpl &resultTypes, + SmallVectorImpl &resultAttrs) { + + bool hasLParen = succeeded(parser.parseOptionalLParen()); + + if (hasLParen) { + // Special case for an empty set of parens. + if (succeeded(parser.parseOptionalRParen())) + return success(); + } + + // Parse a single type. + Type ty; + if (parser.parseType(ty)) + return failure(); + resultTypes.push_back(ty); + resultAttrs.emplace_back(); + + // There can be no attribute without parentheses (they would be confused with + // the function body) + if (!hasLParen) + return success(); + + // Parse result attributes if any. + NamedAttrList attrs; + if (succeeded(parser.parseOptionalAttrDict(attrs))) + resultAttrs.back() = attrs.getDictionary(parser.getContext()); + + return parser.parseRParen(); +} + +ParseResult +parseFunctionSignature(OpAsmParser &parser, bool allowVariadic, + SmallVectorImpl &arguments, + bool &isVariadic, SmallVectorImpl &resultTypes, + SmallVectorImpl &resultAttrs) { + if (parseFunctionArgumentList(parser, allowVariadic, arguments, isVariadic)) + return failure(); + if (succeeded(parser.parseOptionalArrow())) + return parseFunctionResult(parser, resultTypes, resultAttrs); + return success(); +} + +ParseResult +parseFunctionOp(OpAsmParser &parser, OperationState &result, bool allowVariadic, + StringAttr typeAttrName, + function_interface_impl::FuncTypeBuilder funcTypeBuilder, + StringAttr argAttrsName, StringAttr resAttrsName) { + SmallVector entryArgs; + SmallVector resultAttrs; + SmallVector resultTypes; + auto &builder = parser.getBuilder(); + + // Parse visibility. + (void)impl::parseOptionalVisibilityKeyword(parser, result.attributes); + + // Parse the name as a symbol. + StringAttr nameAttr; + if (parser.parseSymbolName(nameAttr, SymbolTable::getSymbolAttrName(), + result.attributes)) + return failure(); + + // Parse the function signature. + SMLoc signatureLocation = parser.getCurrentLocation(); + bool isVariadic = false; + if (parseFunctionSignature(parser, allowVariadic, entryArgs, isVariadic, + resultTypes, resultAttrs)) + return failure(); + + std::string errorMessage; + SmallVector argTypes; + argTypes.reserve(entryArgs.size()); + for (auto &arg : entryArgs) + argTypes.push_back(arg.type); + Type type = funcTypeBuilder(builder, argTypes, resultTypes, + function_interface_impl::VariadicFlag(isVariadic), + errorMessage); + if (!type) { + return parser.emitError(signatureLocation) + << "failed to construct function type" + << (errorMessage.empty() ? "" : ": ") << errorMessage; + } + result.addAttribute(typeAttrName, TypeAttr::get(type)); + + // If function attributes are present, parse them. + NamedAttrList parsedAttributes; + SMLoc attributeDictLocation = parser.getCurrentLocation(); + if (parser.parseOptionalAttrDictWithKeyword(parsedAttributes)) + return failure(); + + // Disallow attributes that are inferred from elsewhere in the attribute + // dictionary. + for (StringRef disallowed : + {SymbolTable::getVisibilityAttrName(), SymbolTable::getSymbolAttrName(), + typeAttrName.getValue()}) { + if (parsedAttributes.get(disallowed)) + return parser.emitError(attributeDictLocation, "'") + << disallowed + << "' is an inferred attribute and should not be specified in the " + "explicit attribute dictionary"; + } + result.attributes.append(parsedAttributes); + + // Add the attributes to the function arguments. + assert(resultAttrs.size() == resultTypes.size()); + function_interface_impl::addArgAndResultAttrs( + builder, result, entryArgs, resultAttrs, argAttrsName, resAttrsName); + + // Parse the optional function body. The printer will not print the body if + // its empty, so disallow parsing of empty body in the parser. + auto *body = result.addRegion(); + SMLoc loc = parser.getCurrentLocation(); + OptionalParseResult parseResult = + parser.parseOptionalRegion(*body, entryArgs, + /*enableNameShadowing=*/false); + if (parseResult.has_value()) { + if (failed(*parseResult)) + return failure(); + // Function body was parsed, make sure its not empty. + if (body->empty()) + return parser.emitError(loc, "expected non-empty function body"); + } + return success(); +} + +void printFunctionSignature(OpAsmPrinter &p, FuncOp op, ArrayRef argTypes, + bool isVariadic, ArrayRef resultTypes) { + Region &body = op->getRegion(0); + bool isExternal = body.empty(); + + p << '('; + ArrayAttr argAttrs = op.getArgAttrsAttr(); + for (unsigned i = 0, e = argTypes.size(); i < e; ++i) { + if (i > 0) + p << ", "; + + // Exclude reference attribute if there is to replace it by ref + SmallVector attrs; + if (argAttrs) { + for (auto attr : llvm::cast(argAttrs[i]).getValue()) { + if (attr.getName() != emitc::getReferenceAttributeName()) + attrs.push_back(attr); + } + } + + if (!isExternal) { + p.printRegionArgument(body.getArgument(i), attrs); + } else { + p.printType(argTypes[i]); + if (argAttrs) + p.printOptionalAttrDict(attrs); + } + } + + if (isVariadic) { + if (!argTypes.empty()) + p << ", "; + p << "..."; + } + + p << ')'; + + if (!resultTypes.empty()) { + assert(resultTypes.size() == 1); + p.getStream() << " -> "; + auto resultAttrs = op.getResAttrsAttr(); + p.printType(resultTypes[0]); + if (resultAttrs) + p.printOptionalAttrDict( + llvm::cast(resultAttrs[0]).getValue()); + } +} + +void printFunctionOp(OpAsmPrinter &p, FuncOp op, bool isVariadic, + StringRef typeAttrName, StringAttr argAttrsName, + StringAttr resAttrsName) { + // Print the operation and the function name. + auto funcName = + op->getAttrOfType(SymbolTable::getSymbolAttrName()) + .getValue(); + p << ' '; + + StringRef visibilityAttrName = SymbolTable::getVisibilityAttrName(); + if (auto visibility = op->getAttrOfType(visibilityAttrName)) + p << visibility.getValue() << ' '; + p.printSymbolName(funcName); + + ArrayRef argTypes = op.getArgumentTypes(); + ArrayRef resultTypes = op.getResultTypes(); + printFunctionSignature(p, op, argTypes, isVariadic, resultTypes); + function_interface_impl::printFunctionAttributes( + p, op, {visibilityAttrName, typeAttrName, argAttrsName, resAttrsName}); + // Print the body if this is not an external function. + Region &body = op->getRegion(0); + if (!body.empty()) { + p << ' '; + p.printRegion(body, /*printEntryBlockArgs=*/false, + /*printBlockTerminators=*/true); + } +} + +} // namespace mlir::emitc diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp index f61de4a420a649..021bad1f6a7cc1 100644 --- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp +++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp @@ -9,6 +9,7 @@ #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" #include "mlir/Dialect/EmitC/IR/EmitC.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" #include "mlir/IR/Dialect.h" @@ -122,6 +123,9 @@ struct CppEmitter { /// Emits operation 'op' with/without training semicolon or returns failure. LogicalResult emitOperation(Operation &op, bool trailingSemicolon); + /// Emits a reference to type 'type' or returns failure. + LogicalResult emitReferenceToType(Location loc, Type type); + /// Emits type 'type' or returns failure. LogicalResult emitType(Location loc, Type type); @@ -143,8 +147,8 @@ struct CppEmitter { bool trailingSemicolon); /// Emits a declaration of a variable with the given type and name. - LogicalResult emitVariableDeclaration(Location loc, Type type, - StringRef name); + LogicalResult emitVariableDeclaration(Location loc, Type type, StringRef name, + bool isReference); /// Emits the variable declaration and assignment prefix for 'op'. /// - emits separate variable followed by std::tie for multi-valued operation; @@ -726,8 +730,14 @@ static LogicalResult printOperation(CppEmitter &emitter, emitc::CastOp castOp) { if (failed(emitter.emitAssignPrefix(op))) return failure(); os << "("; - if (failed(emitter.emitType(op.getLoc(), op.getResult(0).getType()))) - return failure(); + if (castOp.getReference()) { + if (failed(emitter.emitReferenceToType(op.getLoc(), + op.getResult(0).getType()))) + return failure(); + } else { + if (failed(emitter.emitType(op.getLoc(), op.getResult(0).getType()))) + return failure(); + } os << ") "; return emitter.emitOperand(castOp.getOperand()); } @@ -914,26 +924,73 @@ static LogicalResult printOperation(CppEmitter &emitter, ModuleOp moduleOp) { return success(); } +template static LogicalResult printFunctionArgs(CppEmitter &emitter, - Operation *functionOp, + FuncOpClass functionOp, ArrayRef arguments) { raw_indented_ostream &os = emitter.ostream(); + return (interleaveCommaWithError( + llvm::enumerate(arguments), os, [&](auto arg) -> LogicalResult { + bool hasReference = + functionOp.template getArgAttrOfType( + arg.index(), emitc::getReferenceAttributeName()) != nullptr; + if (hasReference) + return emitter.emitReferenceToType(functionOp->getLoc(), arg.value()); + return emitter.emitType(functionOp->getLoc(), arg.value()); + })); +} + +static LogicalResult printFunctionArgs(CppEmitter &emitter, + Operation *functionOp, + ArrayRef arguments) { + if (auto emitCDialectFunc = dyn_cast(functionOp)) { + return printFunctionArgs(emitter, emitCDialectFunc, arguments); + } + if (auto funcDialectFunc = dyn_cast(functionOp)) { + return printFunctionArgs(emitter, funcDialectFunc, arguments); + } + + raw_indented_ostream &os = emitter.ostream(); return ( interleaveCommaWithError(arguments, os, [&](Type arg) -> LogicalResult { return emitter.emitType(functionOp->getLoc(), arg); })); } +template static LogicalResult printFunctionArgs(CppEmitter &emitter, - Operation *functionOp, + FuncOpClass functionOp, Region::BlockArgListType arguments) { raw_indented_ostream &os = emitter.ostream(); return (interleaveCommaWithError( arguments, os, [&](BlockArgument arg) -> LogicalResult { + bool hasReference = functionOp.template getArgAttrOfType( + arg.getArgNumber(), + emitc::getReferenceAttributeName()) != nullptr; return emitter.emitVariableDeclaration( - functionOp->getLoc(), arg.getType(), emitter.getOrCreateName(arg)); + functionOp->getLoc(), arg.getType(), emitter.getOrCreateName(arg), + hasReference); + })); +} + +static LogicalResult printFunctionArgs(CppEmitter &emitter, + Operation *functionOp, + Region::BlockArgListType arguments) { + if (auto emitCDialectFunc = dyn_cast(functionOp)) { + return printFunctionArgs(emitter, emitCDialectFunc, arguments); + } + if (auto funcDialectFunc = dyn_cast(functionOp)) { + return printFunctionArgs(emitter, funcDialectFunc, arguments); + } + + raw_indented_ostream &os = emitter.ostream(); + return (interleaveCommaWithError( + arguments, os, [&](BlockArgument arg) -> LogicalResult { + return emitter.emitVariableDeclaration( + functionOp->getLoc(), arg.getType(), emitter.getOrCreateName(arg), + /*isReference=*/false); })); } @@ -1401,9 +1458,18 @@ LogicalResult CppEmitter::emitVariableDeclaration(OpResult result, return result.getDefiningOp()->emitError( "result variable for the operation already declared"); } + Operation *definingOp = result.getDefiningOp(); + bool isReference = false; + // List all ops that can produce references here + if (auto castOp = llvm::dyn_cast(definingOp)) { + isReference = castOp.getReference(); + } + if (auto globalOp = llvm::dyn_cast(definingOp)) { + isReference = globalOp.getReference(); + } if (failed(emitVariableDeclaration(result.getOwner()->getLoc(), - result.getType(), - getOrCreateName(result)))) + result.getType(), getOrCreateName(result), + isReference))) return failure(); if (trailingSemicolon) os << ";\n"; @@ -1419,7 +1485,7 @@ LogicalResult CppEmitter::emitGlobalVariable(GlobalOp op) { os << "const "; if (failed(emitVariableDeclaration(op->getLoc(), op.getType(), - op.getSymName()))) { + op.getSymName(), op.getReference()))) { return failure(); } @@ -1525,11 +1591,17 @@ LogicalResult CppEmitter::emitOperation(Operation &op, bool trailingSemicolon) { } LogicalResult CppEmitter::emitVariableDeclaration(Location loc, Type type, - StringRef name) { + StringRef name, + bool isReference) { if (auto arrType = dyn_cast(type)) { if (failed(emitType(loc, arrType.getElementType()))) return failure(); - os << " " << name; + os << " "; + if (isReference) + os << "(&"; + os << name; + if (isReference) + os << ")"; for (auto dim : arrType.getShape()) { os << "[" << dim << "]"; } @@ -1537,7 +1609,25 @@ LogicalResult CppEmitter::emitVariableDeclaration(Location loc, Type type, } if (failed(emitType(loc, type))) return failure(); - os << " " << name; + os << " "; + if (isReference) + os << "&"; + os << name; + return success(); +} + +LogicalResult CppEmitter::emitReferenceToType(Location loc, Type type) { + if (auto aType = dyn_cast(type)) { + if (failed(emitType(loc, aType.getElementType()))) + return failure(); + os << " (&)"; + for (auto dim : aType.getShape()) + os << "[" << dim << "]"; + return success(); + } + if (failed(emitType(loc, type))) + return failure(); + os << " &"; return success(); } diff --git a/mlir/test/Dialect/EmitC/func.mlir b/mlir/test/Dialect/EmitC/func.mlir new file mode 100644 index 00000000000000..c047958f44b457 --- /dev/null +++ b/mlir/test/Dialect/EmitC/func.mlir @@ -0,0 +1,27 @@ +// RUN: mlir-opt %s -split-input-file + +// CHECK: emitc.func @f +// CHECK-SAME: %{{[^:]*}}: i32 ref +emitc.func @f(%x: i32 {emitc.reference}) { + emitc.return +} + +// ----- + +// CHECK: emitc.func @f +// CHECK-SAME: %{{[^:]*}}: i32 ref +emitc.func @f(%x: i32 ref) { + emitc.return +} + +// ----- + +// CHECK: emitc.func @f +// CHECK-SAME: i32 ref +emitc.func @f(i32 ref) + +// ----- + +// CHECK: emitc.func @f +// CHECK-SAME: i32 ref +emitc.func @f(i32 {emitc.reference}) diff --git a/mlir/test/Dialect/EmitC/invalid_ops.mlir b/mlir/test/Dialect/EmitC/invalid_ops.mlir index 8cd8bdca4df336..aa2c969b05cc23 100644 --- a/mlir/test/Dialect/EmitC/invalid_ops.mlir +++ b/mlir/test/Dialect/EmitC/invalid_ops.mlir @@ -317,7 +317,7 @@ func.func @test_expression_multiple_results(%arg0: i32) -> i32 { // ----- -// expected-error @+1 {{'emitc.func' op requires zero or exactly one result, but has 2}} +// expected-error @+1 {{expected ')'}} emitc.func @multiple_results(%0: i32) -> (i32, i32) { emitc.return %0 : i32 } @@ -450,3 +450,13 @@ func.func @use_global() { %0 = emitc.get_global @myglobal : f32 return } + +// ----- + +// expected-error @+1 {{'emitc.global' op global reference initial value must be an opaque attribute, got dense<128>}} +emitc.global const @myref : !emitc.array<2xi16> = dense<128> ref + +// ----- + +// expected-error @+1 {{'emitc.global' op global reference must be initialized}} +emitc.global const @myref : !emitc.array<2xi16> ref diff --git a/mlir/test/Dialect/EmitC/ops.mlir b/mlir/test/Dialect/EmitC/ops.mlir index 6cfacca6446cbb..7b11c230e9a9dd 100644 --- a/mlir/test/Dialect/EmitC/ops.mlir +++ b/mlir/test/Dialect/EmitC/ops.mlir @@ -242,6 +242,7 @@ emitc.global extern @external_linkage : i32 emitc.global static @internal_linkage : i32 emitc.global @myglobal : !emitc.array<2xf32> = dense<4.000000e+00> emitc.global const @myconstant : !emitc.array<2xi16> = dense<2> +emitc.global const @myref : !emitc.array<2xi16> = #emitc.opaque<"myconstant"> ref func.func @use_global(%i: index) -> f32 { %0 = emitc.get_global @myglobal : !emitc.array<2xf32> diff --git a/mlir/test/Target/Cpp/common-cpp.mlir b/mlir/test/Target/Cpp/common-cpp.mlir index 0e24bdd19993f0..a638263cf1350c 100644 --- a/mlir/test/Target/Cpp/common-cpp.mlir +++ b/mlir/test/Target/Cpp/common-cpp.mlir @@ -94,3 +94,13 @@ func.func @apply(%arg0: i32) -> !emitc.ptr { func.func @array_type(%arg0: !emitc.array<3xi32>, %arg1: !emitc.array<10x20xf32>) { return } + +// CHECK: void arg_references(int32_t (&v1)[3], float (&v2)[10][20], int32_t &v3) +func.func @arg_references(%arg0: !emitc.array<3xi32> {emitc.reference}, %arg1: !emitc.array<10x20xf32> {emitc.reference}, %arg2: i32 {emitc.reference}) { + return +} + +// CHECK: void emitc_arg_references(int32_t (&v1)[3], float (&v2)[10][20], int32_t &v3) +emitc.func @emitc_arg_references(%arg0: !emitc.array<3xi32> ref, %arg1: !emitc.array<10x20xf32> ref, %arg2: i32 ref) { + emitc.return +} diff --git a/mlir/test/Target/Cpp/declare_func.mlir b/mlir/test/Target/Cpp/declare_func.mlir index 00680d71824ae0..40f73d659b368c 100644 --- a/mlir/test/Target/Cpp/declare_func.mlir +++ b/mlir/test/Target/Cpp/declare_func.mlir @@ -22,3 +22,17 @@ emitc.declare_func @array_arg emitc.func @array_arg(%arg0: !emitc.array<3xi32>) { emitc.return } + +// CHECK: void reference_scalar_arg(int32_t &[[V2:[^ ]*]]); +emitc.declare_func @reference_scalar_arg +// CHECK: void reference_scalar_arg(int32_t &[[V2:[^ ]*]]) { +emitc.func @reference_scalar_arg(%arg0: i32 ref) { + emitc.return +} + +// CHECK: void reference_array_arg(int32_t (&[[V2:[^ ]*]])[3]); +emitc.declare_func @reference_array_arg +// CHECK: void reference_array_arg(int32_t (&[[V2:[^ ]*]])[3]) { +emitc.func @reference_array_arg(%arg0: !emitc.array<3xi32> ref) { + emitc.return +} diff --git a/mlir/test/Target/Cpp/func.mlir b/mlir/test/Target/Cpp/func.mlir index 9c9ea55bfc4e1a..8adb6a8adbf2fb 100644 --- a/mlir/test/Target/Cpp/func.mlir +++ b/mlir/test/Target/Cpp/func.mlir @@ -43,3 +43,12 @@ emitc.func private @extern_func(i32) attributes {specifiers = ["extern"]} emitc.func private @array_arg(!emitc.array<3xi32>) attributes {specifiers = ["extern"]} // CPP-DEFAULT: extern void array_arg(int32_t[3]); + +emitc.func private @reference_scalar_arg(i32 ref) attributes {specifiers = ["extern"]} +// CPP-DEFAULT: extern void reference_scalar_arg(int32_t &); + +emitc.func private @reference_array_arg(!emitc.array<3xi32> ref) attributes {specifiers = ["extern"]} +// CPP-DEFAULT: extern void reference_array_arg(int32_t (&)[3]); + +emitc.func private @reference_multi_arg(!emitc.array<3xi32> ref, !emitc.array<3xi32>, i32 ref) attributes {specifiers = ["extern"]} +// CPP-DEFAULT: extern void reference_multi_arg(int32_t (&)[3], int32_t[3], int32_t &); diff --git a/mlir/test/Target/Cpp/global.mlir b/mlir/test/Target/Cpp/global.mlir index f0d92e862ae322..059c991fd7839f 100644 --- a/mlir/test/Target/Cpp/global.mlir +++ b/mlir/test/Target/Cpp/global.mlir @@ -36,3 +36,6 @@ func.func @use_global(%i: index) -> f32 { // CHECK-SAME: (size_t [[V1:.*]]) // CHECK: return myglobal[[[V1]]]; } + +emitc.global @ref : i32 = #emitc.opaque<"myglobal_int"> ref +// CHECK: int32_t &ref = myglobal_int; diff --git a/mlir/test/Target/Cpp/invalid.mlir b/mlir/test/Target/Cpp/invalid.mlir index 513371a09cde1d..b7373a03c638d9 100644 --- a/mlir/test/Target/Cpp/invalid.mlir +++ b/mlir/test/Target/Cpp/invalid.mlir @@ -85,3 +85,11 @@ func.func @ptr_to_array() { %v = "emitc.variable"(){value = #emitc.opaque<"NULL">} : () -> !emitc.ptr> return } + +// ----- + +func.func @cast_ref(%arg0 : i32) { + // expected-error@+1 {{'emitc.cast' op cast of value type must not bear a reference}} + %1 = emitc.cast %arg0 : i32 to i32 ref + return +}