diff --git a/src/probnum/_lib/argtypes.py b/src/probnum/_lib/argtypes.py
index a305328f8d..f93b97ff8b 100644
--- a/src/probnum/_lib/argtypes.py
+++ b/src/probnum/_lib/argtypes.py
@@ -10,4 +10,14 @@
 ShapeArgType = Union[IntArgType, Iterable[IntArgType]]
 DTypeArgType = Union[np.dtype, str]
 
+ArrayLikeGetitemArgType = Union[
+    int,
+    slice,
+    np.ndarray,
+    np.newaxis,
+    None,
+    type(Ellipsis),
+    Tuple[Union[int, slice, np.ndarray, np.newaxis, None, type(Ellipsis)], ...],
+]
+
 RandomStateArgType = Union[None, int, np.random.RandomState, np.random.Generator]
diff --git a/src/probnum/core/random_variables/_dirac.py b/src/probnum/core/random_variables/_dirac.py
index b4adb10361..7a710ba2ff 100644
--- a/src/probnum/core/random_variables/_dirac.py
+++ b/src/probnum/core/random_variables/_dirac.py
@@ -3,7 +3,11 @@
 import numpy as np
 
 from probnum import utils as _utils
-from probnum._lib.argtypes import RandomStateArgType, ShapeArgType
+from probnum._lib.argtypes import (
+    ArrayLikeGetitemArgType,
+    RandomStateArgType,
+    ShapeArgType,
+)
 from probnum.typing import ShapeType
 
 from . import _random_variable
@@ -79,7 +83,7 @@ def __init__(
     def support(self) -> _ValueType:
         return self._support
 
-    def __getitem__(self, key) -> "Dirac":
+    def __getitem__(self, key: ArrayLikeGetitemArgType) -> "Dirac":
         """
         Marginalization for multivariate Dirac distributions, expressed by means of
         (advanced) indexing, masking and slicing.
diff --git a/src/probnum/core/random_variables/_normal.py b/src/probnum/core/random_variables/_normal.py
index 0072398a16..f6cebf3047 100644
--- a/src/probnum/core/random_variables/_normal.py
+++ b/src/probnum/core/random_variables/_normal.py
@@ -4,7 +4,11 @@
 import scipy.stats
 
 from probnum import utils as _utils
-from probnum._lib.argtypes import RandomStateArgType, ShapeArgType
+from probnum._lib.argtypes import (
+    ArrayLikeGetitemArgType,
+    RandomStateArgType,
+    ShapeArgType,
+)
 from probnum.linalg import linops
 from probnum.typing import ShapeType
 
@@ -223,7 +227,7 @@ def __init__(
             entropy=entropy,
         )
 
-    def __getitem__(self, key) -> "Normal":
+    def __getitem__(self, key: ArrayLikeGetitemArgType) -> "Normal":
         """
         Marginalization in multi- and matrixvariate normal distributions, expressed
         by means of (advanced) indexing, masking and slicing.
@@ -486,7 +490,7 @@ def _scale(self, scalar, other_random_state=None):
         )
 
     # Univariate Gaussians
-    def _univariate_sample(self, size: ShapeArgType = ()) -> np.floating:
+    def _univariate_sample(self, size: ShapeType = ()) -> np.floating:
         sample = scipy.stats.norm.rvs(
             loc=self._mean, scale=self._cov, size=size, random_state=self.random_state
         )
@@ -496,7 +500,7 @@ def _univariate_sample(self, size: ShapeArgType = ()) -> np.floating:
         else:
             sample = sample.astype(self.dtype)
 
-        assert sample.shape == _utils.as_shape(size)
+        assert sample.shape == size
 
         return sample
 
@@ -532,7 +536,7 @@ def _univariate_entropy(self: _ValueType) -> np.float_:
         )
 
     # Multi- and matrixvariate Gaussians with dense covariance
-    def _dense_sample(self, size: ShapeArgType = ()) -> np.ndarray:
+    def _dense_sample(self, size: ShapeType = ()) -> np.ndarray:
         sample = scipy.stats.multivariate_normal.rvs(
             mean=self._mean.ravel(),
             cov=self._cov,
@@ -591,7 +595,7 @@ def _operatorvariate_params_todense(self) -> Tuple[np.ndarray, np.ndarray]:
 
         return mean, self._cov.todense()
 
-    def _operatorvariate_sample(self, size: ShapeArgType = ()) -> np.ndarray:
+    def _operatorvariate_sample(self, size: ShapeType = ()) -> np.ndarray:
         mean, cov = self._operatorvariate_params_todense()
 
         sample = scipy.stats.multivariate_normal.rvs(
@@ -603,15 +607,13 @@ def _operatorvariate_sample(self, size: ShapeArgType = ()) -> np.ndarray:
     # Operatorvariate Gaussian with symmetric Kronecker covariance from identical
     # factors
     def _symmetric_kronecker_identical_factors_sample(
-        self, size: ShapeArgType = ()
+        self, size: ShapeType = ()
     ) -> np.ndarray:
         assert isinstance(self._cov, linops.SymmetricKronecker) and self._cov._ABequal
 
         n = self._mean.shape[1]
 
         # Draw standard normal samples
-        size = _utils.as_shape(size)
-
         size_sample = (n * n,) + size
 
         stdnormal_samples = scipy.stats.norm.rvs(
diff --git a/src/probnum/core/random_variables/_random_variable.py b/src/probnum/core/random_variables/_random_variable.py
index 29f9c46e81..e17f91c9dc 100644
--- a/src/probnum/core/random_variables/_random_variable.py
+++ b/src/probnum/core/random_variables/_random_variable.py
@@ -74,7 +74,7 @@ def __init__(
         dtype: DTypeArgType,
         random_state: RandomStateArgType = None,
         parameters: Optional[Dict[str, Any]] = None,
-        sample: Optional[Callable[[ShapeArgType], _ValueType]] = None,
+        sample: Optional[Callable[[ShapeType], _ValueType]] = None,
         in_support: Optional[Callable[[_ValueType], bool]] = None,
         cdf: Optional[Callable[[_ValueType], np.float_]] = None,
         logcdf: Optional[Callable[[_ValueType], np.float_]] = None,
@@ -352,7 +352,7 @@ def in_support(self, x: _ValueType) -> bool:
 
         return self.__in_support(x)
 
-    def sample(self, size=()) -> _ValueType:
+    def sample(self, size: ShapeArgType = ()) -> _ValueType:
         """
         Draw realizations from a random variable.
 
@@ -371,7 +371,7 @@ def sample(self, size=()) -> _ValueType:
 
         return self.__sample(size=_utils.as_shape(size))
 
-    def cdf(self, x) -> np.float_:
+    def cdf(self, x: _ValueType) -> np.float_:
         """
         Cumulative distribution function.
 
@@ -396,7 +396,7 @@ def cdf(self, x) -> np.float_:
             )
         )
 
-    def logcdf(self, x) -> np.float_:
+    def logcdf(self, x: _ValueType) -> np.float_:
         """
         Log-cumulative distribution function.
 
@@ -681,7 +681,7 @@ def pmf(self, x: _ValueType) -> np.float_:
         else:
             raise NotImplementedError
 
-    def logpmf(self, x) -> np.float_:
+    def logpmf(self, x: _ValueType) -> np.float_:
         if self.__logpmf is not None:
             return self.__logpmf(x)
         elif self.__pmf is not None:
@@ -735,7 +735,7 @@ def __init__(
             entropy=entropy,
         )
 
-    def pdf(self, x) -> np.float_:
+    def pdf(self, x: _ValueType) -> np.float_:
         """
         Probability density or mass function.
 
@@ -760,7 +760,7 @@ def pdf(self, x) -> np.float_:
             )
         )
 
-    def logpdf(self, x) -> np.float_:
+    def logpdf(self, x: _ValueType) -> np.float_:
         """
         Natural logarithm of the probability density function.
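
For illustration only (not part of the diff): a minimal, self-contained sketch of what the new ArrayLikeGetitemArgType alias is meant to cover, mirroring the annotations added to Dirac.__getitem__ and Normal.__getitem__ above. The Marginal class below is hypothetical and only stands in for a random variable whose indexing delegates to NumPy; the alias itself is reproduced verbatim from the diff so the snippet runs on its own.

from typing import Tuple, Union

import numpy as np

# Alias as introduced in argtypes.py above, reproduced here so the sketch is
# self-contained. (np.newaxis is None at runtime, so Union simply collapses
# the duplicate entry.)
ArrayLikeGetitemArgType = Union[
    int,
    slice,
    np.ndarray,
    np.newaxis,
    None,
    type(Ellipsis),
    Tuple[Union[int, slice, np.ndarray, np.newaxis, None, type(Ellipsis)], ...],
]


class Marginal:
    """Hypothetical stand-in for a random variable whose __getitem__ delegates
    marginalization to NumPy-style indexing."""

    def __init__(self, values: np.ndarray) -> None:
        self._values = values

    def __getitem__(self, key: ArrayLikeGetitemArgType) -> "Marginal":
        # (Advanced) indexing, masking and slicing are handled by np.ndarray.
        return Marginal(self._values[key])


rv = Marginal(np.arange(12).reshape(3, 4))
rv[0]                              # int
rv[1:, ::2]                        # tuple of slices
rv[np.array([True, False, True])]  # boolean mask (np.ndarray)
rv[..., np.newaxis]                # Ellipsis and np.newaxis / None

The tuple member of the union is what lets multi-axis keys such as rv[1:, ::2] type-check under the same single annotation as plain integer or slice keys.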