diff --git a/src/probnum/_lib/__init__.py b/src/probnum/_lib/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/src/probnum/_lib/argtypes.py b/src/probnum/_lib/argtypes.py deleted file mode 100644 index f93b97ff8b..0000000000 --- a/src/probnum/_lib/argtypes.py +++ /dev/null @@ -1,23 +0,0 @@ -import numbers -from typing import Iterable, Tuple, Union - -import numpy as np - - -IntArgType = Union[int, numbers.Integral, np.integer] -FloatArgType = Union[float, numbers.Real, np.floating] - -ShapeArgType = Union[IntArgType, Iterable[IntArgType]] -DTypeArgType = Union[np.dtype, str] - -ArrayLikeGetitemArgType = Union[ - int, - slice, - np.ndarray, - np.newaxis, - None, - type(Ellipsis), - Tuple[Union[int, slice, np.ndarray, np.newaxis, None, type(Ellipsis)], ...], -] - -RandomStateArgType = Union[None, int, np.random.RandomState, np.random.Generator] diff --git a/src/probnum/random_variables/_dirac.py b/src/probnum/random_variables/_dirac.py index b92f549b42..9716498d00 100644 --- a/src/probnum/random_variables/_dirac.py +++ b/src/probnum/random_variables/_dirac.py @@ -5,12 +5,13 @@ import numpy as np from probnum import utils as _utils -from probnum._lib.argtypes import ( +from probnum.typing import ( + ShapeType, + # Argument Types ArrayLikeGetitemArgType, RandomStateArgType, ShapeArgType, ) -from probnum.typing import ShapeType from . import _random_variable diff --git a/src/probnum/random_variables/_normal.py b/src/probnum/random_variables/_normal.py index a7fba62627..177cdbeeee 100644 --- a/src/probnum/random_variables/_normal.py +++ b/src/probnum/random_variables/_normal.py @@ -6,14 +6,15 @@ import scipy.stats from probnum import utils as _utils -from probnum._lib.argtypes import ( +from probnum.linalg import linops +from probnum.typing import ( + ShapeType, + # Argument Types ArrayLikeGetitemArgType, FloatArgType, RandomStateArgType, ShapeArgType, ) -from probnum.linalg import linops -from probnum.typing import ShapeType from . import _random_variable diff --git a/src/probnum/random_variables/_random_variable.py b/src/probnum/random_variables/_random_variable.py index 625edcb1d0..89aa9ca72a 100644 --- a/src/probnum/random_variables/_random_variable.py +++ b/src/probnum/random_variables/_random_variable.py @@ -10,13 +10,15 @@ import numpy as np from probnum import utils as _utils -from probnum._lib.argtypes import ( +from probnum.typing import ( + RandomStateType, + ShapeType, + # Argument Types DTypeArgType, FloatArgType, RandomStateArgType, ShapeArgType, ) -from probnum.typing import RandomStateType, ShapeType try: # functools.cached_property is only available in Python >=3.8 diff --git a/src/probnum/typing.py b/src/probnum/typing.py index f98fec1988..92a9e075d6 100644 --- a/src/probnum/typing.py +++ b/src/probnum/typing.py @@ -1,7 +1,36 @@ -from typing import Tuple, Union +import numbers +from typing import Iterable, Tuple, Union import numpy as np +######################################################################################## +# API Types +######################################################################################## + ShapeType = Tuple[int, ...] RandomStateType = Union[np.random.RandomState, np.random.Generator] + +######################################################################################## +# Argument Types +######################################################################################## + +IntArgType = Union[int, numbers.Integral, np.integer] +FloatArgType = Union[float, numbers.Real, np.floating] + +ShapeArgType = Union[IntArgType, Iterable[IntArgType]] +DTypeArgType = Union[np.dtype, str] + +ScalarArgType = Union[int, float, complex, numbers.Number, np.float_] + +ArrayLikeGetitemArgType = Union[ + int, + slice, + np.ndarray, + np.newaxis, + None, + type(Ellipsis), + Tuple[Union[int, slice, np.ndarray, np.newaxis, None, type(Ellipsis)], ...], +] + +RandomStateArgType = Union[None, int, np.random.RandomState, np.random.Generator] diff --git a/src/probnum/utils/__init__.py b/src/probnum/utils/__init__.py index 64e92b37bd..a3760107a8 100644 --- a/src/probnum/utils/__init__.py +++ b/src/probnum/utils/__init__.py @@ -2,7 +2,6 @@ from .arrayutils import * from .fctutils import * from .randomutils import * -from .scalarutils import * # Public classes and functions. Order is reflected in documentation. __all__ = [ diff --git a/src/probnum/utils/argutils.py b/src/probnum/utils/argutils.py index e68bfa8040..25f97b6a5e 100644 --- a/src/probnum/utils/argutils.py +++ b/src/probnum/utils/argutils.py @@ -3,8 +3,20 @@ import numpy as np import scipy._lib._util -from probnum.typing import ShapeType, RandomStateType -from probnum._lib.argtypes import ShapeArgType, RandomStateArgType +from probnum.typing import ( + DTypeArgType, + RandomStateArgType, + RandomStateType, + ScalarArgType, + ShapeArgType, + ShapeType, +) + +__all__ = ["as_shape", "as_random_state", "as_numpy_scalar"] + + +def as_random_state(x: RandomStateArgType) -> RandomStateType: + return scipy._lib._util.check_random_state(x) def as_shape(x: ShapeArgType) -> ShapeType: @@ -26,5 +38,11 @@ def as_shape(x: ShapeArgType) -> ShapeType: return tuple(int(item) for item in x) -def as_random_state(x: RandomStateArgType) -> RandomStateType: - return scipy._lib._util.check_random_state(x) +def as_numpy_scalar(x: ScalarArgType, dtype: DTypeArgType = None) -> np.generic: + is_scalar = np.isscalar(x) + is_scalar_array = isinstance(x, np.ndarray) and x.ndim == 0 + + if not (is_scalar or is_scalar_array): + raise ValueError("The given input is not a scalar.") + + return np.asarray(x, dtype=dtype)[()] diff --git a/src/probnum/utils/scalarutils.py b/src/probnum/utils/scalarutils.py deleted file mode 100644 index a1ff4ebf9a..0000000000 --- a/src/probnum/utils/scalarutils.py +++ /dev/null @@ -1,10 +0,0 @@ -import numpy as np - -__all__ = ["as_numpy_scalar"] - - -def as_numpy_scalar(x, dtype=None): - if not np.isscalar(x): - raise ValueError("The given input is not a scalar") - - return np.asarray([x], dtype=dtype)[0]