Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

core: (irdl) Add regions to declarative assembly format #3065

Merged
merged 13 commits into from
Aug 21, 2024
Merged
Empty file modified .github/workflows/update_xdsl_pyodide_build.py
100755 → 100644
kimxworrall marked this conversation as resolved.
Show resolved Hide resolved
Empty file.
179 changes: 179 additions & 0 deletions tests/irdl/test_declarative_assembly_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,13 @@
opt_attr_def,
opt_operand_def,
opt_prop_def,
opt_region_def,
opt_result_def,
prop_def,
region_def,
result_def,
var_operand_def,
var_region_def,
var_result_def,
)
from xdsl.parser import Parser
Expand Down Expand Up @@ -1086,6 +1089,182 @@
check_roundtrip(program, ctx)
check_equivalence(program, generic_program, ctx)

################################################################################
# Regions #
################################################################################

def test_missing_region():
"""Test that regions should be parsed."""
with pytest.raises(PyRDLOpDefinitionError, match="region 'region' not found"):

@irdl_op_definition
class NoRegionTypeOp(IRDLOperation): # pyright: ignore[reportUnusedClass]
name = "test.no_region_op"
region = region_def()

assembly_format = "attr-dict-with-keyword"

def test_region_missing_keyword():
"""Test that regions require the 'attr-dict-with-keyword' directive."""
with pytest.raises(PyRDLOpDefinitionError, match="'attr-dict' directive must be 'attr-dict-with-keyword'"):

@irdl_op_definition
class NoRegionTypeOp(IRDLOperation): # pyright: ignore[reportUnusedClass]
name = "test.region_op_missing_keyword"
region = region_def()

assembly_format = "attr-dict $region"

def test_region_before_attr():
"""Test that regions should appear after the 'attr-dict-with-keyword' directive."""
kimxworrall marked this conversation as resolved.
Show resolved Hide resolved
with pytest.raises(PyRDLOpDefinitionError, match="'attr-dict' directive must appear"):

@irdl_op_definition
class NoRegionTypeOp(IRDLOperation): # pyright: ignore[reportUnusedClass]
name = "test.region_before_attr"
region = region_def()

assembly_format = "$region attr-dict-with-keyword"



@irdl_op_definition
class TypedAttributeOp(IRDLOperation):
name = "test.typed_attr"
attr = attr_def(IntegerAttr[I32])

assembly_format = "$attr attr-dict"


@pytest.mark.parametrize(
"format, program, generic_program",
[
(
"attr-dict-with-keyword $fst $snd",
"test.two_regions {\n} {\n}",
'"test.two_regions"() ({}, {}) : () -> ()',
),
(
"attr-dict-with-keyword $fst $snd",
'test.two_regions {\n test.typed_attr 3\n} {\n test.typed_attr 3\n}',
'"test.two_regions"() ({ test.typed_attr 3}, { test.typed_attr 3}) : () -> ()',
),
(
"attr-dict-with-keyword $fst $snd",
'test.two_regions attributes {"a" = 2 : i32} {\n test.typed_attr 3\n} {\n test.typed_attr 3\n}',
'"test.two_regions"() ({ test.typed_attr 3}, { test.typed_attr 3}) {"a" = 2 : i32} : () -> ()',
),
],
)
def test_regions(format: str, program: str, generic_program: str):
"""Test the parsing of regions"""

@irdl_op_definition
class TwoRegionsOp(IRDLOperation):
name = "test.two_regions"
fst = region_def()
snd = region_def()

assembly_format = format

ctx = MLContext()
ctx.load_op(TwoRegionsOp)
ctx.load_op(TypedAttributeOp)
ctx.load_dialect(Test)

check_roundtrip(program, ctx)
check_equivalence(program, generic_program, ctx)

@pytest.mark.parametrize(
"format, program, generic_program",
[
(
"attr-dict-with-keyword $region",
"test.variadic_region ",
'"test.variadic_region"() : () -> ()',
),
(
"attr-dict-with-keyword $region",
'test.variadic_region {\n "test.op"() : () -> ()\n}',
'"test.variadic_region"() ({ "test.op"() : () -> ()}) : () -> ()',
),
(
"attr-dict-with-keyword $region",
'test.variadic_region {\n "test.op"() : () -> ()\n} {\n "test.op"() : () -> ()\n}',
'"test.variadic_region"() ({ "test.op"() : () -> ()}, { "test.op"() : () -> ()}) : () -> ()',
),
(
"attr-dict-with-keyword $region",
'test.variadic_region {\n "test.op"() : () -> ()\n} {\n "test.op"() : () -> ()\n}',
'"test.variadic_region"() ({ "test.op"() : () -> ()}, {"test.op"() : () -> ()}) : () -> ()',
),
kimxworrall marked this conversation as resolved.
Show resolved Hide resolved
],
)
def test_variadic_region(format: str, program: str, generic_program: str):
"""Test the parsing of variadic regions"""

@irdl_op_definition
class VariadicRegionOp(IRDLOperation):
name = "test.variadic_region"
region = var_region_def()

assembly_format = format

ctx = MLContext()
ctx.load_op(VariadicRegionOp)
ctx.load_dialect(Test)

check_roundtrip(program, ctx)
check_equivalence(program, generic_program, ctx)


@pytest.mark.parametrize(
"format, program, generic_program",
[
(
"attr-dict-with-keyword $region",
"test.optional_region ",
'"test.optional_region"() : () -> ()',
),
(
"attr-dict-with-keyword $region",
'test.optional_region {\n "test.op"() : () -> ()\n}',
'"test.optional_region"() ({ "test.op"() : () -> ()}) : () -> ()',
),
],
)
def test_optional_region(format: str, program: str, generic_program: str):
"""Test the parsing of optional regions"""

@irdl_op_definition
class OptionalRegionOp(IRDLOperation):
name = "test.optional_region"
region = opt_region_def()

assembly_format = format

ctx = MLContext()
ctx.load_op(OptionalRegionOp)
ctx.load_dialect(Test)

check_roundtrip(program, ctx)
check_equivalence(program, generic_program, ctx)

def test_multiple_optional_regions():
"""Test the parsing of multiple optional regions"""

"""Test that multiple optional regions requires the ABCMeta PyRDL option."""
kimxworrall marked this conversation as resolved.
Show resolved Hide resolved
with pytest.raises(PyRDLOpDefinitionError, match='Operation test.optional_regions defines more than two variadic regions'):
@irdl_op_definition
class OptionalOperandsOp(IRDLOperation):
kimxworrall marked this conversation as resolved.
Show resolved Hide resolved
name = "test.optional_regions"
region1 = opt_region_def()
region2 = opt_region_def()

assembly_format = (
"attr-dict-with-keyword $region1 $region2"
)


kimxworrall marked this conversation as resolved.
Show resolved Hide resolved
################################################################################
# Inference #
Expand All @@ -1109,7 +1288,7 @@
class TwoOperandsOneResultWithVarOp(IRDLOperation):
T = Annotated[Attribute, ConstraintVar("T")]

name = "test.two_operands_one_result_with_var"

Check failure on line 1291 in tests/irdl/test_declarative_assembly_format.py

View workflow job for this annotation

GitHub Actions / build (3.11)

Class "OptionalOperandsOp" is not accessed (reportUnusedClass)
res = result_def(T)
lhs = operand_def(T)
rhs = operand_def(T)
Expand Down
90 changes: 87 additions & 3 deletions xdsl/irdl/declarative_assembly_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from abc import ABC, abstractmethod
from collections.abc import Sequence
from dataclasses import dataclass, field
from os import stat

Check failure on line 12 in xdsl/irdl/declarative_assembly_format.py

View workflow job for this annotation

GitHub Actions / build (3.11)

Import "stat" is not accessed (reportUnusedImport)
from re import A

Check failure on line 13 in xdsl/irdl/declarative_assembly_format.py

View workflow job for this annotation

GitHub Actions / build (3.11)

Import "A" is not accessed (reportUnusedImport)
from typing import Any, Literal, cast

from xdsl.dialects.builtin import UnitAttr
Expand All @@ -18,7 +20,9 @@
ParametrizedAttribute,
SSAValue,
TypedAttribute,
Region,
)
from xdsl.ir.core import Block, Operation
kimxworrall marked this conversation as resolved.
Show resolved Hide resolved
from xdsl.irdl import (
ConstraintContext,
IRDLOperation,
Expand Down Expand Up @@ -48,19 +52,21 @@
operands: list[UnresolvedOperand | None | list[UnresolvedOperand | None]]
operand_types: list[Attribute | None | list[Attribute | None]]
result_types: list[Attribute | None | list[Attribute | None]]
regions: list[Region | None | list[Region | None]]
attributes: dict[str, Attribute]
properties: dict[str, Attribute]
constraint_context: ConstraintContext

def __init__(self, op_def: OpDef):
if op_def.regions or op_def.successors:
if op_def.successors:
raise NotImplementedError(
"Operation definitions with regions "
"or successors are not yet supported"
"Operation definitions with "
"successors are not yet supported"
)
self.operands = [None] * len(op_def.operands)
self.operand_types = [None] * len(op_def.operands)
self.result_types = [None] * len(op_def.results)
self.regions = [None] * len(op_def.regions)
self.attributes = {}
self.properties = {}
self.constraint_context = ConstraintContext()
Expand Down Expand Up @@ -159,11 +165,16 @@
properties = state.properties
else:
properties = op_def.split_properties(state.attributes)

kimxworrall marked this conversation as resolved.
Show resolved Hide resolved
kimxworrall marked this conversation as resolved.
Show resolved Hide resolved
# Get the regions and cast them to the type needed in op_type
regions = cast(Sequence[Region | Sequence[Operation] | Sequence[Block] | Sequence[Region | Sequence[Operation] | Sequence[Block]] | None], state.regions)

return op_type.build(
result_types=result_types,
operands=operands,
attributes=state.attributes,
properties=properties,
regions= regions,
)

def assign_constraint_variables(
Expand Down Expand Up @@ -722,6 +733,79 @@
state.last_was_punctuation = False
state.should_emit_space = True

@dataclass(frozen=True)
class RegionVariable(VariableDirective):
kimxworrall marked this conversation as resolved.
Show resolved Hide resolved
"""
A region variable, with the following format:
region-directive ::= dollar-ident
The directive will request a space to be printed after.
"""

def parse(self, parser: Parser, state: ParsingState) -> None:
kimxworrall marked this conversation as resolved.
Show resolved Hide resolved
region = parser.parse_region()
state.regions[self.index] = region

def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None:
if state.should_emit_space or not state.last_was_punctuation:
printer.print(" ")
printer.print_region(getattr(op, self.name))
state.last_was_punctuation = False
state.should_emit_space = True


@dataclass(frozen=True)
class VariadicRegionVariable(
VariadicVariable, VariableDirective, OptionallyParsableDirective
):
"""
A variadic region variable, with the following format:
region-directive ::= ( percent-ident ( `,` percent-id )* )?
kimxworrall marked this conversation as resolved.
Show resolved Hide resolved
The directive will request a space to be printed after.
"""

def parse_optional(self, parser: Parser, state: ParsingState) -> bool:
regions : list[Region] = []
current_region = parser.parse_optional_region()
while current_region is not None:
regions.append(current_region)
current_region = parser.parse_optional_region()

state.regions[self.index] = cast(list[Region | None], regions)
return bool(regions)

def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None:
breakpoint()
if state.should_emit_space or not state.last_was_punctuation:
printer.print(" ")
region = getattr(op, self.name)
if region:
printer.print_list(region, printer.print_region, delimiter=" ")
state.last_was_punctuation = False
state.should_emit_space = True


class OptionalRegionVariable(OptionalVariable, OptionallyParsableDirective):
"""
An optional region variable, with the following format:
region-directive ::= ( percent-ident )?
kimxworrall marked this conversation as resolved.
Show resolved Hide resolved
The directive will request a space to be printed after.
"""

def parse_optional(self, parser: Parser, state: ParsingState) -> bool:
region = parser.parse_optional_region()
if region is None:
region = list[Region | None]()
state.regions[self.index] = region
return bool(region)

def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None:
if state.should_emit_space or not state.last_was_punctuation:
printer.print(" ")
region = getattr(op, self.name)
if region:
printer.print_region(region)
state.last_was_punctuation = False
state.should_emit_space = True

@dataclass(frozen=True)
class AttributeVariable(FormatDirective):
Expand Down
Loading
Loading