From dffa3051c8b309ba3e4b516e2cfeb315e6c28e39 Mon Sep 17 00:00:00 2001 From: Alex Rice Date: Fri, 22 Nov 2024 09:16:43 +0000 Subject: [PATCH] core: remove generic check on TypedAttribute --- tests/irdl/test_attribute_definition.py | 16 +------------ xdsl/dialects/builtin.py | 12 +++++----- xdsl/ir/core.py | 6 ++--- xdsl/irdl/attributes.py | 31 +++++++------------------ 4 files changed, 18 insertions(+), 47 deletions(-) diff --git a/tests/irdl/test_attribute_definition.py b/tests/irdl/test_attribute_definition.py index f13f9d9b88..8f403ca4e6 100644 --- a/tests/irdl/test_attribute_definition.py +++ b/tests/irdl/test_attribute_definition.py @@ -19,7 +19,6 @@ IntegerType, NoneAttr, Signedness, - StringAttr, ) from xdsl.ir import ( Attribute, @@ -252,23 +251,10 @@ def test_typed_attribute(): @irdl_attr_definition class TypedAttr( # pyright: ignore[reportUnusedClass] - TypedAttribute[Attribute] + TypedAttribute ): name = "test.typed" - with pytest.raises( - Exception, - match="A TypedAttribute `type` parameter must be of the same type as the type variable in the TypedAttribute base class.", - ): - - @irdl_attr_definition - class TypedAttrBis( # pyright: ignore[reportUnusedClass] - TypedAttribute[IntegerAttr[IndexType]] - ): - name = "test.typed" - - type: ParameterDef[StringAttr] - ################################################################################ # IntegerAttr diff --git a/xdsl/dialects/builtin.py b/xdsl/dialects/builtin.py index 3cbccf68bc..fe3c33384e 100644 --- a/xdsl/dialects/builtin.py +++ b/xdsl/dialects/builtin.py @@ -449,7 +449,7 @@ class IndexType(ParametrizedAttribute): @irdl_attr_definition class IntegerAttr( Generic[_IntegerAttrType], - TypedAttribute[_IntegerAttrType], + TypedAttribute, ): name = "integer" value: ParameterDef[IntAttr] @@ -504,8 +504,8 @@ def verify(self) -> None: @staticmethod def parse_with_type( parser: AttrParser, - type: AttributeInvT, - ) -> TypedAttribute[AttributeInvT]: + type: Attribute, + ) -> TypedAttribute: assert isinstance(type, IntegerType | IndexType) return IntegerAttr(parser.parse_integer(allow_boolean=(type == i1)), type) @@ -634,7 +634,7 @@ def __hash__(self): @irdl_attr_definition -class FloatAttr(Generic[_FloatAttrType], TypedAttribute[_FloatAttrType]): +class FloatAttr(Generic[_FloatAttrType], TypedAttribute): name = "float" value: ParameterDef[FloatData] @@ -671,8 +671,8 @@ def __init__( @staticmethod def parse_with_type( parser: AttrParser, - type: AttributeInvT, - ) -> TypedAttribute[AttributeInvT]: + type: Attribute, + ) -> TypedAttribute: assert isinstance(type, AnyFloat) return FloatAttr(parser.parse_float(), type) diff --git a/xdsl/ir/core.py b/xdsl/ir/core.py index b4bece71de..698f52e25d 100644 --- a/xdsl/ir/core.py +++ b/xdsl/ir/core.py @@ -606,7 +606,7 @@ def _verify(self): super()._verify() -class TypedAttribute(ParametrizedAttribute, Generic[AttributeCovT], ABC): +class TypedAttribute(ParametrizedAttribute, ABC): """ An attribute with a type. """ @@ -617,8 +617,8 @@ def get_type_index(cls) -> int: ... @staticmethod def parse_with_type( parser: AttrParser, - type: AttributeInvT, - ) -> TypedAttribute[AttributeInvT]: + type: Attribute, + ) -> TypedAttribute: """ Parse the attribute with the given type. """ diff --git a/xdsl/irdl/attributes.py b/xdsl/irdl/attributes.py index 08db888feb..24d8cd97a8 100644 --- a/xdsl/irdl/attributes.py +++ b/xdsl/irdl/attributes.py @@ -24,7 +24,6 @@ from xdsl.ir import ( Attribute, - AttributeCovT, AttributeInvT, Data, ParametrizedAttribute, @@ -34,7 +33,6 @@ from xdsl.utils.hints import ( PropertyType, get_type_var_from_generic_class, - get_type_var_mapping, ) from xdsl.utils.runtime_final import runtime_final @@ -160,25 +158,6 @@ def from_pyrdl( name = clsdict["name"] param_hints = irdl_param_attr_get_param_type_hints(pyrdl_def) - if issubclass(pyrdl_def, TypedAttribute): - pyrdl_def = cast(type[TypedAttribute[Attribute]], pyrdl_def) - try: - param_names = [name for name, _ in param_hints] - type_index = param_names.index("type") - except ValueError: - raise PyRDLAttrDefinitionError( - f"TypedAttribute {pyrdl_def.__name__} should have a 'type' parameter." - ) - typed_hint = param_hints[type_index][1] - if get_origin(typed_hint) is Annotated: - typed_hint = get_args(typed_hint)[0] - type_var = get_type_var_mapping(pyrdl_def)[1][AttributeCovT] - - if typed_hint != type_var: - raise ValueError( - "A TypedAttribute `type` parameter must be of the same type" - " as the type variable in the TypedAttribute base class." - ) parameters = list[tuple[str, AttrConstraint]]() for param_name, param_type in param_hints: @@ -233,8 +212,14 @@ def irdl_param_attr_definition(cls: _PAttrTT) -> _PAttrTT: new_fields = get_accessors_from_param_attr_def(attr_def) if issubclass(cls, TypedAttribute): - parameter_names: tuple[str] = tuple(zip(*attr_def.parameters))[0] - type_index = parameter_names.index("type") + type_indexes = tuple( + i for i, (p, _) in enumerate(attr_def.parameters) if p == "type" + ) + if not type_indexes: + raise PyRDLAttrDefinitionError( + f"TypedAttribute {cls.__name__} should have a 'type' parameter." + ) + type_index = type_indexes[0] @classmethod def get_type_index(cls: Any) -> int: