Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

dialects: (builtin) make DenseIntOrFPElementsAttr generic on element type #3492

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/Toy/toy/dialects/toy.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,12 +84,12 @@ class ConstantOp(IRDLOperation):
"""

name = "toy.constant"
value = attr_def(DenseIntOrFPElementsAttr)
value = attr_def(DenseIntOrFPElementsAttr[Float64Type])
res = result_def(TensorTypeF64)

traits = traits_def(Pure())

def __init__(self, value: DenseIntOrFPElementsAttr):
def __init__(self, value: DenseIntOrFPElementsAttr[Float64Type]):
super().__init__(result_types=[value.type], attributes={"value": value})

@staticmethod
Expand Down
4 changes: 2 additions & 2 deletions xdsl/backend/csl/print_csl.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,8 +449,8 @@ def attribute_value_to_str(self, attr: Attribute) -> str:
return str(val.data)
case StringAttr() as s:
return f'"{s.data}"'
case DenseIntOrFPElementsAttr(data=ArrayAttr(data=data), type=typ):
return f"{self.mlir_type_to_csl_type(typ)} {{ {', '.join(self.attribute_value_to_str(d) for d in data)} }}"
case DenseIntOrFPElementsAttr(data=ArrayAttr(data=data)):
return f"{self.mlir_type_to_csl_type(attr.get_type())} {{ {', '.join(self.attribute_value_to_str(d) for d in data)} }}"
case _:
return f"<!unknown value {attr}>"

Expand Down
6 changes: 3 additions & 3 deletions xdsl/dialects/affine.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
AnyIntegerAttr,
ArrayAttr,
ContainerType,
DenseIntOrFPElementsAttr,
DenseIntElementsAttr,
IndexType,
IntegerAttr,
IntegerType,
Expand Down Expand Up @@ -224,9 +224,9 @@ class ParallelOp(IRDLOperation):

reductions = prop_def(ArrayAttr[StringAttr])
lowerBoundsMap = prop_def(AffineMapAttr)
lowerBoundsGroups = prop_def(DenseIntOrFPElementsAttr)
lowerBoundsGroups = prop_def(DenseIntElementsAttr)
upperBoundsMap = prop_def(AffineMapAttr)
upperBoundsGroups = prop_def(DenseIntOrFPElementsAttr)
upperBoundsGroups = prop_def(DenseIntElementsAttr)
steps = prop_def(ArrayAttr[IntegerAttr[IntegerType]])

res = var_result_def()
Expand Down
7 changes: 5 additions & 2 deletions xdsl/dialects/arith.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import ClassVar, Literal, TypeVar, cast, overload

from xdsl.dialects.builtin import (
AnyDenseElement,
AnyFloat,
AnyFloatConstr,
AnyIntegerAttr,
Expand Down Expand Up @@ -132,7 +133,9 @@ class Constant(IRDLOperation):
@overload
def __init__(
self,
value: AnyIntegerAttr | FloatAttr[AnyFloat] | DenseIntOrFPElementsAttr,
value: AnyIntegerAttr
| FloatAttr[AnyFloat]
| DenseIntOrFPElementsAttr[AnyDenseElement],
value_type: None = None,
) -> None: ...

Expand Down Expand Up @@ -179,7 +182,7 @@ def parse(cls: type[Constant], parser: Parser) -> Constant:
value,
base(AnyIntegerAttr)
| base(FloatAttr[AnyFloat])
| base(DenseIntOrFPElementsAttr),
| base(DenseIntOrFPElementsAttr[AnyDenseElement]),
):
parser.raise_error("Invalid constant value", p0, parser.pos)

Expand Down
153 changes: 97 additions & 56 deletions xdsl/dialects/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@
SymbolTable,
)
from xdsl.utils.exceptions import DiagnosticException, VerifyException
from xdsl.utils.hints import isa
from xdsl.utils.isattr import isattr

if TYPE_CHECKING:
Expand Down Expand Up @@ -1715,17 +1716,20 @@ def get_element_type(self) -> _UnrankedMemrefTypeElems:
VectorType[AttributeCovT] | TensorType[AttributeCovT] | MemRefType[AttributeCovT]
)

AnyDenseElement: TypeAlias = IntegerType | IndexType | AnyFloat
DenseElementT = TypeVar("DenseElementT", bound=AnyDenseElement, covariant=True)
_DenseElementT = TypeVar("_DenseElementT", bound=AnyDenseElement)
Comment on lines +1719 to +1721
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this the equivalent of sized in MLIR? Would it make sense to add this independently of the other changes?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I doubt it is as IndexType is not sized right? (Also IndexType should be forbidden from DenseIntOrFPElementsAttr but that's another matter)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's quite relevant, and we should forbid it indeed, I thought we already made that change recently.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah it was DenseArrayBase #3258

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think in any case that should be a different PR (which could happen before or after this one)

FloatTypeT = TypeVar("FloatTypeT", bound=AnyFloat)


@irdl_attr_definition
class DenseIntOrFPElementsAttr(
ParametrizedAttribute, ContainerType[IntegerType | IndexType | AnyFloat]
Generic[DenseElementT],
TypedAttribute,
ContainerType[DenseElementT],
):
name = "dense"
type: ParameterDef[
RankedStructure[IntegerType]
| RankedStructure[IndexType]
| RankedStructure[AnyFloat]
]
type: ParameterDef[RankedStructure[DenseElementT]]
data: ParameterDef[ArrayAttr[AnyIntegerAttr] | ArrayAttr[AnyFloatAttr]]

# The type stores the shape data
Expand All @@ -1734,7 +1738,7 @@ def get_shape(self) -> tuple[int, ...] | None:
return None
return self.type.get_shape()

def get_element_type(self) -> IntegerType | IndexType | AnyFloat:
def get_element_type(self) -> DenseElementT:
return self.type.get_element_type()

@property
Expand All @@ -1757,21 +1761,21 @@ def shape_is_complete(self) -> bool:
def create_dense_index(
type: RankedStructure[IndexType],
data: Sequence[int] | Sequence[IntegerAttr[IndexType]],
) -> DenseIntOrFPElementsAttr:
) -> DenseIntOrFPElementsAttr[IndexType]:
if len(data) and isinstance(data[0], int):
attr_list = [
IntegerAttr.from_index_int_value(d) for d in cast(Sequence[int], data)
]
else:
attr_list = cast(Sequence[IntegerAttr[IndexType]], data)

return DenseIntOrFPElementsAttr([type, ArrayAttr(attr_list)])
return DenseIntOrFPElementsAttr[IndexType]([type, ArrayAttr(attr_list)])

@staticmethod
def create_dense_int(
type: RankedStructure[IntegerType],
data: Sequence[int] | Sequence[IntegerAttr[IntegerType]],
) -> DenseIntOrFPElementsAttr:
) -> DenseIntOrFPElementsAttr[IntegerType]:
if len(data) and isinstance(data[0], int):
attr_list = [
IntegerAttr[IntegerType](d, type.element_type)
Expand All @@ -1784,9 +1788,9 @@ def create_dense_int(

@staticmethod
def create_dense_float(
type: RankedStructure[AnyFloat],
type: RankedStructure[FloatTypeT],
data: Sequence[int | float] | Sequence[AnyFloatAttr],
) -> DenseIntOrFPElementsAttr:
) -> DenseIntOrFPElementsAttr[FloatTypeT]:
if len(data) and isinstance(data[0], int | float):
attr_list = [
FloatAttr(float(d), type.element_type)
Expand All @@ -1797,64 +1801,40 @@ def create_dense_float(

return DenseIntOrFPElementsAttr([type, ArrayAttr(attr_list)])

@overload
@staticmethod
def from_list(
type: (
RankedStructure[AnyFloat | IntegerType | IndexType]
| RankedStructure[AnyFloat]
| RankedStructure[IntegerType]
| RankedStructure[IndexType]
),
data: (
Sequence[int]
| Sequence[IntegerAttr[IndexType]]
| Sequence[IntegerAttr[IntegerType]]
),
) -> DenseIntOrFPElementsAttr: ...

@overload
@staticmethod
def from_list(
type: (
RankedStructure[AnyFloat | IntegerType | IndexType]
| RankedStructure[AnyFloat]
| RankedStructure[IntegerType]
| RankedStructure[IndexType]
),
data: Sequence[int | float] | Sequence[AnyFloatAttr],
) -> DenseIntOrFPElementsAttr: ...

@staticmethod
def from_list(
type: (
RankedStructure[AnyFloat | IntegerType | IndexType]
| RankedStructure[AnyFloat]
| RankedStructure[IntegerType]
| RankedStructure[IndexType]
),
type: RankedStructure[_DenseElementT],
data: Sequence[int | float] | Sequence[AnyIntegerAttr] | Sequence[AnyFloatAttr],
) -> DenseIntOrFPElementsAttr:
) -> DenseIntOrFPElementsAttr[_DenseElementT]:
if isinstance(type.element_type, AnyFloat):
new_type = cast(RankedStructure[AnyFloat], type)
new_data = cast(Sequence[int | float] | Sequence[FloatAttr[AnyFloat]], data)
return DenseIntOrFPElementsAttr.create_dense_float(new_type, new_data)
return cast(
DenseIntOrFPElementsAttr[_DenseElementT],
DenseIntOrFPElementsAttr.create_dense_float(new_type, new_data),
)
elif isinstance(type.element_type, IntegerType):
new_type = cast(RankedStructure[IntegerType], type)
new_data = cast(Sequence[int] | Sequence[IntegerAttr[IntegerType]], data)
return DenseIntOrFPElementsAttr.create_dense_int(new_type, new_data)
return cast(
DenseIntOrFPElementsAttr[_DenseElementT],
DenseIntOrFPElementsAttr.create_dense_int(new_type, new_data),
)
else:
new_type = cast(RankedStructure[IndexType], type)
new_data = cast(Sequence[int] | Sequence[IntegerAttr[IndexType]], data)
return DenseIntOrFPElementsAttr.create_dense_index(new_type, new_data)
return cast(
DenseIntOrFPElementsAttr[_DenseElementT],
DenseIntOrFPElementsAttr.create_dense_index(new_type, new_data),
)

@staticmethod
def vector_from_list(
data: Sequence[int] | Sequence[float],
data_type: IntegerType | IndexType | AnyFloat,
) -> DenseIntOrFPElementsAttr:
data_type: _DenseElementT,
) -> DenseIntOrFPElementsAttr[_DenseElementT]:
t = VectorType(data_type, [len(data)])
return DenseIntOrFPElementsAttr.from_list(t, data)
return DenseIntOrFPElementsAttr[_DenseElementT].from_list(t, data)

@staticmethod
def tensor_from_list(
Expand All @@ -1865,11 +1845,72 @@ def tensor_from_list(
| Sequence[IntegerAttr[IntegerType]]
| Sequence[AnyFloatAttr]
),
data_type: IntegerType | IndexType | AnyFloat,
data_type: _DenseElementT,
shape: Sequence[int],
) -> DenseIntOrFPElementsAttr:
) -> DenseIntOrFPElementsAttr[_DenseElementT]:
t = TensorType(data_type, shape)
return DenseIntOrFPElementsAttr.from_list(t, data)
return DenseIntOrFPElementsAttr[_DenseElementT].from_list(t, data)

@staticmethod
def parse_with_type(parser: AttrParser, type: Attribute) -> TypedAttribute:
assert (
isa(type, VectorType[AnyDenseElement])
or isa(type, TensorType[AnyDenseElement])
or isa(type, MemRefType[AnyDenseElement])
)

return parser.parse_dense_int_or_fp_elements_attr(type)

@staticmethod
def _print_one_elem(val: Attribute, printer: Printer):
if isinstance(val, IntegerAttr):
printer.print_string(f"{val.value.data}")
elif isinstance(val, FloatAttr):
printer.print_float(cast(AnyFloatAttr, val))
else:
raise Exception(
"unexpected attribute type "
"in DenseIntOrFPElementsAttr: "
f"{type(val)}"
)

@staticmethod
def _print_dense_list(
array: Sequence[AnyIntegerAttr] | Sequence[AnyFloatAttr],
shape: Sequence[int],
printer: Printer,
):
printer.print_string("[")
if len(shape) > 1:
k = len(array) // shape[0]
printer.print_list(
(array[i : i + k] for i in range(0, len(array), k)),
lambda subarray: DenseIntOrFPElementsAttr._print_dense_list(
subarray, shape[1:], printer
),
)
else:
printer.print_list(
array,
lambda val: DenseIntOrFPElementsAttr._print_one_elem(val, printer),
)
printer.print_string("]")

def print_without_type(self, printer: Printer):
printer.print_string("dense<")
data = self.data.data
shape = self.get_shape() if self.shape_is_complete else (len(data),)
assert shape is not None, "If shape is complete, then it cannot be None"
if len(data) == 0:
pass
elif data.count(data[0]) == len(data):
DenseIntOrFPElementsAttr._print_one_elem(data[0], printer)
else:
DenseIntOrFPElementsAttr._print_dense_list(data, shape, printer)
printer.print_string(">")


DenseIntElementsAttr: TypeAlias = DenseIntOrFPElementsAttr[IntegerType]
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This line feels like it could have its own PR, we could probably have a custom constraint to start with to minimise the diff

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I still feel like this could be its own PR :)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You want a separate PR just for the type alias?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah that was the idea, it technically is its own API surface change and seems to be a large contribution to this PR's diff, independent of the other changes.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't see how it makes any impact on the diff at all. Anything that is changed to a DenseIntElementAttr would have had to be changed to a DenseIntOrFPElementsAttr[IntegerType] or at least DenseIntOrFPElementsAttr[AnyDenseElement]

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry that's what I mean, this change in the diff of this PR could already be done in main directly, and is not dependent on the changes you make here. Or am I misunderstanding?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't see how to separate the changes. I can't change DenseIntOrFPElementsAttr to DenseIntOrFPElementsAttr[AnyDenseElement] before this PR.

I could potentially do two steps of changing every DenseIntOrFPElementsAttr to DenseIntOrFPElementsAttr[AnyDenseElement] and then do a second PR specialising some of these to DenseIntElementsAttr = DenseIntOrFPElementsAttr[IntegerType], but I don't think the first diff will be any smaller.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah sorry that's not what I meant, I meant something like

Suggested change
DenseIntElementsAttr: TypeAlias = DenseIntOrFPElementsAttr[IntegerType]
DenseIntElementsAttr: TypeAlias = Annotated[DenseIntOrFPElementsAttr, ParametrizedAttrConstraint(Attribute, (BaseAttr(IntegerType),AnyAttr()))]

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah I get it now, I could give that a go



Builtin = Dialect(
Expand Down
13 changes: 6 additions & 7 deletions xdsl/dialects/cf.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,7 @@

from xdsl.dialects.builtin import (
DenseArrayBase,
DenseIntOrFPElementsAttr,
IndexType,
DenseIntElementsAttr,
IndexTypeConstr,
IntegerType,
SignlessIntegerConstraint,
Expand Down Expand Up @@ -178,7 +177,7 @@ class Switch(IRDLOperation):

name = "cf.switch"

case_values = opt_prop_def(DenseIntOrFPElementsAttr)
case_values = opt_prop_def(DenseIntElementsAttr)

flag = operand_def(IndexTypeConstr | SignlessIntegerConstraint)

Expand All @@ -202,7 +201,7 @@ def __init__(
flag: Operation | SSAValue,
default_block: Successor,
default_operands: Sequence[Operation | SSAValue],
case_values: DenseIntOrFPElementsAttr | None = None,
case_values: DenseIntElementsAttr | None = None,
case_blocks: Sequence[Successor] = [],
case_operands: Sequence[Sequence[Operation | SSAValue]] = [],
attr_dict: dict[str, Attribute] | None = None,
Expand Down Expand Up @@ -355,15 +354,15 @@ def parse(cls, parser: Parser) -> Self:
parser.parse_punctuation("[")
parser.parse_keyword("default")
(default_block, default_args) = cls._parse_case_body(parser)
case_values: DenseIntOrFPElementsAttr | None = None
case_values: DenseIntElementsAttr | None = None
case_blocks: tuple[Block, ...] = ()
case_operands: tuple[tuple[SSAValue, ...], ...] = ()
if parser.parse_optional_punctuation(","):
cases = parser.parse_comma_separated_list(
Parser.Delimiter.NONE, lambda: cls._parse_case(parser)
)
assert isinstance(flag_type, IntegerType | IndexType)
case_values = DenseIntOrFPElementsAttr.vector_from_list(
assert isinstance(flag_type, IntegerType)
case_values = DenseIntElementsAttr.vector_from_list(
[x for (x, _, _) in cases], flag_type
)
case_blocks = tuple(x for (_, x, _) in cases)
Expand Down
10 changes: 5 additions & 5 deletions xdsl/dialects/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
AnyTensorType,
ArrayAttr,
DenseArrayBase,
DenseIntOrFPElementsAttr,
DenseIntElementsAttr,
IntegerType,
MemRefType,
ShapedType,
Expand Down Expand Up @@ -961,8 +961,8 @@ class PoolingOpsBase(IRDLOperation, ABC):
"`outs` `(` $outputs `:` type($outputs) `)` `->` type($res)"
)

strides = attr_def(DenseIntOrFPElementsAttr)
dilations = attr_def(DenseIntOrFPElementsAttr)
strides = attr_def(DenseIntElementsAttr)
dilations = attr_def(DenseIntElementsAttr)

irdl_options = [AttrSizedOperandSegments(as_property=True), ParsePropInAttrDict()]

Expand Down Expand Up @@ -1012,8 +1012,8 @@ class ConvOpsBase(IRDLOperation, ABC):
"`outs` `(` $outputs `:` type($outputs) `)` `->` type($res)"
)

strides = attr_def(DenseIntOrFPElementsAttr)
dilations = attr_def(DenseIntOrFPElementsAttr)
strides = attr_def(DenseIntElementsAttr)
dilations = attr_def(DenseIntElementsAttr)

irdl_options = [AttrSizedOperandSegments(as_property=True), ParsePropInAttrDict()]

Expand Down
Loading
Loading