From 061f484cbd8294a707adb0238183bf76bcad0d95 Mon Sep 17 00:00:00 2001 From: Sasha Lopoukhine Date: Mon, 6 Jan 2025 21:57:39 +0000 Subject: [PATCH] dialects: (builtin) add unpack and iter_unpack to IntegerAttr and FloatAttr (#3706) This is useful for `DenseArrayBase`, and will be useful for `DenseIntOrFPElementsAttr` once that migrates to be backed by bytes. #3623 --- tests/dialects/test_builtin.py | 50 ++++++++++++++++++++++++++++++++ xdsl/dialects/builtin.py | 52 ++++++++++++++++++++++++++++++++++ 2 files changed, 102 insertions(+) diff --git a/tests/dialects/test_builtin.py b/tests/dialects/test_builtin.py index 6068aa37f3..352ad96952 100644 --- a/tests/dialects/test_builtin.py +++ b/tests/dialects/test_builtin.py @@ -175,48 +175,72 @@ def test_IntegerType_packing(): buffer_i1 = i1.pack(nums_i1) unpacked_i1 = i1.unpack(buffer_i1, len(nums_i1)) assert nums_i1 == unpacked_i1 + attrs_i1 = IntegerAttr.unpack(i1, buffer_i1, len(nums_i1)) + assert attrs_i1 == tuple(IntegerAttr(n, i1) for n in nums_i1) + assert tuple(attr for attr in IntegerAttr.iter_unpack(i1, buffer_i1)) == attrs_i1 # i8 nums_i8 = (-128, -1, 0, 1, 127) buffer_i8 = i8.pack(nums_i8) unpacked_i8 = i8.unpack(buffer_i8, len(nums_i8)) assert nums_i8 == unpacked_i8 + attrs_i8 = IntegerAttr.unpack(i8, buffer_i8, len(nums_i8)) + assert attrs_i8 == tuple(IntegerAttr(n, i8) for n in nums_i8) + assert tuple(attr for attr in IntegerAttr.iter_unpack(i8, buffer_i8)) == attrs_i8 # i16 nums_i16 = (-32768, -1, 0, 1, 32767) buffer_i16 = i16.pack(nums_i16) unpacked_i16 = i16.unpack(buffer_i16, len(nums_i16)) assert nums_i16 == unpacked_i16 + attrs_i16 = IntegerAttr.unpack(i16, buffer_i16, len(nums_i16)) + assert attrs_i16 == tuple(IntegerAttr(n, i16) for n in nums_i16) + assert tuple(attr for attr in IntegerAttr.iter_unpack(i16, buffer_i16)) == attrs_i16 # i32 nums_i32 = (-2147483648, -1, 0, 1, 2147483647) buffer_i32 = i32.pack(nums_i32) unpacked_i32 = i32.unpack(buffer_i32, len(nums_i32)) assert nums_i32 == unpacked_i32 + attrs_i32 = IntegerAttr.unpack(i32, buffer_i32, len(nums_i32)) + assert attrs_i32 == tuple(IntegerAttr(n, i32) for n in nums_i32) + assert tuple(attr for attr in IntegerAttr.iter_unpack(i32, buffer_i32)) == attrs_i32 # i64 nums_i64 = (-9223372036854775808, -1, 0, 1, 9223372036854775807) buffer_i64 = i64.pack(nums_i64) unpacked_i64 = i64.unpack(buffer_i64, len(nums_i64)) assert nums_i64 == unpacked_i64 + attrs_i64 = IntegerAttr.unpack(i64, buffer_i64, len(nums_i64)) + assert attrs_i64 == tuple(IntegerAttr(n, i64) for n in nums_i64) + assert tuple(attr for attr in IntegerAttr.iter_unpack(i64, buffer_i64)) == attrs_i64 # f16 nums_f16 = (-3.140625, -1.0, 0.0, 1.0, 3.140625) buffer_f16 = f16.pack(nums_f16) unpacked_f16 = f16.unpack(buffer_f16, len(nums_f16)) assert nums_f16 == unpacked_f16 + attrs_f16 = FloatAttr.unpack(f16, buffer_f16, len(nums_f16)) + assert attrs_f16 == tuple(FloatAttr(n, f16) for n in nums_f16) + assert tuple(attr for attr in FloatAttr.iter_unpack(f16, buffer_f16)) == attrs_f16 # f32 nums_f32 = (-3.140000104904175, -1.0, 0.0, 1.0, 3.140000104904175) buffer_f32 = f32.pack(nums_f32) unpacked_f32 = f32.unpack(buffer_f32, len(nums_f32)) assert nums_f32 == unpacked_f32 + attrs_f32 = FloatAttr.unpack(f32, buffer_f32, len(nums_f32)) + assert attrs_f32 == tuple(FloatAttr(n, f32) for n in nums_f32) + assert tuple(attr for attr in FloatAttr.iter_unpack(f32, buffer_f32)) == attrs_f32 # f64 nums_f64 = (-3.14159265359, -1.0, 0.0, 1.0, 3.14159265359) buffer_f64 = f64.pack(nums_f64) unpacked_f64 = f64.unpack(buffer_f64, len(nums_f64)) assert nums_f64 == unpacked_f64 + attrs_f64 = FloatAttr.unpack(f64, buffer_f64, len(nums_f64)) + assert attrs_f64 == tuple(FloatAttr(n, f64) for n in nums_f64) + assert tuple(attr for attr in FloatAttr.iter_unpack(f64, buffer_f64)) == attrs_f64 # Test error cases # Different Python versions have different error messages for these @@ -486,9 +510,35 @@ def test_complex_init(): def test_dense_as_tuple(): floats = DenseArrayBase.from_list(f32, [3.14159, 2.71828]) assert floats.get_values() == (3.141590118408203, 2.718280076980591) + assert tuple(floats.iter_values()) == (3.141590118408203, 2.718280076980591) + assert tuple(floats.iter_attrs()) == ( + FloatAttr(3.141590118408203, f32), + FloatAttr(2.718280076980591, f32), + ) + assert floats.get_attrs() == ( + FloatAttr(3.141590118408203, f32), + FloatAttr(2.718280076980591, f32), + ) ints = DenseArrayBase.from_list(i32, [1, 1, 2, 3, 5, 8]) assert ints.get_values() == (1, 1, 2, 3, 5, 8) + assert tuple(ints.iter_values()) == (1, 1, 2, 3, 5, 8) + assert tuple(ints.iter_attrs()) == ( + IntegerAttr(1, i32), + IntegerAttr(1, i32), + IntegerAttr(2, i32), + IntegerAttr(3, i32), + IntegerAttr(5, i32), + IntegerAttr(8, i32), + ) + assert ints.get_attrs() == ( + IntegerAttr(1, i32), + IntegerAttr(1, i32), + IntegerAttr(2, i32), + IntegerAttr(3, i32), + IntegerAttr(5, i32), + IntegerAttr(8, i32), + ) def test_create_dense_int(): diff --git a/xdsl/dialects/builtin.py b/xdsl/dialects/builtin.py index 1f5f91b7d6..581761344a 100644 --- a/xdsl/dialects/builtin.py +++ b/xdsl/dialects/builtin.py @@ -602,6 +602,7 @@ def format(self) -> str: _IntegerAttrType = TypeVar( "_IntegerAttrType", bound=IntegerType | IndexType, covariant=True ) +_IntegerAttrTypeInvT = TypeVar("_IntegerAttrTypeInvT", bound=IntegerType | IndexType) _IntegerAttrTypeConstrT = TypeVar( "_IntegerAttrTypeConstrT", bound=IntegerType | IndexType, covariant=True ) @@ -706,6 +707,25 @@ def constr( ), ) + @staticmethod + def iter_unpack( + type: _IntegerAttrTypeInvT, buffer: ReadableBuffer, / + ) -> Iterator[IntegerAttr[_IntegerAttrTypeInvT]]: + """ + Yields unpacked values one at a time, starting at the beginning of the buffer. + """ + for value in type.iter_unpack(buffer): + yield IntegerAttr(value, type) + + @staticmethod + def unpack( + type: _IntegerAttrTypeInvT, buffer: ReadableBuffer, num: int, / + ) -> tuple[IntegerAttr[_IntegerAttrTypeInvT], ...]: + """ + Unpack `num` values from the beginning of the buffer. + """ + return tuple(IntegerAttr(value, type) for value in type.unpack(buffer, num)) + AnyIntegerAttr: TypeAlias = IntegerAttr[IntegerType | IndexType] AnyIntegerAttrConstr: BaseAttr[AnyIntegerAttr] = BaseAttr(IntegerAttr) @@ -833,6 +853,7 @@ def __hash__(self): _FloatAttrType = TypeVar("_FloatAttrType", bound=AnyFloat, covariant=True) +_FloatAttrTypeInvT = TypeVar("_FloatAttrTypeInvT", bound=AnyFloat) @irdl_attr_definition @@ -886,6 +907,25 @@ def parse_with_type( def print_without_type(self, printer: Printer): return printer.print_float_attr(self) + @staticmethod + def iter_unpack( + type: _FloatAttrTypeInvT, buffer: ReadableBuffer, / + ) -> Iterator[FloatAttr[_FloatAttrTypeInvT]]: + """ + Yields unpacked values one at a time, starting at the beginning of the buffer. + """ + for value in type.iter_unpack(buffer): + yield FloatAttr(value, type) + + @staticmethod + def unpack( + type: _FloatAttrTypeInvT, buffer: ReadableBuffer, num: int, / + ) -> tuple[FloatAttr[_FloatAttrTypeInvT], ...]: + """ + Unpack `num` values from the beginning of the buffer. + """ + return tuple(FloatAttr(value, type) for value in type.unpack(buffer, num)) + AnyFloatAttr: TypeAlias = FloatAttr[AnyFloat] AnyFloatAttrConstr: BaseAttr[AnyFloatAttr] = BaseAttr(FloatAttr) @@ -1250,6 +1290,18 @@ def iter_values(self) -> Iterator[float] | Iterator[int]: def get_values(self) -> tuple[int, ...] | tuple[float, ...]: return self.elt_type.unpack(self.data.data, len(self)) + def iter_attrs(self) -> Iterator[AnyIntegerAttr] | Iterator[AnyFloatAttr]: + if isinstance(self.elt_type, IntegerType): + return IntegerAttr.iter_unpack(self.elt_type, self.data.data) + else: + return FloatAttr.iter_unpack(self.elt_type, self.data.data) + + def get_attrs(self) -> tuple[AnyIntegerAttr, ...] | tuple[AnyFloatAttr, ...]: + if isinstance(self.elt_type, IntegerType): + return IntegerAttr.unpack(self.elt_type, self.data.data, len(self)) + else: + return FloatAttr.unpack(self.elt_type, self.data.data, len(self)) + def __len__(self) -> int: return len(self.data.data) // self.elt_type.size