From 5ea86687c370f3f3a2142045851b960d98b9d132 Mon Sep 17 00:00:00 2001 From: Marvin Pfoertner Date: Wed, 22 Dec 2021 17:14:19 +0100 Subject: [PATCH] Cleanup of probnum's type aliases --- src/probnum/backend/_core/__init__.py | 4 +-- src/probnum/backend/random/_jax.py | 10 +++--- src/probnum/backend/random/_numpy.py | 12 +++---- src/probnum/backend/random/_torch.py | 4 +-- src/probnum/randvars/_random_variable.py | 1 - src/probnum/typing.py | 44 ++++++++++++------------ 6 files changed, 37 insertions(+), 38 deletions(-) diff --git a/src/probnum/backend/_core/__init__.py b/src/probnum/backend/_core/__init__.py index 6cf990c267..02cc9fe70b 100644 --- a/src/probnum/backend/_core/__init__.py +++ b/src/probnum/backend/_core/__init__.py @@ -1,5 +1,5 @@ from probnum import backend as _backend -from probnum.typing import ArrayType, DTypeArgType, ScalarArgType +from probnum.typing import ArrayType, DTypeArgType, ScalarLike if _backend.BACKEND is _backend.Backend.NUMPY: from . import _numpy as _core @@ -79,7 +79,7 @@ jit_method = _core.jit_method -def as_scalar(x: ScalarArgType, dtype: DTypeArgType = None) -> ArrayType: +def as_scalar(x: ScalarLike, dtype: DTypeArgType = None) -> ArrayType: """Convert a scalar into a NumPy scalar. Parameters diff --git a/src/probnum/backend/random/_jax.py b/src/probnum/backend/random/_jax.py index 8759a1754a..d98bcb7046 100644 --- a/src/probnum/backend/random/_jax.py +++ b/src/probnum/backend/random/_jax.py @@ -5,7 +5,7 @@ import jax from jax import numpy as jnp -from probnum.typing import DTypeArgType, FloatArgType, ShapeArgType +from probnum.typing import DTypeArgType, FloatLike, ShapeLike def seed(seed: Optional[int]) -> jnp.ndarray: @@ -28,9 +28,9 @@ def standard_normal(seed: jnp.ndarray, shape=(), dtype=jnp.double): def gamma( seed: jnp.ndarray, - shape_param: FloatArgType, - scale_param: FloatArgType = 1.0, - shape: ShapeArgType = (), + shape_param: FloatLike, + scale_param: FloatLike = 1.0, + shape: ShapeLike = (), dtype: DTypeArgType = jnp.double, ): return ( @@ -43,7 +43,7 @@ def gamma( def uniform_so_group( seed: jnp.ndarray, n: int, - shape: ShapeArgType = (), + shape: ShapeLike = (), dtype: DTypeArgType = jnp.double, ) -> jnp.ndarray: if n == 1: diff --git a/src/probnum/backend/random/_numpy.py b/src/probnum/backend/random/_numpy.py index ec11e98827..dae2971f12 100644 --- a/src/probnum/backend/random/_numpy.py +++ b/src/probnum/backend/random/_numpy.py @@ -3,7 +3,7 @@ import numpy as np -from probnum.typing import DTypeArgType, FloatArgType, ShapeArgType +from probnum.typing import DTypeArgType, FloatLike, ShapeLike def seed(seed: Optional[int]) -> np.random.SeedSequence: @@ -21,7 +21,7 @@ def split( def standard_normal( seed: np.random.SeedSequence, - shape: ShapeArgType = (), + shape: ShapeLike = (), dtype: DTypeArgType = np.double, ) -> np.ndarray: return _make_rng(seed).standard_normal(size=shape, dtype=dtype) @@ -29,9 +29,9 @@ def standard_normal( def gamma( seed: np.random.SeedSequence, - shape_param: FloatArgType, - scale_param: FloatArgType = 1.0, - shape: ShapeArgType = (), + shape_param: FloatLike, + scale_param: FloatLike = 1.0, + shape: ShapeLike = (), dtype: DTypeArgType = np.double, ) -> np.ndarray: return ( @@ -43,7 +43,7 @@ def gamma( def uniform_so_group( seed: np.random.SeedSequence, n: int, - shape: ShapeArgType = (), + shape: ShapeLike = (), dtype: DTypeArgType = np.double, ) -> np.ndarray: if n == 1: diff --git a/src/probnum/backend/random/_torch.py b/src/probnum/backend/random/_torch.py index 4e85d5c90f..968885ffbb 100644 --- a/src/probnum/backend/random/_torch.py +++ b/src/probnum/backend/random/_torch.py @@ -4,7 +4,7 @@ import torch from torch.distributions.utils import broadcast_all -from probnum.typing import DTypeArgType, ShapeArgType +from probnum.typing import DTypeArgType, ShapeLike _RNG_STATE_SIZE = torch.Generator().get_state().shape[0] @@ -51,7 +51,7 @@ def gamma( def uniform_so_group( seed: np.random.SeedSequence, n: int, - shape: ShapeArgType = (), + shape: ShapeLike = (), dtype: DTypeArgType = torch.double, ) -> torch.Tensor: if n == 1: diff --git a/src/probnum/randvars/_random_variable.py b/src/probnum/randvars/_random_variable.py index e1f15abc2c..877ba925c9 100644 --- a/src/probnum/randvars/_random_variable.py +++ b/src/probnum/randvars/_random_variable.py @@ -9,7 +9,6 @@ from probnum import backend, utils as _utils from probnum.typing import ( ArrayIndicesLike, - ArrayLike, ArrayType, DTypeLike, SeedType, diff --git a/src/probnum/typing.py b/src/probnum/typing.py index 57541f1f90..1d9fc6bb18 100644 --- a/src/probnum/typing.py +++ b/src/probnum/typing.py @@ -29,18 +29,13 @@ # Array Utilities ShapeType = Tuple[int, ...] -# Backend Types -ArrayType = Union[np.ndarray, "jax.numpy.ndarray", "torch.Tensor"] -ScalarType = Union[np.ndarray, "jax.numpy.ndarray", "torch.Tensor"] +# Scalars, Arrays and Matrices +ScalarType = "probnum.backend.ndarray" +MatrixType = Union["probnum.backend.ndarray", "probnum.linops.LinearOperator"] +# Random Number Generation SeedType = Union[np.random.SeedSequence, "jax.random.PRNGKey"] -# ProbNum Types -MatrixType = Union[ArrayType, "probnum.linops.LinearOperator"] - -# Scalars, Arrays and Matrices -ScalarType = np.number -MatrixType = Union[np.ndarray, "probnum.linops.LinearOperator"] ######################################################################################## # Argument Types @@ -64,39 +59,39 @@ """Type of a public API argument for supplying a shape. Values of this type should always be converted into :class:`ShapeType` using the -function :func:`probnum.backend.as_scalar` before further internal processing.""" +function :func:`probnum.backend.as_shape` before further internal processing.""" -DTypeLike = _NumPyDTypeLike +DTypeLike = Union[_NumPyDTypeLike, "jax.numpy.dtype", "torch.dtype"] """Type of a public API argument for supplying an array's dtype. -Values of this type should always be converted into :class:`np.dtype`\\ s before further -internal processing.""" +Values of this type should always be converted into :class:`backend.dtype`\\ s using the +function :func:`probnum.backend.as_dtype` before further internal processing.""" _ArrayIndexLike = Union[ int, slice, type(Ellipsis), None, - np.newaxis, - np.ndarray, + "probnum.backend.newaxis", + "probnum.backend.ndarray", ] ArrayIndicesLike = Union[_ArrayIndexLike, Tuple[_ArrayIndexLike, ...]] """Type of the argument to the :meth:`__getitem__` method of a NumPy-like array type -such as :class:`np.ndarray`, :class:`probnum.linops.LinearOperator` or +such as :class:`probnum.backend.ndarray`, :class:`probnum.linops.LinearOperator` or :class:`probnum.randvars.RandomVariable`.""" # Scalars, Arrays and Matrices -ScalarLike = Union[int, float, complex, numbers.Number, np.number] +ScalarLike = Union[ScalarType, int, float, complex, numbers.Number, np.number] """Type of a public API argument for supplying a scalar value. -Values of this type should always be converted into :class:`np.number`\\ s using the -function :func:`probnum.utils.as_scalar` before further internal processing.""" +Values of this type should always be converted into :class:`ScalarType`\\ s using +the function :func:`probnum.backend.as_scalar` before further internal processing.""" ArrayLike = Union[_NumPyArrayLike, "jax.numpy.ndarray", "torch.Tensor"] """Type of a public API argument for supplying an array. -Values of this type should always be converted into :class:`np.ndarray`\\ s using -the function :func:`np.asarray` before further internal processing.""" +Values of this type should always be converted into :class:`backend.ndarray`\\ s using +the function :func:`probnum.backend.as_array` before further internal processing.""" LinearOperatorLike = Union[ ArrayLike, @@ -106,10 +101,15 @@ """Type of a public API argument for supplying a finite-dimensional linear operator. Values of this type should always be converted into :class:`probnum.linops.\\ -LinearOperator`\\ s using the function :func:`probnum.linops.aslinop` before further +LinearOperator`\\ s using the function :func:`probnum.linops.as_linop` before further internal processing.""" +# Random Number Generation SeedLike = Optional[int] +"""Type of a public API argument for supplying the seed of a random number generator. + +Values of this type should always be converted to :class:`SeedType` using the function +:func:`probnum.backend.random.seed` before further internal processing.""" ######################################################################################## # Other Types