Skip to content

Commit

Permalink
Add type hints in probnum.core
Browse files Browse the repository at this point in the history
  • Loading branch information
marvinpfoertner committed Aug 25, 2020
1 parent 8bdeb76 commit 533ca96
Show file tree
Hide file tree
Showing 11 changed files with 226 additions and 132 deletions.
Empty file added src/probnum/_lib/__init__.py
Empty file.
13 changes: 13 additions & 0 deletions src/probnum/_lib/argtypes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import numbers
from typing import Iterable, Tuple, Union

import numpy as np


IntArgType = Union[int, numbers.Integral, np.integer]
FloatArgType = Union[float, numbers.Real, np.floating]

ShapeArgType = Union[IntArgType, Iterable[IntArgType]]
DTypeArgType = Union[np.dtype, str]

RandomStateArgType = Union[None, int, np.random.RandomState, np.random.Generator]
15 changes: 13 additions & 2 deletions src/probnum/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,18 @@

from . import random_variables

from .random_variables import RandomVariable, asrandvar
from .random_variables import (
asrandvar,
RandomVariable,
DiscreteRandomVariable,
ContinuousRandomVariable,
)

# Public classes and functions. Order is reflected in documentation.
__all__ = ["random_variables", "RandomVariable", "asrandvar"]
__all__ = [
"random_variables",
"asrandvar",
"RandomVariable",
"DiscreteRandomVariable",
"ContinuousRandomVariable",
]
10 changes: 9 additions & 1 deletion src/probnum/core/random_variables/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
from ._random_variable import RandomVariable, asrandvar
from ._random_variable import (
asrandvar,
RandomVariable,
DiscreteRandomVariable,
ContinuousRandomVariable,
)

from ._dirac import Dirac
from ._normal import Normal
Expand All @@ -7,5 +12,8 @@
asrandvar.__module__ = "probnum.random_variables"

RandomVariable.__module__ = "probnum.random_variables"
DiscreteRandomVariable.__module__ = "probnum.random_variables"
ContinuousRandomVariable.__module__ = "probnum.random_variables"

Dirac.__module__ = "probnum.random_variables"
Normal.__module__ = "probnum.random_variables"
30 changes: 19 additions & 11 deletions src/probnum/core/random_variables/_arithmetic.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import operator
from typing import Any, Callable, Dict, Tuple, Union

from ._random_variable import RandomVariable as _RandomVariable
from probnum.core.random_variables import Dirac as _Dirac, Normal as _Normal
Expand Down Expand Up @@ -41,7 +42,14 @@ def pow_(rv1: _RandomVariable, rv2: _RandomVariable) -> _RandomVariable:


# Operator registry
def _apply(op_registry, rv1: _RandomVariable, rv2: _RandomVariable) -> _RandomVariable:
_OperatorRegistryType = Dict[
Tuple[type, type], Callable[[_RandomVariable, _RandomVariable], _RandomVariable]
]


def _apply(
op_registry: _OperatorRegistryType, rv1: _RandomVariable, rv2: _RandomVariable
) -> Union[_RandomVariable, type(NotImplemented)]:
key = (type(rv1), type(rv2))

if key not in op_registry:
Expand All @@ -52,18 +60,18 @@ def _apply(op_registry, rv1: _RandomVariable, rv2: _RandomVariable) -> _RandomVa
return res


_add_fns = {}
_sub_fns = {}
_mul_fns = {}
_matmul_fns = {}
_truediv_fns = {}
_floordiv_fns = {}
_mod_fns = {}
_divmod_fns = {}
_pow_fns = {}
_add_fns: _OperatorRegistryType = {}
_sub_fns: _OperatorRegistryType = {}
_mul_fns: _OperatorRegistryType = {}
_matmul_fns: _OperatorRegistryType = {}
_truediv_fns: _OperatorRegistryType = {}
_floordiv_fns: _OperatorRegistryType = {}
_mod_fns: _OperatorRegistryType = {}
_divmod_fns: _OperatorRegistryType = {}
_pow_fns: _OperatorRegistryType = {}


def _swap_operands(fn):
def _swap_operands(fn: Callable[[Any, Any], Any]) -> Callable[[Any, Any], Any]:
return lambda op1, op2: fn(op2, op1)


Expand Down
27 changes: 13 additions & 14 deletions src/probnum/core/random_variables/_dirac.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import numpy as np

from probnum import utils as _utils
from probnum._lib.argtypes import RandomStateArgType, ShapeArgType
from probnum.typing import ShapeType

from . import _random_variable

Expand Down Expand Up @@ -43,9 +45,7 @@ class Dirac(_random_variable.DiscreteRandomVariable[_ValueType]):
"""

def __init__(
self,
support: _ValueType,
random_state: Optional[_random_variable.RandomStateType] = None,
self, support: _ValueType, random_state: RandomStateArgType = None,
):
if np.isscalar(support):
support = _utils.as_numpy_scalar(support)
Expand All @@ -59,8 +59,8 @@ def __init__(
parameters={"support": self._support},
sample=self._sample,
in_support=lambda x: np.all(x == self._support),
pmf=lambda x: 1.0 if np.all(x == self._support) else 0.0,
cdf=lambda x: 0.0 if np.any(x < self._support) else 0.0,
pmf=lambda x: np.float_(1.0 if np.all(x == self._support) else 0.0),
cdf=lambda x: np.float_(0.0 if np.any(x < self._support) else 0.0),
mode=lambda: self._support,
median=lambda: self._support,
mean=lambda: self._support,
Expand All @@ -76,10 +76,10 @@ def __init__(
)

@property
def support(self):
def support(self) -> _ValueType:
return self._support

def __getitem__(self, key):
def __getitem__(self, key) -> "Dirac":
"""
Marginalization for multivariate Dirac distributions, expressed by means of
(advanced) indexing, masking and slicing.
Expand All @@ -96,21 +96,20 @@ def __getitem__(self, key):
"""
return Dirac(support=self._support[key], random_state=self.random_state)

def reshape(self, newshape):
def reshape(self, newshape: ShapeType) -> "Dirac":
return Dirac(
support=self._support.reshape(newshape),
random_state=_utils.derive_random_seed(self.random_state),
)

def transpose(self, *axes):
def transpose(self, *axes: int) -> "Dirac":
return Dirac(
support=self._support.transpose(*axes),
random_state=_utils.derive_random_seed(self.random_state),
)

def _sample(self, size=()):
if isinstance(size, int):
size = (size,)
def _sample(self, size: ShapeArgType = ()) -> _ValueType:
size = _utils.as_shape(size)

if size == ():
return self._support.copy()
Expand Down Expand Up @@ -143,12 +142,12 @@ def __abs__(self) -> "Dirac":
def _binary_operator_factory(
operator: Callable[[_ValueType, _ValueType], _ValueType]
) -> Callable[["Dirac", "Dirac"], "Dirac"]:
def _dirac_operator(dirac_rv1: Dirac, dirac_rv2: Dirac) -> Dirac:
def _dirac_binary_operator(dirac_rv1: Dirac, dirac_rv2: Dirac) -> Dirac:
return Dirac(
support=operator(dirac_rv1, dirac_rv2),
random_state=_utils.derive_random_seed(
dirac_rv1.random_state, dirac_rv2.random_state,
),
)

return _dirac_operator
return _dirac_binary_operator
Loading

0 comments on commit 533ca96

Please sign in to comment.