Skip to content

Commit 955600d

Browse files
authored
core: Add attributes to func outputs (#3661)
Mlir [supports](https://mlir.llvm.org/docs/Dialects/Func/#funcfunc-funcfuncop) attributes for output types. We need for cross compatibility with other mlir projects.
1 parent e5e9069 commit 955600d

File tree

8 files changed

+115
-45
lines changed

8 files changed

+115
-45
lines changed

tests/filecheck/dialects/func/func_ops.mlir

+10
Original file line numberDiff line numberDiff line change
@@ -71,4 +71,14 @@ builtin.module {
7171
// CHECK: func.func public @arg_attrs(%{{.*}} : tensor<8x8xf64> {"llvm.noalias"}, %{{.*}} : tensor<8x8xf64> {"llvm.noalias"}, %{{.*}} : tensor<8x8xf64> {"llvm.noalias"}) -> tensor<8x8xf64> {
7272
// CHECK-NEXT: return %{{.*}} : tensor<8x8xf64>
7373
// CHECK-NEXT: }
74+
75+
func.func @output_attributes() -> (f32 {dialect.a = 0 : i32}, f32 {dialect.b = 0 : i32, dialect.c = 1 : i64}) {
76+
%r1, %r2 = "test.op"() : () -> (f32, f32)
77+
return %r1, %r2 : f32, f32
78+
}
79+
80+
// CHECK: func.func @output_attributes() -> (f32 {"dialect.a" = 0 : i32}, f32 {"dialect.b" = 0 : i32, "dialect.c" = 1 : i64}) {
81+
// CHECK-NEXT: %r1, %r2 = "test.op"() : () -> (f32, f32)
82+
// CHECK-NEXT: func.return %r1, %r2 : f32, f32
83+
// CHECK-NEXT: }
7484
}

tests/filecheck/dialects/func/func_ops_generic.mlir

+10
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,13 @@
1010
// CHECK-NEXT: ^0(%arg0 : tensor<8x8xf64>, %arg1 : tensor<8x8xf64>):
1111
// CHECK-NEXT: "func.return"(%arg0, %arg1) : (tensor<8x8xf64>, tensor<8x8xf64>) -> ()
1212
// CHECK-NEXT: }) : () -> ()
13+
14+
func.func @output_attributes() -> (f32 {dialect.a = 0 : i32}, f32 {dialect.b = 0 : i32, dialect.c = 1 : i64}) {
15+
%r1, %r2 = "test.op"() : () -> (f32, f32)
16+
return %r1, %r2 : f32, f32
17+
}
18+
19+
// CHECK: "func.func"() <{"sym_name" = "output_attributes", "function_type" = () -> (f32, f32), "res_attrs" = [{"dialect.a" = 0 : i32}, {"dialect.b" = 0 : i32, "dialect.c" = 1 : i64}]}> ({
20+
// CHECK-NEXT: %r1, %r2 = "test.op"() : () -> (f32, f32)
21+
// CHECK-NEXT: "func.return"(%r1, %r2) : (f32, f32) -> ()
22+
// CHECK-NEXT: }) : () -> ()

tests/filecheck/mlir-conversion/with-mlir/dialects/func/func_ops.mlir

+10
Original file line numberDiff line numberDiff line change
@@ -72,4 +72,14 @@ builtin.module {
7272
// CHECK: func.func public @arg_attrs(%{{.*}}: tensor<8x8xf64> {"llvm.noalias"}, %{{.*}}: tensor<8x8xf64> {"llvm.noalias"}, %{{.*}}: tensor<8x8xf64> {"llvm.noalias"}) -> tensor<8x8xf64> {
7373
// CHECK-NEXT: func.return %{{.*}} : tensor<8x8xf64>
7474
// CHECK-NEXT: }
75+
76+
func.func @output_attributes() -> (f32 {dialect.a = 0 : i32}, f32 {dialect.b = 0 : i32, dialect.c = 1 : i64}) {
77+
%r1, %r2 = "test.op"() : () -> (f32, f32)
78+
return %r1, %r2 : f32, f32
79+
}
80+
81+
// CHECK: func.func @output_attributes() -> (f32 {"dialect.a" = 0 : i32}, f32 {"dialect.b" = 0 : i32, "dialect.c" = 1 : i64}) {
82+
// CHECK-NEXT: %0, %1 = "test.op"() : () -> (f32, f32)
83+
// CHECK-NEXT: func.return %0, %1 : f32, f32
84+
// CHECK-NEXT: }
7585
}

xdsl/dialects/arm_func.py

+8-10
Original file line numberDiff line numberDiff line change
@@ -83,18 +83,16 @@ def __init__(
8383
@classmethod
8484
def parse(cls, parser: Parser) -> FuncOp:
8585
visibility = parser.parse_optional_visibility_keyword()
86-
(
87-
name,
88-
input_types,
89-
return_types,
90-
region,
91-
extra_attrs,
92-
arg_attrs,
93-
) = parse_func_op_like(
94-
parser, reserved_attr_names=("sym_name", "function_type", "sym_visibility")
86+
(name, input_types, return_types, region, extra_attrs, arg_attrs, res_attrs) = (
87+
parse_func_op_like(
88+
parser,
89+
reserved_attr_names=("sym_name", "function_type", "sym_visibility"),
90+
)
9591
)
9692
if arg_attrs:
97-
raise NotImplementedError("arg_attrs not implemented in riscv_func")
93+
raise NotImplementedError("arg_attrs not implemented in arm_func")
94+
if res_attrs:
95+
raise NotImplementedError("res_attrs not implemented in arm_func")
9896
func = FuncOp(name, region, (input_types, return_types), visibility)
9997
if extra_attrs is not None:
10098
func.attributes |= extra_attrs.data

xdsl/dialects/csl/csl.py

+15-18
Original file line numberDiff line numberDiff line change
@@ -780,17 +780,16 @@ def verify_(self) -> None:
780780

781781
@classmethod
782782
def parse(cls, parser: Parser) -> FuncOp:
783-
(
784-
name,
785-
input_types,
786-
return_types,
787-
region,
788-
extra_attrs,
789-
arg_attrs,
790-
) = parse_func_op_like(
791-
parser, reserved_attr_names=("sym_name", "function_type", "sym_visibility")
783+
(name, input_types, return_types, region, extra_attrs, arg_attrs, res_attrs) = (
784+
parse_func_op_like(
785+
parser,
786+
reserved_attr_names=("sym_name", "function_type", "sym_visibility"),
787+
)
792788
)
793789

790+
if res_attrs:
791+
raise NotImplementedError("res_attrs not implemented in csl FuncOp")
792+
794793
assert (
795794
len(return_types) <= 1
796795
), f"{cls.name} can't have more than one result type!"
@@ -890,16 +889,14 @@ def verify_(self) -> None:
890889
@classmethod
891890
def parse(cls, parser: Parser) -> TaskOp:
892891
pos = parser.pos
893-
(
894-
name,
895-
input_types,
896-
return_types,
897-
region,
898-
extra_attrs,
899-
arg_attrs,
900-
) = parse_func_op_like(
901-
parser, reserved_attr_names=("sym_name", "function_type", "sym_visibility")
892+
(name, input_types, return_types, region, extra_attrs, arg_attrs, res_attrs) = (
893+
parse_func_op_like(
894+
parser,
895+
reserved_attr_names=("sym_name", "function_type", "sym_visibility"),
896+
)
902897
)
898+
if res_attrs:
899+
raise NotImplementedError("res_attrs not implemented in csl TaskOp")
903900
if (
904901
extra_attrs is None
905902
or "kind" not in extra_attrs.data

xdsl/dialects/func.py

+3
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,7 @@ def parse(cls, parser: Parser) -> FuncOp:
166166
region,
167167
extra_attrs,
168168
arg_attrs,
169+
res_attrs,
169170
) = parse_func_op_like(
170171
parser, reserved_attr_names=("sym_name", "function_type", "sym_visibility")
171172
)
@@ -175,6 +176,7 @@ def parse(cls, parser: Parser) -> FuncOp:
175176
region=region,
176177
visibility=visibility,
177178
arg_attrs=arg_attrs,
179+
res_attrs=res_attrs,
178180
)
179181
if extra_attrs is not None:
180182
func.attributes |= extra_attrs.data
@@ -192,6 +194,7 @@ def print(self, printer: Printer):
192194
self.body,
193195
self.attributes,
194196
arg_attrs=self.arg_attrs,
197+
res_attrs=self.res_attrs,
195198
reserved_attr_names=(
196199
"sym_name",
197200
"function_type",

xdsl/dialects/riscv_func.py

+7-9
Original file line numberDiff line numberDiff line change
@@ -174,18 +174,16 @@ def __init__(
174174
@classmethod
175175
def parse(cls, parser: Parser) -> FuncOp:
176176
visibility = parser.parse_optional_visibility_keyword()
177-
(
178-
name,
179-
input_types,
180-
return_types,
181-
region,
182-
extra_attrs,
183-
arg_attrs,
184-
) = parse_func_op_like(
185-
parser, reserved_attr_names=("sym_name", "function_type", "sym_visibility")
177+
(name, input_types, return_types, region, extra_attrs, arg_attrs, res_attrs) = (
178+
parse_func_op_like(
179+
parser,
180+
reserved_attr_names=("sym_name", "function_type", "sym_visibility"),
181+
)
186182
)
187183
if arg_attrs:
188184
raise NotImplementedError("arg_attrs not implemented in riscv_func")
185+
if res_attrs:
186+
raise NotImplementedError("res_attrs not implemented in riscv_func")
189187
func = FuncOp(name, region, (input_types, return_types), visibility)
190188
if extra_attrs is not None:
191189
func.attributes |= extra_attrs.data

xdsl/dialects/utils/format.py

+52-8
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ def print_func_op_like(
4343
attributes: dict[str, Attribute],
4444
*,
4545
arg_attrs: ArrayAttr[DictionaryAttr] | None = None,
46+
res_attrs: ArrayAttr[DictionaryAttr] | None = None,
4647
reserved_attr_names: Sequence[str],
4748
):
4849
printer.print(f" @{sym_name.data}")
@@ -62,7 +63,15 @@ def print_func_op_like(
6263
printer.print("-> ")
6364
if len(function_type.outputs) > 1:
6465
printer.print("(")
65-
printer.print_list(function_type.outputs, printer.print_attribute)
66+
if res_attrs is not None:
67+
printer.print_list(
68+
zip(function_type.outputs, res_attrs),
69+
lambda arg_with_attrs: print_func_output(
70+
printer, arg_with_attrs[0], arg_with_attrs[1]
71+
),
72+
)
73+
else:
74+
printer.print_list(function_type.outputs, printer.print_attribute)
6675
if len(function_type.outputs) > 1:
6776
printer.print(")")
6877
printer.print(" ")
@@ -85,9 +94,10 @@ def parse_func_op_like(
8594
Region,
8695
DictionaryAttr | None,
8796
ArrayAttr[DictionaryAttr] | None,
97+
ArrayAttr[DictionaryAttr] | None,
8898
]:
8999
"""
90-
Returns the function name, argument types, return types, body, extra args, and arg_attrs.
100+
Returns the function name, argument types, return types, body, extra args, arg_attrs and res_attrs.
91101
"""
92102
# Parse function name
93103
name = parser.parse_symbol_name().data
@@ -103,6 +113,13 @@ def parse_fun_input() -> Attribute | tuple[Parser.Argument, dict[str, Attribute]
103113
ret = (arg, arg_attr_dict)
104114
return ret
105115

116+
def parse_fun_output() -> tuple[Attribute, dict[str, Attribute]]:
117+
arg_type = parser.parse_optional_type()
118+
if arg_type is None:
119+
parser.raise_error("Return type should be specified")
120+
arg_attr_dict = parser.parse_optional_dictionary_attr_dict()
121+
return (arg_type, arg_attr_dict)
122+
106123
# Parse function arguments
107124
args = parser.parse_comma_separated_list(
108125
parser.Delimiter.PAREN,
@@ -135,14 +152,25 @@ def parse_fun_input() -> Attribute | tuple[Parser.Argument, dict[str, Attribute]
135152
arg_attrs = None
136153

137154
# Parse return type
155+
return_types: list[Attribute] = []
156+
res_attrs_raw: list[dict[str, Attribute]] | None = []
138157
if parser.parse_optional_punctuation("->"):
139-
return_types = parser.parse_optional_comma_separated_list(
140-
parser.Delimiter.PAREN, parser.parse_type
158+
return_attributes = parser.parse_optional_comma_separated_list(
159+
parser.Delimiter.PAREN, parse_fun_output
141160
)
142-
if return_types is None:
143-
return_types = [parser.parse_type()]
161+
if return_attributes is None:
162+
# output attributes are supported only if return results are enclosed in brackets (...)
163+
return_types, res_attrs_raw = [parser.parse_type()], None
164+
else:
165+
return_types, res_attrs_raw = (
166+
[el[0] for el in return_attributes],
167+
[el[1] for el in return_attributes],
168+
)
169+
170+
if res_attrs_raw is not None and any(res_attrs_raw):
171+
res_attrs = ArrayAttr(DictionaryAttr(attrs) for attrs in res_attrs_raw)
144172
else:
145-
return_types = []
173+
res_attrs = None
146174

147175
extra_attributes = parser.parse_optional_attr_dict_with_keyword(reserved_attr_names)
148176

@@ -151,7 +179,15 @@ def parse_fun_input() -> Attribute | tuple[Parser.Argument, dict[str, Attribute]
151179
if region is None:
152180
region = Region()
153181

154-
return name, input_types, return_types, region, extra_attributes, arg_attrs
182+
return (
183+
name,
184+
input_types,
185+
return_types,
186+
region,
187+
extra_attributes,
188+
arg_attrs,
189+
res_attrs,
190+
)
155191

156192

157193
def print_func_argument(
@@ -162,6 +198,14 @@ def print_func_argument(
162198
printer.print_op_attributes(attrs.data)
163199

164200

201+
def print_func_output(
202+
printer: Printer, out_type: Attribute, attrs: DictionaryAttr | None
203+
):
204+
printer.print_attribute(out_type)
205+
if attrs is not None and attrs.data:
206+
printer.print_op_attributes(attrs.data)
207+
208+
165209
def print_assignment(printer: Printer, arg: BlockArgument, val: SSAValue):
166210
printer.print_block_argument(arg, print_type=False)
167211
printer.print_string(" = ")

0 commit comments

Comments
 (0)