Skip to content

Commit

Permalink
Cleanup of probnum's type aliases
Browse files Browse the repository at this point in the history
  • Loading branch information
marvinpfoertner committed Dec 30, 2021
1 parent e1de948 commit 5ea8668
Show file tree
Hide file tree
Showing 6 changed files with 37 additions and 38 deletions.
4 changes: 2 additions & 2 deletions src/probnum/backend/_core/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from probnum import backend as _backend
from probnum.typing import ArrayType, DTypeArgType, ScalarArgType
from probnum.typing import ArrayType, DTypeArgType, ScalarLike

if _backend.BACKEND is _backend.Backend.NUMPY:
from . import _numpy as _core
Expand Down Expand Up @@ -79,7 +79,7 @@
jit_method = _core.jit_method


def as_scalar(x: ScalarArgType, dtype: DTypeArgType = None) -> ArrayType:
def as_scalar(x: ScalarLike, dtype: DTypeArgType = None) -> ArrayType:
"""Convert a scalar into a NumPy scalar.
Parameters
Expand Down
10 changes: 5 additions & 5 deletions src/probnum/backend/random/_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import jax
from jax import numpy as jnp

from probnum.typing import DTypeArgType, FloatArgType, ShapeArgType
from probnum.typing import DTypeArgType, FloatLike, ShapeLike


def seed(seed: Optional[int]) -> jnp.ndarray:
Expand All @@ -28,9 +28,9 @@ def standard_normal(seed: jnp.ndarray, shape=(), dtype=jnp.double):

def gamma(
seed: jnp.ndarray,
shape_param: FloatArgType,
scale_param: FloatArgType = 1.0,
shape: ShapeArgType = (),
shape_param: FloatLike,
scale_param: FloatLike = 1.0,
shape: ShapeLike = (),
dtype: DTypeArgType = jnp.double,
):
return (
Expand All @@ -43,7 +43,7 @@ def gamma(
def uniform_so_group(
seed: jnp.ndarray,
n: int,
shape: ShapeArgType = (),
shape: ShapeLike = (),
dtype: DTypeArgType = jnp.double,
) -> jnp.ndarray:
if n == 1:
Expand Down
12 changes: 6 additions & 6 deletions src/probnum/backend/random/_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import numpy as np

from probnum.typing import DTypeArgType, FloatArgType, ShapeArgType
from probnum.typing import DTypeArgType, FloatLike, ShapeLike


def seed(seed: Optional[int]) -> np.random.SeedSequence:
Expand All @@ -21,17 +21,17 @@ def split(

def standard_normal(
seed: np.random.SeedSequence,
shape: ShapeArgType = (),
shape: ShapeLike = (),
dtype: DTypeArgType = np.double,
) -> np.ndarray:
return _make_rng(seed).standard_normal(size=shape, dtype=dtype)


def gamma(
seed: np.random.SeedSequence,
shape_param: FloatArgType,
scale_param: FloatArgType = 1.0,
shape: ShapeArgType = (),
shape_param: FloatLike,
scale_param: FloatLike = 1.0,
shape: ShapeLike = (),
dtype: DTypeArgType = np.double,
) -> np.ndarray:
return (
Expand All @@ -43,7 +43,7 @@ def gamma(
def uniform_so_group(
seed: np.random.SeedSequence,
n: int,
shape: ShapeArgType = (),
shape: ShapeLike = (),
dtype: DTypeArgType = np.double,
) -> np.ndarray:
if n == 1:
Expand Down
4 changes: 2 additions & 2 deletions src/probnum/backend/random/_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import torch
from torch.distributions.utils import broadcast_all

from probnum.typing import DTypeArgType, ShapeArgType
from probnum.typing import DTypeArgType, ShapeLike

_RNG_STATE_SIZE = torch.Generator().get_state().shape[0]

Expand Down Expand Up @@ -51,7 +51,7 @@ def gamma(
def uniform_so_group(
seed: np.random.SeedSequence,
n: int,
shape: ShapeArgType = (),
shape: ShapeLike = (),
dtype: DTypeArgType = torch.double,
) -> torch.Tensor:
if n == 1:
Expand Down
1 change: 0 additions & 1 deletion src/probnum/randvars/_random_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from probnum import backend, utils as _utils
from probnum.typing import (
ArrayIndicesLike,
ArrayLike,
ArrayType,
DTypeLike,
SeedType,
Expand Down
44 changes: 22 additions & 22 deletions src/probnum/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,18 +29,13 @@
# Array Utilities
ShapeType = Tuple[int, ...]

# Backend Types
ArrayType = Union[np.ndarray, "jax.numpy.ndarray", "torch.Tensor"]
ScalarType = Union[np.ndarray, "jax.numpy.ndarray", "torch.Tensor"]
# Scalars, Arrays and Matrices
ScalarType = "probnum.backend.ndarray"
MatrixType = Union["probnum.backend.ndarray", "probnum.linops.LinearOperator"]

# Random Number Generation
SeedType = Union[np.random.SeedSequence, "jax.random.PRNGKey"]

# ProbNum Types
MatrixType = Union[ArrayType, "probnum.linops.LinearOperator"]

# Scalars, Arrays and Matrices
ScalarType = np.number
MatrixType = Union[np.ndarray, "probnum.linops.LinearOperator"]

########################################################################################
# Argument Types
Expand All @@ -64,39 +59,39 @@
"""Type of a public API argument for supplying a shape.
Values of this type should always be converted into :class:`ShapeType` using the
function :func:`probnum.backend.as_scalar` before further internal processing."""
function :func:`probnum.backend.as_shape` before further internal processing."""

DTypeLike = _NumPyDTypeLike
DTypeLike = Union[_NumPyDTypeLike, "jax.numpy.dtype", "torch.dtype"]
"""Type of a public API argument for supplying an array's dtype.
Values of this type should always be converted into :class:`np.dtype`\\ s before further
internal processing."""
Values of this type should always be converted into :class:`backend.dtype`\\ s using the
function :func:`probnum.backend.as_dtype` before further internal processing."""

_ArrayIndexLike = Union[
int,
slice,
type(Ellipsis),
None,
np.newaxis,
np.ndarray,
"probnum.backend.newaxis",
"probnum.backend.ndarray",
]
ArrayIndicesLike = Union[_ArrayIndexLike, Tuple[_ArrayIndexLike, ...]]
"""Type of the argument to the :meth:`__getitem__` method of a NumPy-like array type
such as :class:`np.ndarray`, :class:`probnum.linops.LinearOperator` or
such as :class:`probnum.backend.ndarray`, :class:`probnum.linops.LinearOperator` or
:class:`probnum.randvars.RandomVariable`."""

# Scalars, Arrays and Matrices
ScalarLike = Union[int, float, complex, numbers.Number, np.number]
ScalarLike = Union[ScalarType, int, float, complex, numbers.Number, np.number]
"""Type of a public API argument for supplying a scalar value.
Values of this type should always be converted into :class:`np.number`\\ s using the
function :func:`probnum.utils.as_scalar` before further internal processing."""
Values of this type should always be converted into :class:`ScalarType`\\ s using
the function :func:`probnum.backend.as_scalar` before further internal processing."""

ArrayLike = Union[_NumPyArrayLike, "jax.numpy.ndarray", "torch.Tensor"]
"""Type of a public API argument for supplying an array.
Values of this type should always be converted into :class:`np.ndarray`\\ s using
the function :func:`np.asarray` before further internal processing."""
Values of this type should always be converted into :class:`backend.ndarray`\\ s using
the function :func:`probnum.backend.as_array` before further internal processing."""

LinearOperatorLike = Union[
ArrayLike,
Expand All @@ -106,10 +101,15 @@
"""Type of a public API argument for supplying a finite-dimensional linear operator.
Values of this type should always be converted into :class:`probnum.linops.\\
LinearOperator`\\ s using the function :func:`probnum.linops.aslinop` before further
LinearOperator`\\ s using the function :func:`probnum.linops.as_linop` before further
internal processing."""

# Random Number Generation
SeedLike = Optional[int]
"""Type of a public API argument for supplying the seed of a random number generator.
Values of this type should always be converted to :class:`SeedType` using the function
:func:`probnum.backend.random.seed` before further internal processing."""

########################################################################################
# Other Types
Expand Down

0 comments on commit 5ea8668

Please sign in to comment.