From c1345849a83e33571e1a51398f94348ea221d96b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Borgna?= <121866228+aborgna-q@users.noreply.github.com> Date: Thu, 8 Aug 2024 13:00:50 +0100 Subject: [PATCH] fix(py): `Hugr.__iter__` returning `NodeData | None` instead of `Node`s (#1401) `Hugr` is a `Mapping[Node, NodeData]`, so its `__iter__` should return a list of `Node`s, but it was returning a list of node data with `None` holes --- hugr-py/src/hugr/hugr.py | 10 +++++++--- hugr-py/tests/test_hugr_build.py | 4 ++++ 2 files changed, 11 insertions(+), 3 deletions(-) diff --git a/hugr-py/src/hugr/hugr.py b/hugr-py/src/hugr/hugr.py index afcd05271..f1eda8e8a 100644 --- a/hugr-py/src/hugr/hugr.py +++ b/hugr-py/src/hugr/hugr.py @@ -2,7 +2,7 @@ from __future__ import annotations -from collections.abc import Iterable, Mapping +from collections.abc import Iterable, Iterator, Mapping from dataclasses import dataclass, field, replace from typing import ( TYPE_CHECKING, @@ -120,8 +120,8 @@ def __getitem__(self, key: ToNode) -> NodeData: raise KeyError(key) return n - def __iter__(self): - return iter(self._nodes) + def __iter__(self) -> Iterator[Node]: + return (Node(idx) for idx, data in enumerate(self._nodes) if data is not None) def __len__(self) -> int: return self.num_nodes() @@ -131,6 +131,10 @@ def _get_typed_op(self, node: ToNode, cl: type[OpVar2]) -> OpVar2: assert isinstance(op, cl) return op + def nodes(self) -> Iterable[tuple[Node, NodeData]]: + """Iterator over nodes of the hugr and their data.""" + return self.items() + def children(self, node: ToNode | None = None) -> list[Node]: """The child nodes of a given `node`. diff --git a/hugr-py/tests/test_hugr_build.py b/hugr-py/tests/test_hugr_build.py index b02cde251..de7a1646c 100644 --- a/hugr-py/tests/test_hugr_build.py +++ b/hugr-py/tests/test_hugr_build.py @@ -19,6 +19,8 @@ def test_stable_indices(): nodes = [h.add_node(Not) for _ in range(3)] assert len(h) == 4 + assert list(iter(h)) == [Node(i) for i in range(4)] + assert all(data is not None for node, data in h.nodes()) h.add_link(nodes[0].out(0), nodes[1].inp(0)) assert h.children() == nodes @@ -47,6 +49,8 @@ def test_stable_indices(): assert len(h) == 4 assert h._free_nodes == [] + assert list(iter(h)) == [Node(i) for i in range(4)] + assert all(data is not None for node, data in h.nodes()) def simple_id() -> Dfg: