From 68c66e9f72647ed62dd626a59a9cd27d9065007b Mon Sep 17 00:00:00 2001 From: Kim Worrall Date: Mon, 19 Aug 2024 13:23:44 +0100 Subject: [PATCH] Add regions to declarative_assembly_format and the parser, and add tests --- .../workflows/update_xdsl_pyodide_build.py | 0 .../irdl/test_declarative_assembly_format.py | 179 ++++++++++++++++++ xdsl/irdl/declarative_assembly_format.py | 90 ++++++++- .../declarative_assembly_format_parser.py | 57 +++++- xdsl/irdl/operations.py | 2 +- xdsl/tools/xdsl-opt | 0 6 files changed, 323 insertions(+), 5 deletions(-) mode change 100755 => 100644 .github/workflows/update_xdsl_pyodide_build.py mode change 100755 => 100644 xdsl/tools/xdsl-opt diff --git a/.github/workflows/update_xdsl_pyodide_build.py b/.github/workflows/update_xdsl_pyodide_build.py old mode 100755 new mode 100644 diff --git a/tests/irdl/test_declarative_assembly_format.py b/tests/irdl/test_declarative_assembly_format.py index 5d23421d14..a77b2f19fd 100644 --- a/tests/irdl/test_declarative_assembly_format.py +++ b/tests/irdl/test_declarative_assembly_format.py @@ -35,10 +35,13 @@ opt_attr_def, opt_operand_def, opt_prop_def, + opt_region_def, opt_result_def, prop_def, + region_def, result_def, var_operand_def, + var_region_def, var_result_def, ) from xdsl.parser import Parser @@ -1086,6 +1089,182 @@ class OptionalResultOp(IRDLOperation): check_roundtrip(program, ctx) check_equivalence(program, generic_program, ctx) +################################################################################ +# Regions # +################################################################################ + +def test_missing_region(): + """Test that regions should be parsed.""" + with pytest.raises(PyRDLOpDefinitionError, match="region 'region' not found"): + + @irdl_op_definition + class NoRegionTypeOp(IRDLOperation): # pyright: ignore[reportUnusedClass] + name = "test.no_region_op" + region = region_def() + + assembly_format = "attr-dict-with-keyword" + +def test_region_missing_keyword(): + """Test that regions require the 'attr-dict-with-keyword' directive.""" + with pytest.raises(PyRDLOpDefinitionError, match="'attr-dict' directive must be 'attr-dict-with-keyword'"): + + @irdl_op_definition + class NoRegionTypeOp(IRDLOperation): # pyright: ignore[reportUnusedClass] + name = "test.region_op_missing_keyword" + region = region_def() + + assembly_format = "attr-dict $region" + +def test_region_before_attr(): + """Test that regions should appear after the 'attr-dict-with-keyword' directive.""" + with pytest.raises(PyRDLOpDefinitionError, match="'attr-dict' directive must appear"): + + @irdl_op_definition + class NoRegionTypeOp(IRDLOperation): # pyright: ignore[reportUnusedClass] + name = "test.region_before_attr" + region = region_def() + + assembly_format = "$region attr-dict-with-keyword" + + + +@irdl_op_definition +class TypedAttributeOp(IRDLOperation): + name = "test.typed_attr" + attr = attr_def(IntegerAttr[I32]) + + assembly_format = "$attr attr-dict" + + +@pytest.mark.parametrize( + "format, program, generic_program", + [ + ( + "attr-dict-with-keyword $fst $snd", + "test.two_regions {\n} {\n}", + '"test.two_regions"() ({}, {}) : () -> ()', + ), + ( + "attr-dict-with-keyword $fst $snd", + 'test.two_regions {\n test.typed_attr 3\n} {\n test.typed_attr 3\n}', + '"test.two_regions"() ({ test.typed_attr 3}, { test.typed_attr 3}) : () -> ()', + ), + ( + "attr-dict-with-keyword $fst $snd", + 'test.two_regions attributes {"a" = 2 : i32} {\n test.typed_attr 3\n} {\n test.typed_attr 3\n}', + '"test.two_regions"() ({ test.typed_attr 3}, { test.typed_attr 3}) {"a" = 2 : i32} : () -> ()', + ), + ], +) +def test_regions(format: str, program: str, generic_program: str): + """Test the parsing of regions""" + + @irdl_op_definition + class TwoRegionsOp(IRDLOperation): + name = "test.two_regions" + fst = region_def() + snd = region_def() + + assembly_format = format + + ctx = MLContext() + ctx.load_op(TwoRegionsOp) + ctx.load_op(TypedAttributeOp) + ctx.load_dialect(Test) + + check_roundtrip(program, ctx) + check_equivalence(program, generic_program, ctx) + +@pytest.mark.parametrize( + "format, program, generic_program", + [ + ( + "attr-dict-with-keyword $region", + "test.variadic_region ", + '"test.variadic_region"() : () -> ()', + ), + ( + "attr-dict-with-keyword $region", + 'test.variadic_region {\n "test.op"() : () -> ()\n}', + '"test.variadic_region"() ({ "test.op"() : () -> ()}) : () -> ()', + ), + ( + "attr-dict-with-keyword $region", + 'test.variadic_region {\n "test.op"() : () -> ()\n} {\n "test.op"() : () -> ()\n}', + '"test.variadic_region"() ({ "test.op"() : () -> ()}, { "test.op"() : () -> ()}) : () -> ()', + ), + ( + "attr-dict-with-keyword $region", + 'test.variadic_region {\n "test.op"() : () -> ()\n} {\n "test.op"() : () -> ()\n}', + '"test.variadic_region"() ({ "test.op"() : () -> ()}, {"test.op"() : () -> ()}) : () -> ()', + ), + ], +) +def test_variadic_region(format: str, program: str, generic_program: str): + """Test the parsing of variadic regions""" + + @irdl_op_definition + class VariadicRegionOp(IRDLOperation): + name = "test.variadic_region" + region = var_region_def() + + assembly_format = format + + ctx = MLContext() + ctx.load_op(VariadicRegionOp) + ctx.load_dialect(Test) + + check_roundtrip(program, ctx) + check_equivalence(program, generic_program, ctx) + + +@pytest.mark.parametrize( + "format, program, generic_program", + [ + ( + "attr-dict-with-keyword $region", + "test.optional_region ", + '"test.optional_region"() : () -> ()', + ), + ( + "attr-dict-with-keyword $region", + 'test.optional_region {\n "test.op"() : () -> ()\n}', + '"test.optional_region"() ({ "test.op"() : () -> ()}) : () -> ()', + ), + ], +) +def test_optional_region(format: str, program: str, generic_program: str): + """Test the parsing of optional regions""" + + @irdl_op_definition + class OptionalRegionOp(IRDLOperation): + name = "test.optional_region" + region = opt_region_def() + + assembly_format = format + + ctx = MLContext() + ctx.load_op(OptionalRegionOp) + ctx.load_dialect(Test) + + check_roundtrip(program, ctx) + check_equivalence(program, generic_program, ctx) + +def test_multiple_optional_regions(): + """Test the parsing of multiple optional regions""" + + """Test that multiple optional regions requires the ABCMeta PyRDL option.""" + with pytest.raises(PyRDLOpDefinitionError, match='Operation test.optional_regions defines more than two variadic regions'): + @irdl_op_definition + class OptionalOperandsOp(IRDLOperation): + name = "test.optional_regions" + region1 = opt_region_def() + region2 = opt_region_def() + + assembly_format = ( + "attr-dict-with-keyword $region1 $region2" + ) + ################################################################################ # Inference # diff --git a/xdsl/irdl/declarative_assembly_format.py b/xdsl/irdl/declarative_assembly_format.py index c7f70d6db7..af440ffed7 100644 --- a/xdsl/irdl/declarative_assembly_format.py +++ b/xdsl/irdl/declarative_assembly_format.py @@ -9,6 +9,8 @@ from abc import ABC, abstractmethod from collections.abc import Sequence from dataclasses import dataclass, field +from os import stat +from re import A from typing import Any, Literal, cast from xdsl.dialects.builtin import UnitAttr @@ -18,7 +20,9 @@ ParametrizedAttribute, SSAValue, TypedAttribute, + Region, ) +from xdsl.ir.core import Block, Operation from xdsl.irdl import ( ConstraintContext, IRDLOperation, @@ -48,19 +52,21 @@ class ParsingState: operands: list[UnresolvedOperand | None | list[UnresolvedOperand | None]] operand_types: list[Attribute | None | list[Attribute | None]] result_types: list[Attribute | None | list[Attribute | None]] + regions: list[Region | None | list[Region | None]] attributes: dict[str, Attribute] properties: dict[str, Attribute] constraint_context: ConstraintContext def __init__(self, op_def: OpDef): - if op_def.regions or op_def.successors: + if op_def.successors: raise NotImplementedError( - "Operation definitions with regions " - "or successors are not yet supported" + "Operation definitions with " + "successors are not yet supported" ) self.operands = [None] * len(op_def.operands) self.operand_types = [None] * len(op_def.operands) self.result_types = [None] * len(op_def.results) + self.regions = [None] * len(op_def.regions) self.attributes = {} self.properties = {} self.constraint_context = ConstraintContext() @@ -159,11 +165,16 @@ def parse( properties = state.properties else: properties = op_def.split_properties(state.attributes) + + # Get the regions and cast them to the type needed in op_type + regions = cast(Sequence[Region | Sequence[Operation] | Sequence[Block] | Sequence[Region | Sequence[Operation] | Sequence[Block]] | None], state.regions) + return op_type.build( result_types=result_types, operands=operands, attributes=state.attributes, properties=properties, + regions= regions, ) def assign_constraint_variables( @@ -722,6 +733,79 @@ def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> No state.last_was_punctuation = False state.should_emit_space = True +@dataclass(frozen=True) +class RegionVariable(VariableDirective): + """ + A region variable, with the following format: + region-directive ::= dollar-ident + The directive will request a space to be printed after. + """ + + def parse(self, parser: Parser, state: ParsingState) -> None: + region = parser.parse_region() + state.regions[self.index] = region + + def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None: + if state.should_emit_space or not state.last_was_punctuation: + printer.print(" ") + printer.print_region(getattr(op, self.name)) + state.last_was_punctuation = False + state.should_emit_space = True + + +@dataclass(frozen=True) +class VariadicRegionVariable( + VariadicVariable, VariableDirective, OptionallyParsableDirective +): + """ + A variadic region variable, with the following format: + region-directive ::= ( percent-ident ( `,` percent-id )* )? + The directive will request a space to be printed after. + """ + + def parse_optional(self, parser: Parser, state: ParsingState) -> bool: + regions : list[Region] = [] + current_region = parser.parse_optional_region() + while current_region is not None: + regions.append(current_region) + current_region = parser.parse_optional_region() + + state.regions[self.index] = cast(list[Region | None], regions) + return bool(regions) + + def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None: + breakpoint() + if state.should_emit_space or not state.last_was_punctuation: + printer.print(" ") + region = getattr(op, self.name) + if region: + printer.print_list(region, printer.print_region, delimiter=" ") + state.last_was_punctuation = False + state.should_emit_space = True + + +class OptionalRegionVariable(OptionalVariable, OptionallyParsableDirective): + """ + An optional region variable, with the following format: + region-directive ::= ( percent-ident )? + The directive will request a space to be printed after. + """ + + def parse_optional(self, parser: Parser, state: ParsingState) -> bool: + region = parser.parse_optional_region() + if region is None: + region = list[Region | None]() + state.regions[self.index] = region + return bool(region) + + def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None: + if state.should_emit_space or not state.last_was_punctuation: + printer.print(" ") + region = getattr(op, self.name) + if region: + printer.print_region(region) + state.last_was_punctuation = False + state.should_emit_space = True @dataclass(frozen=True) class AttributeVariable(FormatDirective): diff --git a/xdsl/irdl/declarative_assembly_format_parser.py b/xdsl/irdl/declarative_assembly_format_parser.py index 79f7b13997..26ea55a500 100644 --- a/xdsl/irdl/declarative_assembly_format_parser.py +++ b/xdsl/irdl/declarative_assembly_format_parser.py @@ -21,11 +21,13 @@ OpDef, OptionalDef, OptOperandDef, + OptRegionDef, OptResultDef, ParamAttrConstraint, ParsePropInAttrDict, VariadicDef, VarOperandDef, + VarRegionDef, VarResultDef, ) from xdsl.irdl.declarative_assembly_format import ( @@ -43,10 +45,12 @@ OptionallyParsableDirective, OptionalOperandTypeDirective, OptionalOperandVariable, + OptionalRegionVariable, OptionalResultTypeDirective, OptionalResultVariable, OptionalUnitAttrVariable, PunctuationDirective, + RegionVariable, ResultTypeDirective, ResultVariable, VariableDirective, @@ -55,6 +59,7 @@ VariadicLikeVariable, VariadicOperandTypeDirective, VariadicOperandVariable, + VariadicRegionVariable, VariadicResultTypeDirective, VariadicResultVariable, WhitespaceDirective, @@ -123,8 +128,12 @@ class FormatParser(BaseParser): """The attributes that are already parsed.""" seen_properties: set[str] """The properties that are already parsed.""" + seen_regions: list[bool] + """The region variables that are already parsed.""" has_attr_dict: bool = field(default=False) - """True if the attribute dictionary has already been parsed.""" + """True if the attribute dictionary has already been parsed.""" + has_attr_dict_with_keyword: bool = field(default=False) + """True if the attr-dict directive was with keyword.""" context: ParsingContext = field(default=ParsingContext.TopLevel) """Indicates if the parser is nested in a particular directive.""" type_resolutions: dict[ @@ -141,6 +150,7 @@ def __init__(self, input: str, op_def: OpDef): self.seen_result_types = [False] * len(op_def.results) self.seen_attributes = set[str]() self.seen_properties = set[str]() + self.seen_regions = [False] * len(op_def.regions) self.type_resolutions = {} def parse_format(self) -> FormatProgram: @@ -161,6 +171,7 @@ def parse_format(self) -> FormatProgram: self.verify_properties() self.verify_operands(seen_variables) self.verify_results(seen_variables) + self.verify_regions() return FormatProgram(elements) def verify_directives(self, elements: list[FormatDirective]): @@ -291,6 +302,26 @@ def verify_properties(self): "'ParsePropInAttrDict' IRDL option." ) + def verify_regions(self): + """ + Check that all regions are present. + """ + for ( + seen_region, + (region_name, _), + ) in zip( + self.seen_regions, + self.op_def.regions, + strict=True, + ): + if not seen_region: + self.raise_error( + f"region '{region_name}' " + f"not found, consider adding a '${region_name}' " + "directive to the custom assembly format." + ) + + def parse_optional_variable( self, ) -> VariableDirective | AttributeVariable | None: @@ -343,6 +374,29 @@ def parse_optional_variable( return VariadicResultVariable(variable_name, idx) else: return ResultVariable(variable_name, idx) + + # Check if the variable is a region + for idx, (region_name, region_def) in enumerate(self.op_def.regions): + if variable_name != region_name: + continue + self.seen_regions[idx] = True + if not self.has_attr_dict: + self.raise_error( + "'attr-dict' directive must appear" + f"before regions, found region'{region_name}'" + ) + if not self.has_attr_dict_with_keyword: + self.raise_error( + "'attr-dict' directive must be 'attr-dict-with-keyword'" + "if regions present." + ) + match region_def: + case OptRegionDef(): + return OptionalRegionVariable(variable_name, idx) + case VarRegionDef(): + return VariadicRegionVariable(variable_name, idx) + case _: + return RegionVariable(variable_name, idx) attr_or_prop_by_name = { attr_name: attr_or_prop @@ -575,6 +629,7 @@ def create_attr_dict_directive(self, with_keyword: bool) -> AttrDictDirective: "in the assembly format description" ) self.has_attr_dict = True + self.has_attr_dict_with_keyword = with_keyword print_properties = any( isinstance(option, ParsePropInAttrDict) for option in self.op_def.options ) diff --git a/xdsl/irdl/operations.py b/xdsl/irdl/operations.py index a98cd59f99..311415ca78 100644 --- a/xdsl/irdl/operations.py +++ b/xdsl/irdl/operations.py @@ -1735,7 +1735,7 @@ def fun(self: Any, idx: int = arg_idx, previous_vars: int = previous_variadics): not any(isinstance(o, arg_size_option) for o in op_def.options) ): arg_size_option_name = type(arg_size_option).__name__ - raise Exception( + raise PyRDLOpDefinitionError( f"Operation {op_def.name} defines more than two variadic " f"{get_construct_name(construct)}s, but do not define the " f"{arg_size_option_name} PyRDL option." diff --git a/xdsl/tools/xdsl-opt b/xdsl/tools/xdsl-opt old mode 100755 new mode 100644