Skip to content

Commit

Permalink
dialects: (builtin) add unpack and iter_unpack to IntegerAttr and Flo…
Browse files Browse the repository at this point in the history
…atAttr (#3706)

This is useful for `DenseArrayBase`, and will be useful for
`DenseIntOrFPElementsAttr` once that migrates to be backed by bytes.
#3623
  • Loading branch information
superlopuh authored Jan 6, 2025
1 parent 05458db commit 061f484
Show file tree
Hide file tree
Showing 2 changed files with 102 additions and 0 deletions.
50 changes: 50 additions & 0 deletions tests/dialects/test_builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,48 +175,72 @@ def test_IntegerType_packing():
buffer_i1 = i1.pack(nums_i1)
unpacked_i1 = i1.unpack(buffer_i1, len(nums_i1))
assert nums_i1 == unpacked_i1
attrs_i1 = IntegerAttr.unpack(i1, buffer_i1, len(nums_i1))
assert attrs_i1 == tuple(IntegerAttr(n, i1) for n in nums_i1)
assert tuple(attr for attr in IntegerAttr.iter_unpack(i1, buffer_i1)) == attrs_i1

# i8
nums_i8 = (-128, -1, 0, 1, 127)
buffer_i8 = i8.pack(nums_i8)
unpacked_i8 = i8.unpack(buffer_i8, len(nums_i8))
assert nums_i8 == unpacked_i8
attrs_i8 = IntegerAttr.unpack(i8, buffer_i8, len(nums_i8))
assert attrs_i8 == tuple(IntegerAttr(n, i8) for n in nums_i8)
assert tuple(attr for attr in IntegerAttr.iter_unpack(i8, buffer_i8)) == attrs_i8

# i16
nums_i16 = (-32768, -1, 0, 1, 32767)
buffer_i16 = i16.pack(nums_i16)
unpacked_i16 = i16.unpack(buffer_i16, len(nums_i16))
assert nums_i16 == unpacked_i16
attrs_i16 = IntegerAttr.unpack(i16, buffer_i16, len(nums_i16))
assert attrs_i16 == tuple(IntegerAttr(n, i16) for n in nums_i16)
assert tuple(attr for attr in IntegerAttr.iter_unpack(i16, buffer_i16)) == attrs_i16

# i32
nums_i32 = (-2147483648, -1, 0, 1, 2147483647)
buffer_i32 = i32.pack(nums_i32)
unpacked_i32 = i32.unpack(buffer_i32, len(nums_i32))
assert nums_i32 == unpacked_i32
attrs_i32 = IntegerAttr.unpack(i32, buffer_i32, len(nums_i32))
assert attrs_i32 == tuple(IntegerAttr(n, i32) for n in nums_i32)
assert tuple(attr for attr in IntegerAttr.iter_unpack(i32, buffer_i32)) == attrs_i32

# i64
nums_i64 = (-9223372036854775808, -1, 0, 1, 9223372036854775807)
buffer_i64 = i64.pack(nums_i64)
unpacked_i64 = i64.unpack(buffer_i64, len(nums_i64))
assert nums_i64 == unpacked_i64
attrs_i64 = IntegerAttr.unpack(i64, buffer_i64, len(nums_i64))
assert attrs_i64 == tuple(IntegerAttr(n, i64) for n in nums_i64)
assert tuple(attr for attr in IntegerAttr.iter_unpack(i64, buffer_i64)) == attrs_i64

# f16
nums_f16 = (-3.140625, -1.0, 0.0, 1.0, 3.140625)
buffer_f16 = f16.pack(nums_f16)
unpacked_f16 = f16.unpack(buffer_f16, len(nums_f16))
assert nums_f16 == unpacked_f16
attrs_f16 = FloatAttr.unpack(f16, buffer_f16, len(nums_f16))
assert attrs_f16 == tuple(FloatAttr(n, f16) for n in nums_f16)
assert tuple(attr for attr in FloatAttr.iter_unpack(f16, buffer_f16)) == attrs_f16

# f32
nums_f32 = (-3.140000104904175, -1.0, 0.0, 1.0, 3.140000104904175)
buffer_f32 = f32.pack(nums_f32)
unpacked_f32 = f32.unpack(buffer_f32, len(nums_f32))
assert nums_f32 == unpacked_f32
attrs_f32 = FloatAttr.unpack(f32, buffer_f32, len(nums_f32))
assert attrs_f32 == tuple(FloatAttr(n, f32) for n in nums_f32)
assert tuple(attr for attr in FloatAttr.iter_unpack(f32, buffer_f32)) == attrs_f32

# f64
nums_f64 = (-3.14159265359, -1.0, 0.0, 1.0, 3.14159265359)
buffer_f64 = f64.pack(nums_f64)
unpacked_f64 = f64.unpack(buffer_f64, len(nums_f64))
assert nums_f64 == unpacked_f64
attrs_f64 = FloatAttr.unpack(f64, buffer_f64, len(nums_f64))
assert attrs_f64 == tuple(FloatAttr(n, f64) for n in nums_f64)
assert tuple(attr for attr in FloatAttr.iter_unpack(f64, buffer_f64)) == attrs_f64

# Test error cases
# Different Python versions have different error messages for these
Expand Down Expand Up @@ -486,9 +510,35 @@ def test_complex_init():
def test_dense_as_tuple():
floats = DenseArrayBase.from_list(f32, [3.14159, 2.71828])
assert floats.get_values() == (3.141590118408203, 2.718280076980591)
assert tuple(floats.iter_values()) == (3.141590118408203, 2.718280076980591)
assert tuple(floats.iter_attrs()) == (
FloatAttr(3.141590118408203, f32),
FloatAttr(2.718280076980591, f32),
)
assert floats.get_attrs() == (
FloatAttr(3.141590118408203, f32),
FloatAttr(2.718280076980591, f32),
)

ints = DenseArrayBase.from_list(i32, [1, 1, 2, 3, 5, 8])
assert ints.get_values() == (1, 1, 2, 3, 5, 8)
assert tuple(ints.iter_values()) == (1, 1, 2, 3, 5, 8)
assert tuple(ints.iter_attrs()) == (
IntegerAttr(1, i32),
IntegerAttr(1, i32),
IntegerAttr(2, i32),
IntegerAttr(3, i32),
IntegerAttr(5, i32),
IntegerAttr(8, i32),
)
assert ints.get_attrs() == (
IntegerAttr(1, i32),
IntegerAttr(1, i32),
IntegerAttr(2, i32),
IntegerAttr(3, i32),
IntegerAttr(5, i32),
IntegerAttr(8, i32),
)


def test_create_dense_int():
Expand Down
52 changes: 52 additions & 0 deletions xdsl/dialects/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -602,6 +602,7 @@ def format(self) -> str:
_IntegerAttrType = TypeVar(
"_IntegerAttrType", bound=IntegerType | IndexType, covariant=True
)
_IntegerAttrTypeInvT = TypeVar("_IntegerAttrTypeInvT", bound=IntegerType | IndexType)
_IntegerAttrTypeConstrT = TypeVar(
"_IntegerAttrTypeConstrT", bound=IntegerType | IndexType, covariant=True
)
Expand Down Expand Up @@ -706,6 +707,25 @@ def constr(
),
)

@staticmethod
def iter_unpack(
type: _IntegerAttrTypeInvT, buffer: ReadableBuffer, /
) -> Iterator[IntegerAttr[_IntegerAttrTypeInvT]]:
"""
Yields unpacked values one at a time, starting at the beginning of the buffer.
"""
for value in type.iter_unpack(buffer):
yield IntegerAttr(value, type)

@staticmethod
def unpack(
type: _IntegerAttrTypeInvT, buffer: ReadableBuffer, num: int, /
) -> tuple[IntegerAttr[_IntegerAttrTypeInvT], ...]:
"""
Unpack `num` values from the beginning of the buffer.
"""
return tuple(IntegerAttr(value, type) for value in type.unpack(buffer, num))


AnyIntegerAttr: TypeAlias = IntegerAttr[IntegerType | IndexType]
AnyIntegerAttrConstr: BaseAttr[AnyIntegerAttr] = BaseAttr(IntegerAttr)
Expand Down Expand Up @@ -833,6 +853,7 @@ def __hash__(self):


_FloatAttrType = TypeVar("_FloatAttrType", bound=AnyFloat, covariant=True)
_FloatAttrTypeInvT = TypeVar("_FloatAttrTypeInvT", bound=AnyFloat)


@irdl_attr_definition
Expand Down Expand Up @@ -886,6 +907,25 @@ def parse_with_type(
def print_without_type(self, printer: Printer):
return printer.print_float_attr(self)

@staticmethod
def iter_unpack(
type: _FloatAttrTypeInvT, buffer: ReadableBuffer, /
) -> Iterator[FloatAttr[_FloatAttrTypeInvT]]:
"""
Yields unpacked values one at a time, starting at the beginning of the buffer.
"""
for value in type.iter_unpack(buffer):
yield FloatAttr(value, type)

@staticmethod
def unpack(
type: _FloatAttrTypeInvT, buffer: ReadableBuffer, num: int, /
) -> tuple[FloatAttr[_FloatAttrTypeInvT], ...]:
"""
Unpack `num` values from the beginning of the buffer.
"""
return tuple(FloatAttr(value, type) for value in type.unpack(buffer, num))


AnyFloatAttr: TypeAlias = FloatAttr[AnyFloat]
AnyFloatAttrConstr: BaseAttr[AnyFloatAttr] = BaseAttr(FloatAttr)
Expand Down Expand Up @@ -1250,6 +1290,18 @@ def iter_values(self) -> Iterator[float] | Iterator[int]:
def get_values(self) -> tuple[int, ...] | tuple[float, ...]:
return self.elt_type.unpack(self.data.data, len(self))

def iter_attrs(self) -> Iterator[AnyIntegerAttr] | Iterator[AnyFloatAttr]:
if isinstance(self.elt_type, IntegerType):
return IntegerAttr.iter_unpack(self.elt_type, self.data.data)
else:
return FloatAttr.iter_unpack(self.elt_type, self.data.data)

def get_attrs(self) -> tuple[AnyIntegerAttr, ...] | tuple[AnyFloatAttr, ...]:
if isinstance(self.elt_type, IntegerType):
return IntegerAttr.unpack(self.elt_type, self.data.data, len(self))
else:
return FloatAttr.unpack(self.elt_type, self.data.data, len(self))

def __len__(self) -> int:
return len(self.data.data) // self.elt_type.size

Expand Down

0 comments on commit 061f484

Please sign in to comment.