diff --git a/src/snakeoil/_fileutils.py b/src/snakeoil/_fileutils.py index cd0bb645..9da88c90 100644 --- a/src/snakeoil/_fileutils.py +++ b/src/snakeoil/_fileutils.py @@ -15,9 +15,11 @@ import itertools import mmap import os +import typing +from typing import Any, AnyStr, IO, Iterable, Iterator, Optional -def mmap_and_close(fd, *args, **kwargs): +def mmap_and_close(fd: int, *args, **kwargs) -> mmap.mmap: """ see :py:obj:`mmap.mmap`; basically this maps, then closes, to ensure the fd doesn't bleed out. @@ -34,7 +36,13 @@ def mmap_and_close(fd, *args, **kwargs): class readlines_iter: __slots__ = ("iterable", "mtime", "source") - def __init__(self, iterable, mtime, close=True, source=None): + def __init__( + self, + iterable: Iterator[AnyStr], + mtime: int, + close=True, + source: Optional[IO[Any]] = None, + ) -> None: if source is None: source = iterable self.source = source @@ -44,7 +52,7 @@ def __init__(self, iterable, mtime, close=True, source=None): self.mtime = mtime @staticmethod - def _close_on_stop(source): + def _close_on_stop(source) -> Iterator[None]: # we explicitly write this to force this method to be # a generator; we intend to return nothing, but close # the file on the way out. @@ -56,7 +64,7 @@ def _close_on_stop(source): yield None source.close() - def close(self): + def close(self) -> None: if hasattr(self.source, "close"): self.source.close() @@ -65,12 +73,12 @@ def __iter__(self): def native_readlines( - mode, - mypath, + mode: str, + mypath: str, strip_whitespace=True, swallow_missing=False, none_on_missing=False, - encoding=None, + encoding: Optional[str] = None, ): """Read a file, yielding each line. @@ -98,12 +106,16 @@ def native_readlines( return readlines_iter(_strip_whitespace_filter(iterable), mtime, source=handle) -def _strip_whitespace_filter(iterable): +def _strip_whitespace_filter( + iterable: Iterable[typing.AnyStr], +) -> Iterator[typing.AnyStr]: for line in iterable: yield line.strip() -def native_readfile(mode, mypath, none_on_missing=False, encoding=None): +def native_readfile( + mode: str, mypath: str, none_on_missing=False, encoding: Optional[str] = None +) -> Optional[AnyStr]: """Read a file, returning the contents. :param mypath: fs path for the file to read diff --git a/src/snakeoil/mappings.py b/src/snakeoil/mappings.py index b05001fa..c13a91b1 100644 --- a/src/snakeoil/mappings.py +++ b/src/snakeoil/mappings.py @@ -1,7 +1,17 @@ """ Miscellaneous mapping related classes and functionality """ -from typing import Any, Callable, cast, Dict, Iterable, Optional, Sequence, Tuple, Union +from typing import ( + Any, + cast, + Generic, + Iterable, + Iterator, + Literal, + Optional, + Tuple, + TypeVar, +) __all__ = ( "DictMixin", @@ -26,7 +36,11 @@ from .klass import contains, get, sentinel, steal_docs -class DictMixin: +KT = TypeVar("KT", bound=Hashable) +VT = TypeVar("VT", bound=Any) + + +class DictMixin(Generic[KT, VT]): """ new style class replacement for :py:func:`UserDict.DictMixin` designed around iter* methods rather then forcing lists as DictMixin does @@ -45,7 +59,9 @@ class DictMixin: __slots__ = () __externally_mutable__ = True - def __init__(self, iterable=None, **kwargs): + def __init__( + self, iterable: Optional[Iterable[tuple[KT, VT]]] = None, **kwargs: VT + ) -> None: """ :param iterables: optional, an iterable of (key, value) to initialize this instance with @@ -59,32 +75,32 @@ def __init__(self, iterable=None, **kwargs): self.update(kwargs.items()) @steal_docs(dict) - def __iter__(self): + def __iter__(self) -> Iterator[KT]: return self.keys() @steal_docs(dict) - def __str__(self): + def __str__(self) -> str: return str(dict(self.items())) @steal_docs(dict) - def items(self): + def items(self) -> Iterator[tuple[KT, VT]]: for k in self: yield k, self[k] @steal_docs(dict) - def keys(self): + def keys(self) -> Iterator[KT]: raise NotImplementedError(self, "keys") @steal_docs(dict) - def values(self): + def values(self) -> Iterator[VT]: return map(self.__getitem__, self) @steal_docs(dict) def update( self, - iterable: Iterable[Tuple[Hashable, Any]] | Mapping[Hashable, Any] = (), + iterable: Iterable[Tuple[KT, VT]] | Mapping[KT, VT] = (), /, - **kwargs: Any, + **kwargs: VT, ) -> None: # this matches how python does dict.update at the c level. if hasattr(iterable, "keys"): @@ -100,7 +116,7 @@ def update( __contains__ = contains @steal_docs(dict) - def __eq__(self, other): + def __eq__(self, other: Mapping[Any, Any]) -> bool: if len(self) != len(other): return False for k1, k2 in zip(sorted(self), sorted(other)): @@ -111,11 +127,11 @@ def __eq__(self, other): return True @steal_docs(dict) - def __ne__(self, other): + def __ne__(self, other: Mapping[Any, Any]) -> bool: return not self.__eq__(other) @steal_docs(dict) - def pop(self, key, default=sentinel): + def pop(self, key: KT, default: VT | Literal[sentinel] = sentinel) -> VT: if not self.__externally_mutable__: raise AttributeError(self, "pop") try: