Skip to content

Commit

Permalink
chore(typing): improve typing support for generic PyTree[T] and reg…
Browse files Browse the repository at this point in the history
…istry lookup / register functions (#160)
  • Loading branch information
XuehaiPan committed Oct 9, 2024
1 parent 52a1f4a commit aa97eb5
Show file tree
Hide file tree
Showing 6 changed files with 246 additions and 185 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

-
- Improve typing support for generic `PyTree[T]` and registry lookup / register functions by [@XuehaiPan](https://github.com/XuehaiPan) in [#160](https://github.com/metaopt/optree/pull/160).

### Changed

Expand Down
1 change: 1 addition & 0 deletions docs/source/spelling_wordlist.txt
Original file line number Diff line number Diff line change
Expand Up @@ -93,3 +93,4 @@ redef
hypot
init
ns
metaclass
13 changes: 6 additions & 7 deletions optree/_C.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,11 @@

import builtins
import enum
from collections.abc import Callable, Iterable, Iterator
from collections.abc import Callable, Collection, Iterable, Iterator
from typing import Any
from typing_extensions import Self

from optree.typing import (
CustomTreeNode,
FlattenFunc,
MetaData,
PyTree,
Expand Down Expand Up @@ -63,7 +62,7 @@ def make_none(
namespace: str = '', # unused
) -> PyTreeSpec: ...
def make_from_collection(
collection: CustomTreeNode[PyTreeSpec],
collection: Collection[PyTreeSpec],
node_is_leaf: bool = False,
namespace: str = '',
) -> PyTreeSpec: ...
Expand Down Expand Up @@ -149,14 +148,14 @@ class PyTreeIter(Iterator[T]):
def __next__(self) -> T: ...

def register_node(
cls: type[CustomTreeNode[T]],
flatten_func: FlattenFunc,
unflatten_func: UnflattenFunc,
cls: type[Collection[T]],
flatten_func: FlattenFunc[T],
unflatten_func: UnflattenFunc[T],
path_entry_type: type[PyTreeEntry],
namespace: str = '',
) -> None: ...
def unregister_node(
cls: type[CustomTreeNode[T]],
cls: type,
namespace: str = '',
) -> None: ...
def is_dict_insertion_ordered(
Expand Down
66 changes: 33 additions & 33 deletions optree/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,12 @@
import itertools
import textwrap
from collections import OrderedDict, defaultdict, deque
from typing import Any, Callable, ClassVar, Iterable, Mapping, overload
from typing import Any, Callable, ClassVar, Collection, Iterable, Mapping, overload

from optree import _C
from optree.accessor import PyTreeAccessor, PyTreeEntry
from optree.registry import PyTreeNodeRegistryEntry, register_pytree_node
from optree.registry import register_pytree_node
from optree.typing import (
CustomTreeNode,
MetaData,
NamedTuple,
PyTree,
Expand Down Expand Up @@ -2477,27 +2476,28 @@ def tree_flatten_one_level(
if (tree is None and none_is_leaf) or (is_leaf is not None and is_leaf(tree)): # type: ignore[unreachable,arg-type]
raise ValueError(f'Cannot flatten leaf-type: {node_type} (node: {tree!r}).')

handler: PyTreeNodeRegistryEntry | None = register_pytree_node.get(node_type, namespace=namespace) # type: ignore[attr-defined]
if handler:
flattened = tuple(handler.flatten_func(tree)) # type: ignore[arg-type]
if len(flattened) == 2:
flattened = (*flattened, None)
elif len(flattened) != 3:
raise RuntimeError(
f'PyTree custom flatten function for type {node_type} should return a 2- or 3-tuple, '
f'got {len(flattened)}.',
)
children, metadata, entries = flattened
children = list(children) # type: ignore[arg-type]
entries = tuple(range(len(children)) if entries is None else entries)
if len(children) != len(entries):
raise RuntimeError(
f'PyTree custom flatten function for type {node_type} returned inconsistent '
f'number of children ({len(children)}) and number of entries ({len(entries)}).',
)
return children, metadata, entries, handler.unflatten_func # type: ignore[return-value]
handler = register_pytree_node.get(node_type, namespace=namespace)
if handler is None:
raise ValueError(f'Cannot flatten leaf-type: {node_type} (node: {tree!r}).')

raise ValueError(f'Cannot flatten leaf-type: {node_type} (node: {tree!r}).')
flattened = tuple(handler.flatten_func(tree))
if len(flattened) == 2:
flattened = (*flattened, None)
elif len(flattened) != 3:
raise RuntimeError(
f'PyTree custom flatten function for type {node_type} should return a 2- or 3-tuple, '
f'got {len(flattened)}.',
)
flattened: tuple[Iterable[PyTree[T]], MetaData, Iterable[Any] | None]
children, metadata, entries = flattened
children = list(children)
entries = tuple(range(len(children)) if entries is None else entries)
if len(children) != len(entries):
raise RuntimeError(
f'PyTree custom flatten function for type {node_type} returned inconsistent '
f'number of children ({len(children)}) and number of entries ({len(entries)}).',
)
return children, metadata, entries, handler.unflatten_func # type: ignore[return-value]


def treespec_paths(treespec: PyTreeSpec) -> list[tuple[Any, ...]]:
Expand Down Expand Up @@ -2791,7 +2791,7 @@ def treespec_tuple(
A treespec representing a tuple node with the given children.
"""
return _C.make_from_collection(
tuple(iterable), # type: ignore[arg-type]
tuple(iterable),
none_is_leaf,
namespace,
)
Expand Down Expand Up @@ -2836,7 +2836,7 @@ def treespec_list(
A treespec representing a list node with the given children.
"""
return _C.make_from_collection(
list(iterable), # type: ignore[arg-type]
list(iterable),
none_is_leaf,
namespace,
)
Expand Down Expand Up @@ -2882,7 +2882,7 @@ def treespec_dict(
A treespec representing a dict node with the given children.
"""
return _C.make_from_collection(
dict(mapping, **kwargs), # type: ignore[arg-type]
dict(mapping, **kwargs),
none_is_leaf,
namespace,
)
Expand Down Expand Up @@ -2927,7 +2927,7 @@ def treespec_namedtuple(
if not is_namedtuple_instance(namedtuple):
raise ValueError(f'Expected a namedtuple of PyTreeSpec(s), got {namedtuple!r}.')
return _C.make_from_collection(
namedtuple, # type: ignore[arg-type]
namedtuple,
none_is_leaf,
namespace,
)
Expand Down Expand Up @@ -2973,7 +2973,7 @@ def treespec_ordereddict(
A treespec representing an OrderedDict node with the given children.
"""
return _C.make_from_collection(
OrderedDict(mapping, **kwargs), # type: ignore[arg-type]
OrderedDict(mapping, **kwargs),
none_is_leaf,
namespace,
)
Expand Down Expand Up @@ -3024,7 +3024,7 @@ def treespec_defaultdict(
A treespec representing a defaultdict node with the given children.
"""
return _C.make_from_collection(
defaultdict(default_factory, mapping, **kwargs), # type: ignore[arg-type]
defaultdict(default_factory, mapping, **kwargs),
none_is_leaf,
namespace,
)
Expand Down Expand Up @@ -3072,7 +3072,7 @@ def treespec_deque(
A treespec representing a deque node with the given children.
"""
return _C.make_from_collection(
deque(iterable, maxlen=maxlen), # type: ignore[arg-type]
deque(iterable, maxlen=maxlen),
none_is_leaf,
namespace,
)
Expand Down Expand Up @@ -3104,14 +3104,14 @@ def treespec_structseq(
if not is_structseq_instance(structseq):
raise ValueError(f'Expected a PyStructSequence of PyTreeSpec(s), got {structseq!r}.')
return _C.make_from_collection(
structseq, # type: ignore[arg-type]
structseq,
none_is_leaf,
namespace,
)


def treespec_from_collection(
collection: CustomTreeNode[PyTreeSpec],
collection: Collection[PyTreeSpec],
*,
none_is_leaf: bool = False,
namespace: str = '',
Expand Down Expand Up @@ -3267,7 +3267,7 @@ def _prefix_error(
return # don't look for more errors in this subtree

# If the keys agree, we should ensure that the children are in the same order:
full_tree_children = [full_tree[k] for k in prefix_tree_keys] # type: ignore[index]
full_tree_children = [full_tree[k] for k in prefix_tree_keys] # type: ignore[misc]

if len(prefix_tree_children) != len(full_tree_children):
yield lambda name: ValueError(
Expand Down
Loading

0 comments on commit aa97eb5

Please sign in to comment.