From 1438bc938e7d867f68a9da8ef6ad50809d87610e Mon Sep 17 00:00:00 2001 From: Marvin Pfoertner Date: Thu, 27 Aug 2020 23:38:48 +0200 Subject: [PATCH] Simplify random variable arithmetic --- src/probnum/random_variables/_arithmetic.py | 31 +++-- .../random_variables/_random_variable.py | 117 +++++++----------- 2 files changed, 64 insertions(+), 84 deletions(-) diff --git a/src/probnum/random_variables/_arithmetic.py b/src/probnum/random_variables/_arithmetic.py index d15e2d9ce2..c977911f6e 100644 --- a/src/probnum/random_variables/_arithmetic.py +++ b/src/probnum/random_variables/_arithmetic.py @@ -4,44 +4,44 @@ import operator from typing import Any, Callable, Dict, Tuple, Union -from ._random_variable import RandomVariable as _RandomVariable +from ._random_variable import RandomVariable as _RandomVariable, asrandvar from ._dirac import Dirac as _Dirac from ._normal import Normal as _Normal -def add(rv1: _RandomVariable, rv2: _RandomVariable) -> _RandomVariable: +def add(rv1: Any, rv2: Any) -> _RandomVariable: return _apply(_add_fns, rv1, rv2) -def sub(rv1: _RandomVariable, rv2: _RandomVariable) -> _RandomVariable: +def sub(rv1: Any, rv2: Any) -> _RandomVariable: return _apply(_sub_fns, rv1, rv2) -def mul(rv1: _RandomVariable, rv2: _RandomVariable) -> _RandomVariable: +def mul(rv1: Any, rv2: Any) -> _RandomVariable: return _apply(_mul_fns, rv1, rv2) -def matmul(rv1: _RandomVariable, rv2: _RandomVariable) -> _RandomVariable: +def matmul(rv1: Any, rv2: Any) -> _RandomVariable: return _apply(_matmul_fns, rv1, rv2) -def truediv(rv1: _RandomVariable, rv2: _RandomVariable) -> _RandomVariable: +def truediv(rv1: Any, rv2: Any) -> _RandomVariable: return _apply(_truediv_fns, rv1, rv2) -def floordiv(rv1: _RandomVariable, rv2: _RandomVariable) -> _RandomVariable: +def floordiv(rv1: Any, rv2: Any) -> _RandomVariable: return _apply(_floordiv_fns, rv1, rv2) -def mod(rv1: _RandomVariable, rv2: _RandomVariable) -> _RandomVariable: +def mod(rv1: Any, rv2: Any) -> _RandomVariable: return _apply(_mod_fns, rv1, rv2) -def divmod_(rv1: _RandomVariable, rv2: _RandomVariable) -> _RandomVariable: +def divmod_(rv1: Any, rv2: Any) -> _RandomVariable: return _apply(_divmod_fns, rv1, rv2) -def pow_(rv1: _RandomVariable, rv2: _RandomVariable) -> _RandomVariable: +def pow_(rv1: Any, rv2: Any) -> _RandomVariable: return _apply(_pow_fns, rv1, rv2) @@ -52,8 +52,15 @@ def pow_(rv1: _RandomVariable, rv2: _RandomVariable) -> _RandomVariable: def _apply( - op_registry: _OperatorRegistryType, rv1: _RandomVariable, rv2: _RandomVariable -) -> Union[_RandomVariable, type(NotImplemented)]: + op_registry: _OperatorRegistryType, + rv1: Any, + rv2: Any, +) -> Union[_RandomVariable]: + # Convert arguments to random variables + rv1 = asrandvar(rv1) + rv2 = asrandvar(rv2) + + # Search fitting method key = (type(rv1), type(rv2)) if key not in op_registry: diff --git a/src/probnum/random_variables/_random_variable.py b/src/probnum/random_variables/_random_variable.py index 45814eb69f..4627a6acba 100644 --- a/src/probnum/random_variables/_random_variable.py +++ b/src/probnum/random_variables/_random_variable.py @@ -581,140 +581,113 @@ def __abs__(self) -> "RandomVariable": RandomVariable with the correct shape is returned. """ - def __add__(self, other) -> "RandomVariable": - if not isinstance(other, RandomVariable): - return self + asrandvar(other) - + def __add__(self, other: Any) -> "RandomVariable": # pylint: disable=import-outside-toplevel,cyclic-import from ._arithmetic import add return add(self, other) - def __radd__(self, other) -> "RandomVariable": - if not isinstance(other, RandomVariable): - return asrandvar(other) + self - - return NotImplemented + def __radd__(self, other: Any) -> "RandomVariable": + # pylint: disable=import-outside-toplevel,cyclic-import + from ._arithmetic import add - def __sub__(self, other) -> "RandomVariable": - if not isinstance(other, RandomVariable): - return self - asrandvar(other) + return add(other, self) + def __sub__(self, other: Any) -> "RandomVariable": # pylint: disable=import-outside-toplevel,cyclic-import from ._arithmetic import sub return sub(self, other) - def __rsub__(self, other) -> "RandomVariable": - if not isinstance(other, RandomVariable): - return asrandvar(other) - self - - return NotImplemented + def __rsub__(self, other: Any) -> "RandomVariable": + # pylint: disable=import-outside-toplevel,cyclic-import + from ._arithmetic import sub - def __mul__(self, other) -> "RandomVariable": - if not isinstance(other, RandomVariable): - return self * asrandvar(other) + return sub(other, self) + def __mul__(self, other: Any) -> "RandomVariable": # pylint: disable=import-outside-toplevel,cyclic-import from ._arithmetic import mul return mul(self, other) - def __rmul__(self, other) -> "RandomVariable": - if not isinstance(other, RandomVariable): - return asrandvar(other) * self - - return NotImplemented + def __rmul__(self, other: Any) -> "RandomVariable": + # pylint: disable=import-outside-toplevel,cyclic-import + from ._arithmetic import mul - def __matmul__(self, other) -> "RandomVariable": - if not isinstance(other, RandomVariable): - return self @ asrandvar(other) + return mul(other, self) + def __matmul__(self, other: Any) -> "RandomVariable": # pylint: disable=import-outside-toplevel,cyclic-import from ._arithmetic import matmul return matmul(self, other) - def __rmatmul__(self, other) -> "RandomVariable": - if not isinstance(other, RandomVariable): - return asrandvar(other) @ self - - return NotImplemented + def __rmatmul__(self, other: Any) -> "RandomVariable": + # pylint: disable=import-outside-toplevel,cyclic-import + from ._arithmetic import matmul - def __truediv__(self, other) -> "RandomVariable": - if not isinstance(other, RandomVariable): - return self / asrandvar(other) + return matmul(other, self) + def __truediv__(self, other: Any) -> "RandomVariable": # pylint: disable=import-outside-toplevel,cyclic-import from ._arithmetic import truediv return truediv(self, other) - def __rtruediv__(self, other) -> "RandomVariable": - if not isinstance(other, RandomVariable): - return asrandvar(other) / self - - return NotImplemented + def __rtruediv__(self, other: Any) -> "RandomVariable": + # pylint: disable=import-outside-toplevel,cyclic-import + from ._arithmetic import truediv - def __floordiv__(self, other) -> "RandomVariable": - if not isinstance(other, RandomVariable): - return self // asrandvar(other) + return truediv(other, self) + def __floordiv__(self, other: Any) -> "RandomVariable": # pylint: disable=import-outside-toplevel,cyclic-import from ._arithmetic import floordiv return floordiv(self, other) - def __rfloordiv__(self, other) -> "RandomVariable": - if not isinstance(other, RandomVariable): - return asrandvar(other) // self - - return NotImplemented + def __rfloordiv__(self, other: Any) -> "RandomVariable": + # pylint: disable=import-outside-toplevel,cyclic-import + from ._arithmetic import floordiv - def __mod__(self, other) -> "RandomVariable": - if not isinstance(other, RandomVariable): - return self % asrandvar(other) + return floordiv(other, self) + def __mod__(self, other: Any) -> "RandomVariable": # pylint: disable=import-outside-toplevel,cyclic-import from ._arithmetic import mod return mod(self, other) - def __rmod__(self, other) -> "RandomVariable": - if not isinstance(other, RandomVariable): - return asrandvar(other) % self - - return NotImplemented + def __rmod__(self, other: Any) -> "RandomVariable": + # pylint: disable=import-outside-toplevel,cyclic-import + from ._arithmetic import mod - def __divmod__(self, other) -> "RandomVariable": - if not isinstance(other, RandomVariable): - return divmod(self, asrandvar(other)) + return mod(other, self) + def __divmod__(self, other: Any) -> "RandomVariable": # pylint: disable=import-outside-toplevel,cyclic-import from ._arithmetic import divmod_ return divmod_(self, other) - def __rdivmod__(self, other) -> "RandomVariable": - if not isinstance(other, RandomVariable): - return divmod(asrandvar(other), self) - - return NotImplemented + def __rdivmod__(self, other: Any) -> "RandomVariable": + # pylint: disable=import-outside-toplevel,cyclic-import + from ._arithmetic import divmod_ - def __pow__(self, other) -> "RandomVariable": - if not isinstance(other, RandomVariable): - return self ** asrandvar(other) + return divmod_(other, self) + def __pow__(self, other: Any) -> "RandomVariable": # pylint: disable=import-outside-toplevel,cyclic-import from ._arithmetic import pow_ return pow_(self, other) - def __rpow__(self, other) -> "RandomVariable": - if not isinstance(other, RandomVariable): - return asrandvar(other) ** self + def __rpow__(self, other: Any) -> "RandomVariable": + # pylint: disable=import-outside-toplevel,cyclic-import + from ._arithmetic import pow_ - return NotImplemented + return pow_(other, self) class DiscreteRandomVariable(RandomVariable[_ValueType]):