Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

core: (irdl) Add regions to declarative assembly format #3065

Merged
merged 13 commits into from
Aug 21, 2024
Merged
210 changes: 210 additions & 0 deletions tests/irdl/test_declarative_assembly_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,13 @@
opt_attr_def,
opt_operand_def,
opt_prop_def,
opt_region_def,
opt_result_def,
prop_def,
region_def,
result_def,
var_operand_def,
var_region_def,
var_result_def,
)
from xdsl.parser import Parser
Expand Down Expand Up @@ -1119,6 +1122,213 @@ class OptionalResultOp(IRDLOperation):
check_equivalence(program, generic_program, ctx)


################################################################################
# Regions #
################################################################################


def test_missing_region():
"""Test that regions should be parsed."""
with pytest.raises(PyRDLOpDefinitionError, match="region 'region' not found"):

@irdl_op_definition
class NoRegionOp(IRDLOperation): # pyright: ignore[reportUnusedClass]
name = "test.no_region_op"
region = region_def()

assembly_format = "attr-dict-with-keyword"


def test_attr_dict_directly_before_region_variable():
"""Test that regions require an 'attr-dict' directive."""
with pytest.raises(
PyRDLOpDefinitionError,
match="An `attr-dict' directive without keyword cannot be directly followed by a region variable",
):

@irdl_op_definition
class RegionAttrDictWrongOp(IRDLOperation): # pyright: ignore[reportUnusedClass]
name = "test.region_op_missing_keyword"
region = region_def()

assembly_format = "attr-dict $region"


@pytest.mark.parametrize(
"format, program, generic_program",
[
(
"$region attr-dict",
'test.region_attr_dict {\n} {"a" = 2 : i32}',
'"test.region_attr_dict"() ({}) {"a" = 2 : i32} : () -> ()',
),
(
"attr-dict `,` $region",
'test.region_attr_dict {"a" = 2 : i32}, {\n "test.op"() : () -> ()\n}',
'"test.region_attr_dict"() ({ "test.op"() : () -> ()}) {"a" = 2 : i32} : () -> ()',
),
],
)
def test_regions_with_attr_dict(format: str, program: str, generic_program: str):
"""Test the parsing of regions"""

@irdl_op_definition
class RegionsOp(IRDLOperation):
name = "test.region_attr_dict"
region = region_def()

assembly_format = format

ctx = MLContext()
ctx.load_op(RegionsOp)
ctx.load_dialect(Test)

check_roundtrip(program, ctx)
check_equivalence(program, generic_program, ctx)


@irdl_op_definition
class MiscOp(IRDLOperation):
name = "test.typed_attr"
attr = attr_def(IntegerAttr[I32])

assembly_format = "$attr attr-dict"
kimxworrall marked this conversation as resolved.
Show resolved Hide resolved


@pytest.mark.parametrize(
"format, program, generic_program",
[
(
"attr-dict-with-keyword $fst $snd",
"test.two_regions {\n} {\n}",
'"test.two_regions"() ({}, {}) : () -> ()',
),
(
"attr-dict-with-keyword $fst $snd",
"test.two_regions {\n test.typed_attr 3\n} {\n test.typed_attr 3\n}",
'"test.two_regions"() ({ test.typed_attr 3}, { test.typed_attr 3}) : () -> ()',
),
(
"attr-dict-with-keyword $fst $snd",
'test.two_regions attributes {"a" = 2 : i32} {\n test.typed_attr 3\n} {\n test.typed_attr 3\n}',
'"test.two_regions"() ({ test.typed_attr 3}, { test.typed_attr 3}) {"a" = 2 : i32} : () -> ()',
),
],
)
def test_regions(format: str, program: str, generic_program: str):
"""Test the parsing of regions"""

@irdl_op_definition
class TwoRegionsOp(IRDLOperation):
name = "test.two_regions"
fst = region_def()
snd = region_def()

assembly_format = format

ctx = MLContext()
ctx.load_op(TwoRegionsOp)
ctx.load_op(MiscOp)
ctx.load_dialect(Test)

check_roundtrip(program, ctx)
check_equivalence(program, generic_program, ctx)


@pytest.mark.parametrize(
"format, program, generic_program",
[
(
"attr-dict-with-keyword $region",
"test.variadic_region ",
'"test.variadic_region"() : () -> ()',
),
(
"attr-dict-with-keyword $region",
'test.variadic_region {\n "test.op"() : () -> ()\n}',
'"test.variadic_region"() ({ "test.op"() : () -> ()}) : () -> ()',
),
(
"attr-dict-with-keyword $region",
'test.variadic_region {\n "test.op"() : () -> ()\n} {\n "test.op"() : () -> ()\n}',
'"test.variadic_region"() ({ "test.op"() : () -> ()}, { "test.op"() : () -> ()}) : () -> ()',
),
(
"attr-dict-with-keyword $region",
'test.variadic_region {\n "test.op"() : () -> ()\n} {\n "test.op"() : () -> ()\n}',
'"test.variadic_region"() ({ "test.op"() : () -> ()}, {"test.op"() : () -> ()}) : () -> ()',
),
kimxworrall marked this conversation as resolved.
Show resolved Hide resolved
],
)
def test_variadic_region(format: str, program: str, generic_program: str):
"""Test the parsing of variadic regions"""

@irdl_op_definition
class VariadicRegionOp(IRDLOperation):
name = "test.variadic_region"
region = var_region_def()

assembly_format = format

ctx = MLContext()
ctx.load_op(VariadicRegionOp)
ctx.load_dialect(Test)

check_roundtrip(program, ctx)
check_equivalence(program, generic_program, ctx)


@pytest.mark.parametrize(
"format, program, generic_program",
[
(
"attr-dict-with-keyword $region",
"test.optional_region ",
'"test.optional_region"() : () -> ()',
),
(
"attr-dict-with-keyword $region",
'test.optional_region {\n "test.op"() : () -> ()\n}',
'"test.optional_region"() ({ "test.op"() : () -> ()}) : () -> ()',
),
],
)
def test_optional_region(format: str, program: str, generic_program: str):
"""Test the parsing of optional regions"""

@irdl_op_definition
class OptionalRegionOp(IRDLOperation):
name = "test.optional_region"
region = opt_region_def()

assembly_format = format

ctx = MLContext()
ctx.load_op(OptionalRegionOp)
ctx.load_dialect(Test)

check_roundtrip(program, ctx)
check_equivalence(program, generic_program, ctx)


def test_multiple_optional_regions():
"""Test the parsing of multiple optional regions"""

"""Test that multiple optional regions requires the ABCMeta PyRDL option."""
kimxworrall marked this conversation as resolved.
Show resolved Hide resolved
with pytest.raises(
PyRDLOpDefinitionError,
match="Operation test.optional_regions defines more than two variadic regions",
):

@irdl_op_definition
class OptionalOperandsOp(IRDLOperation): # pyright: ignore[reportUnusedClass]
name = "test.optional_regions"
region1 = opt_region_def()
region2 = opt_region_def()

assembly_format = "attr-dict-with-keyword $region1 $region2"


kimxworrall marked this conversation as resolved.
Show resolved Hide resolved
################################################################################
# Inference #
################################################################################
Expand Down
88 changes: 85 additions & 3 deletions xdsl/irdl/declarative_assembly_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
Attribute,
Data,
ParametrizedAttribute,
Region,
SSAValue,
TypedAttribute,
)
Expand Down Expand Up @@ -48,19 +49,20 @@ class ParsingState:
operands: list[UnresolvedOperand | None | list[UnresolvedOperand | None]]
operand_types: list[Attribute | None | list[Attribute | None]]
result_types: list[Attribute | None | list[Attribute | None]]
regions: list[Region | None | list[Region]]
attributes: dict[str, Attribute]
properties: dict[str, Attribute]
constraint_context: ConstraintContext

def __init__(self, op_def: OpDef):
if op_def.regions or op_def.successors:
if op_def.successors:
raise NotImplementedError(
"Operation definitions with regions "
"or successors are not yet supported"
"Operation definitions with successors are not yet supported"
)
self.operands = [None] * len(op_def.operands)
self.operand_types = [None] * len(op_def.operands)
self.result_types = [None] * len(op_def.results)
self.regions = [None] * len(op_def.regions)
self.attributes = {}
self.properties = {}
self.constraint_context = ConstraintContext()
Expand Down Expand Up @@ -164,6 +166,7 @@ def parse(
operands=operands,
attributes=state.attributes,
properties=properties,
regions=state.regions,
)

def assign_constraint_variables(
Expand Down Expand Up @@ -723,6 +726,85 @@ def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> No
state.should_emit_space = True


@dataclass(frozen=True)
class RegionVariable(VariableDirective, OptionallyParsableDirective):
"""
A region variable, with the following format:
region-directive ::= dollar-ident
The directive will request a space to be printed after.
"""

def parse(self, parser: Parser, state: ParsingState) -> None:
kimxworrall marked this conversation as resolved.
Show resolved Hide resolved
region = parser.parse_region()
state.regions[self.index] = region

def parse_optional(self, parser: Parser, state: ParsingState) -> bool:
region = parser.parse_optional_region()
state.regions[self.index] = region
return region is not None

def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None:
if state.should_emit_space or not state.last_was_punctuation:
printer.print(" ")
printer.print_region(getattr(op, self.name))
state.last_was_punctuation = False
state.should_emit_space = True


@dataclass(frozen=True)
class VariadicRegionVariable(
VariadicVariable, VariableDirective, OptionallyParsableDirective
):
"""
A variadic region variable, with the following format:
region-directive ::= dollar-ident
The directive will request a space to be printed after.
"""

def parse_optional(self, parser: Parser, state: ParsingState) -> bool:
regions: list[Region] = []
current_region = parser.parse_optional_region()
while current_region is not None:
regions.append(current_region)
current_region = parser.parse_optional_region()

state.regions[self.index] = regions
return bool(regions)

def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None:
if state.should_emit_space or not state.last_was_punctuation:
printer.print(" ")
region = getattr(op, self.name)
if region:
printer.print_list(region, printer.print_region, delimiter=" ")
state.last_was_punctuation = False
state.should_emit_space = True


class OptionalRegionVariable(OptionalVariable, OptionallyParsableDirective):
"""
An optional region variable, with the following format:
region-directive ::= dollar-ident
The directive will request a space to be printed after.
"""

def parse_optional(self, parser: Parser, state: ParsingState) -> bool:
region = parser.parse_optional_region()
if region is None:
region = list[Region]()
state.regions[self.index] = region
return bool(region)

def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None:
if state.should_emit_space or not state.last_was_punctuation:
printer.print(" ")
region = getattr(op, self.name)
if region:
printer.print_region(region)
state.last_was_punctuation = False
state.should_emit_space = True


@dataclass(frozen=True)
class AttributeVariable(FormatDirective):
"""
Expand Down
Loading
Loading