Skip to content

Commit

Permalink
Add regions to declarative_assembly_format and the parser, and add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
kimxworrall committed Aug 19, 2024
1 parent 4a5e0df commit 68c66e9
Show file tree
Hide file tree
Showing 6 changed files with 323 additions and 5 deletions.
Empty file modified .github/workflows/update_xdsl_pyodide_build.py
100755 → 100644
Empty file.
179 changes: 179 additions & 0 deletions tests/irdl/test_declarative_assembly_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,13 @@
opt_attr_def,
opt_operand_def,
opt_prop_def,
opt_region_def,
opt_result_def,
prop_def,
region_def,
result_def,
var_operand_def,
var_region_def,
var_result_def,
)
from xdsl.parser import Parser
Expand Down Expand Up @@ -1086,6 +1089,182 @@ class OptionalResultOp(IRDLOperation):
check_roundtrip(program, ctx)
check_equivalence(program, generic_program, ctx)

################################################################################
# Regions #
################################################################################

def test_missing_region():
"""Test that regions should be parsed."""
with pytest.raises(PyRDLOpDefinitionError, match="region 'region' not found"):

@irdl_op_definition
class NoRegionTypeOp(IRDLOperation): # pyright: ignore[reportUnusedClass]
name = "test.no_region_op"
region = region_def()

assembly_format = "attr-dict-with-keyword"

def test_region_missing_keyword():
"""Test that regions require the 'attr-dict-with-keyword' directive."""
with pytest.raises(PyRDLOpDefinitionError, match="'attr-dict' directive must be 'attr-dict-with-keyword'"):

@irdl_op_definition
class NoRegionTypeOp(IRDLOperation): # pyright: ignore[reportUnusedClass]
name = "test.region_op_missing_keyword"
region = region_def()

assembly_format = "attr-dict $region"

def test_region_before_attr():
"""Test that regions should appear after the 'attr-dict-with-keyword' directive."""
with pytest.raises(PyRDLOpDefinitionError, match="'attr-dict' directive must appear"):

@irdl_op_definition
class NoRegionTypeOp(IRDLOperation): # pyright: ignore[reportUnusedClass]
name = "test.region_before_attr"
region = region_def()

assembly_format = "$region attr-dict-with-keyword"



@irdl_op_definition
class TypedAttributeOp(IRDLOperation):
name = "test.typed_attr"
attr = attr_def(IntegerAttr[I32])

assembly_format = "$attr attr-dict"


@pytest.mark.parametrize(
"format, program, generic_program",
[
(
"attr-dict-with-keyword $fst $snd",
"test.two_regions {\n} {\n}",
'"test.two_regions"() ({}, {}) : () -> ()',
),
(
"attr-dict-with-keyword $fst $snd",
'test.two_regions {\n test.typed_attr 3\n} {\n test.typed_attr 3\n}',
'"test.two_regions"() ({ test.typed_attr 3}, { test.typed_attr 3}) : () -> ()',
),
(
"attr-dict-with-keyword $fst $snd",
'test.two_regions attributes {"a" = 2 : i32} {\n test.typed_attr 3\n} {\n test.typed_attr 3\n}',
'"test.two_regions"() ({ test.typed_attr 3}, { test.typed_attr 3}) {"a" = 2 : i32} : () -> ()',
),
],
)
def test_regions(format: str, program: str, generic_program: str):
"""Test the parsing of regions"""

@irdl_op_definition
class TwoRegionsOp(IRDLOperation):
name = "test.two_regions"
fst = region_def()
snd = region_def()

assembly_format = format

ctx = MLContext()
ctx.load_op(TwoRegionsOp)
ctx.load_op(TypedAttributeOp)
ctx.load_dialect(Test)

check_roundtrip(program, ctx)
check_equivalence(program, generic_program, ctx)

@pytest.mark.parametrize(
"format, program, generic_program",
[
(
"attr-dict-with-keyword $region",
"test.variadic_region ",
'"test.variadic_region"() : () -> ()',
),
(
"attr-dict-with-keyword $region",
'test.variadic_region {\n "test.op"() : () -> ()\n}',
'"test.variadic_region"() ({ "test.op"() : () -> ()}) : () -> ()',
),
(
"attr-dict-with-keyword $region",
'test.variadic_region {\n "test.op"() : () -> ()\n} {\n "test.op"() : () -> ()\n}',
'"test.variadic_region"() ({ "test.op"() : () -> ()}, { "test.op"() : () -> ()}) : () -> ()',
),
(
"attr-dict-with-keyword $region",
'test.variadic_region {\n "test.op"() : () -> ()\n} {\n "test.op"() : () -> ()\n}',
'"test.variadic_region"() ({ "test.op"() : () -> ()}, {"test.op"() : () -> ()}) : () -> ()',
),
],
)
def test_variadic_region(format: str, program: str, generic_program: str):
"""Test the parsing of variadic regions"""

@irdl_op_definition
class VariadicRegionOp(IRDLOperation):
name = "test.variadic_region"
region = var_region_def()

assembly_format = format

ctx = MLContext()
ctx.load_op(VariadicRegionOp)
ctx.load_dialect(Test)

check_roundtrip(program, ctx)
check_equivalence(program, generic_program, ctx)


@pytest.mark.parametrize(
"format, program, generic_program",
[
(
"attr-dict-with-keyword $region",
"test.optional_region ",
'"test.optional_region"() : () -> ()',
),
(
"attr-dict-with-keyword $region",
'test.optional_region {\n "test.op"() : () -> ()\n}',
'"test.optional_region"() ({ "test.op"() : () -> ()}) : () -> ()',
),
],
)
def test_optional_region(format: str, program: str, generic_program: str):
"""Test the parsing of optional regions"""

@irdl_op_definition
class OptionalRegionOp(IRDLOperation):
name = "test.optional_region"
region = opt_region_def()

assembly_format = format

ctx = MLContext()
ctx.load_op(OptionalRegionOp)
ctx.load_dialect(Test)

check_roundtrip(program, ctx)
check_equivalence(program, generic_program, ctx)

def test_multiple_optional_regions():
"""Test the parsing of multiple optional regions"""

"""Test that multiple optional regions requires the ABCMeta PyRDL option."""
with pytest.raises(PyRDLOpDefinitionError, match='Operation test.optional_regions defines more than two variadic regions'):
@irdl_op_definition
class OptionalOperandsOp(IRDLOperation):
name = "test.optional_regions"
region1 = opt_region_def()
region2 = opt_region_def()

assembly_format = (
"attr-dict-with-keyword $region1 $region2"
)


################################################################################
# Inference #
Expand Down
90 changes: 87 additions & 3 deletions xdsl/irdl/declarative_assembly_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from abc import ABC, abstractmethod
from collections.abc import Sequence
from dataclasses import dataclass, field
from os import stat

Check failure on line 12 in xdsl/irdl/declarative_assembly_format.py

View workflow job for this annotation

GitHub Actions / build (3.11)

Import "stat" is not accessed (reportUnusedImport)
from re import A

Check failure on line 13 in xdsl/irdl/declarative_assembly_format.py

View workflow job for this annotation

GitHub Actions / build (3.11)

Import "A" is not accessed (reportUnusedImport)
from typing import Any, Literal, cast

from xdsl.dialects.builtin import UnitAttr
Expand All @@ -18,7 +20,9 @@
ParametrizedAttribute,
SSAValue,
TypedAttribute,
Region,
)
from xdsl.ir.core import Block, Operation
from xdsl.irdl import (
ConstraintContext,
IRDLOperation,
Expand Down Expand Up @@ -48,19 +52,21 @@ class ParsingState:
operands: list[UnresolvedOperand | None | list[UnresolvedOperand | None]]
operand_types: list[Attribute | None | list[Attribute | None]]
result_types: list[Attribute | None | list[Attribute | None]]
regions: list[Region | None | list[Region | None]]
attributes: dict[str, Attribute]
properties: dict[str, Attribute]
constraint_context: ConstraintContext

def __init__(self, op_def: OpDef):
if op_def.regions or op_def.successors:
if op_def.successors:
raise NotImplementedError(
"Operation definitions with regions "
"or successors are not yet supported"
"Operation definitions with "
"successors are not yet supported"
)
self.operands = [None] * len(op_def.operands)
self.operand_types = [None] * len(op_def.operands)
self.result_types = [None] * len(op_def.results)
self.regions = [None] * len(op_def.regions)
self.attributes = {}
self.properties = {}
self.constraint_context = ConstraintContext()
Expand Down Expand Up @@ -159,11 +165,16 @@ def parse(
properties = state.properties
else:
properties = op_def.split_properties(state.attributes)

# Get the regions and cast them to the type needed in op_type
regions = cast(Sequence[Region | Sequence[Operation] | Sequence[Block] | Sequence[Region | Sequence[Operation] | Sequence[Block]] | None], state.regions)

return op_type.build(
result_types=result_types,
operands=operands,
attributes=state.attributes,
properties=properties,
regions= regions,
)

def assign_constraint_variables(
Expand Down Expand Up @@ -722,6 +733,79 @@ def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> No
state.last_was_punctuation = False
state.should_emit_space = True

@dataclass(frozen=True)
class RegionVariable(VariableDirective):
"""
A region variable, with the following format:
region-directive ::= dollar-ident
The directive will request a space to be printed after.
"""

def parse(self, parser: Parser, state: ParsingState) -> None:
region = parser.parse_region()
state.regions[self.index] = region

def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None:
if state.should_emit_space or not state.last_was_punctuation:
printer.print(" ")
printer.print_region(getattr(op, self.name))
state.last_was_punctuation = False
state.should_emit_space = True


@dataclass(frozen=True)
class VariadicRegionVariable(
VariadicVariable, VariableDirective, OptionallyParsableDirective
):
"""
A variadic region variable, with the following format:
region-directive ::= ( percent-ident ( `,` percent-id )* )?
The directive will request a space to be printed after.
"""

def parse_optional(self, parser: Parser, state: ParsingState) -> bool:
regions : list[Region] = []
current_region = parser.parse_optional_region()
while current_region is not None:
regions.append(current_region)
current_region = parser.parse_optional_region()

state.regions[self.index] = cast(list[Region | None], regions)
return bool(regions)

def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None:
breakpoint()
if state.should_emit_space or not state.last_was_punctuation:
printer.print(" ")
region = getattr(op, self.name)
if region:
printer.print_list(region, printer.print_region, delimiter=" ")
state.last_was_punctuation = False
state.should_emit_space = True


class OptionalRegionVariable(OptionalVariable, OptionallyParsableDirective):
"""
An optional region variable, with the following format:
region-directive ::= ( percent-ident )?
The directive will request a space to be printed after.
"""

def parse_optional(self, parser: Parser, state: ParsingState) -> bool:
region = parser.parse_optional_region()
if region is None:
region = list[Region | None]()
state.regions[self.index] = region
return bool(region)

def print(self, printer: Printer, state: PrintingState, op: IRDLOperation) -> None:
if state.should_emit_space or not state.last_was_punctuation:
printer.print(" ")
region = getattr(op, self.name)
if region:
printer.print_region(region)
state.last_was_punctuation = False
state.should_emit_space = True

@dataclass(frozen=True)
class AttributeVariable(FormatDirective):
Expand Down
Loading

0 comments on commit 68c66e9

Please sign in to comment.