Skip to content

Commit

Permalink
api: constr helpers for generic attributes now standalone functions
Browse files Browse the repository at this point in the history
  • Loading branch information
superlopuh committed Oct 11, 2024
1 parent b431115 commit 9a30300
Show file tree
Hide file tree
Showing 17 changed files with 158 additions and 133 deletions.
40 changes: 18 additions & 22 deletions tests/irdl/test_declarative_assembly_format.py
Original file line number Diff line number Diff line change
Expand Up @@ -1664,25 +1664,23 @@ class ParamOne(ParametrizedAttribute, TypeAttribute, Generic[_T]):
p: ParameterDef[_T]
q: ParameterDef[Attribute]

@classmethod
def constr(
cls,
*,
n: GenericAttrConstraint[Attribute] | None = None,
p: GenericAttrConstraint[_T] | None = None,
q: GenericAttrConstraint[Attribute] | None = None,
) -> BaseAttr[ParamOne[Attribute]] | ParamAttrConstraint[ParamOne[_T]]:
if n is None and p is None and q is None:
return BaseAttr(cls)
return ParamAttrConstraint(cls, (n, p, q))
def ParamOneConstr(
*,
n: GenericAttrConstraint[Attribute] | None = None,
p: GenericAttrConstraint[_T] | None = None,
q: GenericAttrConstraint[Attribute] | None = None,
) -> BaseAttr[ParamOne[Attribute]] | ParamAttrConstraint[ParamOne[_T]]:
if n is None and p is None and q is None:
return BaseAttr[ParamOne[Attribute]](ParamOne)
return ParamAttrConstraint[ParamOne[_T]](ParamOne, (n, p, q))

@irdl_op_definition
class TwoOperandsNestedVarOp(IRDLOperation):
T: ClassVar[VarConstraint[Attribute]] = VarConstraint("T", AnyAttr())

name = "test.two_operands_one_result_with_var"
res = result_def(T)
lhs = operand_def(ParamOne[Attribute].constr(p=T))
lhs = operand_def(ParamOneConstr(p=T))
rhs = operand_def(T)

assembly_format = "$lhs $rhs attr-dict `:` type($lhs)"
Expand Down Expand Up @@ -1710,23 +1708,21 @@ class ParamOne(ParametrizedAttribute, TypeAttribute, Generic[_T]):
name = "test.param_one"
p: ParameterDef[_T]

@classmethod
def constr(
cls,
*,
p: GenericAttrConstraint[_T] | None = None,
) -> BaseAttr[ParamOne[Attribute]] | ParamAttrConstraint[ParamOne[_T]]:
if p is None:
return BaseAttr(cls)
return ParamAttrConstraint(cls, (p,))
def ParamOneConstr(
*,
p: GenericAttrConstraint[_T] | None = None,
) -> BaseAttr[ParamOne[Attribute]] | ParamAttrConstraint[ParamOne[_T]]:
if p is None:
return BaseAttr[ParamOne[Attribute]](ParamOne)
return ParamAttrConstraint[ParamOne[_T]](ParamOne, (p,))

@irdl_op_definition
class OneOperandOneResultNestedOp(IRDLOperation):
T: ClassVar[VarConstraint[Attribute]] = VarConstraint("T", AnyAttr())

name = "test.one_operand_one_result_nested"
res = result_def(T)
lhs = operand_def(ParamOne[Attribute].constr(p=T))
lhs = operand_def(ParamOneConstr(p=T))

assembly_format = "$lhs attr-dict `:` type($lhs)"

Expand Down
4 changes: 1 addition & 3 deletions tests/tblgen_to_py/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,7 @@ class Test_AnyOp(IRDLOperation):
class Test_AttributesOp(IRDLOperation):
name = "test.attributes"

int_attr = prop_def(
IntegerAttr[IntegerType].constr(type=EqAttrConstraint(IntegerType(16)))
)
int_attr = prop_def(IntegerAttrConstr(type=EqAttrConstraint(IntegerType(16))))

in_ = prop_def(BaseAttr(Test_TestAttr), prop_name="in")

Expand Down
1 change: 0 additions & 1 deletion tests/tblgen_to_py/test_tblgen.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,5 +11,4 @@ def test_run_tblgen_to_py():
with open("tests/tblgen_to_py/test.py") as f:
expected = f.read()

assert len(out_str.strip()) == len(expected.strip())
assert out_str.strip() == expected.strip()
5 changes: 3 additions & 2 deletions xdsl/dialects/affine.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
IndexType,
IntegerAttr,
IntegerType,
MemRefTypeConstr,
ShapedType,
StringAttr,
)
Expand Down Expand Up @@ -262,7 +263,7 @@ class Store(IRDLOperation):
T: ClassVar[VarConstraint[Attribute]] = VarConstraint("T", AnyAttr())

value = operand_def(T)
memref = operand_def(MemRefType[Attribute].constr(element_type=T))
memref = operand_def(MemRefTypeConstr(element_type=T))
indices = var_operand_def(IndexType)
map = opt_prop_def(AffineMapAttr)

Expand Down Expand Up @@ -294,7 +295,7 @@ class Load(IRDLOperation):

T: ClassVar[VarConstraint[Attribute]] = VarConstraint("T", AnyAttr())

memref = operand_def(MemRefType[Attribute].constr(element_type=T))
memref = operand_def(MemRefTypeConstr(element_type=T))
indices = var_operand_def(IndexType)

result = result_def(T)
Expand Down
70 changes: 34 additions & 36 deletions xdsl/dialects/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,22 +513,21 @@ def parse_with_type(
def print_without_type(self, printer: Printer):
return printer.print(self.value.data)

@classmethod
def constr(
cls,
*,
value: AttrConstraint | None = None,
type: GenericAttrConstraint[_IntegerAttrType] = IntegerAttrTypeConstr,
) -> GenericAttrConstraint[IntegerAttr[_IntegerAttrType]]:
if value is None and type == AnyAttr():
return BaseAttr[IntegerAttr[_IntegerAttrType]](IntegerAttr)
return ParamAttrConstraint[IntegerAttr[_IntegerAttrType]](
IntegerAttr,
(
value,
type,
),
)

def IntegerAttrConstr(
*,
value: AttrConstraint | None = None,
type: GenericAttrConstraint[_IntegerAttrType] = IntegerAttrTypeConstr,
) -> GenericAttrConstraint[IntegerAttr[_IntegerAttrType]]:
if value is None and type == AnyAttr():
return BaseAttr[IntegerAttr[_IntegerAttrType]](IntegerAttr)
return ParamAttrConstraint[IntegerAttr[_IntegerAttrType]](
IntegerAttr,
(
value,
type,
),
)


AnyIntegerAttr: TypeAlias = IntegerAttr[IntegerType | IndexType]
Expand Down Expand Up @@ -1598,29 +1597,28 @@ def get_strides(self) -> Sequence[int | None] | None:
case _:
return self.layout.get_strides()

@classmethod
def constr(
cls,
*,
shape: GenericAttrConstraint[Attribute] | None = None,
element_type: GenericAttrConstraint[_MemRefTypeElement] = AnyAttr(),
layout: GenericAttrConstraint[Attribute] | None = None,
memory_space: GenericAttrConstraint[Attribute] | None = None,
) -> GenericAttrConstraint[MemRefType[_MemRefTypeElement]]:
if (
shape is None
and element_type == AnyAttr()
and layout is None
and memory_space is None
):
return BaseAttr[MemRefType[_MemRefTypeElement]](MemRefType)
return ParamAttrConstraint[MemRefType[_MemRefTypeElement]](
MemRefType, (shape, element_type, layout, memory_space)
)

def MemRefTypeConstr(
*,
shape: GenericAttrConstraint[Attribute] | None = None,
element_type: GenericAttrConstraint[_MemRefTypeElement] = AnyAttr(),
layout: GenericAttrConstraint[Attribute] | None = None,
memory_space: GenericAttrConstraint[Attribute] | None = None,
) -> GenericAttrConstraint[MemRefType[_MemRefTypeElement]]:
if (
shape is None
and element_type == AnyAttr()
and layout is None
and memory_space is None
):
return BaseAttr[MemRefType[_MemRefTypeElement]](MemRefType)
return ParamAttrConstraint[MemRefType[_MemRefTypeElement]](
MemRefType, (shape, element_type, layout, memory_space)
)


AnyMemRefType: TypeAlias = MemRefType[Attribute]
AnyMemRefTypeConstr = BaseAttr[MemRefType[Attribute]](MemRefType)
AnyMemRefTypeConstr = MemRefTypeConstr()


@irdl_attr_definition
Expand Down
9 changes: 3 additions & 6 deletions xdsl/dialects/csl/csl.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
IntegerAttr,
IntegerType,
MemRefType,
MemRefTypeConstr,
ModuleOp,
Signedness,
StringAttr,
Expand Down Expand Up @@ -601,9 +602,7 @@ class ZerosOp(IRDLOperation):

size = opt_operand_def(T)

result = result_def(
MemRefType[IntegerType | Float32Type | Float16Type].constr(element_type=T)
)
result = result_def(MemRefTypeConstr(element_type=T))

is_const = opt_prop_def(builtin.UnitAttr)

Expand Down Expand Up @@ -640,9 +639,7 @@ class ConstantsOp(IRDLOperation):

value = operand_def(T)

result = result_def(
MemRefType[IntegerType | Float32Type | Float16Type].constr(element_type=T)
)
result = result_def(MemRefTypeConstr(element_type=T))

is_const = opt_prop_def(builtin.UnitAttr)

Expand Down
16 changes: 8 additions & 8 deletions xdsl/dialects/csl/csl_stencil.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ class PrefetchOp(IRDLOperation):
name = "csl_stencil.prefetch"

input_stencil = operand_def(
stencil.StencilTypeConstr | AnyMemRefTypeConstr | AnyTensorTypeConstr
stencil.AnyStencilTypeConstr | AnyMemRefTypeConstr | AnyTensorTypeConstr
)

swaps = prop_def(builtin.ArrayAttr[ExchangeDeclarationAttr])
Expand Down Expand Up @@ -202,12 +202,12 @@ class ApplyOp(IRDLOperation):

name = "csl_stencil.apply"

field = operand_def(stencil.StencilTypeConstr | AnyMemRefTypeConstr)
field = operand_def(stencil.AnyStencilTypeConstr | AnyMemRefTypeConstr)

accumulator = operand_def(AnyTensorTypeConstr | AnyMemRefTypeConstr)

args = var_operand_def(Attribute)
dest = var_operand_def(stencil.FieldTypeConstr | AnyMemRefTypeConstr)
dest = var_operand_def(stencil.AnyFieldTypeConstr | AnyMemRefTypeConstr)

receive_chunk = region_def()
done_exchange = region_def()
Expand All @@ -220,7 +220,7 @@ class ApplyOp(IRDLOperation):

bounds = opt_prop_def(stencil.StencilBoundsAttr)

res = var_result_def(stencil.StencilTypeConstr)
res = var_result_def(stencil.AnyStencilTypeConstr)

traits = frozenset(
[
Expand Down Expand Up @@ -376,7 +376,7 @@ def get_rank(self) -> int:
res_type = self.dest[0].type
else:
res_type = self.res[0].type
if isattr(res_type, stencil.StencilTypeConstr):
if isattr(res_type, stencil.AnyStencilTypeConstr):
return res_type.get_num_dims()
elif self.bounds:
return len(self.bounds.ub)
Expand Down Expand Up @@ -423,7 +423,7 @@ class AccessOp(IRDLOperation):

name = "csl_stencil.access"
op = operand_def(
AnyMemRefTypeConstr | stencil.StencilTypeConstr | AnyTensorTypeConstr
AnyMemRefTypeConstr | stencil.AnyStencilTypeConstr | AnyTensorTypeConstr
)
offset = prop_def(stencil.IndexAttr)
offset_mapping = opt_prop_def(stencil.IndexAttr)
Expand Down Expand Up @@ -502,7 +502,7 @@ def parse(cls, parser: Parser):
props["offset_mapping"] = stencil.IndexAttr.get(*offset_mapping)
parser.parse_punctuation(":")
res_type = parser.parse_attribute()
if isattr(res_type, stencil.StencilTypeConstr):
if isattr(res_type, stencil.AnyStencilTypeConstr):
return cls.build(
operands=[temp],
result_types=[res_type.get_element_type()],
Expand Down Expand Up @@ -535,7 +535,7 @@ def verify_(self) -> None:
raise VerifyException(
f"{type(self)} access to own data requires{self.op.type} but found {self.result.type}"
)
elif isattr(self.op.type, stencil.StencilTypeConstr):
elif isattr(self.op.type, stencil.AnyStencilTypeConstr):
if not self.result.type == self.op.type.get_element_type():
raise VerifyException(
f"{type(self)} access to own data requires{self.op.type.get_element_type()} but found {self.result.type}"
Expand Down
2 changes: 1 addition & 1 deletion xdsl/dialects/experimental/dmp.py
Original file line number Diff line number Diff line change
Expand Up @@ -686,7 +686,7 @@ class SwapOp(IRDLOperation):

name = "dmp.swap"

input_stencil = operand_def(stencil.StencilTypeConstr)
input_stencil = operand_def(stencil.AnyStencilTypeConstr)
swapped_values = opt_result_def(stencil.TempType[Attribute])

swaps = attr_def(builtin.ArrayAttr[ExchangeDeclarationAttr])
Expand Down
9 changes: 4 additions & 5 deletions xdsl/dialects/memref.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
IntegerType,
MemrefLayoutAttr,
MemRefType,
MemRefTypeConstr,
NoneAttr,
SignlessIntegerConstraint,
StridedLayoutAttr,
Expand Down Expand Up @@ -76,7 +77,7 @@ class Load(IRDLOperation):

nontemporal = opt_prop_def(BoolAttr)

memref = operand_def(MemRefType[Attribute].constr(element_type=T))
memref = operand_def(MemRefTypeConstr(element_type=T))
indices = var_operand_def(IndexType())
res = result_def(T)

Expand Down Expand Up @@ -116,7 +117,7 @@ class Store(IRDLOperation):
nontemporal = opt_prop_def(BoolAttr)

value = operand_def(T)
memref = operand_def(MemRefType[Attribute].constr(element_type=T))
memref = operand_def(MemRefTypeConstr(element_type=T))
indices = var_operand_def(IndexType())

irdl_options = [ParsePropInAttrDict()]
Expand Down Expand Up @@ -365,9 +366,7 @@ class AtomicRMWOp(IRDLOperation):
)

value = operand_def(T)
memref = operand_def(
MemRefType[AnyFloat | AnySignlessIntegerType].constr(element_type=T)
)
memref = operand_def(MemRefTypeConstr(element_type=T))
indices = var_operand_def(IndexType)

kind = prop_def(IntegerAttr[I64])
Expand Down
3 changes: 2 additions & 1 deletion xdsl/dialects/memref_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
IntAttr,
IntegerAttr,
IntegerType,
MemRefTypeConstr,
StringAttr,
)
from xdsl.dialects.utils import AbstractYieldOperation
Expand Down Expand Up @@ -851,7 +852,7 @@ class FillOp(IRDLOperation):

T: ClassVar[VarConstraint[Attribute]] = VarConstraint("T", AnyAttr())

memref = operand_def(memref.MemRefType[Attribute].constr(element_type=T))
memref = operand_def(MemRefTypeConstr(element_type=T))
value = operand_def(T)

assembly_format = "$memref `with` $value attr-dict `:` type($memref)"
Expand Down
Loading

0 comments on commit 9a30300

Please sign in to comment.