Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

core: (irdl) Add regions to declarative assembly format #3065

Merged
merged 13 commits into from
Aug 21, 2024
Merged
Empty file modified .github/workflows/update_xdsl_pyodide_build.py
100755 → 100644
kimxworrall marked this conversation as resolved.
Show resolved Hide resolved
Empty file.
191 changes: 191 additions & 0 deletions tests/irdl/test_declarative_assembly_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,13 @@
opt_attr_def,
opt_operand_def,
opt_prop_def,
opt_region_def,
opt_result_def,
prop_def,
region_def,
result_def,
var_operand_def,
var_region_def,
var_result_def,
)
from xdsl.parser import Parser
Expand Down Expand Up @@ -1087,6 +1090,194 @@
check_equivalence(program, generic_program, ctx)


################################################################################
# Regions #
################################################################################


def test_missing_region():
"""Test that regions should be parsed."""
with pytest.raises(PyRDLOpDefinitionError, match="region 'region' not found"):

@irdl_op_definition
class NoRegionTypeOp(IRDLOperation): # pyright: ignore[reportUnusedClass]
name = "test.no_region_op"
region = region_def()

assembly_format = "attr-dict-with-keyword"


def test_region_missing_keyword():
"""Test that regions require the 'attr-dict-with-keyword' directive."""
with pytest.raises(
PyRDLOpDefinitionError,
match="'attr-dict' directive must be 'attr-dict-with-keyword'",
):

@irdl_op_definition
class NoRegionTypeOp(IRDLOperation): # pyright: ignore[reportUnusedClass]
name = "test.region_op_missing_keyword"
region = region_def()

assembly_format = "attr-dict $region"


def test_region_before_attr():
"""Test that regions should appear after the 'attr-dict-with-keyword' directive."""
kimxworrall marked this conversation as resolved.
Show resolved Hide resolved
with pytest.raises(
PyRDLOpDefinitionError, match="'attr-dict' directive must appear"
):

@irdl_op_definition
class NoRegionTypeOp(IRDLOperation): # pyright: ignore[reportUnusedClass]
name = "test.region_before_attr"
region = region_def()

assembly_format = "$region attr-dict-with-keyword"


@irdl_op_definition
class TypedAttributeOp(IRDLOperation):
name = "test.typed_attr"
attr = attr_def(IntegerAttr[I32])

assembly_format = "$attr attr-dict"


@pytest.mark.parametrize(
"format, program, generic_program",
[
(
"attr-dict-with-keyword $fst $snd",
"test.two_regions {\n} {\n}",
'"test.two_regions"() ({}, {}) : () -> ()',
),
(
"attr-dict-with-keyword $fst $snd",
"test.two_regions {\n test.typed_attr 3\n} {\n test.typed_attr 3\n}",
'"test.two_regions"() ({ test.typed_attr 3}, { test.typed_attr 3}) : () -> ()',
),
(
"attr-dict-with-keyword $fst $snd",
'test.two_regions attributes {"a" = 2 : i32} {\n test.typed_attr 3\n} {\n test.typed_attr 3\n}',
'"test.two_regions"() ({ test.typed_attr 3}, { test.typed_attr 3}) {"a" = 2 : i32} : () -> ()',
),
],
)
def test_regions(format: str, program: str, generic_program: str):
"""Test the parsing of regions"""

@irdl_op_definition
class TwoRegionsOp(IRDLOperation):
name = "test.two_regions"
fst = region_def()
snd = region_def()

assembly_format = format

ctx = MLContext()
ctx.load_op(TwoRegionsOp)
ctx.load_op(TypedAttributeOp)
ctx.load_dialect(Test)

check_roundtrip(program, ctx)
check_equivalence(program, generic_program, ctx)


@pytest.mark.parametrize(
"format, program, generic_program",
[
(
"attr-dict-with-keyword $region",
"test.variadic_region ",
'"test.variadic_region"() : () -> ()',
),
(
"attr-dict-with-keyword $region",
'test.variadic_region {\n "test.op"() : () -> ()\n}',
'"test.variadic_region"() ({ "test.op"() : () -> ()}) : () -> ()',
),
(
"attr-dict-with-keyword $region",
'test.variadic_region {\n "test.op"() : () -> ()\n} {\n "test.op"() : () -> ()\n}',
'"test.variadic_region"() ({ "test.op"() : () -> ()}, { "test.op"() : () -> ()}) : () -> ()',
),
(
"attr-dict-with-keyword $region",
'test.variadic_region {\n "test.op"() : () -> ()\n} {\n "test.op"() : () -> ()\n}',
'"test.variadic_region"() ({ "test.op"() : () -> ()}, {"test.op"() : () -> ()}) : () -> ()',
),
kimxworrall marked this conversation as resolved.
Show resolved Hide resolved
],
)
def test_variadic_region(format: str, program: str, generic_program: str):
"""Test the parsing of variadic regions"""

@irdl_op_definition
class VariadicRegionOp(IRDLOperation):
name = "test.variadic_region"
region = var_region_def()

assembly_format = format

ctx = MLContext()
ctx.load_op(VariadicRegionOp)
ctx.load_dialect(Test)

check_roundtrip(program, ctx)
check_equivalence(program, generic_program, ctx)


@pytest.mark.parametrize(
"format, program, generic_program",
[
(
"attr-dict-with-keyword $region",
"test.optional_region ",
'"test.optional_region"() : () -> ()',
),
(
"attr-dict-with-keyword $region",
'test.optional_region {\n "test.op"() : () -> ()\n}',
'"test.optional_region"() ({ "test.op"() : () -> ()}) : () -> ()',
),
],
)
def test_optional_region(format: str, program: str, generic_program: str):
"""Test the parsing of optional regions"""

@irdl_op_definition
class OptionalRegionOp(IRDLOperation):
name = "test.optional_region"
region = opt_region_def()

assembly_format = format

ctx = MLContext()
ctx.load_op(OptionalRegionOp)
ctx.load_dialect(Test)

check_roundtrip(program, ctx)
check_equivalence(program, generic_program, ctx)


def test_multiple_optional_regions():
"""Test the parsing of multiple optional regions"""

"""Test that multiple optional regions requires the ABCMeta PyRDL option."""
kimxworrall marked this conversation as resolved.
Show resolved Hide resolved
with pytest.raises(
PyRDLOpDefinitionError,
match="Operation test.optional_regions defines more than two variadic regions",
):

@irdl_op_definition
class OptionalOperandsOp(IRDLOperation):
kimxworrall marked this conversation as resolved.
Show resolved Hide resolved
name = "test.optional_regions"
region1 = opt_region_def()
region2 = opt_region_def()

assembly_format = "attr-dict-with-keyword $region1 $region2"


kimxworrall marked this conversation as resolved.
Show resolved Hide resolved
################################################################################
# Inference #
################################################################################
Expand All @@ -1111,7 +1302,7 @@

name = "test.two_operands_one_result_with_var"
res = result_def(T)
lhs = operand_def(T)

Check failure on line 1305 in tests/irdl/test_declarative_assembly_format.py

View workflow job for this annotation

GitHub Actions / build (3.11)

Class "OptionalOperandsOp" is not accessed (reportUnusedClass)
rhs = operand_def(T)

assembly_format = format
Expand Down
98 changes: 95 additions & 3 deletions xdsl/irdl/declarative_assembly_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,17 @@
Attribute,
Data,
ParametrizedAttribute,
Region,
SSAValue,
TypedAttribute,
)
from xdsl.irdl import (
Block,
ConstraintContext,
IRDLOperation,
IRDLOperationInvT,
OpDef,
Operation,
OptionalDef,
VariadicDef,
VarIRConstruct,
Expand All @@ -48,19 +51,20 @@ class ParsingState:
operands: list[UnresolvedOperand | None | list[UnresolvedOperand | None]]
operand_types: list[Attribute | None | list[Attribute | None]]
result_types: list[Attribute | None | list[Attribute | None]]
regions: list[Region | None | list[Region | None]]
attributes: dict[str, Attribute]
properties: dict[str, Attribute]
constraint_context: ConstraintContext

def __init__(self, op_def: OpDef):
if op_def.regions or op_def.successors:
if op_def.successors:
raise NotImplementedError(
"Operation definitions with regions "
"or successors are not yet supported"
"Operation definitions with " "successors are not yet supported"
kimxworrall marked this conversation as resolved.
Show resolved Hide resolved
)
self.operands = [None] * len(op_def.operands)
self.operand_types = [None] * len(op_def.operands)
self.result_types = [None] * len(op_def.results)
self.regions = [None] * len(op_def.regions)
self.attributes = {}
self.properties = {}
self.constraint_context = ConstraintContext()
Expand Down Expand Up @@ -159,11 +163,25 @@ def parse(
properties = state.properties
else:
properties = op_def.split_properties(state.attributes)

kimxworrall marked this conversation as resolved.
Show resolved Hide resolved
kimxworrall marked this conversation as resolved.
Show resolved Hide resolved
# Get the regions and cast them to the type needed in op_type
regions = cast(
Sequence[
Region
| Sequence[Operation]
| Sequence[Block]
| Sequence[Region | Sequence[Operation] | Sequence[Block]]
| None
],
state.regions,
)
kimxworrall marked this conversation as resolved.
Show resolved Hide resolved

return op_type.build(
result_types=result_types,
operands=operands,
attributes=state.attributes,
properties=properties,
regions=regions,
)

def assign_constraint_variables(
Expand Down Expand Up @@ -723,6 +741,80 @@ def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> No
state.should_emit_space = True


@dataclass(frozen=True)
class RegionVariable(VariableDirective):
kimxworrall marked this conversation as resolved.
Show resolved Hide resolved
"""
A region variable, with the following format:
region-directive ::= dollar-ident
The directive will request a space to be printed after.
"""

def parse(self, parser: Parser, state: ParsingState) -> None:
kimxworrall marked this conversation as resolved.
Show resolved Hide resolved
region = parser.parse_region()
state.regions[self.index] = region

def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None:
if state.should_emit_space or not state.last_was_punctuation:
printer.print(" ")
printer.print_region(getattr(op, self.name))
state.last_was_punctuation = False
state.should_emit_space = True


@dataclass(frozen=True)
class VariadicRegionVariable(
VariadicVariable, VariableDirective, OptionallyParsableDirective
):
"""
A variadic region variable, with the following format:
region-directive ::= ( percent-ident ( `,` percent-id )* )?
kimxworrall marked this conversation as resolved.
Show resolved Hide resolved
The directive will request a space to be printed after.
"""

def parse_optional(self, parser: Parser, state: ParsingState) -> bool:
regions: list[Region] = []
current_region = parser.parse_optional_region()
while current_region is not None:
regions.append(current_region)
current_region = parser.parse_optional_region()

state.regions[self.index] = cast(list[Region | None], regions)
return bool(regions)

def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None:
if state.should_emit_space or not state.last_was_punctuation:
printer.print(" ")
region = getattr(op, self.name)
if region:
printer.print_list(region, printer.print_region, delimiter=" ")
state.last_was_punctuation = False
state.should_emit_space = True


class OptionalRegionVariable(OptionalVariable, OptionallyParsableDirective):
"""
An optional region variable, with the following format:
region-directive ::= ( percent-ident )?
kimxworrall marked this conversation as resolved.
Show resolved Hide resolved
The directive will request a space to be printed after.
"""

def parse_optional(self, parser: Parser, state: ParsingState) -> bool:
region = parser.parse_optional_region()
if region is None:
region = list[Region | None]()
state.regions[self.index] = region
return bool(region)

def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None:
if state.should_emit_space or not state.last_was_punctuation:
printer.print(" ")
region = getattr(op, self.name)
if region:
printer.print_region(region)
state.last_was_punctuation = False
state.should_emit_space = True


@dataclass(frozen=True)
class AttributeVariable(FormatDirective):
"""
Expand Down
Loading
Loading