Skip to content

Commit

Permalink
Additional value checks for cdf, pdf, pmf
Browse files Browse the repository at this point in the history
  • Loading branch information
marvinpfoertner committed Aug 28, 2020
1 parent f00c817 commit 4ce2dde
Showing 1 changed file with 71 additions and 48 deletions.
119 changes: 71 additions & 48 deletions src/probnum/random_variables/_random_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ def __init__(
var: Optional[Callable[[], _ValueType]] = None,
std: Optional[Callable[[], _ValueType]] = None,
entropy: Optional[Callable[[], np.float_]] = None,
as_value_type: Optional[Callable[[Any], _ValueType]] = None,
):
# pylint: disable=too-many-arguments,too-many-locals
"""Create a new random variable."""
Expand Down Expand Up @@ -114,9 +115,12 @@ def __init__(
self.__std = std
self.__entropy = entropy

# Utilities
self._median_dtype = np.promote_types(self._dtype, np.float_)
self._moments_dtype = np.promote_types(self._dtype, np.float_)

self.__as_value_type = as_value_type

def __repr__(self) -> str:
return f"<{self.shape} {self.__class__.__name__} with dtype={self.dtype}>"

Expand Down Expand Up @@ -171,44 +175,6 @@ def parameters(self) -> Dict[str, Any]:
"""
return self._parameters.copy()

@staticmethod
def _check_property_value(
name: str,
value: Any,
shape: Optional[Tuple[int, ...]] = None,
dtype: Optional[np.dtype] = None,
):
if shape is not None:
if value.shape != shape:
raise ValueError(
f"The {name} of the random variable does not have the correct "
f"shape. Expected {shape} but got {value.shape}."
)

if dtype is not None:
if not np.issubdtype(value.dtype, dtype):
raise ValueError(
f"The {name} of the random variable does not have the correct "
f"dtype. Expected {dtype.name} but got {value.dtype.name}."
)

@classmethod
def _ensure_numpy_float(cls, name: str, value: Any) -> np.float_:
if not isinstance(value, np.float_):
try:
value = _utils.as_numpy_scalar(value, dtype=np.float_)
except TypeError as err:
raise TypeError(
f"The function `{name}` specified via the constructor of "
f"`{cls.__name__}` must return a scalar value that can be "
f"converted to a `np.float_`, which is possible for {value} "
f"of type {type(value)}."
) from err

assert isinstance(value, np.float_)

return value

@cached_property
def mode(self) -> _ValueType:
"""
Expand Down Expand Up @@ -403,7 +369,7 @@ def in_support(self, x: _ValueType) -> bool:
if self.__in_support is None:
raise NotImplementedError

in_support = self.__in_support(x)
in_support = self.__in_support(self._as_value_type(x))

if not isinstance(in_support, bool):
raise ValueError(
Expand Down Expand Up @@ -447,9 +413,11 @@ def cdf(self, x: _ValueType) -> np.float_:
Value of the cumulative density function at the given points.
"""
if self.__cdf is not None:
return RandomVariable._ensure_numpy_float("cdf", self.__cdf(x))
return RandomVariable._ensure_numpy_float(
"cdf", self.__cdf(self._as_value_type(x))
)
elif self.__logcdf is not None:
cdf = np.exp(self.logcdf(x))
cdf = np.exp(self.logcdf(self._as_value_type(x)))

assert isinstance(cdf, np.float_)

Expand All @@ -475,7 +443,9 @@ def logcdf(self, x: _ValueType) -> np.float_:
Value of the log-cumulative density function at the given points.
"""
if self.__logcdf is not None:
return RandomVariable._ensure_numpy_float("logcdf", self.__logcdf(x))
return RandomVariable._ensure_numpy_float(
"logcdf", self.__logcdf(self._as_value_type(x))
)
elif self.__cdf is not None:
logcdf = np.log(self.__cdf(x))

Expand All @@ -497,6 +467,13 @@ def quantile(self, p: FloatArgType) -> _ValueType:
if self.__quantile is None:
raise NotImplementedError

try:
p = _utils.as_numpy_scalar(p, dtype=np.floating)
except TypeError as exc:
raise TypeError(
"The given argument `p` can not be cast to a `np.floating` object."
) from exc

quantile = self.__quantile(p)

if quantile.shape != self._shape:
Expand Down Expand Up @@ -689,6 +666,50 @@ def __rpow__(self, other: Any) -> "RandomVariable":

return pow_(other, self)

def _as_value_type(self, x: Any) -> _ValueType:
if self.__as_value_type is not None:
return self.__as_value_type(x)

return x

@staticmethod
def _check_property_value(
name: str,
value: Any,
shape: Optional[Tuple[int, ...]] = None,
dtype: Optional[np.dtype] = None,
):
if shape is not None:
if value.shape != shape:
raise ValueError(
f"The {name} of the random variable does not have the correct "
f"shape. Expected {shape} but got {value.shape}."
)

if dtype is not None:
if not np.issubdtype(value.dtype, dtype):
raise ValueError(
f"The {name} of the random variable does not have the correct "
f"dtype. Expected {dtype.name} but got {value.dtype.name}."
)

@classmethod
def _ensure_numpy_float(cls, name: str, value: Any) -> np.float_:
if not isinstance(value, np.float_):
try:
value = _utils.as_numpy_scalar(value, dtype=np.float_)
except TypeError as err:
raise TypeError(
f"The function `{name}` specified via the constructor of "
f"`{cls.__name__}` must return a scalar value that can be "
f"converted to a `np.float_`, which is possible for {value} "
f"of type {type(value)}."
) from err

assert isinstance(value, np.float_)

return value


class DiscreteRandomVariable(RandomVariable[_ValueType]):
def __init__(
Expand Down Expand Up @@ -753,10 +774,10 @@ def pmf(self, x: _ValueType) -> np.float_:
def logpmf(self, x: _ValueType) -> np.float_:
if self.__logpmf is not None:
return DiscreteRandomVariable._ensure_numpy_float(
"logpmf", self.__logpmf(x)
"logpmf", self.__logpmf(self._as_value_type(x))
)
elif self.__pmf is not None:
logpmf = np.log(self.__pmf(x))
logpmf = np.log(self.__pmf(self._as_value_type(x)))

assert isinstance(logpmf, np.float_)

Expand Down Expand Up @@ -829,9 +850,11 @@ def pdf(self, x: _ValueType) -> np.float_:
"""
if self.__pdf is not None:
return ContinuousRandomVariable._ensure_numpy_float("pdf", self.__pdf(x))
return ContinuousRandomVariable._ensure_numpy_float(
"pdf", self.__pdf(self._as_value_type(x))
)
if self.__logpdf is not None:
pdf = np.exp(self.__logpdf(x))
pdf = np.exp(self.__logpdf(self._as_value_type(x)))

assert isinstance(pdf, np.float_)

Expand All @@ -857,10 +880,10 @@ def logpdf(self, x: _ValueType) -> np.float_:
"""
if self.__logpdf is not None:
return ContinuousRandomVariable._ensure_numpy_float(
"logpdf", self.__logpdf(x)
"logpdf", self.__logpdf(self._as_value_type(x))
)
elif self.__pdf is not None:
logpdf = np.log(self.__pdf(x))
logpdf = np.log(self.__pdf(self._as_value_type(x)))

assert isinstance(logpdf, np.float_)

Expand Down

0 comments on commit 4ce2dde

Please sign in to comment.