Skip to content

Commit

Permalink
Simplify random variable arithmetic
Browse files Browse the repository at this point in the history
  • Loading branch information
marvinpfoertner committed Aug 28, 2020
1 parent deacf2b commit 1438bc9
Show file tree
Hide file tree
Showing 2 changed files with 64 additions and 84 deletions.
31 changes: 19 additions & 12 deletions src/probnum/random_variables/_arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,44 +4,44 @@
import operator
from typing import Any, Callable, Dict, Tuple, Union

from ._random_variable import RandomVariable as _RandomVariable
from ._random_variable import RandomVariable as _RandomVariable, asrandvar
from ._dirac import Dirac as _Dirac
from ._normal import Normal as _Normal


def add(rv1: _RandomVariable, rv2: _RandomVariable) -> _RandomVariable:
def add(rv1: Any, rv2: Any) -> _RandomVariable:
return _apply(_add_fns, rv1, rv2)


def sub(rv1: _RandomVariable, rv2: _RandomVariable) -> _RandomVariable:
def sub(rv1: Any, rv2: Any) -> _RandomVariable:
return _apply(_sub_fns, rv1, rv2)


def mul(rv1: _RandomVariable, rv2: _RandomVariable) -> _RandomVariable:
def mul(rv1: Any, rv2: Any) -> _RandomVariable:
return _apply(_mul_fns, rv1, rv2)


def matmul(rv1: _RandomVariable, rv2: _RandomVariable) -> _RandomVariable:
def matmul(rv1: Any, rv2: Any) -> _RandomVariable:
return _apply(_matmul_fns, rv1, rv2)


def truediv(rv1: _RandomVariable, rv2: _RandomVariable) -> _RandomVariable:
def truediv(rv1: Any, rv2: Any) -> _RandomVariable:
return _apply(_truediv_fns, rv1, rv2)


def floordiv(rv1: _RandomVariable, rv2: _RandomVariable) -> _RandomVariable:
def floordiv(rv1: Any, rv2: Any) -> _RandomVariable:
return _apply(_floordiv_fns, rv1, rv2)


def mod(rv1: _RandomVariable, rv2: _RandomVariable) -> _RandomVariable:
def mod(rv1: Any, rv2: Any) -> _RandomVariable:
return _apply(_mod_fns, rv1, rv2)


def divmod_(rv1: _RandomVariable, rv2: _RandomVariable) -> _RandomVariable:
def divmod_(rv1: Any, rv2: Any) -> _RandomVariable:
return _apply(_divmod_fns, rv1, rv2)


def pow_(rv1: _RandomVariable, rv2: _RandomVariable) -> _RandomVariable:
def pow_(rv1: Any, rv2: Any) -> _RandomVariable:
return _apply(_pow_fns, rv1, rv2)


Expand All @@ -52,8 +52,15 @@ def pow_(rv1: _RandomVariable, rv2: _RandomVariable) -> _RandomVariable:


def _apply(
op_registry: _OperatorRegistryType, rv1: _RandomVariable, rv2: _RandomVariable
) -> Union[_RandomVariable, type(NotImplemented)]:
op_registry: _OperatorRegistryType,
rv1: Any,
rv2: Any,
) -> Union[_RandomVariable]:
# Convert arguments to random variables
rv1 = asrandvar(rv1)
rv2 = asrandvar(rv2)

# Search fitting method
key = (type(rv1), type(rv2))

if key not in op_registry:
Expand Down
117 changes: 45 additions & 72 deletions src/probnum/random_variables/_random_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -581,140 +581,113 @@ def __abs__(self) -> "RandomVariable":
RandomVariable with the correct shape is returned.
"""

def __add__(self, other) -> "RandomVariable":
if not isinstance(other, RandomVariable):
return self + asrandvar(other)

def __add__(self, other: Any) -> "RandomVariable":
# pylint: disable=import-outside-toplevel,cyclic-import
from ._arithmetic import add

return add(self, other)

def __radd__(self, other) -> "RandomVariable":
if not isinstance(other, RandomVariable):
return asrandvar(other) + self

return NotImplemented
def __radd__(self, other: Any) -> "RandomVariable":
# pylint: disable=import-outside-toplevel,cyclic-import
from ._arithmetic import add

def __sub__(self, other) -> "RandomVariable":
if not isinstance(other, RandomVariable):
return self - asrandvar(other)
return add(other, self)

def __sub__(self, other: Any) -> "RandomVariable":
# pylint: disable=import-outside-toplevel,cyclic-import
from ._arithmetic import sub

return sub(self, other)

def __rsub__(self, other) -> "RandomVariable":
if not isinstance(other, RandomVariable):
return asrandvar(other) - self

return NotImplemented
def __rsub__(self, other: Any) -> "RandomVariable":
# pylint: disable=import-outside-toplevel,cyclic-import
from ._arithmetic import sub

def __mul__(self, other) -> "RandomVariable":
if not isinstance(other, RandomVariable):
return self * asrandvar(other)
return sub(other, self)

def __mul__(self, other: Any) -> "RandomVariable":
# pylint: disable=import-outside-toplevel,cyclic-import
from ._arithmetic import mul

return mul(self, other)

def __rmul__(self, other) -> "RandomVariable":
if not isinstance(other, RandomVariable):
return asrandvar(other) * self

return NotImplemented
def __rmul__(self, other: Any) -> "RandomVariable":
# pylint: disable=import-outside-toplevel,cyclic-import
from ._arithmetic import mul

def __matmul__(self, other) -> "RandomVariable":
if not isinstance(other, RandomVariable):
return self @ asrandvar(other)
return mul(other, self)

def __matmul__(self, other: Any) -> "RandomVariable":
# pylint: disable=import-outside-toplevel,cyclic-import
from ._arithmetic import matmul

return matmul(self, other)

def __rmatmul__(self, other) -> "RandomVariable":
if not isinstance(other, RandomVariable):
return asrandvar(other) @ self

return NotImplemented
def __rmatmul__(self, other: Any) -> "RandomVariable":
# pylint: disable=import-outside-toplevel,cyclic-import
from ._arithmetic import matmul

def __truediv__(self, other) -> "RandomVariable":
if not isinstance(other, RandomVariable):
return self / asrandvar(other)
return matmul(other, self)

def __truediv__(self, other: Any) -> "RandomVariable":
# pylint: disable=import-outside-toplevel,cyclic-import
from ._arithmetic import truediv

return truediv(self, other)

def __rtruediv__(self, other) -> "RandomVariable":
if not isinstance(other, RandomVariable):
return asrandvar(other) / self

return NotImplemented
def __rtruediv__(self, other: Any) -> "RandomVariable":
# pylint: disable=import-outside-toplevel,cyclic-import
from ._arithmetic import truediv

def __floordiv__(self, other) -> "RandomVariable":
if not isinstance(other, RandomVariable):
return self // asrandvar(other)
return truediv(other, self)

def __floordiv__(self, other: Any) -> "RandomVariable":
# pylint: disable=import-outside-toplevel,cyclic-import
from ._arithmetic import floordiv

return floordiv(self, other)

def __rfloordiv__(self, other) -> "RandomVariable":
if not isinstance(other, RandomVariable):
return asrandvar(other) // self

return NotImplemented
def __rfloordiv__(self, other: Any) -> "RandomVariable":
# pylint: disable=import-outside-toplevel,cyclic-import
from ._arithmetic import floordiv

def __mod__(self, other) -> "RandomVariable":
if not isinstance(other, RandomVariable):
return self % asrandvar(other)
return floordiv(other, self)

def __mod__(self, other: Any) -> "RandomVariable":
# pylint: disable=import-outside-toplevel,cyclic-import
from ._arithmetic import mod

return mod(self, other)

def __rmod__(self, other) -> "RandomVariable":
if not isinstance(other, RandomVariable):
return asrandvar(other) % self

return NotImplemented
def __rmod__(self, other: Any) -> "RandomVariable":
# pylint: disable=import-outside-toplevel,cyclic-import
from ._arithmetic import mod

def __divmod__(self, other) -> "RandomVariable":
if not isinstance(other, RandomVariable):
return divmod(self, asrandvar(other))
return mod(other, self)

def __divmod__(self, other: Any) -> "RandomVariable":
# pylint: disable=import-outside-toplevel,cyclic-import
from ._arithmetic import divmod_

return divmod_(self, other)

def __rdivmod__(self, other) -> "RandomVariable":
if not isinstance(other, RandomVariable):
return divmod(asrandvar(other), self)

return NotImplemented
def __rdivmod__(self, other: Any) -> "RandomVariable":
# pylint: disable=import-outside-toplevel,cyclic-import
from ._arithmetic import divmod_

def __pow__(self, other) -> "RandomVariable":
if not isinstance(other, RandomVariable):
return self ** asrandvar(other)
return divmod_(other, self)

def __pow__(self, other: Any) -> "RandomVariable":
# pylint: disable=import-outside-toplevel,cyclic-import
from ._arithmetic import pow_

return pow_(self, other)

def __rpow__(self, other) -> "RandomVariable":
if not isinstance(other, RandomVariable):
return asrandvar(other) ** self
def __rpow__(self, other: Any) -> "RandomVariable":
# pylint: disable=import-outside-toplevel,cyclic-import
from ._arithmetic import pow_

return NotImplemented
return pow_(other, self)


class DiscreteRandomVariable(RandomVariable[_ValueType]):
Expand Down

0 comments on commit 1438bc9

Please sign in to comment.