Skip to content

Commit

Permalink
fixed attr-dict handling and tests
Browse files Browse the repository at this point in the history
  • Loading branch information
kimxworrall committed Aug 19, 2024
1 parent e4d7ea5 commit 78393ec
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 34 deletions.
55 changes: 37 additions & 18 deletions tests/irdl/test_declarative_assembly_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -1100,44 +1100,63 @@ def test_missing_region():
with pytest.raises(PyRDLOpDefinitionError, match="region 'region' not found"):

@irdl_op_definition
class NoRegionTypeOp(IRDLOperation): # pyright: ignore[reportUnusedClass]
class NoRegionOp(IRDLOperation): # pyright: ignore[reportUnusedClass]
name = "test.no_region_op"
region = region_def()

assembly_format = "attr-dict-with-keyword"


def test_region_missing_keyword():
"""Test that regions require the 'attr-dict-with-keyword' directive."""
def test_attr_dict_directly_before_region_variable():
"""Test that regions require an 'attr-dict' directive."""
with pytest.raises(
PyRDLOpDefinitionError,
match="'attr-dict' directive must be 'attr-dict-with-keyword'",
match="An `attr-dict' directive without keyword cannot be directly followed by a region variable",
):

@irdl_op_definition
class NoRegionTypeOp(IRDLOperation): # pyright: ignore[reportUnusedClass]
class RegionAttrDictWrongOp(IRDLOperation): # pyright: ignore[reportUnusedClass]
name = "test.region_op_missing_keyword"
region = region_def()

assembly_format = "attr-dict $region"


def test_region_before_attr():
"""Test that regions should appear after the 'attr-dict-with-keyword' directive."""
with pytest.raises(
PyRDLOpDefinitionError, match="'attr-dict' directive must appear"
):
@pytest.mark.parametrize(
"format, program, generic_program",
[
(
"$region attr-dict",
'test.region_attr_dict {\n} {"a" = 2 : i32}',
'"test.region_attr_dict"() ({}) {"a" = 2 : i32} : () -> ()',
),
(
"attr-dict `,` $region",
'test.region_attr_dict {"a" = 2 : i32}, {\n "test.op"() : () -> ()\n}',
'"test.region_attr_dict"() ({ "test.op"() : () -> ()}) {"a" = 2 : i32} : () -> ()',
),
],
)
def test_regions_with_attr_dict(format: str, program: str, generic_program: str):
"""Test the parsing of regions"""

@irdl_op_definition
class NoRegionTypeOp(IRDLOperation): # pyright: ignore[reportUnusedClass]
name = "test.region_before_attr"
region = region_def()
@irdl_op_definition
class RegionsOp(IRDLOperation):
name = "test.region_attr_dict"
region = region_def()

assembly_format = "$region attr-dict-with-keyword"
assembly_format = format

ctx = MLContext()
ctx.load_op(RegionsOp)
ctx.load_dialect(Test)

check_roundtrip(program, ctx)
check_equivalence(program, generic_program, ctx)


@irdl_op_definition
class TypedAttributeOp(IRDLOperation):
class MiscOp(IRDLOperation):
name = "test.typed_attr"
attr = attr_def(IntegerAttr[I32])

Expand Down Expand Up @@ -1177,7 +1196,7 @@ class TwoRegionsOp(IRDLOperation):

ctx = MLContext()
ctx.load_op(TwoRegionsOp)
ctx.load_op(TypedAttributeOp)
ctx.load_op(MiscOp)
ctx.load_dialect(Test)

check_roundtrip(program, ctx)
Expand Down Expand Up @@ -1270,7 +1289,7 @@ def test_multiple_optional_regions():
):

@irdl_op_definition
class OptionalOperandsOp(IRDLOperation):
class OptionalOperandsOp(IRDLOperation): # pyright: ignore[reportUnusedClass]
name = "test.optional_regions"
region1 = opt_region_def()
region2 = opt_region_def()
Expand Down
11 changes: 8 additions & 3 deletions xdsl/irdl/declarative_assembly_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -742,7 +742,7 @@ def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> No


@dataclass(frozen=True)
class RegionVariable(VariableDirective):
class RegionVariable(VariableDirective, OptionallyParsableDirective):
"""
A region variable, with the following format:
region-directive ::= dollar-ident
Expand All @@ -753,6 +753,11 @@ def parse(self, parser: Parser, state: ParsingState) -> None:
region = parser.parse_region()
state.regions[self.index] = region

def parse_optional(self, parser: Parser, state: ParsingState) -> bool:
region = parser.parse_optional_region()
state.regions[self.index] = region
return region is not None

def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None:
if state.should_emit_space or not state.last_was_punctuation:
printer.print(" ")
Expand All @@ -767,7 +772,7 @@ class VariadicRegionVariable(
):
"""
A variadic region variable, with the following format:
region-directive ::= ( percent-ident ( `,` percent-id )* )?
region-directive ::= ( dollar-ident ( `,` dollar-id )* )?
The directive will request a space to be printed after.
"""

Expand All @@ -794,7 +799,7 @@ def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> No
class OptionalRegionVariable(OptionalVariable, OptionallyParsableDirective):
"""
An optional region variable, with the following format:
region-directive ::= ( percent-ident )?
region-directive ::= ( dollar-ident )?
The directive will request a space to be printed after.
"""

Expand Down
17 changes: 4 additions & 13 deletions xdsl/irdl/declarative_assembly_format_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,6 @@ class FormatParser(BaseParser):
"""The region variables that are already parsed."""
has_attr_dict: bool = field(default=False)
"""True if the attribute dictionary has already been parsed."""
has_attr_dict_with_keyword: bool = field(default=False)
"""True if the attr-dict directive was with keyword."""
context: ParsingContext = field(default=ParsingContext.TopLevel)
"""Indicates if the parser is nested in a particular directive."""
type_resolutions: dict[
Expand Down Expand Up @@ -196,6 +194,10 @@ def verify_directives(self, elements: list[FormatDirective]):
self.raise_error(
"A variadic operand variable cannot be followed by another variadic operand variable."
)
case AttrDictDirective(), RegionVariable() if not (a.with_keyword):
self.raise_error(
"An `attr-dict' directive without keyword cannot be directly followed by a region variable as it is ambiguous."
)
case _:
pass

Expand Down Expand Up @@ -379,16 +381,6 @@ def parse_optional_variable(
if variable_name != region_name:
continue
self.seen_regions[idx] = True
if not self.has_attr_dict:
self.raise_error(
"'attr-dict' directive must appear"
f"before regions, found region'{region_name}'"
)
if not self.has_attr_dict_with_keyword:
self.raise_error(
"'attr-dict' directive must be 'attr-dict-with-keyword'"
"if regions present."
)
match region_def:
case OptRegionDef():
return OptionalRegionVariable(variable_name, idx)
Expand Down Expand Up @@ -628,7 +620,6 @@ def create_attr_dict_directive(self, with_keyword: bool) -> AttrDictDirective:
"in the assembly format description"
)
self.has_attr_dict = True
self.has_attr_dict_with_keyword = with_keyword
print_properties = any(
isinstance(option, ParsePropInAttrDict) for option in self.op_def.options
)
Expand Down

0 comments on commit 78393ec

Please sign in to comment.