Skip to content

Commit

Permalink
chore(typing): improve typing support for generic PyTree[T] and reg…
Browse files Browse the repository at this point in the history
…istry lookup / register functions (#160)
  • Loading branch information
XuehaiPan authored Oct 8, 2024
1 parent 52a1f4a commit 9a1a110
Show file tree
Hide file tree
Showing 5 changed files with 230 additions and 167 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

-
- Improve typing support for generic `PyTree[T]` and registry lookup / register functions by [@XuehaiPan](https://github.com/XuehaiPan) in [#160](https://github.com/metaopt/optree/pull/160).

### Changed

Expand Down
1 change: 1 addition & 0 deletions docs/source/spelling_wordlist.txt
Original file line number Diff line number Diff line change
Expand Up @@ -93,3 +93,4 @@ redef
hypot
init
ns
metaclass
45 changes: 23 additions & 22 deletions optree/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@

from optree import _C
from optree.accessor import PyTreeAccessor, PyTreeEntry
from optree.registry import PyTreeNodeRegistryEntry, register_pytree_node
from optree.registry import register_pytree_node
from optree.typing import (
CustomTreeNode,
MetaData,
Expand Down Expand Up @@ -2477,27 +2477,28 @@ def tree_flatten_one_level(
if (tree is None and none_is_leaf) or (is_leaf is not None and is_leaf(tree)): # type: ignore[unreachable,arg-type]
raise ValueError(f'Cannot flatten leaf-type: {node_type} (node: {tree!r}).')

handler: PyTreeNodeRegistryEntry | None = register_pytree_node.get(node_type, namespace=namespace) # type: ignore[attr-defined]
if handler:
flattened = tuple(handler.flatten_func(tree)) # type: ignore[arg-type]
if len(flattened) == 2:
flattened = (*flattened, None)
elif len(flattened) != 3:
raise RuntimeError(
f'PyTree custom flatten function for type {node_type} should return a 2- or 3-tuple, '
f'got {len(flattened)}.',
)
children, metadata, entries = flattened
children = list(children) # type: ignore[arg-type]
entries = tuple(range(len(children)) if entries is None else entries)
if len(children) != len(entries):
raise RuntimeError(
f'PyTree custom flatten function for type {node_type} returned inconsistent '
f'number of children ({len(children)}) and number of entries ({len(entries)}).',
)
return children, metadata, entries, handler.unflatten_func # type: ignore[return-value]
handler = register_pytree_node.get(node_type, namespace=namespace)
if handler is None:
raise ValueError(f'Cannot flatten leaf-type: {node_type} (node: {tree!r}).')

raise ValueError(f'Cannot flatten leaf-type: {node_type} (node: {tree!r}).')
flattened = tuple(handler.flatten_func(tree))
if len(flattened) == 2:
flattened = (*flattened, None)
elif len(flattened) != 3:
raise RuntimeError(
f'PyTree custom flatten function for type {node_type} should return a 2- or 3-tuple, '
f'got {len(flattened)}.',
)
flattened: tuple[Iterable[PyTree[T]], MetaData, Iterable[Any] | None]
children, metadata, entries = flattened
children = list(children)
entries = tuple(range(len(children)) if entries is None else entries)
if len(children) != len(entries):
raise RuntimeError(
f'PyTree custom flatten function for type {node_type} returned inconsistent '
f'number of children ({len(children)}) and number of entries ({len(entries)}).',
)
return children, metadata, entries, handler.unflatten_func # type: ignore[return-value]


def treespec_paths(treespec: PyTreeSpec) -> list[tuple[Any, ...]]:
Expand Down Expand Up @@ -3267,7 +3268,7 @@ def _prefix_error(
return # don't look for more errors in this subtree

# If the keys agree, we should ensure that the children are in the same order:
full_tree_children = [full_tree[k] for k in prefix_tree_keys] # type: ignore[index]
full_tree_children = [full_tree[k] for k in prefix_tree_keys] # type: ignore[misc]

if len(prefix_tree_children) != len(full_tree_children):
yield lambda name: ValueError(
Expand Down
Loading

0 comments on commit 9a1a110

Please sign in to comment.