diff --git a/optree/_C.pyi b/optree/_C.pyi index cf8ce63a..55a7503e 100644 --- a/optree/_C.pyi +++ b/optree/_C.pyi @@ -43,12 +43,14 @@ GLIBCXX_USE_CXX11_ABI: bool def flatten( tree: PyTree[T], + /, leaf_predicate: Callable[[T], bool] | None = None, node_is_leaf: bool = False, namespace: str = '', ) -> tuple[list[T], PyTreeSpec]: ... def flatten_with_path( tree: PyTree[T], + /, leaf_predicate: Callable[[T], bool] | None = None, node_is_leaf: bool = False, namespace: str = '', @@ -63,29 +65,32 @@ def make_none( ) -> PyTreeSpec: ... def make_from_collection( collection: Collection[PyTreeSpec], + /, node_is_leaf: bool = False, namespace: str = '', ) -> PyTreeSpec: ... def is_leaf( obj: T, + /, leaf_predicate: Callable[[T], bool] | None = None, node_is_leaf: bool = False, namespace: str = '', ) -> bool: ... def all_leaves( iterable: Iterable[T], + /, leaf_predicate: Callable[[T], bool] | None = None, node_is_leaf: bool = False, namespace: str = '', ) -> bool: ... -def is_namedtuple(obj: object | type) -> bool: ... -def is_namedtuple_instance(obj: object) -> bool: ... -def is_namedtuple_class(cls: type) -> bool: ... -def namedtuple_fields(obj: tuple | type[tuple]) -> tuple[str, ...]: ... -def is_structseq(obj: object | type) -> bool: ... -def is_structseq_instance(obj: object) -> bool: ... -def is_structseq_class(cls: type) -> bool: ... -def structseq_fields(obj: tuple | type[tuple]) -> tuple[str, ...]: ... +def is_namedtuple(obj: object | type, /) -> bool: ... +def is_namedtuple_instance(obj: object, /) -> bool: ... +def is_namedtuple_class(cls: type, /) -> bool: ... +def namedtuple_fields(obj: tuple | type[tuple], /) -> tuple[str, ...]: ... +def is_structseq(obj: object | type, /) -> bool: ... +def is_structseq_instance(obj: object, /) -> bool: ... +def is_structseq_class(cls: type, /) -> bool: ... +def structseq_fields(obj: tuple | type[tuple], /) -> tuple[str, ...]: ... class PyTreeKind(enum.IntEnum): CUSTOM = 0 # a custom type @@ -108,12 +113,13 @@ class PyTreeSpec: namespace: str type: builtins.type | None kind: PyTreeKind - def unflatten(self, leaves: Iterable[T]) -> PyTree[T]: ... - def flatten_up_to(self, full_tree: PyTree[T]) -> list[PyTree[T]]: ... - def broadcast_to_common_suffix(self, other: PyTreeSpec) -> PyTreeSpec: ... - def compose(self, inner_treespec: PyTreeSpec) -> PyTreeSpec: ... + def unflatten(self, leaves: Iterable[T], /) -> PyTree[T]: ... + def flatten_up_to(self, full_tree: PyTree[T], /) -> list[PyTree[T]]: ... + def broadcast_to_common_suffix(self, other: PyTreeSpec, /) -> PyTreeSpec: ... + def compose(self, inner_treespec: PyTreeSpec, /) -> PyTreeSpec: ... def walk( self, + /, f_node: Callable[[tuple[U, ...], MetaData], U], f_leaf: Callable[[T], U] | None, leaves: Iterable[T], @@ -121,18 +127,18 @@ class PyTreeSpec: def paths(self) -> list[tuple[Any, ...]]: ... def accessors(self) -> list[PyTreeAccessor]: ... def entries(self) -> list[Any]: ... - def entry(self, index: int) -> Any: ... + def entry(self, index: int, /) -> Any: ... def children(self) -> list[PyTreeSpec]: ... - def child(self, index: int) -> PyTreeSpec: ... - def is_leaf(self, strict: bool = True) -> bool: ... - def is_prefix(self, other: PyTreeSpec, strict: bool = False) -> bool: ... - def is_suffix(self, other: PyTreeSpec, strict: bool = False) -> bool: ... - def __eq__(self, other: object) -> bool: ... - def __ne__(self, other: object) -> bool: ... - def __lt__(self, other: object) -> bool: ... - def __le__(self, other: object) -> bool: ... - def __gt__(self, other: object) -> bool: ... - def __ge__(self, other: object) -> bool: ... + def child(self, index: int, /) -> PyTreeSpec: ... + def is_leaf(self, /, strict: bool = True) -> bool: ... + def is_prefix(self, other: PyTreeSpec, /, strict: bool = False) -> bool: ... + def is_suffix(self, other: PyTreeSpec, /, strict: bool = False) -> bool: ... + def __eq__(self, other: object, /) -> bool: ... + def __ne__(self, other: object, /) -> bool: ... + def __lt__(self, other: object, /) -> bool: ... + def __le__(self, other: object, /) -> bool: ... + def __gt__(self, other: object, /) -> bool: ... + def __ge__(self, other: object, /) -> bool: ... def __hash__(self) -> int: ... def __len__(self) -> int: ... @@ -140,6 +146,7 @@ class PyTreeIter(Iterator[T]): def __init__( self, tree: PyTree[T], + /, leaf_predicate: Callable[[T], bool] | None = None, node_is_leaf: bool = False, namespace: str = '', @@ -149,6 +156,7 @@ class PyTreeIter(Iterator[T]): def register_node( cls: type[Collection[T]], + /, flatten_func: FlattenFunc[T], unflatten_func: UnflattenFunc[T], path_entry_type: type[PyTreeEntry], @@ -156,6 +164,7 @@ def register_node( ) -> None: ... def unregister_node( cls: type, + /, namespace: str = '', ) -> None: ... def is_dict_insertion_ordered( @@ -164,5 +173,6 @@ def is_dict_insertion_ordered( ) -> bool: ... def set_dict_insertion_ordered( mode: bool, + /, namespace: str = '', ) -> None: ... diff --git a/optree/accessor.py b/optree/accessor.py index c985f832..44ad237b 100644 --- a/optree/accessor.py +++ b/optree/accessor.py @@ -76,7 +76,7 @@ def __post_init__(self) -> None: if self.kind == PyTreeKind.NONE: raise ValueError('Cannot create a path entry for None.') - def __call__(self, obj: Any) -> Any: + def __call__(self, obj: Any, /) -> Any: """Get the child object.""" try: return obj[self.entry] # should be overridden @@ -85,7 +85,7 @@ def __call__(self, obj: Any) -> Any: f'{self.__class__!r} cannot access through {obj!r} via entry {self.entry!r}', ) from ex - def __add__(self, other: object) -> PyTreeAccessor: + def __add__(self, other: object, /) -> PyTreeAccessor: """Join the path entry with another path entry or accessor.""" if isinstance(other, PyTreeEntry): return PyTreeAccessor((self, other)) @@ -93,7 +93,7 @@ def __add__(self, other: object) -> PyTreeAccessor: return PyTreeAccessor((self, *other)) return NotImplemented - def __eq__(self, other: object) -> bool: + def __eq__(self, other: object, /) -> bool: """Check if the path entries are equal.""" return isinstance(other, PyTreeEntry) and ( ( @@ -196,7 +196,7 @@ class GetItemEntry(PyTreeEntry): __slots__: ClassVar[tuple[()]] = () - def __call__(self, obj: Any) -> Any: + def __call__(self, obj: Any, /) -> Any: """Get the child object.""" return obj[self.entry] @@ -217,7 +217,7 @@ def name(self) -> str: """Get the attribute name.""" return self.entry - def __call__(self, obj: Any) -> Any: + def __call__(self, obj: Any, /) -> Any: """Get the child object.""" return getattr(obj, self.name) @@ -245,7 +245,7 @@ def index(self) -> int: """Get the index.""" return self.entry - def __call__(self, obj: Sequence[_T_co]) -> _T_co: + def __call__(self, obj: Sequence[_T_co], /) -> _T_co: """Get the child object.""" return obj[self.index] @@ -267,7 +267,7 @@ def key(self) -> _KT_co: """Get the key.""" return self.entry - def __call__(self, obj: Mapping[_KT_co, _VT_co]) -> _VT_co: + def __call__(self, obj: Mapping[_KT_co, _VT_co], /) -> _VT_co: """Get the child object.""" return obj[self.key] @@ -383,27 +383,27 @@ def __new__(cls, path: Iterable[PyTreeEntry] = ()) -> Self: raise TypeError(f'Expected a path of PyTreeEntry, got {path!r}.') return super().__new__(cls, path) - def __call__(self, obj: Any) -> Any: + def __call__(self, obj: Any, /) -> Any: """Get the child object.""" for entry in self: obj = entry(obj) return obj @overload # type: ignore[override] - def __getitem__(self, index: int) -> PyTreeEntry: # noqa: D105,RUF100 + def __getitem__(self, index: int, /) -> PyTreeEntry: # noqa: D105,RUF100 ... @overload - def __getitem__(self, index: slice) -> PyTreeAccessor: # noqa: D105,RUF100 + def __getitem__(self, index: slice, /) -> PyTreeAccessor: # noqa: D105,RUF100 ... - def __getitem__(self, index: int | slice) -> PyTreeEntry | PyTreeAccessor: + def __getitem__(self, index: int | slice, /) -> PyTreeEntry | PyTreeAccessor: """Get the child path entry or an accessor for a subpath.""" if isinstance(index, slice): return PyTreeAccessor(super().__getitem__(index)) return super().__getitem__(index) - def __add__(self, other: object) -> PyTreeAccessor: + def __add__(self, other: object, /) -> PyTreeAccessor: """Join the accessor with another path entry or accessor.""" if isinstance(other, PyTreeEntry): return PyTreeAccessor((*self, other)) @@ -411,15 +411,15 @@ def __add__(self, other: object) -> PyTreeAccessor: return PyTreeAccessor((*self, *other)) return NotImplemented - def __mul__(self, value: int) -> PyTreeAccessor: # type: ignore[override] + def __mul__(self, value: int, /) -> PyTreeAccessor: # type: ignore[override] """Repeat the accessor.""" return PyTreeAccessor(super().__mul__(value)) - def __rmul__(self, value: int) -> PyTreeAccessor: # type: ignore[override] + def __rmul__(self, value: int, /) -> PyTreeAccessor: # type: ignore[override] """Repeat the accessor.""" return PyTreeAccessor(super().__rmul__(value)) - def __eq__(self, other: object) -> bool: + def __eq__(self, other: object, /) -> bool: """Check if the accessors are equal.""" return isinstance(other, PyTreeAccessor) and super().__eq__(other) diff --git a/optree/dataclasses.py b/optree/dataclasses.py index 421326c6..38998405 100644 --- a/optree/dataclasses.py +++ b/optree/dataclasses.py @@ -155,6 +155,7 @@ def field( # type: ignore[no-redef] # pylint: disable=function-redefined,too-ma @dataclass_transform(field_specifiers=(field,)) def dataclass( # pylint: disable=too-many-arguments cls: None, + /, *, init: bool = True, repr: bool = True, # pylint: disable=redefined-builtin @@ -175,6 +176,7 @@ def dataclass( # pylint: disable=too-many-arguments @dataclass_transform(field_specifiers=(field,)) def dataclass( # pylint: disable=too-many-arguments cls: _TypeT, + /, *, init: bool = True, repr: bool = True, # pylint: disable=redefined-builtin @@ -193,6 +195,7 @@ def dataclass( # pylint: disable=too-many-arguments @dataclass_transform(field_specifiers=(field,)) def dataclass( # noqa: C901 # pylint: disable=function-redefined,too-many-arguments,too-many-locals,too-many-branches cls: _TypeT | None = None, + /, *, init: bool = True, repr: bool = True, # pylint: disable=redefined-builtin diff --git a/optree/functools.py b/optree/functools.py index 1262e1b4..502122b3 100644 --- a/optree/functools.py +++ b/optree/functools.py @@ -42,13 +42,13 @@ class _HashablePartialShim: args: tuple[Any, ...] keywords: dict[str, Any] - def __init__(self, partial_func: functools.partial) -> None: + def __init__(self, partial_func: functools.partial, /) -> None: self.partial_func: functools.partial = partial_func def __call__(self, *args: Any, **kwargs: Any) -> Any: return self.partial_func(*args, **kwargs) - def __eq__(self, other: object) -> bool: + def __eq__(self, other: object, /) -> bool: if isinstance(other, _HashablePartialShim): return self.partial_func == other.partial_func return self.partial_func == other @@ -118,7 +118,7 @@ class partial( # noqa: N801 # pylint: disable=invalid-name,too-few-public-metho TREE_PATH_ENTRY_TYPE: ClassVar[type[PyTreeEntry]] = GetAttrEntry - def __new__(cls, func: Callable[..., Any], *args: T, **keywords: T) -> Self: + def __new__(cls, func: Callable[..., Any], /, *args: T, **keywords: T) -> Self: """Create a new :class:`partial` instance.""" # In Python 3.10+, if func is itself a functools.partial instance, functools.partial.__new__ # would merge the arguments of this partial instance with the arguments of the func. We box diff --git a/optree/ops.py b/optree/ops.py index 4444e6c5..958d29f4 100644 --- a/optree/ops.py +++ b/optree/ops.py @@ -120,6 +120,7 @@ def tree_flatten( tree: PyTree[T], + /, is_leaf: Callable[[T], bool] | None = None, *, none_is_leaf: bool = False, @@ -188,6 +189,7 @@ def tree_flatten( def tree_flatten_with_path( tree: PyTree[T], + /, is_leaf: Callable[[T], bool] | None = None, *, none_is_leaf: bool = False, @@ -261,6 +263,7 @@ def tree_flatten_with_path( def tree_flatten_with_accessor( tree: PyTree[T], + /, is_leaf: Callable[[T], bool] | None = None, *, none_is_leaf: bool = False, @@ -383,6 +386,7 @@ def tree_unflatten(treespec: PyTreeSpec, leaves: Iterable[T]) -> PyTree[T]: def tree_iter( tree: PyTree[T], + /, is_leaf: Callable[[T], bool] | None = None, *, none_is_leaf: bool = False, @@ -424,6 +428,7 @@ def tree_iter( def tree_leaves( tree: PyTree[T], + /, is_leaf: Callable[[T], bool] | None = None, *, none_is_leaf: bool = False, @@ -465,6 +470,7 @@ def tree_leaves( def tree_structure( tree: PyTree[T], + /, is_leaf: Callable[[T], bool] | None = None, *, none_is_leaf: bool = False, @@ -506,6 +512,7 @@ def tree_structure( def tree_paths( tree: PyTree[T], + /, is_leaf: Callable[[T], bool] | None = None, *, none_is_leaf: bool = False, @@ -547,6 +554,7 @@ def tree_paths( def tree_accessors( tree: PyTree[T], + /, is_leaf: Callable[[T], bool] | None = None, *, none_is_leaf: bool = False, @@ -602,6 +610,7 @@ def tree_accessors( def tree_is_leaf( tree: PyTree[T], + /, is_leaf: Callable[[T], bool] | None = None, *, none_is_leaf: bool = False, @@ -640,6 +649,7 @@ def tree_is_leaf( def all_leaves( iterable: Iterable[T], + /, is_leaf: Callable[[T], bool] | None = None, *, none_is_leaf: bool = False, @@ -693,6 +703,7 @@ def all_leaves( def tree_map( func: Callable[..., U], tree: PyTree[T], + /, *rests: PyTree[S], is_leaf: Callable[[T], bool] | None = None, none_is_leaf: bool = False, @@ -749,6 +760,7 @@ def tree_map( def tree_map_( func: Callable[..., Any], tree: PyTree[T], + /, *rests: PyTree[S], is_leaf: Callable[[T], bool] | None = None, none_is_leaf: bool = False, @@ -790,6 +802,7 @@ def tree_map_( def tree_map_with_path( func: Callable[..., U], tree: PyTree[T], + /, *rests: PyTree[S], is_leaf: Callable[[T], bool] | None = None, none_is_leaf: bool = False, @@ -840,6 +853,7 @@ def tree_map_with_path( def tree_map_with_path_( func: Callable[..., Any], tree: PyTree[T], + /, *rests: PyTree[S], is_leaf: Callable[[T], bool] | None = None, none_is_leaf: bool = False, @@ -883,6 +897,7 @@ def tree_map_with_path_( def tree_map_with_accessor( func: Callable[..., U], tree: PyTree[T], + /, *rests: PyTree[S], is_leaf: Callable[[T], bool] | None = None, none_is_leaf: bool = False, @@ -956,6 +971,7 @@ def tree_map_with_accessor( def tree_map_with_accessor_( func: Callable[..., Any], tree: PyTree[T], + /, *rests: PyTree[S], is_leaf: Callable[[T], bool] | None = None, none_is_leaf: bool = False, @@ -996,7 +1012,7 @@ def tree_map_with_accessor_( return tree -def tree_replace_nones(sentinel: Any, tree: PyTree[T] | None, namespace: str = '') -> PyTree[T]: +def tree_replace_nones(sentinel: Any, tree: PyTree[T] | None, /, namespace: str = '') -> PyTree[T]: """Replace :data:`None` in ``tree`` with ``sentinel``. See also :func:`tree_flatten` and :func:`tree_map`. @@ -1029,6 +1045,7 @@ def tree_transpose( outer_treespec: PyTreeSpec, inner_treespec: PyTreeSpec, tree: PyTree[T], + /, is_leaf: Callable[[T], bool] | None = None, ) -> PyTree[T]: # PyTree[PyTree[T]] """Transform a tree having tree structure (outer, inner) into one having structure (inner, outer). @@ -1105,6 +1122,7 @@ def tree_transpose( def tree_transpose_map( func: Callable[..., PyTree[U]], tree: PyTree[T], + /, *rests: PyTree[S], inner_treespec: PyTreeSpec | None = None, is_leaf: Callable[[T], bool] | None = None, @@ -1197,6 +1215,7 @@ def tree_transpose_map( def tree_transpose_map_with_path( func: Callable[..., PyTree[U]], tree: PyTree[T], + /, *rests: PyTree[S], inner_treespec: PyTreeSpec | None = None, is_leaf: Callable[[T], bool] | None = None, @@ -1283,6 +1302,7 @@ def tree_transpose_map_with_path( def tree_transpose_map_with_accessor( func: Callable[..., PyTree[U]], tree: PyTree[T], + /, *rests: PyTree[S], inner_treespec: PyTreeSpec | None = None, is_leaf: Callable[[T], bool] | None = None, @@ -1396,6 +1416,7 @@ def tree_transpose_map_with_accessor( def tree_broadcast_prefix( prefix_tree: PyTree[T], full_tree: PyTree[S], + /, is_leaf: Callable[[T], bool] | None = None, *, none_is_leaf: bool = False, @@ -1476,6 +1497,7 @@ def broadcast_leaves(x: T, subtree: PyTree[S]) -> PyTree[T]: def broadcast_prefix( prefix_tree: PyTree[T], full_tree: PyTree[S], + /, is_leaf: Callable[[T], bool] | None = None, *, none_is_leaf: bool = False, @@ -1558,6 +1580,7 @@ def add_leaves(x: T, subtree: PyTree[S]) -> None: def tree_broadcast_common( tree: PyTree[T], other_tree: PyTree[T], + /, is_leaf: Callable[[T], bool] | None = None, *, none_is_leaf: bool = False, @@ -1651,6 +1674,7 @@ def broadcast_leaves(x: T, subtree: PyTree[T]) -> PyTree[T]: def broadcast_common( tree: PyTree[T], other_tree: PyTree[T], + /, is_leaf: Callable[[T], bool] | None = None, *, none_is_leaf: bool = False, @@ -1733,6 +1757,7 @@ def add_leaves(x: T, y: T) -> None: def _tree_broadcast_common( tree: PyTree[T], + /, *rests: PyTree[T], is_leaf: Callable[[T], bool] | None = None, none_is_leaf: bool = False, @@ -1768,6 +1793,7 @@ def _tree_broadcast_common( def tree_broadcast_map( func: Callable[..., U], tree: PyTree[T], + /, *rests: PyTree[T], is_leaf: Callable[[T], bool] | None = None, none_is_leaf: bool = False, @@ -1838,6 +1864,7 @@ def tree_broadcast_map( def tree_broadcast_map_with_path( func: Callable[..., U], tree: PyTree[T], + /, *rests: PyTree[T], is_leaf: Callable[[T], bool] | None = None, none_is_leaf: bool = False, @@ -1915,6 +1942,7 @@ def tree_broadcast_map_with_path( def tree_broadcast_map_with_accessor( func: Callable[..., U], tree: PyTree[T], + /, *rests: PyTree[T], is_leaf: Callable[[T], bool] | None = None, none_is_leaf: bool = False, @@ -2020,6 +2048,7 @@ def __repr__(self) -> str: def tree_reduce( func: Callable[[T, T], T], tree: PyTree[T], + /, *, is_leaf: Callable[[T], bool] | None = None, none_is_leaf: bool = False, @@ -2031,6 +2060,7 @@ def tree_reduce( def tree_reduce( func: Callable[[T, S], T], tree: PyTree[S], + /, initial: T = __MISSING, *, is_leaf: Callable[[S], bool] | None = None, @@ -2042,6 +2072,7 @@ def tree_reduce( def tree_reduce( func: Callable[[T, S], T], tree: PyTree[S], + /, initial: T = __MISSING, *, is_leaf: Callable[[S], bool] | None = None, @@ -2088,6 +2119,7 @@ def tree_reduce( def tree_sum( tree: PyTree[T], + /, start: T = 0, # type: ignore[assignment] *, is_leaf: Callable[[T], bool] | None = None, @@ -2140,6 +2172,7 @@ def tree_sum( @overload def tree_max( tree: PyTree[T], + /, *, is_leaf: Callable[[T], bool] | None = None, key: Callable[[T], Any] | None = None, @@ -2151,6 +2184,7 @@ def tree_max( @overload def tree_max( tree: PyTree[T], + /, *, default: T = __MISSING, key: Callable[[T], Any] | None = None, @@ -2162,6 +2196,7 @@ def tree_max( def tree_max( tree: PyTree[T], + /, *, default: T = __MISSING, key: Callable[[T], Any] | None = None, @@ -2229,6 +2264,7 @@ def tree_max( @overload def tree_min( tree: PyTree[T], + /, *, key: Callable[[T], Any] | None = None, is_leaf: Callable[[T], bool] | None = None, @@ -2240,6 +2276,7 @@ def tree_min( @overload def tree_min( tree: PyTree[T], + /, *, default: T = __MISSING, key: Callable[[T], Any] | None = None, @@ -2251,6 +2288,7 @@ def tree_min( def tree_min( tree: PyTree[T], + /, *, default: T = __MISSING, key: Callable[[T], Any] | None = None, @@ -2317,6 +2355,7 @@ def tree_min( def tree_all( tree: PyTree[T], + /, *, is_leaf: Callable[[T], bool] | None = None, none_is_leaf: bool = False, @@ -2368,6 +2407,7 @@ def tree_all( def tree_any( tree: PyTree[T], + /, *, is_leaf: Callable[[T], bool] | None = None, none_is_leaf: bool = False, @@ -2419,6 +2459,7 @@ def tree_any( def tree_flatten_one_level( tree: PyTree[T], + /, is_leaf: Callable[[T], bool] | None = None, *, none_is_leaf: bool = False, @@ -2492,7 +2533,7 @@ def tree_flatten_one_level( return children, metadata, entries, handler.unflatten_func # type: ignore[return-value] -def treespec_paths(treespec: PyTreeSpec) -> list[tuple[Any, ...]]: +def treespec_paths(treespec: PyTreeSpec, /) -> list[tuple[Any, ...]]: """Return a list of paths to the leaves of a treespec. See also :func:`tree_flatten_with_path`, :func:`tree_paths`, and :meth:`PyTreeSpec.paths`. @@ -2500,7 +2541,7 @@ def treespec_paths(treespec: PyTreeSpec) -> list[tuple[Any, ...]]: return treespec.paths() -def treespec_accessors(treespec: PyTreeSpec) -> list[PyTreeAccessor]: +def treespec_accessors(treespec: PyTreeSpec, /) -> list[PyTreeAccessor]: """Return a list of accessors to the leaves of a treespec. See also :func:`tree_flatten_with_accessor`, :func:`tree_accessors` and @@ -2509,7 +2550,7 @@ def treespec_accessors(treespec: PyTreeSpec) -> list[PyTreeAccessor]: return treespec.accessors() -def treespec_entries(treespec: PyTreeSpec) -> list[Any]: +def treespec_entries(treespec: PyTreeSpec, /) -> list[Any]: """Return a list of one-level entries of a treespec to its children. See also :func:`treespec_entry`, :func:`treespec_paths`, :func:`treespec_children`, @@ -2518,7 +2559,7 @@ def treespec_entries(treespec: PyTreeSpec) -> list[Any]: return treespec.entries() -def treespec_entry(treespec: PyTreeSpec, index: int) -> Any: +def treespec_entry(treespec: PyTreeSpec, index: int, /) -> Any: """Return the entry of a treespec at the given index. See also :func:`treespec_entries`, :func:`treespec_children`, and :meth:`PyTreeSpec.entry`. @@ -2526,7 +2567,7 @@ def treespec_entry(treespec: PyTreeSpec, index: int) -> Any: return treespec.entry(index) -def treespec_children(treespec: PyTreeSpec) -> list[PyTreeSpec]: +def treespec_children(treespec: PyTreeSpec, /) -> list[PyTreeSpec]: """Return a list of treespecs for the children of a treespec. See also :func:`treespec_child`, :func:`treespec_paths`, :func:`treespec_entries`, @@ -2535,7 +2576,7 @@ def treespec_children(treespec: PyTreeSpec) -> list[PyTreeSpec]: return treespec.children() -def treespec_child(treespec: PyTreeSpec, index: int) -> PyTreeSpec: +def treespec_child(treespec: PyTreeSpec, index: int, /) -> PyTreeSpec: """Return the treespec of the child of a treespec at the given index. See also :func:`treespec_children`, :func:`treespec_entries`, and :meth:`PyTreeSpec.child`. @@ -2543,7 +2584,7 @@ def treespec_child(treespec: PyTreeSpec, index: int) -> PyTreeSpec: return treespec.child(index) -def treespec_is_leaf(treespec: PyTreeSpec, strict: bool = True) -> bool: +def treespec_is_leaf(treespec: PyTreeSpec, /, strict: bool = True) -> bool: """Return whether the treespec is a leaf that has no children. See also :func:`treespec_is_strict_leaf` and :meth:`PyTreeSpec.is_leaf`. @@ -2587,7 +2628,7 @@ def treespec_is_leaf(treespec: PyTreeSpec, strict: bool = True) -> bool: return treespec.num_nodes == 1 -def treespec_is_strict_leaf(treespec: PyTreeSpec) -> bool: +def treespec_is_strict_leaf(treespec: PyTreeSpec, /) -> bool: """Return whether the treespec is a strict leaf. See also :func:`treespec_is_leaf` and :meth:`PyTreeSpec.is_leaf`. @@ -2623,6 +2664,7 @@ def treespec_is_strict_leaf(treespec: PyTreeSpec) -> bool: def treespec_is_prefix( treespec: PyTreeSpec, other_treespec: PyTreeSpec, + /, strict: bool = False, ) -> bool: """Return whether ``treespec`` is a prefix of ``other_treespec``. @@ -2635,6 +2677,7 @@ def treespec_is_prefix( def treespec_is_suffix( treespec: PyTreeSpec, other_treespec: PyTreeSpec, + /, strict: bool = False, ) -> bool: """Return whether ``treespec`` is a suffix of ``other_treespec``. @@ -2746,6 +2789,7 @@ def treespec_none( def treespec_tuple( iterable: Iterable[PyTreeSpec] = (), + /, *, none_is_leaf: bool = False, namespace: str = '', @@ -2791,6 +2835,7 @@ def treespec_tuple( def treespec_list( iterable: Iterable[PyTreeSpec] = (), + /, *, none_is_leaf: bool = False, namespace: str = '', @@ -2836,6 +2881,7 @@ def treespec_list( def treespec_dict( mapping: Mapping[Any, PyTreeSpec] | Iterable[tuple[Any, PyTreeSpec]] = (), + /, *, none_is_leaf: bool = False, namespace: str = '', @@ -2882,6 +2928,7 @@ def treespec_dict( def treespec_namedtuple( namedtuple: NamedTuple[PyTreeSpec], # type: ignore[type-arg] + /, *, none_is_leaf: bool = False, namespace: str = '', @@ -2927,6 +2974,7 @@ def treespec_namedtuple( def treespec_ordereddict( mapping: Mapping[Any, PyTreeSpec] | Iterable[tuple[Any, PyTreeSpec]] = (), + /, *, none_is_leaf: bool = False, namespace: str = '', @@ -2974,6 +3022,7 @@ def treespec_ordereddict( def treespec_defaultdict( default_factory: Callable[[], Any] | None = None, mapping: Mapping[Any, PyTreeSpec] | Iterable[tuple[Any, PyTreeSpec]] = (), + /, *, none_is_leaf: bool = False, namespace: str = '', @@ -3024,6 +3073,7 @@ def treespec_defaultdict( def treespec_deque( iterable: Iterable[PyTreeSpec] = (), + /, maxlen: int | None = None, *, none_is_leaf: bool = False, @@ -3072,6 +3122,7 @@ def treespec_deque( def treespec_structseq( structseq: PyStructSequence[PyTreeSpec], + /, *, none_is_leaf: bool = False, namespace: str = '', @@ -3104,6 +3155,7 @@ def treespec_structseq( def treespec_from_collection( collection: Collection[PyTreeSpec], + /, *, none_is_leaf: bool = False, namespace: str = '', @@ -3153,6 +3205,7 @@ def treespec_from_collection( def prefix_errors( prefix_tree: PyTree[T], full_tree: PyTree[S], + /, is_leaf: Callable[[T], bool] | None = None, *, none_is_leaf: bool = False, @@ -3179,6 +3232,7 @@ def _prefix_error( accessor: PyTreeAccessor, prefix_tree: PyTree[T], full_tree: PyTree[S], + /, is_leaf: Callable[[T], bool] | None = None, *, none_is_leaf: bool = False, @@ -3340,6 +3394,7 @@ def _prefix_error( def _child_entries( tree: PyTree[T], + /, is_leaf: Callable[[T], bool] | None = None, *, none_is_leaf: bool = False, diff --git a/optree/registry.py b/optree/registry.py index b4b714b7..ec7f4e8a 100644 --- a/optree/registry.py +++ b/optree/registry.py @@ -128,11 +128,12 @@ def get(self, *args: _GetP.args, **kwargs: _GetP.kwargs) -> _GetT: def _add_get( get: Callable[_GetP, _GetT], + /, ) -> Callable[ [Callable[_P, _T]], _CallableWithGet[_P, _T, _GetP, _GetT], ]: - def decorator(func: Callable[_P, _T]) -> _CallableWithGet[_P, _T, _GetP, _GetT]: + def decorator(func: Callable[_P, _T], /) -> _CallableWithGet[_P, _T, _GetP, _GetT]: func.get = get # type: ignore[attr-defined] return func # type: ignore[return-value] @@ -141,6 +142,7 @@ def decorator(func: Callable[_P, _T]) -> _CallableWithGet[_P, _T, _GetP, _GetT]: def _pytree_node_registry_get( cls: type, + /, *, namespace: str = '', ) -> PyTreeNodeRegistryEntry | None: @@ -179,6 +181,7 @@ def _pytree_node_registry_get( @_add_get(_pytree_node_registry_get) def register_pytree_node( cls: type[Collection[T]], + /, flatten_func: FlattenFunc[T], unflatten_func: UnflattenFunc[T], *, @@ -345,6 +348,7 @@ def register_pytree_node( @overload def register_pytree_node_class( cls: str | None = None, + /, *, path_entry_type: type[PyTreeEntry] | None = None, namespace: str | None = None, @@ -354,6 +358,7 @@ def register_pytree_node_class( @overload def register_pytree_node_class( cls: CustomTreeNodeType, + /, *, path_entry_type: type[PyTreeEntry] | None, namespace: str, @@ -362,6 +367,7 @@ def register_pytree_node_class( def register_pytree_node_class( # noqa: C901 cls: CustomTreeNodeType | str | None = None, + /, *, path_entry_type: type[PyTreeEntry] | None = None, namespace: str | None = None, @@ -467,7 +473,7 @@ def tree_unflatten(cls, metadata, children): return cls -def unregister_pytree_node(cls: type, *, namespace: str) -> PyTreeNodeRegistryEntry: +def unregister_pytree_node(cls: type, /, *, namespace: str) -> PyTreeNodeRegistryEntry: """Remove a type from the pytree node registry. See also :func:`register_pytree_node` and :func:`register_pytree_node_class`. @@ -521,7 +527,7 @@ def unregister_pytree_node(cls: type, *, namespace: str) -> PyTreeNodeRegistryEn @contextlib.contextmanager -def dict_insertion_ordered(mode: bool, *, namespace: str) -> Generator[None]: +def dict_insertion_ordered(mode: bool, /, *, namespace: str) -> Generator[None]: """Context manager to temporarily set the dictionary sorting mode. This context manager is used to temporarily set the dictionary sorting mode for a specific @@ -567,70 +573,73 @@ def dict_insertion_ordered(mode: bool, *, namespace: str) -> Generator[None]: _C.set_dict_insertion_ordered(prev, namespace) -def _sorted_items(items: Iterable[tuple[KT, VT]]) -> list[tuple[KT, VT]]: +def _sorted_items(items: Iterable[tuple[KT, VT]], /) -> list[tuple[KT, VT]]: return total_order_sorted(items, key=itemgetter(0)) -def _none_flatten(none: None) -> tuple[tuple[()], None]: +def _none_flatten(none: None, /) -> tuple[tuple[()], None]: return (), None -def _none_unflatten(_: None, children: Iterable[Any]) -> None: +def _none_unflatten(_: None, /, children: Iterable[Any]) -> None: sentinel = object() if next(iter(children), sentinel) is not sentinel: raise ValueError('Expected no children.') return None # noqa: RET501 -def _tuple_flatten(tup: tuple[T, ...]) -> tuple[tuple[T, ...], None]: +def _tuple_flatten(tup: tuple[T, ...], /) -> tuple[tuple[T, ...], None]: return tup, None -def _tuple_unflatten(_: None, children: Iterable[T]) -> tuple[T, ...]: +def _tuple_unflatten(_: None, children: Iterable[T], /) -> tuple[T, ...]: return tuple(children) -def _list_flatten(lst: list[T]) -> tuple[list[T], None]: +def _list_flatten(lst: list[T], /) -> tuple[list[T], None]: return lst, None -def _list_unflatten(_: None, children: Iterable[T]) -> list[T]: +def _list_unflatten(_: None, children: Iterable[T], /) -> list[T]: return list(children) -def _dict_flatten(dct: dict[KT, VT]) -> tuple[tuple[VT, ...], list[KT], tuple[KT, ...]]: +def _dict_flatten(dct: dict[KT, VT], /) -> tuple[tuple[VT, ...], list[KT], tuple[KT, ...]]: keys, values = unzip2(_sorted_items(dct.items())) return values, list(keys), keys -def _dict_unflatten(keys: list[KT], values: Iterable[VT]) -> dict[KT, VT]: +def _dict_unflatten(keys: list[KT], values: Iterable[VT], /) -> dict[KT, VT]: return dict(safe_zip(keys, values)) def _dict_insertion_ordered_flatten( dct: dict[KT, VT], + /, ) -> tuple[tuple[VT, ...], list[KT], tuple[KT, ...]]: keys, values = unzip2(dct.items()) return values, list(keys), keys -def _dict_insertion_ordered_unflatten(keys: list[KT], values: Iterable[VT]) -> dict[KT, VT]: +def _dict_insertion_ordered_unflatten(keys: list[KT], values: Iterable[VT], /) -> dict[KT, VT]: return dict(safe_zip(keys, values)) def _ordereddict_flatten( dct: OrderedDict[KT, VT], + /, ) -> tuple[tuple[VT, ...], list[KT], tuple[KT, ...]]: keys, values = unzip2(dct.items()) return values, list(keys), keys -def _ordereddict_unflatten(keys: list[KT], values: Iterable[VT]) -> OrderedDict[KT, VT]: +def _ordereddict_unflatten(keys: list[KT], values: Iterable[VT], /) -> OrderedDict[KT, VT]: return OrderedDict(safe_zip(keys, values)) def _defaultdict_flatten( dct: defaultdict[KT, VT], + /, ) -> tuple[tuple[VT, ...], tuple[Callable[[], VT] | None, list[KT]], tuple[KT, ...]]: values, keys, entries = _dict_flatten(dct) return values, (dct.default_factory, keys), entries @@ -639,6 +648,7 @@ def _defaultdict_flatten( def _defaultdict_unflatten( metadata: tuple[Callable[[], VT], list[KT]], values: Iterable[VT], + /, ) -> defaultdict[KT, VT]: default_factory, keys = metadata return defaultdict(default_factory, _dict_unflatten(keys, values)) @@ -646,6 +656,7 @@ def _defaultdict_unflatten( def _defaultdict_insertion_ordered_flatten( dct: defaultdict[KT, VT], + /, ) -> tuple[tuple[VT, ...], tuple[Callable[[], VT] | None, list[KT]], tuple[KT, ...]]: values, keys, entries = _dict_insertion_ordered_flatten(dct) return values, (dct.default_factory, keys), entries @@ -654,32 +665,33 @@ def _defaultdict_insertion_ordered_flatten( def _defaultdict_insertion_ordered_unflatten( metadata: tuple[Callable[[], VT], list[KT]], values: Iterable[VT], + /, ) -> defaultdict[KT, VT]: default_factory, keys = metadata return defaultdict(default_factory, _dict_insertion_ordered_unflatten(keys, values)) -def _deque_flatten(deq: deque[T]) -> tuple[deque[T], int | None]: +def _deque_flatten(deq: deque[T], /) -> tuple[deque[T], int | None]: return deq, deq.maxlen -def _deque_unflatten(maxlen: int | None, children: Iterable[T]) -> deque[T]: +def _deque_unflatten(maxlen: int | None, children: Iterable[T], /) -> deque[T]: return deque(children, maxlen=maxlen) -def _namedtuple_flatten(tup: NamedTuple[T]) -> tuple[tuple[T, ...], type[NamedTuple[T]]]: # type: ignore[type-arg] +def _namedtuple_flatten(tup: NamedTuple[T], /) -> tuple[tuple[T, ...], type[NamedTuple[T]]]: # type: ignore[type-arg] return tup, type(tup) -def _namedtuple_unflatten(cls: type[NamedTuple[T]], children: Iterable[T]) -> NamedTuple[T]: # type: ignore[type-arg] +def _namedtuple_unflatten(cls: type[NamedTuple[T]], children: Iterable[T], /) -> NamedTuple[T]: # type: ignore[type-arg] return cls(*children) # type: ignore[call-overload] -def _structseq_flatten(seq: structseq[T]) -> tuple[tuple[T, ...], type[structseq[T]]]: +def _structseq_flatten(seq: structseq[T], /) -> tuple[tuple[T, ...], type[structseq[T]]]: return seq, type(seq) -def _structseq_unflatten(cls: type[structseq[T]], children: Iterable[T]) -> structseq[T]: +def _structseq_unflatten(cls: type[structseq[T]], children: Iterable[T], /) -> structseq[T]: return cls(children) @@ -707,7 +719,7 @@ def _structseq_unflatten(cls: type[structseq[T]], children: Iterable[T]) -> stru 'The function `_sorted_keys` is deprecated and will be removed in a future version.', category=FutureWarning, ) - def _sorted_keys(dct: dict[KT, VT]) -> list[KT]: + def _sorted_keys(dct: dict[KT, VT], /) -> list[KT]: return total_order_sorted(dct) @deprecated( @@ -718,14 +730,14 @@ def _sorted_keys(dct: dict[KT, VT]) -> list[KT]: class KeyPathEntry(NamedTuple): key: Any - def __add__(self, other: object) -> KeyPath: + def __add__(self, other: object, /) -> KeyPath: if isinstance(other, KeyPathEntry): return KeyPath((self, other)) if isinstance(other, KeyPath): return KeyPath((self, *other.keys)) return NotImplemented - def __eq__(self, other: object) -> bool: + def __eq__(self, other: object, /) -> bool: return isinstance(other, self.__class__) and self.key == other.key def pprint(self) -> str: @@ -740,14 +752,14 @@ def pprint(self) -> str: class KeyPath(NamedTuple): keys: tuple[KeyPathEntry, ...] = () - def __add__(self, other: object) -> KeyPath: + def __add__(self, other: object, /) -> KeyPath: if isinstance(other, KeyPathEntry): return KeyPath((*self.keys, other)) if isinstance(other, KeyPath): return KeyPath(self.keys + other.keys) return NotImplemented - def __eq__(self, other: object) -> bool: + def __eq__(self, other: object, /) -> bool: return isinstance(other, KeyPath) and self.keys == other.keys def pprint(self) -> str: @@ -803,6 +815,7 @@ def pprint(self) -> str: @_add_get(_KEYPATH_REGISTRY.get) def register_keypaths( cls: type[Collection[T]], + /, handler: KeyPathHandler[T], ) -> KeyPathHandler[T]: """Register a key path handler for a custom pytree node type.""" diff --git a/optree/typing.py b/optree/typing.py index fb3a0e0c..e3fb903f 100644 --- a/optree/typing.py +++ b/optree/typing.py @@ -155,7 +155,7 @@ def tree_unflatten(cls, metadata: MetaData, children: Children[T]) -> Self: _UnionType = type(Union[int, str]) -def _tp_cache(func: Callable[P, T]) -> Callable[P, T]: +def _tp_cache(func: Callable[P, T], /) -> Callable[P, T]: cached = functools.lru_cache(func) @functools.wraps(func) @@ -247,15 +247,15 @@ def __init_subclass__(cls, *args: Any, **kwargs: Any) -> NoReturn: """Prohibit subclassing.""" raise TypeError('Cannot subclass special typing classes.') - def __getitem__(self, key: Any) -> PyTree[T] | T: + def __getitem__(self, key: Any, /) -> PyTree[T] | T: """Emulate collection-like behavior.""" raise NotImplementedError - def __getattr__(self, name: str) -> PyTree[T] | T: + def __getattr__(self, name: str, /) -> PyTree[T] | T: """Emulate dataclass-like behavior.""" raise NotImplementedError - def __contains__(self, key: Any | T) -> bool: + def __contains__(self, key: Any | T, /) -> bool: """Emulate collection-like behavior.""" raise NotImplementedError @@ -267,15 +267,15 @@ def __iter__(self) -> Iterator[PyTree[T] | T | Any]: """Emulate collection-like behavior.""" raise NotImplementedError - def index(self, key: Any | T) -> int: + def index(self, key: Any | T, /) -> int: """Emulate sequence-like behavior.""" raise NotImplementedError - def count(self, key: Any | T) -> int: + def count(self, key: Any | T, /) -> int: """Emulate sequence-like behavior.""" raise NotImplementedError - def get(self, key: Any, default: T | None = None) -> T | None: + def get(self, key: Any, /, default: T | None = None) -> T | None: """Emulate mapping-like behavior.""" raise NotImplementedError @@ -329,7 +329,7 @@ def __init_subclass__(cls, *args: Any, **kwargs: Any) -> NoReturn: UnflattenFunc: TypeAlias = Callable[[MetaData, Children[T]], Collection[T]] -def _override_with_(cxx_implementation: F) -> Callable[[F], F]: +def _override_with_(cxx_implementation: F, /) -> Callable[[F], F]: """Decorator to override the Python implementation with the C++ implementation. >>> @_override_with_(any) @@ -343,7 +343,7 @@ def _override_with_(cxx_implementation: F) -> Callable[[F], F]: True """ - def wrapper(python_implementation: F) -> F: + def wrapper(python_implementation: F, /) -> F: @functools.wraps(python_implementation) def wrapped(*args: Any, **kwargs: Any) -> Any: return cxx_implementation(*args, **kwargs) @@ -357,20 +357,20 @@ def wrapped(*args: Any, **kwargs: Any) -> Any: @_override_with_(_C.is_namedtuple) -def is_namedtuple(obj: object | type) -> bool: +def is_namedtuple(obj: object | type, /) -> bool: """Return whether the object is an instance of namedtuple or a subclass of namedtuple.""" cls = obj if isinstance(obj, type) else type(obj) return is_namedtuple_class(cls) @_override_with_(_C.is_namedtuple_instance) -def is_namedtuple_instance(obj: object) -> bool: +def is_namedtuple_instance(obj: object, /) -> bool: """Return whether the object is an instance of namedtuple.""" return is_namedtuple_class(type(obj)) @_override_with_(_C.is_namedtuple_class) -def is_namedtuple_class(cls: type) -> bool: +def is_namedtuple_class(cls: type, /) -> bool: """Return whether the class is a subclass of namedtuple.""" return ( isinstance(cls, type) @@ -386,7 +386,7 @@ def is_namedtuple_class(cls: type) -> bool: @_override_with_(_C.namedtuple_fields) -def namedtuple_fields(obj: tuple | type[tuple]) -> tuple[str, ...]: +def namedtuple_fields(obj: tuple | type[tuple], /) -> tuple[str, ...]: """Return the field names of a namedtuple.""" if isinstance(obj, type): cls = obj @@ -405,7 +405,7 @@ def namedtuple_fields(obj: tuple | type[tuple]) -> tuple[str, ...]: class StructSequenceMeta(type): """The metaclass for PyStructSequence stub type.""" - def __subclasscheck__(cls, subclass: type) -> bool: + def __subclasscheck__(cls, subclass: type, /) -> bool: """Return whether the class is a PyStructSequence type. >>> import time @@ -420,7 +420,7 @@ def __subclasscheck__(cls, subclass: type) -> bool: """ return is_structseq_class(subclass) - def __instancecheck__(cls, instance: Any) -> bool: + def __instancecheck__(cls, instance: Any, /) -> bool: """Return whether the object is a PyStructSequence instance. >>> import sys @@ -456,14 +456,14 @@ def __new__(cls, sequence: Iterable[_T_co], dict: dict[str, Any] = ...) -> Self: @_override_with_(_C.is_structseq) -def is_structseq(obj: object | type) -> bool: +def is_structseq(obj: object | type, /) -> bool: """Return whether the object is an instance of PyStructSequence or a class of PyStructSequence.""" cls = obj if isinstance(obj, type) else type(obj) return is_structseq_class(cls) @_override_with_(_C.is_structseq_instance) -def is_structseq_instance(obj: object) -> bool: +def is_structseq_instance(obj: object, /) -> bool: """Return whether the object is an instance of PyStructSequence.""" return is_structseq_class(type(obj)) @@ -473,7 +473,7 @@ def is_structseq_instance(obj: object) -> bool: @_override_with_(_C.is_structseq_class) -def is_structseq_class(cls: type) -> bool: +def is_structseq_class(cls: type, /) -> bool: """Return whether the class is a class of PyStructSequence.""" if ( isinstance(cls, type) @@ -496,7 +496,7 @@ def is_structseq_class(cls: type) -> bool: @_override_with_(_C.structseq_fields) -def structseq_fields(obj: tuple | type[tuple]) -> tuple[str, ...]: +def structseq_fields(obj: tuple | type[tuple], /) -> tuple[str, ...]: """Return the field names of a PyStructSequence.""" if isinstance(obj, type): cls = obj diff --git a/optree/utils.py b/optree/utils.py index 0e877295..e472b7ab 100644 --- a/optree/utils.py +++ b/optree/utils.py @@ -25,6 +25,7 @@ def total_order_sorted( iterable: Iterable[T], + /, *, key: Callable[[T], Any] | None = None, reverse: bool = False, @@ -61,32 +62,36 @@ def key_fn(x: T) -> tuple[str, Any]: @overload def safe_zip( - __iter1: Iterable[T], + iter1: Iterable[T], + /, ) -> zip[tuple[T]]: ... @overload def safe_zip( - __iter1: Iterable[T], - __iter2: Iterable[S], + iter1: Iterable[T], + iter2: Iterable[S], + /, ) -> zip[tuple[T, S]]: ... @overload def safe_zip( - __iter1: Iterable[T], - __iter2: Iterable[S], - __iter3: Iterable[U], + iter1: Iterable[T], + iter2: Iterable[S], + iter3: Iterable[U], + /, ) -> zip[tuple[T, S, U]]: ... @overload def safe_zip( - __iter1: Iterable[Any], - __iter2: Iterable[Any], - __iter3: Iterable[Any], - __iter4: Iterable[Any], - *__iters: Iterable[Any], + iter1: Iterable[Any], + iter2: Iterable[Any], + iter3: Iterable[Any], + iter4: Iterable[Any], + /, + *iters: Iterable[Any], ) -> zip[tuple[Any, ...]]: ... @@ -98,7 +103,7 @@ def safe_zip(*args: Iterable[Any]) -> zip[tuple[Any, ...]]: return zip(*seqs) -def unzip2(xys: Iterable[tuple[T, S]]) -> tuple[tuple[T, ...], tuple[S, ...]]: +def unzip2(xys: Iterable[tuple[T, S]], /) -> tuple[tuple[T, ...], tuple[S, ...]]: """Unzip sequence of length-2 tuples into two tuples.""" # Note: we deliberately don't use zip(*xys) because it is lazily evaluated, # is too permissive about inputs, and does not guarantee a length-2 output. diff --git a/src/optree.cpp b/src/optree.cpp index cf6a2a17..860c26ab 100644 --- a/src/optree.cpp +++ b/src/optree.cpp @@ -20,6 +20,7 @@ limitations under the License. #include // std::optional, std::nullopt #include // std::string +#include #include #include @@ -60,6 +61,7 @@ void BuildModule(py::module_& mod) { // NOLINT[runtime/references] "Register a Python type. Extends the set of types that are considered internal nodes " "in pytrees.", py::arg("cls"), + py::pos_only(), py::arg("flatten_func"), py::arg("unflatten_func"), py::arg("path_entry_type"), @@ -68,6 +70,7 @@ void BuildModule(py::module_& mod) { // NOLINT[runtime/references] &PyTreeTypeRegistry::Unregister, "Unregister a Python type.", py::arg("cls"), + py::pos_only(), py::arg("namespace") = "") .def("is_dict_insertion_ordered", &PyTreeSpec::IsDictInsertionOrdered, @@ -78,11 +81,13 @@ void BuildModule(py::module_& mod) { // NOLINT[runtime/references] &PyTreeSpec::SetDictInsertionOrdered, "Set whether need to preserve the dict insertion order during flattening.", py::arg("mode"), + py::pos_only(), py::arg("namespace") = "") .def("flatten", &PyTreeSpec::Flatten, "Flattens a pytree.", py::arg("tree"), + py::pos_only(), py::arg("leaf_predicate") = std::nullopt, py::arg("none_is_leaf") = false, py::arg("namespace") = "") @@ -90,6 +95,7 @@ void BuildModule(py::module_& mod) { // NOLINT[runtime/references] &PyTreeSpec::FlattenWithPath, "Flatten a pytree and additionally record the paths.", py::arg("tree"), + py::pos_only(), py::arg("leaf_predicate") = std::nullopt, py::arg("none_is_leaf") = false, py::arg("namespace") = "") @@ -97,6 +103,7 @@ void BuildModule(py::module_& mod) { // NOLINT[runtime/references] &IsLeaf, "Test whether the given object is a leaf node.", py::arg("obj"), + py::pos_only(), py::arg("leaf_predicate") = std::nullopt, py::arg("none_is_leaf") = false, py::arg("namespace") = "") @@ -104,6 +111,7 @@ void BuildModule(py::module_& mod) { // NOLINT[runtime/references] &AllLeaves, "Test whether all elements in the given iterable are all leaves.", py::arg("iterable"), + py::pos_only(), py::arg("leaf_predicate") = std::nullopt, py::arg("none_is_leaf") = false, py::arg("namespace") = "") @@ -120,42 +128,51 @@ void BuildModule(py::module_& mod) { // NOLINT[runtime/references] .def("make_from_collection", &PyTreeSpec::MakeFromCollection, "Make a treespec from a collection of child treespecs.", - py::arg("tuple"), + py::arg("collection"), + py::pos_only(), py::arg("none_is_leaf") = false, py::arg("namespace") = "") .def("is_namedtuple", &IsNamedTuple, "Return whether the object is an instance of namedtuple or a subclass of namedtuple.", - py::arg("obj")) + py::arg("obj"), + py::pos_only()) .def("is_namedtuple_instance", &IsNamedTupleInstance, "Return whether the object is an instance of namedtuple.", - py::arg("obj")) + py::arg("obj"), + py::pos_only()) .def("is_namedtuple_class", &IsNamedTupleClass, "Return whether the class is a subclass of namedtuple.", - py::arg("cls")) + py::arg("cls"), + py::pos_only()) .def("namedtuple_fields", &NamedTupleGetFields, "Return the field names of a namedtuple.", - py::arg("obj")) + py::arg("obj"), + py::pos_only()) .def("is_structseq", &IsStructSequence, "Return whether the object is an instance of PyStructSequence or a class of " "PyStructSequence.", - py::arg("obj")) + py::arg("obj"), + py::pos_only()) .def("is_structseq_instance", &IsStructSequenceInstance, "Return whether the object is an instance of PyStructSequence.", - py::arg("obj")) + py::arg("obj"), + py::pos_only()) .def("is_structseq_class", &IsStructSequenceClass, "Return whether the object is a class of PyStructSequence.", - py::arg("cls")) + py::arg("cls"), + py::pos_only()) .def("structseq_fields", &StructSequenceGetFields, "Return the field names of a PyStructSequence.", - py::arg("obj")); + py::arg("obj"), + py::pos_only()); auto PyTreeKindTypeObject = py::enum_(mod, "PyTreeKind", "The kind of a pytree node.", py::module_local()) @@ -194,24 +211,29 @@ void BuildModule(py::module_& mod) { // NOLINT[runtime/references] .def("unflatten", &PyTreeSpec::Unflatten, "Reconstruct a pytree from the leaves.", - py::arg("leaves")) + py::arg("leaves"), + py::pos_only()) .def("flatten_up_to", &PyTreeSpec::FlattenUpTo, "Flatten the subtrees in ``full_tree`` up to the structure of this treespec " "and return a list of subtrees.", - py::arg("full_tree")) + py::arg("full_tree"), + py::pos_only()) .def("broadcast_to_common_suffix", &PyTreeSpec::BroadcastToCommonSuffix, "Broadcast to the common suffix of this treespec and other treespec.", - py::arg("other")) + py::arg("other"), + py::pos_only()) .def("compose", &PyTreeSpec::Compose, "Compose two treespecs. Constructs the inner treespec as a subtree at each leaf node.", - py::arg("inner_treespec")) + py::arg("inner_treespec"), + py::pos_only()) .def("walk", &PyTreeSpec::Walk, "Walk over the pytree structure, calling ``f_node(children, node_data)`` at nodes, " "and ``f_leaf(leaf)`` at leaves.", + py::pos_only(), py::arg("f_node"), py::arg("f_leaf"), py::arg("leaves")) @@ -220,12 +242,17 @@ void BuildModule(py::module_& mod) { // NOLINT[runtime/references] &PyTreeSpec::Accessors, "Return a list of accessors to the leaves in the treespec.") .def("entries", &PyTreeSpec::Entries, "Return a list of one-level entries to the children.") - .def("entry", &PyTreeSpec::Entry, "Return the entry at the given index.", py::arg("index")) + .def("entry", + &PyTreeSpec::Entry, + "Return the entry at the given index.", + py::arg("index"), + py::pos_only()) .def("children", &PyTreeSpec::Children, "Return a list of treespecs for the children.") .def("child", &PyTreeSpec::Child, "Return the treespec for the child at the given index.", - py::arg("index")) + py::arg("index"), + py::pos_only()) .def_property_readonly("num_leaves", &PyTreeSpec::GetNumLeaves, "Number of leaves in the tree.") @@ -255,47 +282,56 @@ void BuildModule(py::module_& mod) { // NOLINT[runtime/references] .def("is_leaf", &PyTreeSpec::IsLeaf, "Test whether the current node is a leaf.", + py::pos_only(), py::arg("strict") = true) .def("is_prefix", &PyTreeSpec::IsPrefix, "Test whether this treespec is a prefix of the given treespec.", py::arg("other"), + py::pos_only(), py::arg("strict") = false) .def("is_suffix", &PyTreeSpec::IsSuffix, "Test whether this treespec is a suffix of the given treespec.", py::arg("other"), + py::pos_only(), py::arg("strict") = false) .def("__eq__", std::equal_to(), "Test for equality to another object.", py::is_operator(), - py::arg("other")) + py::arg("other"), + py::pos_only()) .def("__ne__", std::not_equal_to(), "Test for inequality to another object.", py::is_operator(), - py::arg("other")) + py::arg("other"), + py::pos_only()) .def("__lt__", std::less(), "Test for this treespec is a strict prefix of another object.", py::is_operator(), - py::arg("other")) + py::arg("other"), + py::pos_only()) .def("__le__", std::less_equal(), "Test for this treespec is a prefix of another object.", py::is_operator(), - py::arg("other")) + py::arg("other"), + py::pos_only()) .def("__gt__", std::greater(), "Test for this treespec is a strict suffix of another object.", py::is_operator(), - py::arg("other")) + py::arg("other"), + py::pos_only()) .def("__ge__", std::greater_equal(), "Test for this treespec is a suffix of another object.", py::is_operator(), - py::arg("other")) + py::arg("other"), + py::pos_only()) .def("__repr__", &PyTreeSpec::ToString, "Return a string representation of the treespec.") .def("__hash__", &PyTreeSpec::HashValue, "Return the hash of the treespec.") .def("__len__", &PyTreeSpec::GetNumLeaves, "Number of leaves in the tree.") @@ -304,7 +340,8 @@ void BuildModule(py::module_& mod) { // NOLINT[runtime/references] return PyTreeSpec::FromPickleable(o); }), "Serialization support for PyTreeSpec.", - py::arg("state")); + py::arg("state"), + py::pos_only()); auto PyTreeIterTypeObject = py::class_( mod, @@ -326,6 +363,7 @@ void BuildModule(py::module_& mod) { // NOLINT[runtime/references] .def(py::init, bool, std::string>(), "Create a new iterator over the leaves of a pytree.", py::arg("tree"), + py::pos_only(), py::arg("leaf_predicate") = std::nullopt, py::arg("none_is_leaf") = false, py::arg("namespace") = "") diff --git a/tests/test_dataclasses.py b/tests/test_dataclasses.py index c206f658..ab61a473 100644 --- a/tests/test_dataclasses.py +++ b/tests/test_dataclasses.py @@ -51,18 +51,7 @@ def test_same_signature(): for name, param in field_parameters.items() ][: len(field_original_parameters)], ) == OrderedDict( - ( - name, - ( - param.name, - ( - param.kind - if param.kind != inspect.Parameter.POSITIONAL_ONLY - else inspect.Parameter.POSITIONAL_OR_KEYWORD - ), - param.default, - ), - ) + (name, (param.name, param.kind, param.default)) for name, param in field_original_parameters.items() ) @@ -79,18 +68,7 @@ def test_same_signature(): for name, param in dataclass_parameters.items() ][: len(dataclass_original_parameters)], ) == OrderedDict( - ( - name, - ( - param.name, - ( - param.kind - if param.kind != inspect.Parameter.POSITIONAL_ONLY - else inspect.Parameter.POSITIONAL_OR_KEYWORD - ), - param.default, - ), - ) + (name, (param.name, param.kind, param.default)) for name, param in dataclass_original_parameters.items() ) @@ -116,18 +94,7 @@ def test_same_signature(): for name, param in make_dataclass_parameters.items() ][: len(make_dataclass_original_parameters)], ) == OrderedDict( - ( - name, - ( - param.name, - ( - param.kind - if param.kind != inspect.Parameter.POSITIONAL_ONLY - else inspect.Parameter.POSITIONAL_OR_KEYWORD - ), - param.default, - ), - ) + (name, (param.name, param.kind, param.default)) for name, param in make_dataclass_original_parameters.items() )