Skip to content

Commit

Permalink
Add additional type hints
Browse files Browse the repository at this point in the history
  • Loading branch information
marvinpfoertner committed Aug 26, 2020
1 parent 7e8ea3a commit 4f92600
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 18 deletions.
10 changes: 10 additions & 0 deletions src/probnum/_lib/argtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,14 @@
ShapeArgType = Union[IntArgType, Iterable[IntArgType]]
DTypeArgType = Union[np.dtype, str]

ArrayLikeGetitemArgType = Union[
int,
slice,
np.ndarray,
np.newaxis,
None,
type(Ellipsis),
Tuple[Union[int, slice, np.ndarray, np.newaxis, None, type(Ellipsis)], ...],
]

RandomStateArgType = Union[None, int, np.random.RandomState, np.random.Generator]
8 changes: 6 additions & 2 deletions src/probnum/core/random_variables/_dirac.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@
import numpy as np

from probnum import utils as _utils
from probnum._lib.argtypes import RandomStateArgType, ShapeArgType
from probnum._lib.argtypes import (
ArrayLikeGetitemArgType,
RandomStateArgType,
ShapeArgType,
)
from probnum.typing import ShapeType

from . import _random_variable
Expand Down Expand Up @@ -79,7 +83,7 @@ def __init__(
def support(self) -> _ValueType:
return self._support

def __getitem__(self, key) -> "Dirac":
def __getitem__(self, key: ArrayLikeGetitemArgType) -> "Dirac":
"""
Marginalization for multivariate Dirac distributions, expressed by means of
(advanced) indexing, masking and slicing.
Expand Down
20 changes: 11 additions & 9 deletions src/probnum/core/random_variables/_normal.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@
import scipy.stats

from probnum import utils as _utils
from probnum._lib.argtypes import RandomStateArgType, ShapeArgType
from probnum._lib.argtypes import (
ArrayLikeGetitemArgType,
RandomStateArgType,
ShapeArgType,
)
from probnum.linalg import linops
from probnum.typing import ShapeType

Expand Down Expand Up @@ -223,7 +227,7 @@ def __init__(
entropy=entropy,
)

def __getitem__(self, key) -> "Normal":
def __getitem__(self, key: ArrayLikeGetitemArgType) -> "Normal":
"""
Marginalization in multi- and matrixvariate normal distributions, expressed by
means of (advanced) indexing, masking and slicing.
Expand Down Expand Up @@ -486,7 +490,7 @@ def _scale(self, scalar, other_random_state=None):
)

# Univariate Gaussians
def _univariate_sample(self, size: ShapeArgType = ()) -> np.floating:
def _univariate_sample(self, size: ShapeType = ()) -> np.floating:
sample = scipy.stats.norm.rvs(
loc=self._mean, scale=self._cov, size=size, random_state=self.random_state
)
Expand All @@ -496,7 +500,7 @@ def _univariate_sample(self, size: ShapeArgType = ()) -> np.floating:
else:
sample = sample.astype(self.dtype)

assert sample.shape == _utils.as_shape(size)
assert sample.shape == size

return sample

Expand Down Expand Up @@ -532,7 +536,7 @@ def _univariate_entropy(self: _ValueType) -> np.float_:
)

# Multi- and matrixvariate Gaussians with dense covariance
def _dense_sample(self, size: ShapeArgType = ()) -> np.ndarray:
def _dense_sample(self, size: ShapeType = ()) -> np.ndarray:
sample = scipy.stats.multivariate_normal.rvs(
mean=self._mean.ravel(),
cov=self._cov,
Expand Down Expand Up @@ -591,7 +595,7 @@ def _operatorvariate_params_todense(self) -> Tuple[np.ndarray, np.ndarray]:

return mean, self._cov.todense()

def _operatorvariate_sample(self, size: ShapeArgType = ()) -> np.ndarray:
def _operatorvariate_sample(self, size: ShapeType = ()) -> np.ndarray:
mean, cov = self._operatorvariate_params_todense()

sample = scipy.stats.multivariate_normal.rvs(
Expand All @@ -603,15 +607,13 @@ def _operatorvariate_sample(self, size: ShapeArgType = ()) -> np.ndarray:
# Operatorvariate Gaussian with symmetric Kronecker covariance from identical
# factors
def _symmetric_kronecker_identical_factors_sample(
self, size: ShapeArgType = ()
self, size: ShapeType = ()
) -> np.ndarray:
assert isinstance(self._cov, linops.SymmetricKronecker) and self._cov._ABequal

n = self._mean.shape[1]

# Draw standard normal samples
size = _utils.as_shape(size)

size_sample = (n * n,) + size

stdnormal_samples = scipy.stats.norm.rvs(
Expand Down
14 changes: 7 additions & 7 deletions src/probnum/core/random_variables/_random_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def __init__(
dtype: DTypeArgType,
random_state: RandomStateArgType = None,
parameters: Optional[Dict[str, Any]] = None,
sample: Optional[Callable[[ShapeArgType], _ValueType]] = None,
sample: Optional[Callable[[ShapeType], _ValueType]] = None,
in_support: Optional[Callable[[_ValueType], bool]] = None,
cdf: Optional[Callable[[_ValueType], np.float_]] = None,
logcdf: Optional[Callable[[_ValueType], np.float_]] = None,
Expand Down Expand Up @@ -352,7 +352,7 @@ def in_support(self, x: _ValueType) -> bool:

return self.__in_support(x)

def sample(self, size=()) -> _ValueType:
def sample(self, size: ShapeArgType = ()) -> _ValueType:
"""
Draw realizations from a random variable.
Expand All @@ -371,7 +371,7 @@ def sample(self, size=()) -> _ValueType:

return self.__sample(size=_utils.as_shape(size))

def cdf(self, x) -> np.float_:
def cdf(self, x: _ValueType) -> np.float_:
"""
Cumulative distribution function.
Expand All @@ -396,7 +396,7 @@ def cdf(self, x) -> np.float_:
)
)

def logcdf(self, x) -> np.float_:
def logcdf(self, x: _ValueType) -> np.float_:
"""
Log-cumulative distribution function.
Expand Down Expand Up @@ -681,7 +681,7 @@ def pmf(self, x: _ValueType) -> np.float_:
else:
raise NotImplementedError

def logpmf(self, x) -> np.float_:
def logpmf(self, x: _ValueType) -> np.float_:
if self.__logpmf is not None:
return self.__logpmf(x)
elif self.__pmf is not None:
Expand Down Expand Up @@ -735,7 +735,7 @@ def __init__(
entropy=entropy,
)

def pdf(self, x) -> np.float_:
def pdf(self, x: _ValueType) -> np.float_:
"""
Probability density or mass function.
Expand All @@ -760,7 +760,7 @@ def pdf(self, x) -> np.float_:
)
)

def logpdf(self, x) -> np.float_:
def logpdf(self, x: _ValueType) -> np.float_:
"""
Natural logarithm of the probability density function.
Expand Down

0 comments on commit 4f92600

Please sign in to comment.