@@ -43,6 +43,7 @@ def print_func_op_like(
43
43
attributes : dict [str , Attribute ],
44
44
* ,
45
45
arg_attrs : ArrayAttr [DictionaryAttr ] | None = None ,
46
+ res_attrs : ArrayAttr [DictionaryAttr ] | None = None ,
46
47
reserved_attr_names : Sequence [str ],
47
48
):
48
49
printer .print (f" @{ sym_name .data } " )
@@ -62,7 +63,15 @@ def print_func_op_like(
62
63
printer .print ("-> " )
63
64
if len (function_type .outputs ) > 1 :
64
65
printer .print ("(" )
65
- printer .print_list (function_type .outputs , printer .print_attribute )
66
+ if res_attrs is not None :
67
+ printer .print_list (
68
+ zip (function_type .outputs , res_attrs ),
69
+ lambda arg_with_attrs : print_func_output (
70
+ printer , arg_with_attrs [0 ], arg_with_attrs [1 ]
71
+ ),
72
+ )
73
+ else :
74
+ printer .print_list (function_type .outputs , printer .print_attribute )
66
75
if len (function_type .outputs ) > 1 :
67
76
printer .print (")" )
68
77
printer .print (" " )
@@ -85,9 +94,10 @@ def parse_func_op_like(
85
94
Region ,
86
95
DictionaryAttr | None ,
87
96
ArrayAttr [DictionaryAttr ] | None ,
97
+ ArrayAttr [DictionaryAttr ] | None ,
88
98
]:
89
99
"""
90
- Returns the function name, argument types, return types, body, extra args, and arg_attrs .
100
+ Returns the function name, argument types, return types, body, extra args, arg_attrs and res_attrs .
91
101
"""
92
102
# Parse function name
93
103
name = parser .parse_symbol_name ().data
@@ -103,6 +113,13 @@ def parse_fun_input() -> Attribute | tuple[Parser.Argument, dict[str, Attribute]
103
113
ret = (arg , arg_attr_dict )
104
114
return ret
105
115
116
+ def parse_fun_output () -> tuple [Attribute , dict [str , Attribute ]]:
117
+ arg_type = parser .parse_optional_type ()
118
+ if arg_type is None :
119
+ parser .raise_error ("Return type should be specified" )
120
+ arg_attr_dict = parser .parse_optional_dictionary_attr_dict ()
121
+ return (arg_type , arg_attr_dict )
122
+
106
123
# Parse function arguments
107
124
args = parser .parse_comma_separated_list (
108
125
parser .Delimiter .PAREN ,
@@ -135,14 +152,25 @@ def parse_fun_input() -> Attribute | tuple[Parser.Argument, dict[str, Attribute]
135
152
arg_attrs = None
136
153
137
154
# Parse return type
155
+ return_types : list [Attribute ] = []
156
+ res_attrs_raw : list [dict [str , Attribute ]] | None = []
138
157
if parser .parse_optional_punctuation ("->" ):
139
- return_types = parser .parse_optional_comma_separated_list (
140
- parser .Delimiter .PAREN , parser . parse_type
158
+ return_attributes = parser .parse_optional_comma_separated_list (
159
+ parser .Delimiter .PAREN , parse_fun_output
141
160
)
142
- if return_types is None :
143
- return_types = [parser .parse_type ()]
161
+ if return_attributes is None :
162
+ # output attributes are supported only if return results are enclosed in brackets (...)
163
+ return_types , res_attrs_raw = [parser .parse_type ()], None
164
+ else :
165
+ return_types , res_attrs_raw = (
166
+ [el [0 ] for el in return_attributes ],
167
+ [el [1 ] for el in return_attributes ],
168
+ )
169
+
170
+ if res_attrs_raw is not None and any (res_attrs_raw ):
171
+ res_attrs = ArrayAttr (DictionaryAttr (attrs ) for attrs in res_attrs_raw )
144
172
else :
145
- return_types = []
173
+ res_attrs = None
146
174
147
175
extra_attributes = parser .parse_optional_attr_dict_with_keyword (reserved_attr_names )
148
176
@@ -151,7 +179,15 @@ def parse_fun_input() -> Attribute | tuple[Parser.Argument, dict[str, Attribute]
151
179
if region is None :
152
180
region = Region ()
153
181
154
- return name , input_types , return_types , region , extra_attributes , arg_attrs
182
+ return (
183
+ name ,
184
+ input_types ,
185
+ return_types ,
186
+ region ,
187
+ extra_attributes ,
188
+ arg_attrs ,
189
+ res_attrs ,
190
+ )
155
191
156
192
157
193
def print_func_argument (
@@ -162,6 +198,14 @@ def print_func_argument(
162
198
printer .print_op_attributes (attrs .data )
163
199
164
200
201
+ def print_func_output (
202
+ printer : Printer , out_type : Attribute , attrs : DictionaryAttr | None
203
+ ):
204
+ printer .print_attribute (out_type )
205
+ if attrs is not None and attrs .data :
206
+ printer .print_op_attributes (attrs .data )
207
+
208
+
165
209
def print_assignment (printer : Printer , arg : BlockArgument , val : SSAValue ):
166
210
printer .print_block_argument (arg , print_type = False )
167
211
printer .print_string (" = " )
0 commit comments