Skip to content

Commit

Permalink
fix(py): Hugr.__iter__ returning NodeData | None instead of `Node…
Browse files Browse the repository at this point in the history
…`s (#1401)

`Hugr` is a `Mapping[Node, NodeData]`, so its `__iter__` should return a
list of `Node`s, but it was returning a list of node data with `None`
holes
  • Loading branch information
aborgna-q authored Aug 8, 2024
1 parent b7f0765 commit c134584
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 3 deletions.
10 changes: 7 additions & 3 deletions hugr-py/src/hugr/hugr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from __future__ import annotations

from collections.abc import Iterable, Mapping
from collections.abc import Iterable, Iterator, Mapping
from dataclasses import dataclass, field, replace
from typing import (
TYPE_CHECKING,
Expand Down Expand Up @@ -120,8 +120,8 @@ def __getitem__(self, key: ToNode) -> NodeData:
raise KeyError(key)
return n

def __iter__(self):
return iter(self._nodes)
def __iter__(self) -> Iterator[Node]:
return (Node(idx) for idx, data in enumerate(self._nodes) if data is not None)

def __len__(self) -> int:
return self.num_nodes()
Expand All @@ -131,6 +131,10 @@ def _get_typed_op(self, node: ToNode, cl: type[OpVar2]) -> OpVar2:
assert isinstance(op, cl)
return op

def nodes(self) -> Iterable[tuple[Node, NodeData]]:
"""Iterator over nodes of the hugr and their data."""
return self.items()

def children(self, node: ToNode | None = None) -> list[Node]:
"""The child nodes of a given `node`.
Expand Down
4 changes: 4 additions & 0 deletions hugr-py/tests/test_hugr_build.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@ def test_stable_indices():

nodes = [h.add_node(Not) for _ in range(3)]
assert len(h) == 4
assert list(iter(h)) == [Node(i) for i in range(4)]
assert all(data is not None for node, data in h.nodes())

h.add_link(nodes[0].out(0), nodes[1].inp(0))
assert h.children() == nodes
Expand Down Expand Up @@ -47,6 +49,8 @@ def test_stable_indices():

assert len(h) == 4
assert h._free_nodes == []
assert list(iter(h)) == [Node(i) for i in range(4)]
assert all(data is not None for node, data in h.nodes())


def simple_id() -> Dfg:
Expand Down

0 comments on commit c134584

Please sign in to comment.