-
Notifications
You must be signed in to change notification settings - Fork 80
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
dialects: (builtin) make DenseIntOrFPElementsAttr generic on element type #3492
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -74,6 +74,7 @@ | |||||
SymbolTable, | ||||||
) | ||||||
from xdsl.utils.exceptions import DiagnosticException, VerifyException | ||||||
from xdsl.utils.hints import isa | ||||||
from xdsl.utils.isattr import isattr | ||||||
|
||||||
if TYPE_CHECKING: | ||||||
|
@@ -1715,17 +1716,20 @@ def get_element_type(self) -> _UnrankedMemrefTypeElems: | |||||
VectorType[AttributeCovT] | TensorType[AttributeCovT] | MemRefType[AttributeCovT] | ||||||
) | ||||||
|
||||||
AnyDenseElement: TypeAlias = IntegerType | IndexType | AnyFloat | ||||||
DenseElementT = TypeVar("DenseElementT", bound=AnyDenseElement, covariant=True) | ||||||
_DenseElementT = TypeVar("_DenseElementT", bound=AnyDenseElement) | ||||||
FloatTypeT = TypeVar("FloatTypeT", bound=AnyFloat) | ||||||
|
||||||
|
||||||
@irdl_attr_definition | ||||||
class DenseIntOrFPElementsAttr( | ||||||
ParametrizedAttribute, ContainerType[IntegerType | IndexType | AnyFloat] | ||||||
Generic[DenseElementT], | ||||||
TypedAttribute, | ||||||
ContainerType[DenseElementT], | ||||||
): | ||||||
name = "dense" | ||||||
type: ParameterDef[ | ||||||
RankedStructure[IntegerType] | ||||||
| RankedStructure[IndexType] | ||||||
| RankedStructure[AnyFloat] | ||||||
] | ||||||
type: ParameterDef[RankedStructure[DenseElementT]] | ||||||
data: ParameterDef[ArrayAttr[AnyIntegerAttr] | ArrayAttr[AnyFloatAttr]] | ||||||
|
||||||
# The type stores the shape data | ||||||
|
@@ -1734,7 +1738,7 @@ def get_shape(self) -> tuple[int, ...] | None: | |||||
return None | ||||||
return self.type.get_shape() | ||||||
|
||||||
def get_element_type(self) -> IntegerType | IndexType | AnyFloat: | ||||||
def get_element_type(self) -> DenseElementT: | ||||||
return self.type.get_element_type() | ||||||
|
||||||
@property | ||||||
|
@@ -1757,21 +1761,21 @@ def shape_is_complete(self) -> bool: | |||||
def create_dense_index( | ||||||
type: RankedStructure[IndexType], | ||||||
data: Sequence[int] | Sequence[IntegerAttr[IndexType]], | ||||||
) -> DenseIntOrFPElementsAttr: | ||||||
) -> DenseIntOrFPElementsAttr[IndexType]: | ||||||
if len(data) and isinstance(data[0], int): | ||||||
attr_list = [ | ||||||
IntegerAttr.from_index_int_value(d) for d in cast(Sequence[int], data) | ||||||
] | ||||||
else: | ||||||
attr_list = cast(Sequence[IntegerAttr[IndexType]], data) | ||||||
|
||||||
return DenseIntOrFPElementsAttr([type, ArrayAttr(attr_list)]) | ||||||
return DenseIntOrFPElementsAttr[IndexType]([type, ArrayAttr(attr_list)]) | ||||||
|
||||||
@staticmethod | ||||||
def create_dense_int( | ||||||
type: RankedStructure[IntegerType], | ||||||
data: Sequence[int] | Sequence[IntegerAttr[IntegerType]], | ||||||
) -> DenseIntOrFPElementsAttr: | ||||||
) -> DenseIntOrFPElementsAttr[IntegerType]: | ||||||
if len(data) and isinstance(data[0], int): | ||||||
attr_list = [ | ||||||
IntegerAttr[IntegerType](d, type.element_type) | ||||||
|
@@ -1784,9 +1788,9 @@ def create_dense_int( | |||||
|
||||||
@staticmethod | ||||||
def create_dense_float( | ||||||
type: RankedStructure[AnyFloat], | ||||||
type: RankedStructure[FloatTypeT], | ||||||
data: Sequence[int | float] | Sequence[AnyFloatAttr], | ||||||
) -> DenseIntOrFPElementsAttr: | ||||||
) -> DenseIntOrFPElementsAttr[FloatTypeT]: | ||||||
if len(data) and isinstance(data[0], int | float): | ||||||
attr_list = [ | ||||||
FloatAttr(float(d), type.element_type) | ||||||
|
@@ -1797,64 +1801,40 @@ def create_dense_float( | |||||
|
||||||
return DenseIntOrFPElementsAttr([type, ArrayAttr(attr_list)]) | ||||||
|
||||||
@overload | ||||||
@staticmethod | ||||||
def from_list( | ||||||
type: ( | ||||||
RankedStructure[AnyFloat | IntegerType | IndexType] | ||||||
| RankedStructure[AnyFloat] | ||||||
| RankedStructure[IntegerType] | ||||||
| RankedStructure[IndexType] | ||||||
), | ||||||
data: ( | ||||||
Sequence[int] | ||||||
| Sequence[IntegerAttr[IndexType]] | ||||||
| Sequence[IntegerAttr[IntegerType]] | ||||||
), | ||||||
) -> DenseIntOrFPElementsAttr: ... | ||||||
|
||||||
@overload | ||||||
@staticmethod | ||||||
def from_list( | ||||||
type: ( | ||||||
RankedStructure[AnyFloat | IntegerType | IndexType] | ||||||
| RankedStructure[AnyFloat] | ||||||
| RankedStructure[IntegerType] | ||||||
| RankedStructure[IndexType] | ||||||
), | ||||||
data: Sequence[int | float] | Sequence[AnyFloatAttr], | ||||||
) -> DenseIntOrFPElementsAttr: ... | ||||||
|
||||||
@staticmethod | ||||||
def from_list( | ||||||
type: ( | ||||||
RankedStructure[AnyFloat | IntegerType | IndexType] | ||||||
| RankedStructure[AnyFloat] | ||||||
| RankedStructure[IntegerType] | ||||||
| RankedStructure[IndexType] | ||||||
), | ||||||
type: RankedStructure[_DenseElementT], | ||||||
data: Sequence[int | float] | Sequence[AnyIntegerAttr] | Sequence[AnyFloatAttr], | ||||||
) -> DenseIntOrFPElementsAttr: | ||||||
) -> DenseIntOrFPElementsAttr[_DenseElementT]: | ||||||
if isinstance(type.element_type, AnyFloat): | ||||||
new_type = cast(RankedStructure[AnyFloat], type) | ||||||
new_data = cast(Sequence[int | float] | Sequence[FloatAttr[AnyFloat]], data) | ||||||
return DenseIntOrFPElementsAttr.create_dense_float(new_type, new_data) | ||||||
return cast( | ||||||
DenseIntOrFPElementsAttr[_DenseElementT], | ||||||
DenseIntOrFPElementsAttr.create_dense_float(new_type, new_data), | ||||||
) | ||||||
elif isinstance(type.element_type, IntegerType): | ||||||
new_type = cast(RankedStructure[IntegerType], type) | ||||||
new_data = cast(Sequence[int] | Sequence[IntegerAttr[IntegerType]], data) | ||||||
return DenseIntOrFPElementsAttr.create_dense_int(new_type, new_data) | ||||||
return cast( | ||||||
DenseIntOrFPElementsAttr[_DenseElementT], | ||||||
DenseIntOrFPElementsAttr.create_dense_int(new_type, new_data), | ||||||
) | ||||||
else: | ||||||
new_type = cast(RankedStructure[IndexType], type) | ||||||
new_data = cast(Sequence[int] | Sequence[IntegerAttr[IndexType]], data) | ||||||
return DenseIntOrFPElementsAttr.create_dense_index(new_type, new_data) | ||||||
return cast( | ||||||
DenseIntOrFPElementsAttr[_DenseElementT], | ||||||
DenseIntOrFPElementsAttr.create_dense_index(new_type, new_data), | ||||||
) | ||||||
|
||||||
@staticmethod | ||||||
def vector_from_list( | ||||||
data: Sequence[int] | Sequence[float], | ||||||
data_type: IntegerType | IndexType | AnyFloat, | ||||||
) -> DenseIntOrFPElementsAttr: | ||||||
data_type: _DenseElementT, | ||||||
) -> DenseIntOrFPElementsAttr[_DenseElementT]: | ||||||
t = VectorType(data_type, [len(data)]) | ||||||
return DenseIntOrFPElementsAttr.from_list(t, data) | ||||||
return DenseIntOrFPElementsAttr[_DenseElementT].from_list(t, data) | ||||||
|
||||||
@staticmethod | ||||||
def tensor_from_list( | ||||||
|
@@ -1865,11 +1845,72 @@ def tensor_from_list( | |||||
| Sequence[IntegerAttr[IntegerType]] | ||||||
| Sequence[AnyFloatAttr] | ||||||
), | ||||||
data_type: IntegerType | IndexType | AnyFloat, | ||||||
data_type: _DenseElementT, | ||||||
shape: Sequence[int], | ||||||
) -> DenseIntOrFPElementsAttr: | ||||||
) -> DenseIntOrFPElementsAttr[_DenseElementT]: | ||||||
t = TensorType(data_type, shape) | ||||||
return DenseIntOrFPElementsAttr.from_list(t, data) | ||||||
return DenseIntOrFPElementsAttr[_DenseElementT].from_list(t, data) | ||||||
|
||||||
@staticmethod | ||||||
def parse_with_type(parser: AttrParser, type: Attribute) -> TypedAttribute: | ||||||
assert ( | ||||||
isa(type, VectorType[AnyDenseElement]) | ||||||
or isa(type, TensorType[AnyDenseElement]) | ||||||
or isa(type, MemRefType[AnyDenseElement]) | ||||||
) | ||||||
|
||||||
return parser.parse_dense_int_or_fp_elements_attr(type) | ||||||
|
||||||
@staticmethod | ||||||
def _print_one_elem(val: Attribute, printer: Printer): | ||||||
if isinstance(val, IntegerAttr): | ||||||
printer.print_string(f"{val.value.data}") | ||||||
elif isinstance(val, FloatAttr): | ||||||
printer.print_float(cast(AnyFloatAttr, val)) | ||||||
else: | ||||||
raise Exception( | ||||||
"unexpected attribute type " | ||||||
"in DenseIntOrFPElementsAttr: " | ||||||
f"{type(val)}" | ||||||
) | ||||||
|
||||||
@staticmethod | ||||||
def _print_dense_list( | ||||||
array: Sequence[AnyIntegerAttr] | Sequence[AnyFloatAttr], | ||||||
shape: Sequence[int], | ||||||
printer: Printer, | ||||||
): | ||||||
printer.print_string("[") | ||||||
if len(shape) > 1: | ||||||
k = len(array) // shape[0] | ||||||
printer.print_list( | ||||||
(array[i : i + k] for i in range(0, len(array), k)), | ||||||
lambda subarray: DenseIntOrFPElementsAttr._print_dense_list( | ||||||
subarray, shape[1:], printer | ||||||
), | ||||||
) | ||||||
else: | ||||||
printer.print_list( | ||||||
array, | ||||||
lambda val: DenseIntOrFPElementsAttr._print_one_elem(val, printer), | ||||||
) | ||||||
printer.print_string("]") | ||||||
|
||||||
def print_without_type(self, printer: Printer): | ||||||
printer.print_string("dense<") | ||||||
data = self.data.data | ||||||
shape = self.get_shape() if self.shape_is_complete else (len(data),) | ||||||
assert shape is not None, "If shape is complete, then it cannot be None" | ||||||
if len(data) == 0: | ||||||
pass | ||||||
elif data.count(data[0]) == len(data): | ||||||
DenseIntOrFPElementsAttr._print_one_elem(data[0], printer) | ||||||
else: | ||||||
DenseIntOrFPElementsAttr._print_dense_list(data, shape, printer) | ||||||
printer.print_string(">") | ||||||
|
||||||
|
||||||
DenseIntElementsAttr: TypeAlias = DenseIntOrFPElementsAttr[IntegerType] | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This line feels like it could have its own PR, we could probably have a custom constraint to start with to minimise the diff There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I still feel like this could be its own PR :) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. You want a separate PR just for the type alias? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yeah that was the idea, it technically is its own API surface change and seems to be a large contribution to this PR's diff, independent of the other changes. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't see how it makes any impact on the diff at all. Anything that is changed to a DenseIntElementAttr would have had to be changed to a There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Sorry that's what I mean, this change in the diff of this PR could already be done in main directly, and is not dependent on the changes you make here. Or am I misunderstanding? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't see how to separate the changes. I can't change I could potentially do two steps of changing every There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah sorry that's not what I meant, I meant something like
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah I get it now, I could give that a go |
||||||
|
||||||
|
||||||
Builtin = Dialect( | ||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this the equivalent of sized in MLIR? Would it make sense to add this independently of the other changes?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I doubt it is as
IndexType
is not sized right? (AlsoIndexType
should be forbidden fromDenseIntOrFPElementsAttr
but that's another matter)There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it's quite relevant, and we should forbid it indeed, I thought we already made that change recently.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah it was DenseArrayBase #3258
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think in any case that should be a different PR (which could happen before or after this one)