diff --git a/tests/filecheck/dialects/builtin/module.mlir b/tests/filecheck/dialects/builtin/module.mlir index 58efcb550a..3ffb7254bc 100644 --- a/tests/filecheck/dialects/builtin/module.mlir +++ b/tests/filecheck/dialects/builtin/module.mlir @@ -9,6 +9,12 @@ builtin.module { builtin.module attributes {a = "foo", b = "bar", unit} {} // CHECK-NEXT: builtin.module attributes {"a" = "foo", "b" = "bar", "unit"} { // CHECK-NEXT: } + builtin.module @moduleName {} + // CHECK-NEXT: builtin.module @moduleName { + // CHECK-NEXT: } + builtin.module @otherModule attributes {dialect.attr} {} + // CHECK-NEXT: builtin.module @otherModule attributes {"dialect.attr"} { + // CHECK-NEXT: } module {} // CHECK-NEXT: builtin.module { // CHECK-NEXT: } diff --git a/xdsl/dialects/builtin.py b/xdsl/dialects/builtin.py index e4bf295f56..e03b9be0d4 100644 --- a/xdsl/dialects/builtin.py +++ b/xdsl/dialects/builtin.py @@ -61,7 +61,7 @@ irdl_attr_definition, irdl_op_definition, irdl_to_attr_constraint, - opt_attr_def, + opt_prop_def, region_def, traits_def, var_operand_def, @@ -1628,7 +1628,7 @@ def verify(self): class ModuleOp(IRDLOperation): name = "builtin.module" - sym_name = opt_attr_def(StringAttr) + sym_name = opt_prop_def(StringAttr) body = region_def("single_block") @@ -1643,6 +1643,7 @@ def __init__( self, ops: list[Operation] | Region, attributes: Mapping[str, Attribute] | None = None, + sym_name: StringAttr | None = None, ): if attributes is None: attributes = {} @@ -1650,7 +1651,8 @@ def __init__( region = ops else: region = Region(Block(ops)) - super().__init__(regions=[region], attributes=attributes) + properties: dict[str, Attribute | None] = {"sym_name": sym_name} + super().__init__(regions=[region], attributes=attributes, properties=properties) @property def ops(self) -> BlockOps: @@ -1658,6 +1660,8 @@ def ops(self) -> BlockOps: @classmethod def parse(cls, parser: Parser) -> ModuleOp: + module_name = parser.parse_optional_symbol_name() + attributes = parser.parse_optional_attr_dict_with_keyword() if attributes is not None: attributes = attributes.data @@ -1667,9 +1671,12 @@ def parse(cls, parser: Parser) -> ModuleOp: if not region.blocks: region.add_block(Block()) - return ModuleOp(region, attributes) + return ModuleOp(region, attributes, module_name) def print(self, printer: Printer) -> None: + if self.sym_name is not None: + printer.print(f" @{self.sym_name.data}") + if len(self.attributes) != 0: printer.print(" attributes ") printer.print_op_attributes(self.attributes)