Skip to content

Commit

Permalink
dialects: (builtin) change data representation of DenseIntOrFPElement…
Browse files Browse the repository at this point in the history
…s to use bytes (#3623)
  • Loading branch information
jorendumoulin authored Jan 8, 2025
1 parent 73834d3 commit 78bea80
Showing 1 changed file with 60 additions and 48 deletions.
108 changes: 60 additions & 48 deletions xdsl/dialects/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -2004,7 +2004,7 @@ class DenseIntOrFPElementsAttr(TypedAttribute, ContainerType[AnyDenseElement]):
| RankedStructure[IndexType]
| RankedStructure[AnyFloat]
]
data: ParameterDef[ArrayAttr[AnyIntegerAttr] | ArrayAttr[AnyFloatAttr]]
data: ParameterDef[BytesAttr]

# The type stores the shape data
def get_shape(self) -> tuple[int, ...]:
Expand All @@ -2014,7 +2014,7 @@ def get_element_type(self) -> IntegerType | IndexType | AnyFloat:
return self.type.get_element_type()

def __len__(self) -> int:
return len(self.data)
return len(self.data.data) // self.type.element_type.compile_time_size

@property
def shape_is_complete(self) -> bool:
Expand All @@ -2028,51 +2028,64 @@ def shape_is_complete(self) -> bool:
n *= dim

# Product of dimensions needs to equal length
return n == len(self.data.data)
return n == len(self)

@staticmethod
def create_dense_index(
type: RankedStructure[IndexType],
data: Sequence[int] | Sequence[IntegerAttr[IndexType]],
) -> DenseIntOrFPElementsAttr:
if len(data) and isinstance(data[0], int):
attr_list = [
IntegerAttr.from_index_int_value(d) for d in cast(Sequence[int], data)
if len(data) and isinstance(data[0], IntegerAttr):
data = [
el.value.data for el in cast(Sequence[IntegerAttr[IndexType]], data)
]
else:
attr_list = cast(Sequence[IntegerAttr[IndexType]], data)
data = cast(Sequence[int], data)

return DenseIntOrFPElementsAttr([type, ArrayAttr(attr_list)])
return DenseIntOrFPElementsAttr([type, BytesAttr(type.element_type.pack(data))])

@staticmethod
def create_dense_int(
type: RankedStructure[IntegerType],
data: Sequence[int] | Sequence[IntegerAttr[IntegerType]],
) -> DenseIntOrFPElementsAttr:
if len(data) and isinstance(data[0], int):
attr_list = [
IntegerAttr[IntegerType](d, type.element_type)
for d in cast(Sequence[int], data)
if len(data) and isinstance(data[0], IntegerAttr):
data = [
el.value.data for el in cast(Sequence[IntegerAttr[IntegerType]], data)
]
else:
attr_list = cast(Sequence[IntegerAttr[IntegerType]], data)
data = cast(Sequence[int], data)

# ints are normalized
normalized_values = tuple(
type.element_type.normalized_value(value) for value in data
)

for value in normalized_values:
if value is None:
min_value, max_value = type.element_type.value_range()
raise ValueError(
f"Integer value {value} is out of range for type {type.element_type} which supports "
f"values in the range [{min_value}, {max_value})"
)

normalized_values = cast(Sequence[int], tuple(normalized_values))

return DenseIntOrFPElementsAttr([type, ArrayAttr(attr_list)])
return DenseIntOrFPElementsAttr(
[type, BytesAttr(type.element_type.pack(normalized_values))]
)

@staticmethod
def create_dense_float(
type: RankedStructure[AnyFloat],
data: Sequence[int | float] | Sequence[AnyFloatAttr],
) -> DenseIntOrFPElementsAttr:
if len(data) and isinstance(data[0], int | float):
attr_list = [
FloatAttr(float(d), type.element_type)
for d in cast(Sequence[int | float], data)
]
if len(data) and isa(data[0], AnyFloatAttr):
data = [el.value.data for el in cast(Sequence[AnyFloatAttr], data)]
else:
attr_list = cast(Sequence[AnyFloatAttr], data)
data = cast(Sequence[float], data)

return DenseIntOrFPElementsAttr([type, ArrayAttr(attr_list)])
return DenseIntOrFPElementsAttr([type, BytesAttr(type.element_type.pack(data))])

@overload
@staticmethod
Expand Down Expand Up @@ -2171,56 +2184,57 @@ def iter_values(self) -> Iterator[int] | Iterator[float]:
"""
Return an iterator over all the values of the elements in this DenseIntOrFPElementsAttr
"""
return (el.value.data for el in self.data.data)
return self.get_element_type().iter_unpack(self.data.data)

def get_values(self) -> Sequence[int] | Sequence[float]:
"""
Return all the values of the elements in this DenseIntOrFPElementsAttr
"""
return tuple(self.iter_values())
return self.get_element_type().unpack(self.data.data, len(self))

def iter_attrs(self) -> Iterator[AnyIntegerAttr] | Iterator[AnyFloatAttr]:
"""
Return an iterator over all elements of the dense attribute in their relevant
attribute representation (IntegerAttr / FloatAttr)
"""
return iter(self.data)
if isinstance(eltype := self.get_element_type(), IntegerType | IndexType):
return IntegerAttr.iter_unpack(eltype, self.data.data)
else:
return FloatAttr.iter_unpack(eltype, self.data.data)

def get_attrs(self) -> Sequence[AnyIntegerAttr] | Sequence[AnyFloatAttr]:
"""
Return all elements of the dense attribute in their relevant
attribute representation (IntegerAttr / FloatAttr)
"""
return self.data.data
if isinstance(eltype := self.get_element_type(), IntegerType | IndexType):
return IntegerAttr.unpack(eltype, self.data.data, len(self))
else:
return FloatAttr.unpack(eltype, self.data.data, len(self))

def is_splat(self) -> bool:
"""
Return whether or not this dense attribute is defined entirely
by a single value (splat).
"""
return self.data.data.count(self.data.data[0]) == len(self.data.data)
values = self.get_values()
return values.count(values[0]) == len(values)

@staticmethod
def parse_with_type(parser: AttrParser, type: Attribute) -> TypedAttribute:
assert isa(type, RankedStructure[AnyDenseElement])
return parser.parse_dense_int_or_fp_elements_attr(type)

@staticmethod
def _print_one_elem(val: Attribute, printer: Printer):
if isinstance(val, IntegerAttr):
val.print_without_type(printer)
elif isinstance(val, FloatAttr):
printer.print_float_attr(cast(AnyFloatAttr, val))
else:
raise Exception(
"unexpected attribute type "
"in DenseIntOrFPElementsAttr: "
f"{type(val)}"
)
def _print_one_elem(self, val: int | float, printer: Printer):
if isinstance(val, int):
element_type = cast(IntegerType | IndexType, self.get_element_type())
element_type.print_value_without_type(val, printer)
else: # float
printer.print_float(val, cast(AnyFloat, self.get_element_type()))

@staticmethod
def _print_dense_list(
array: Sequence[AnyIntegerAttr] | Sequence[AnyFloatAttr],
self,
array: Sequence[int] | Sequence[float],
shape: Sequence[int],
printer: Printer,
):
Expand All @@ -2229,28 +2243,26 @@ def _print_dense_list(
k = len(array) // shape[0]
printer.print_list(
(array[i : i + k] for i in range(0, len(array), k)),
lambda subarray: DenseIntOrFPElementsAttr._print_dense_list(
subarray, shape[1:], printer
),
lambda subarray: self._print_dense_list(subarray, shape[1:], printer),
)
else:
printer.print_list(
array,
lambda val: DenseIntOrFPElementsAttr._print_one_elem(val, printer),
lambda val: self._print_one_elem(val, printer),
)
printer.print_string("]")

def print_without_type(self, printer: Printer):
printer.print_string("dense<")
data = self.data.data
data = self.get_values()
shape = self.get_shape() if self.shape_is_complete else (len(data),)
assert shape is not None, "If shape is complete, then it cannot be None"
if len(data) == 0:
pass
elif data.count(data[0]) == len(data):
DenseIntOrFPElementsAttr._print_one_elem(data[0], printer)
elif self.is_splat():
self._print_one_elem(data[0], printer)
else:
DenseIntOrFPElementsAttr._print_dense_list(data, shape, printer)
self._print_dense_list(data, shape, printer)
printer.print_string(">")


Expand Down

0 comments on commit 78bea80

Please sign in to comment.