Skip to content

Commit

Permalink
core: (irdl) Add regions to declarative assembly format (#3065)
Browse files Browse the repository at this point in the history
Added support for regions to the declarative assembly format.

The declarative assembly format now supports:
- Referring to region variables
- Referring to optional or variadic region variables
- Printing custom assembly for region variables

The format requires that:
- The attr-dict-with-keyword directive is present and not attr-dict.
This is to maintain consistency with mlir.
- The regions appear after the attributes.

Also added tests.

---------

Co-authored-by: Sasha Lopoukhine <[email protected]>
  • Loading branch information
kimxworrall and superlopuh authored Aug 21, 2024
1 parent 5299b90 commit 600dd91
Show file tree
Hide file tree
Showing 3 changed files with 340 additions and 3 deletions.
210 changes: 210 additions & 0 deletions tests/irdl/test_declarative_assembly_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,13 @@
opt_attr_def,
opt_operand_def,
opt_prop_def,
opt_region_def,
opt_result_def,
prop_def,
region_def,
result_def,
var_operand_def,
var_region_def,
var_result_def,
)
from xdsl.parser import Parser
Expand Down Expand Up @@ -1119,6 +1122,213 @@ class OptionalResultOp(IRDLOperation):
check_equivalence(program, generic_program, ctx)


################################################################################
# Regions #
################################################################################


def test_missing_region():
"""Test that regions should be parsed."""
with pytest.raises(PyRDLOpDefinitionError, match="region 'region' not found"):

@irdl_op_definition
class NoRegionOp(IRDLOperation): # pyright: ignore[reportUnusedClass]
name = "test.no_region_op"
region = region_def()

assembly_format = "attr-dict-with-keyword"


def test_attr_dict_directly_before_region_variable():
"""Test that regions require an 'attr-dict' directive."""
with pytest.raises(
PyRDLOpDefinitionError,
match="An `attr-dict' directive without keyword cannot be directly followed by a region variable",
):

@irdl_op_definition
class RegionAttrDictWrongOp(IRDLOperation): # pyright: ignore[reportUnusedClass]
name = "test.region_op_missing_keyword"
region = region_def()

assembly_format = "attr-dict $region"


@pytest.mark.parametrize(
"format, program, generic_program",
[
(
"$region attr-dict",
'test.region_attr_dict {\n} {"a" = 2 : i32}',
'"test.region_attr_dict"() ({}) {"a" = 2 : i32} : () -> ()',
),
(
"attr-dict `,` $region",
'test.region_attr_dict {"a" = 2 : i32}, {\n "test.op"() : () -> ()\n}',
'"test.region_attr_dict"() ({ "test.op"() : () -> ()}) {"a" = 2 : i32} : () -> ()',
),
],
)
def test_regions_with_attr_dict(format: str, program: str, generic_program: str):
"""Test the parsing of regions"""

@irdl_op_definition
class RegionsOp(IRDLOperation):
name = "test.region_attr_dict"
region = region_def()

assembly_format = format

ctx = MLContext()
ctx.load_op(RegionsOp)
ctx.load_dialect(Test)

check_roundtrip(program, ctx)
check_equivalence(program, generic_program, ctx)


@irdl_op_definition
class MiscOp(IRDLOperation):
name = "test.typed_attr"
attr = attr_def(IntegerAttr[I32])

assembly_format = "$attr attr-dict"


@pytest.mark.parametrize(
"format, program, generic_program",
[
(
"attr-dict-with-keyword $fst $snd",
"test.two_regions {\n} {\n}",
'"test.two_regions"() ({}, {}) : () -> ()',
),
(
"attr-dict-with-keyword $fst $snd",
"test.two_regions {\n test.typed_attr 3\n} {\n test.typed_attr 3\n}",
'"test.two_regions"() ({ test.typed_attr 3}, { test.typed_attr 3}) : () -> ()',
),
(
"attr-dict-with-keyword $fst $snd",
'test.two_regions attributes {"a" = 2 : i32} {\n test.typed_attr 3\n} {\n test.typed_attr 3\n}',
'"test.two_regions"() ({ test.typed_attr 3}, { test.typed_attr 3}) {"a" = 2 : i32} : () -> ()',
),
],
)
def test_regions(format: str, program: str, generic_program: str):
"""Test the parsing of regions"""

@irdl_op_definition
class TwoRegionsOp(IRDLOperation):
name = "test.two_regions"
fst = region_def()
snd = region_def()

assembly_format = format

ctx = MLContext()
ctx.load_op(TwoRegionsOp)
ctx.load_op(MiscOp)
ctx.load_dialect(Test)

check_roundtrip(program, ctx)
check_equivalence(program, generic_program, ctx)


@pytest.mark.parametrize(
"format, program, generic_program",
[
(
"attr-dict-with-keyword $region",
"test.variadic_region ",
'"test.variadic_region"() : () -> ()',
),
(
"attr-dict-with-keyword $region",
'test.variadic_region {\n "test.op"() : () -> ()\n}',
'"test.variadic_region"() ({ "test.op"() : () -> ()}) : () -> ()',
),
(
"attr-dict-with-keyword $region",
'test.variadic_region {\n "test.op"() : () -> ()\n} {\n "test.op"() : () -> ()\n}',
'"test.variadic_region"() ({ "test.op"() : () -> ()}, { "test.op"() : () -> ()}) : () -> ()',
),
(
"attr-dict-with-keyword $region",
'test.variadic_region {\n "test.op"() : () -> ()\n} {\n "test.op"() : () -> ()\n}',
'"test.variadic_region"() ({ "test.op"() : () -> ()}, {"test.op"() : () -> ()}) : () -> ()',
),
],
)
def test_variadic_region(format: str, program: str, generic_program: str):
"""Test the parsing of variadic regions"""

@irdl_op_definition
class VariadicRegionOp(IRDLOperation):
name = "test.variadic_region"
region = var_region_def()

assembly_format = format

ctx = MLContext()
ctx.load_op(VariadicRegionOp)
ctx.load_dialect(Test)

check_roundtrip(program, ctx)
check_equivalence(program, generic_program, ctx)


@pytest.mark.parametrize(
"format, program, generic_program",
[
(
"attr-dict-with-keyword $region",
"test.optional_region ",
'"test.optional_region"() : () -> ()',
),
(
"attr-dict-with-keyword $region",
'test.optional_region {\n "test.op"() : () -> ()\n}',
'"test.optional_region"() ({ "test.op"() : () -> ()}) : () -> ()',
),
],
)
def test_optional_region(format: str, program: str, generic_program: str):
"""Test the parsing of optional regions"""

@irdl_op_definition
class OptionalRegionOp(IRDLOperation):
name = "test.optional_region"
region = opt_region_def()

assembly_format = format

ctx = MLContext()
ctx.load_op(OptionalRegionOp)
ctx.load_dialect(Test)

check_roundtrip(program, ctx)
check_equivalence(program, generic_program, ctx)


def test_multiple_optional_regions():
"""Test the parsing of multiple optional regions"""

"""Test that multiple optional regions requires the ABCMeta PyRDL option."""
with pytest.raises(
PyRDLOpDefinitionError,
match="Operation test.optional_regions defines more than two variadic regions",
):

@irdl_op_definition
class OptionalOperandsOp(IRDLOperation): # pyright: ignore[reportUnusedClass]
name = "test.optional_regions"
region1 = opt_region_def()
region2 = opt_region_def()

assembly_format = "attr-dict-with-keyword $region1 $region2"


################################################################################
# Inference #
################################################################################
Expand Down
88 changes: 85 additions & 3 deletions xdsl/irdl/declarative_assembly_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
Attribute,
Data,
ParametrizedAttribute,
Region,
SSAValue,
TypedAttribute,
)
Expand Down Expand Up @@ -48,19 +49,20 @@ class ParsingState:
operands: list[UnresolvedOperand | None | list[UnresolvedOperand | None]]
operand_types: list[Attribute | None | list[Attribute | None]]
result_types: list[Attribute | None | list[Attribute | None]]
regions: list[Region | None | list[Region]]
attributes: dict[str, Attribute]
properties: dict[str, Attribute]
constraint_context: ConstraintContext

def __init__(self, op_def: OpDef):
if op_def.regions or op_def.successors:
if op_def.successors:
raise NotImplementedError(
"Operation definitions with regions "
"or successors are not yet supported"
"Operation definitions with successors are not yet supported"
)
self.operands = [None] * len(op_def.operands)
self.operand_types = [None] * len(op_def.operands)
self.result_types = [None] * len(op_def.results)
self.regions = [None] * len(op_def.regions)
self.attributes = {}
self.properties = {}
self.constraint_context = ConstraintContext()
Expand Down Expand Up @@ -164,6 +166,7 @@ def parse(
operands=operands,
attributes=state.attributes,
properties=properties,
regions=state.regions,
)

def assign_constraint_variables(
Expand Down Expand Up @@ -723,6 +726,85 @@ def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> No
state.should_emit_space = True


@dataclass(frozen=True)
class RegionVariable(VariableDirective, OptionallyParsableDirective):
"""
A region variable, with the following format:
region-directive ::= dollar-ident
The directive will request a space to be printed after.
"""

def parse(self, parser: Parser, state: ParsingState) -> None:
region = parser.parse_region()
state.regions[self.index] = region

def parse_optional(self, parser: Parser, state: ParsingState) -> bool:
region = parser.parse_optional_region()
state.regions[self.index] = region
return region is not None

def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None:
if state.should_emit_space or not state.last_was_punctuation:
printer.print(" ")
printer.print_region(getattr(op, self.name))
state.last_was_punctuation = False
state.should_emit_space = True


@dataclass(frozen=True)
class VariadicRegionVariable(
VariadicVariable, VariableDirective, OptionallyParsableDirective
):
"""
A variadic region variable, with the following format:
region-directive ::= dollar-ident
The directive will request a space to be printed after.
"""

def parse_optional(self, parser: Parser, state: ParsingState) -> bool:
regions: list[Region] = []
current_region = parser.parse_optional_region()
while current_region is not None:
regions.append(current_region)
current_region = parser.parse_optional_region()

state.regions[self.index] = regions
return bool(regions)

def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None:
if state.should_emit_space or not state.last_was_punctuation:
printer.print(" ")
region = getattr(op, self.name)
if region:
printer.print_list(region, printer.print_region, delimiter=" ")
state.last_was_punctuation = False
state.should_emit_space = True


class OptionalRegionVariable(OptionalVariable, OptionallyParsableDirective):
"""
An optional region variable, with the following format:
region-directive ::= dollar-ident
The directive will request a space to be printed after.
"""

def parse_optional(self, parser: Parser, state: ParsingState) -> bool:
region = parser.parse_optional_region()
if region is None:
region = list[Region]()
state.regions[self.index] = region
return bool(region)

def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None:
if state.should_emit_space or not state.last_was_punctuation:
printer.print(" ")
region = getattr(op, self.name)
if region:
printer.print_region(region)
state.last_was_punctuation = False
state.should_emit_space = True


@dataclass(frozen=True)
class AttributeVariable(FormatDirective):
"""
Expand Down
Loading

0 comments on commit 600dd91

Please sign in to comment.