From 16fa0fe5cfa1182c2f7b66acc914e7929c5ff56b Mon Sep 17 00:00:00 2001 From: Marvin Pfoertner Date: Mon, 24 Aug 2020 18:30:58 +0200 Subject: [PATCH] Normal random variable should cast arguments to floating point dtype --- src/probnum/core/random_variables/_normal.py | 25 +++++++++++++++++++ .../core/random_variables/_random_variable.py | 20 +++++++-------- 2 files changed, 35 insertions(+), 10 deletions(-) diff --git a/src/probnum/core/random_variables/_normal.py b/src/probnum/core/random_variables/_normal.py index 5a696696d9..fe9ddbcee6 100644 --- a/src/probnum/core/random_variables/_normal.py +++ b/src/probnum/core/random_variables/_normal.py @@ -68,6 +68,31 @@ def __init__( if np.isscalar(cov): cov = _utils.as_numpy_scalar(cov) + # Data type normalization + is_mean_floating = mean.dtype is not None and np.issubdtype( + mean.dtype, np.floating + ) + is_cov_floating = cov.dtype is not None and np.issubdtype( + cov.dtype, np.floating + ) + + if is_mean_floating and is_cov_floating: + dtype = np.promote_types(mean.dtype, cov.dtype) + elif is_mean_floating: + dtype = mean.dtype + elif is_cov_floating: + dtype = cov.dtype + else: + dtype = np.float_ + + # TODO: Implement casting for linear operators + if not isinstance(mean, linops.LinearOperator): + mean = mean.astype(dtype, order="C", casting="safe", subok=True, copy=False) + + # TODO: Implement casting for linear operators + if not isinstance(cov, linops.LinearOperator): + cov = cov.astype(dtype, order="C", casting="safe", subok=True, copy=False) + # Shape checking if len(mean.shape) not in [0, 1, 2]: raise ValueError( diff --git a/src/probnum/core/random_variables/_random_variable.py b/src/probnum/core/random_variables/_random_variable.py index 46219467df..f7c824c8c3 100644 --- a/src/probnum/core/random_variables/_random_variable.py +++ b/src/probnum/core/random_variables/_random_variable.py @@ -75,10 +75,10 @@ def __init__( parameters: Optional[Dict[str, Any]] = None, sample: Optional[Callable[[ShapeArgType], _ValueType]] = None, in_support: Optional[Callable[[_ValueType], bool]] = None, - pdf: Optional[Callable[[_ValueType], np.float64]] = None, - logpdf: Optional[Callable[[_ValueType], np.float64]] = None, - cdf: Optional[Callable[[_ValueType], np.float64]] = None, - logcdf: Optional[Callable[[_ValueType], np.float64]] = None, + pdf: Optional[Callable[[_ValueType], np.float_]] = None, + logpdf: Optional[Callable[[_ValueType], np.float_]] = None, + cdf: Optional[Callable[[_ValueType], np.float_]] = None, + logcdf: Optional[Callable[[_ValueType], np.float_]] = None, quantile: Optional[Callable[[FloatArgType], _ValueType]] = None, mode: Optional[Callable[[], _ValueType]] = None, median: Optional[Callable[[], _ValueType]] = None, @@ -86,7 +86,7 @@ def __init__( cov: Optional[Callable[[], _ValueType]] = None, var: Optional[Callable[[], _ValueType]] = None, std: Optional[Callable[[], _ValueType]] = None, - entropy: Optional[Callable[[], np.float64]] = None, + entropy: Optional[Callable[[], np.float_]] = None, ): """Create a new random variable.""" self._shape = RandomVariable._check_shape(shape) @@ -704,10 +704,10 @@ def __init__( parameters: Optional[Dict[str, Any]] = None, sample: Optional[Callable[[ShapeArgType], _ValueType]] = None, in_support: Optional[Callable[[_ValueType], bool]] = None, - pmf: Optional[Callable[[_ValueType], np.float64]] = None, - logpmf: Optional[Callable[[_ValueType], np.float64]] = None, - cdf: Optional[Callable[[_ValueType], np.float64]] = None, - logcdf: Optional[Callable[[_ValueType], np.float64]] = None, + pmf: Optional[Callable[[_ValueType], np.float_]] = None, + logpmf: Optional[Callable[[_ValueType], np.float_]] = None, + cdf: Optional[Callable[[_ValueType], np.float_]] = None, + logcdf: Optional[Callable[[_ValueType], np.float_]] = None, quantile: Optional[Callable[[FloatArgType], _ValueType]] = None, mode: Optional[Callable[[], _ValueType]] = None, median: Optional[Callable[[], _ValueType]] = None, @@ -715,7 +715,7 @@ def __init__( cov: Optional[Callable[[], _ValueType]] = None, var: Optional[Callable[[], _ValueType]] = None, std: Optional[Callable[[], _ValueType]] = None, - entropy: Optional[Callable[[], np.float64]] = None, + entropy: Optional[Callable[[], np.float_]] = None, ): # Probability mass function self.__pmf = pmf