Skip to content

Commit

Permalink
dialects: (builtin) DenseIntOrFPElementsAttr: add iterators for value…
Browse files Browse the repository at this point in the history
…s and attrs
  • Loading branch information
jorendumoulin committed Jan 7, 2025
1 parent a0d424d commit 3aa6a85
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 1 deletion.
54 changes: 54 additions & 0 deletions tests/dialects/test_builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,60 @@ def test_DenseIntOrFPElementsAttr_from_list():
assert len(attr) == 4


def test_DenseIntOrFPElementsAttr_values():
int_attr = DenseIntOrFPElementsAttr.tensor_from_list([1, 2, 3, 4], i32, [4])
assert tuple(int_attr.get_values()) == (1, 2, 3, 4)
assert tuple(int_attr.iter_values()) == (1, 2, 3, 4)
assert tuple(int_attr.get_attrs()) == (
IntegerAttr(1, i32),
IntegerAttr(2, i32),
IntegerAttr(3, i32),
IntegerAttr(4, i32),
)
assert tuple(int_attr.iter_attrs()) == (
IntegerAttr(1, i32),
IntegerAttr(2, i32),
IntegerAttr(3, i32),
IntegerAttr(4, i32),
)

index_attr = DenseIntOrFPElementsAttr.tensor_from_list(
[1, 2, 3, 4], IndexType(), [4]
)
assert tuple(index_attr.get_values()) == (1, 2, 3, 4)
assert tuple(index_attr.iter_values()) == (1, 2, 3, 4)
assert tuple(index_attr.get_attrs()) == (
IntegerAttr(1, IndexType()),
IntegerAttr(2, IndexType()),
IntegerAttr(3, IndexType()),
IntegerAttr(4, IndexType()),
)
assert tuple(index_attr.iter_attrs()) == (
IntegerAttr(1, IndexType()),
IntegerAttr(2, IndexType()),
IntegerAttr(3, IndexType()),
IntegerAttr(4, IndexType()),
)

float_attr = DenseIntOrFPElementsAttr.tensor_from_list(
[1.0, 2.0, 3.0, 4.0], f32, [4]
)
assert tuple(float_attr.get_values()) == (1.0, 2.0, 3.0, 4.0)
assert tuple(float_attr.iter_values()) == (1.0, 2.0, 3.0, 4.0)
assert tuple(float_attr.get_attrs()) == (
FloatAttr(1.0, f32),
FloatAttr(2.0, f32),
FloatAttr(3.0, f32),
FloatAttr(4.0, f32),
)
assert tuple(float_attr.iter_attrs()) == (
FloatAttr(1.0, f32),
FloatAttr(2.0, f32),
FloatAttr(3.0, f32),
FloatAttr(4.0, f32),
)


@pytest.mark.parametrize(
"ref,expected",
[
Expand Down
15 changes: 14 additions & 1 deletion xdsl/dialects/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -2167,11 +2167,24 @@ def tensor_from_list(
t = TensorType(data_type, shape)
return DenseIntOrFPElementsAttr.from_list(t, data)

def iter_values(self) -> Iterator[int] | Iterator[float]:
"""
Return an iterator over all the values of the elements in this DenseIntOrFPElementsAttr
"""
return (el.value.data for el in self.data.data)

def get_values(self) -> Sequence[int] | Sequence[float]:
"""
Return all the values of the elements in this DenseIntOrFPElementsAttr
"""
return tuple(el.value.data for el in self.data.data)
return tuple(self.iter_values())

def iter_attrs(self) -> Iterator[AnyIntegerAttr] | Iterator[AnyFloatAttr]:
"""
Return an iterator over all elements of the dense attribute in their relevant
attribute representation (IntegerAttr / FloatAttr)
"""
return iter(self.data)

def get_attrs(self) -> Sequence[AnyIntegerAttr] | Sequence[AnyFloatAttr]:
"""
Expand Down

0 comments on commit 3aa6a85

Please sign in to comment.