From 3aa6a85991a1bf337139f163af937eefe182fac9 Mon Sep 17 00:00:00 2001 From: Joren Dumoulin Date: Tue, 7 Jan 2025 16:31:38 +0100 Subject: [PATCH] dialects: (builtin) DenseIntOrFPElementsAttr: add iterators for values and attrs --- tests/dialects/test_builtin.py | 54 ++++++++++++++++++++++++++++++++++ xdsl/dialects/builtin.py | 15 +++++++++- 2 files changed, 68 insertions(+), 1 deletion(-) diff --git a/tests/dialects/test_builtin.py b/tests/dialects/test_builtin.py index 352ad96952..15e33b6270 100644 --- a/tests/dialects/test_builtin.py +++ b/tests/dialects/test_builtin.py @@ -324,6 +324,60 @@ def test_DenseIntOrFPElementsAttr_from_list(): assert len(attr) == 4 +def test_DenseIntOrFPElementsAttr_values(): + int_attr = DenseIntOrFPElementsAttr.tensor_from_list([1, 2, 3, 4], i32, [4]) + assert tuple(int_attr.get_values()) == (1, 2, 3, 4) + assert tuple(int_attr.iter_values()) == (1, 2, 3, 4) + assert tuple(int_attr.get_attrs()) == ( + IntegerAttr(1, i32), + IntegerAttr(2, i32), + IntegerAttr(3, i32), + IntegerAttr(4, i32), + ) + assert tuple(int_attr.iter_attrs()) == ( + IntegerAttr(1, i32), + IntegerAttr(2, i32), + IntegerAttr(3, i32), + IntegerAttr(4, i32), + ) + + index_attr = DenseIntOrFPElementsAttr.tensor_from_list( + [1, 2, 3, 4], IndexType(), [4] + ) + assert tuple(index_attr.get_values()) == (1, 2, 3, 4) + assert tuple(index_attr.iter_values()) == (1, 2, 3, 4) + assert tuple(index_attr.get_attrs()) == ( + IntegerAttr(1, IndexType()), + IntegerAttr(2, IndexType()), + IntegerAttr(3, IndexType()), + IntegerAttr(4, IndexType()), + ) + assert tuple(index_attr.iter_attrs()) == ( + IntegerAttr(1, IndexType()), + IntegerAttr(2, IndexType()), + IntegerAttr(3, IndexType()), + IntegerAttr(4, IndexType()), + ) + + float_attr = DenseIntOrFPElementsAttr.tensor_from_list( + [1.0, 2.0, 3.0, 4.0], f32, [4] + ) + assert tuple(float_attr.get_values()) == (1.0, 2.0, 3.0, 4.0) + assert tuple(float_attr.iter_values()) == (1.0, 2.0, 3.0, 4.0) + assert tuple(float_attr.get_attrs()) == ( + FloatAttr(1.0, f32), + FloatAttr(2.0, f32), + FloatAttr(3.0, f32), + FloatAttr(4.0, f32), + ) + assert tuple(float_attr.iter_attrs()) == ( + FloatAttr(1.0, f32), + FloatAttr(2.0, f32), + FloatAttr(3.0, f32), + FloatAttr(4.0, f32), + ) + + @pytest.mark.parametrize( "ref,expected", [ diff --git a/xdsl/dialects/builtin.py b/xdsl/dialects/builtin.py index 581761344a..c1683d8169 100644 --- a/xdsl/dialects/builtin.py +++ b/xdsl/dialects/builtin.py @@ -2167,11 +2167,24 @@ def tensor_from_list( t = TensorType(data_type, shape) return DenseIntOrFPElementsAttr.from_list(t, data) + def iter_values(self) -> Iterator[int] | Iterator[float]: + """ + Return an iterator over all the values of the elements in this DenseIntOrFPElementsAttr + """ + return (el.value.data for el in self.data.data) + def get_values(self) -> Sequence[int] | Sequence[float]: """ Return all the values of the elements in this DenseIntOrFPElementsAttr """ - return tuple(el.value.data for el in self.data.data) + return tuple(self.iter_values()) + + def iter_attrs(self) -> Iterator[AnyIntegerAttr] | Iterator[AnyFloatAttr]: + """ + Return an iterator over all elements of the dense attribute in their relevant + attribute representation (IntegerAttr / FloatAttr) + """ + return iter(self.data) def get_attrs(self) -> Sequence[AnyIntegerAttr] | Sequence[AnyFloatAttr]: """