From 2d39e73851c533e78ba8d06e85b5833d6ca7d27f Mon Sep 17 00:00:00 2001 From: Alex Rice Date: Fri, 24 Jan 2025 10:59:35 +0000 Subject: [PATCH] core: (assembly-format) omit default attributes/properties from attr-dict directive (#3783) The `attr-dict` directive currently prints attributes/properties which are set to the default value, which is not the correct behaviour. --- .../irdl/test_declarative_assembly_format.py | 62 +++++++++++++++++++ xdsl/irdl/declarative_assembly_format.py | 32 +++++++--- 2 files changed, 84 insertions(+), 10 deletions(-) diff --git a/tests/irdl/test_declarative_assembly_format.py b/tests/irdl/test_declarative_assembly_format.py index 4c3ad1a01e..b994c4af96 100644 --- a/tests/irdl/test_declarative_assembly_format.py +++ b/tests/irdl/test_declarative_assembly_format.py @@ -2965,6 +2965,68 @@ class DefaultConstantOp(IRDLOperation): check_equivalence(program, generic, ctx) +@pytest.mark.parametrize( + "program, generic", + [ + ( + "test.default_attr_dict", + '"test.default_attr_dict"() <{prop = false}> {attr = false} : () -> ()', + ), + ( + "test.default_attr_dict {attr = true, prop = true}", + '"test.default_attr_dict"() <{prop = true}> {attr = true} : () -> ()', + ), + ], +) +def test_default_property_in_attr_dict(program: str, generic: str): + @irdl_op_definition + class DefaultAttrDictOp(IRDLOperation): + name = "test.default_attr_dict" + + prop = prop_def(BoolAttr, default_value=BoolAttr.from_bool(False)) + + attr = attr_def(BoolAttr, default_value=BoolAttr.from_bool(False)) + + irdl_options = [ParsePropInAttrDict()] + + assembly_format = "attr-dict" + + ctx = MLContext() + ctx.load_op(DefaultAttrDictOp) + + check_roundtrip(program, ctx) + check_equivalence(program, generic, ctx) + + +@pytest.mark.parametrize( + "program, generic", + [ + ( + "test.default_attr_dict", + '"test.default_attr_dict"() {attr = false} : () -> ()', + ), + ( + "test.default_attr_dict {attr = true}", + '"test.default_attr_dict"() {attr = true} : () -> ()', + ), + ], +) +def test_default_attr_in_attr_dict(program: str, generic: str): + @irdl_op_definition + class DefaultAttrDictOp(IRDLOperation): + name = "test.default_attr_dict" + + attr = attr_def(BoolAttr, default_value=BoolAttr.from_bool(False)) + + assembly_format = "attr-dict" + + ctx = MLContext() + ctx.load_op(DefaultAttrDictOp) + + check_roundtrip(program, ctx) + check_equivalence(program, generic, ctx) + + ################################################################################ # Extractors # ################################################################################ diff --git a/xdsl/irdl/declarative_assembly_format.py b/xdsl/irdl/declarative_assembly_format.py index a17c9b6089..8206f71d85 100644 --- a/xdsl/irdl/declarative_assembly_format.py +++ b/xdsl/irdl/declarative_assembly_format.py @@ -438,27 +438,39 @@ def parse(self, parser: Parser, state: ParsingState) -> None: def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None: if self.print_properties: - if ( - not (set(op.attributes.keys()) | set(op.properties.keys())) - - self.reserved_attr_names - ): - return if any(name in op.attributes for name in op.properties): raise ValueError( "Cannot print attributes and properties with the same name " - "in a signle dictionary" + "in a single dictionary" ) + op_def = op.get_irdl_definition() + dictionary = op.attributes | op.properties + reserved_or_default = self.reserved_attr_names.union( + name + for name, d in (op_def.properties | op_def.attributes).items() + if d.default_value is not None + and dictionary.get(name) == d.default_value + ) + if reserved_or_default.issuperset(dictionary.keys()): + return printer.print_op_attributes( - op.attributes | op.properties, - reserved_attr_names=self.reserved_attr_names, + dictionary, + reserved_attr_names=reserved_or_default, print_keyword=self.with_keyword, ) else: - if not set(op.attributes.keys()) - self.reserved_attr_names: + op_def = op.get_irdl_definition() + reserved_or_default = self.reserved_attr_names.union( + name + for name, d in op_def.attributes.items() + if d.default_value is not None + and op.attributes.get(name) == d.default_value + ) + if reserved_or_default.issuperset(op.attributes.keys()): return printer.print_op_attributes( op.attributes, - reserved_attr_names=self.reserved_attr_names, + reserved_attr_names=reserved_or_default, print_keyword=self.with_keyword, )