diff --git a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td index 0015ff0cd8810..681c8709e574b 100644 --- a/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td +++ b/mlir/include/mlir/Dialect/EmitC/IR/EmitC.td @@ -292,6 +292,7 @@ def EmitC_CallOpaqueOp : EmitC_Op<"call_opaque", [CExpression]> { Arg:$callee, Arg, "the order of operands and further attributes">:$args, Arg, "template arguments">:$template_args, + Arg, "template argument names">:$template_arg_names, Variadic:$operands ); let results = (outs Variadic); @@ -302,7 +303,7 @@ def EmitC_CallOpaqueOp : EmitC_Op<"call_opaque", [CExpression]> { "::mlir::ValueRange":$operands, CArg<"::mlir::ArrayAttr", "{}">:$args, CArg<"::mlir::ArrayAttr", "{}">:$template_args), [{ - build($_builder, $_state, resultTypes, callee, args, template_args, + build($_builder, $_state, resultTypes, callee, args, template_args, {}, operands); }] > diff --git a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp index c2fb835c2ebc3..7c8267a234368 100644 --- a/mlir/lib/Dialect/EmitC/IR/EmitC.cpp +++ b/mlir/lib/Dialect/EmitC/IR/EmitC.cpp @@ -355,6 +355,19 @@ LogicalResult emitc::CallOpaqueOp::verify() { } } + if (std::optional templateArgNames = getTemplateArgNames()) { + if (std::optional templateArgsAttr = getTemplateArgs()) { + if ((*templateArgNames).size() && + (*templateArgNames).size() != (*templateArgsAttr).size()) { + return emitOpError("number of template argument names must be equal to " + "number of template arguments"); + } + } else { + return emitOpError("should not have names for template arguments if it " + "does not have template arguments"); + } + } + if (llvm::any_of(getResultTypes(), llvm::IsaPred)) { return emitOpError() << "cannot return array type"; } diff --git a/mlir/lib/Target/Cpp/TranslateToCpp.cpp b/mlir/lib/Target/Cpp/TranslateToCpp.cpp index cc32828a0c544..ee94967ab56a8 100644 --- a/mlir/lib/Target/Cpp/TranslateToCpp.cpp +++ b/mlir/lib/Target/Cpp/TranslateToCpp.cpp @@ -659,11 +659,31 @@ static LogicalResult printOperation(CppEmitter &emitter, return success(); }; + auto emitNamedArgs = + [&](std::tuple tuple) + -> LogicalResult { + Attribute attr = std::get<0>(tuple); + StringAttr argName = cast(std::get<1>(tuple)); + + os << "/*" << argName.str() << "=*/"; + return emitArgs(attr); + }; + if (callOpaqueOp.getTemplateArgs()) { os << "<"; - if (failed(interleaveCommaWithError(*callOpaqueOp.getTemplateArgs(), os, - emitArgs))) - return failure(); + if (callOpaqueOp.getTemplateArgNames() && + !callOpaqueOp.getTemplateArgNames()->empty()) { + if (failed(interleaveCommaWithError( + llvm::zip_equal(*callOpaqueOp.getTemplateArgs(), + *callOpaqueOp.getTemplateArgNames()), + os, emitNamedArgs))) { + return failure(); + } + } else { + if (failed(interleaveCommaWithError(*callOpaqueOp.getTemplateArgs(), os, + emitArgs))) + return failure(); + } os << ">"; } diff --git a/mlir/test/Dialect/EmitC/invalid_ops.mlir b/mlir/test/Dialect/EmitC/invalid_ops.mlir index 3b4c6046a08c5..7f0e89b57b01a 100644 --- a/mlir/test/Dialect/EmitC/invalid_ops.mlir +++ b/mlir/test/Dialect/EmitC/invalid_ops.mlir @@ -524,3 +524,27 @@ func.func @test_verbatim(%arg0 : !emitc.ptr, %arg1 : i32) { emitc.verbatim "{a} " args %arg0, %arg1 : !emitc.ptr, i32 return } + +// ----- + +func.func @template_args_with_names(%arg0: i32) { + // expected-error @+1 {{'emitc.call_opaque' op number of template argument names must be equal to number of template arguments}} + emitc.call_opaque "kernel1"(%arg0) {template_arg_names = ["N", "P"], template_args = [42 : i32]} : (i32) -> () + return +} + +// ----- + +func.func @template_args_with_names(%arg0: i32) { + // expected-error @+1 {{'emitc.call_opaque' op number of template argument names must be equal to number of template arguments}} + emitc.call_opaque "kernel1"(%arg0) {template_arg_names = ["N"], template_args = [42 : i32, 56 : i32]} : (i32) -> () + return +} + +// ----- + +func.func @template_args_with_names(%arg0: i32) { + // expected-error @+1 {{'emitc.call_opaque' op should not have names for template arguments if it does not have template arguments}} + emitc.call_opaque "kernel1"(%arg0) {template_arg_names = ["N"]} : (i32) -> () + return +} diff --git a/mlir/test/Dialect/EmitC/ops.mlir b/mlir/test/Dialect/EmitC/ops.mlir index 4e86642c2a3a9..8fe0a828c84db 100644 --- a/mlir/test/Dialect/EmitC/ops.mlir +++ b/mlir/test/Dialect/EmitC/ops.mlir @@ -282,3 +282,10 @@ func.func @member_access(%arg0: !emitc.opaque<"mystruct">, %arg1: !emitc.opaque< %2 = "emitc.member_of_ptr" (%arg2) {member = "a"} : (!emitc.ptr>) -> i32 return } + +func.func @template_args_with_names(%arg0: i32, %arg1: f32) { + emitc.call_opaque "kernel1"(%arg0, %arg1) {template_arg_names = ["N", "P"], template_args = [42 : i32, 56]} : (i32, f32) -> () + emitc.call_opaque "kernel2"(%arg0, %arg1) {template_arg_names = ["N"], template_args = [42 : i32]} : (i32, f32) -> () + emitc.call_opaque "kernel3"(%arg0, %arg1) {template_arg_names = [], template_args = [#emitc.opaque<"42">]} : (i32, f32) -> () + return +} diff --git a/mlir/test/Target/Cpp/template_arg_names.mlir b/mlir/test/Target/Cpp/template_arg_names.mlir new file mode 100644 index 0000000000000..f4e504b059474 --- /dev/null +++ b/mlir/test/Target/Cpp/template_arg_names.mlir @@ -0,0 +1,14 @@ +// RUN: mlir-translate -mlir-to-cpp %s | FileCheck %s -check-prefix=CPP-DEFAULT + +// CPP-DEFAULT-LABEL: void basic +func.func @basic(%arg0: i32, %arg1: f32) { + emitc.call_opaque "kernel3"(%arg0, %arg1) : (i32, f32) -> () +// CPP-DEFAULT: kernel3( + emitc.call_opaque "kernel4"(%arg0, %arg1) {template_arg_names = ["N", "P"], template_args = [42 : i32, 56]} : (i32, f32) -> () +// CPP-DEFAULT: kernel4( + emitc.call_opaque "kernel4"(%arg0, %arg1) {template_arg_names = ["N"], template_args = [#emitc.opaque<"42">]} : (i32, f32) -> () +// CPP-DEFAULT: kernel4( + return +} + +