From 6a2f4ec169817e64fe9f72c2e15cfe6821c470ae Mon Sep 17 00:00:00 2001 From: Marvin Pfoertner Date: Tue, 7 Dec 2021 16:52:21 +0100 Subject: [PATCH] Remove as_scalar from utils --- src/probnum/backend/_core/__init__.py | 18 ++++++++++++ src/probnum/randprocs/kernels/_linear.py | 4 +-- src/probnum/randprocs/kernels/_polynomial.py | 6 ++-- .../randprocs/kernels/_rational_quadratic.py | 8 ++---- src/probnum/randprocs/kernels/_white_noise.py | 4 +-- src/probnum/typing.py | 2 +- src/probnum/utils/__init__.py | 1 - src/probnum/utils/argutils.py | 28 ++----------------- 8 files changed, 31 insertions(+), 40 deletions(-) diff --git a/src/probnum/backend/_core/__init__.py b/src/probnum/backend/_core/__init__.py index 6b61a62866..4057cd6ee8 100644 --- a/src/probnum/backend/_core/__init__.py +++ b/src/probnum/backend/_core/__init__.py @@ -1,4 +1,5 @@ from probnum import backend as _backend +from probnum.typing import ArrayType, DTypeArgType, ScalarArgType if _backend.BACKEND is _backend.Backend.NUMPY: from . import _numpy as _core @@ -73,3 +74,20 @@ # Just-in-Time Compilation jit = _core.jit jit_method = _core.jit_method + + +def as_scalar(x: ScalarArgType, dtype: DTypeArgType = None) -> ArrayType: + """Convert a scalar into a NumPy scalar. + + Parameters + ---------- + x + Scalar value. + dtype + Data type of the scalar. + """ + + if ndim(x) != 0: + raise ValueError("The given input is not a scalar.") + + return asarray(x, dtype=dtype)[()] diff --git a/src/probnum/randprocs/kernels/_linear.py b/src/probnum/randprocs/kernels/_linear.py index bcb1888eaf..968eaa35a1 100644 --- a/src/probnum/randprocs/kernels/_linear.py +++ b/src/probnum/randprocs/kernels/_linear.py @@ -2,7 +2,7 @@ from typing import Optional -from probnum import backend, utils +from probnum import backend from probnum.typing import ArrayType, IntArgType, ScalarArgType from ._kernel import Kernel @@ -39,7 +39,7 @@ class Linear(Kernel): """ def __init__(self, input_dim: IntArgType, constant: ScalarArgType = 0.0): - self.constant = utils.as_scalar(constant) + self.constant = backend.as_scalar(constant) super().__init__(input_dim=input_dim) @backend.jit_method diff --git a/src/probnum/randprocs/kernels/_polynomial.py b/src/probnum/randprocs/kernels/_polynomial.py index df0d88c47d..a6fa60f010 100644 --- a/src/probnum/randprocs/kernels/_polynomial.py +++ b/src/probnum/randprocs/kernels/_polynomial.py @@ -2,7 +2,7 @@ from typing import Optional -from probnum import backend, utils +from probnum import backend from probnum.typing import ArrayType, IntArgType, ScalarArgType from ._kernel import Kernel @@ -46,8 +46,8 @@ def __init__( constant: ScalarArgType = 0.0, exponent: IntArgType = 1.0, ): - self.constant = utils.as_scalar(constant) - self.exponent = utils.as_scalar(exponent) + self.constant = backend.as_scalar(constant) + self.exponent = backend.as_scalar(exponent) super().__init__(input_dim=input_dim) @backend.jit_method diff --git a/src/probnum/randprocs/kernels/_rational_quadratic.py b/src/probnum/randprocs/kernels/_rational_quadratic.py index 470b2fb6e8..15e0998254 100644 --- a/src/probnum/randprocs/kernels/_rational_quadratic.py +++ b/src/probnum/randprocs/kernels/_rational_quadratic.py @@ -2,9 +2,7 @@ from typing import Optional -import numpy as np - -from probnum import backend, utils +from probnum import backend from probnum.typing import ArrayType, IntArgType, ScalarArgType from ._kernel import IsotropicMixin, Kernel @@ -62,8 +60,8 @@ def __init__( lengthscale: ScalarArgType = 1.0, alpha: ScalarArgType = 1.0, ): - self.lengthscale = utils.as_scalar(lengthscale) - self.alpha = utils.as_scalar(alpha) + self.lengthscale = backend.as_scalar(lengthscale) + self.alpha = backend.as_scalar(alpha) if not self.alpha > 0: raise ValueError(f"Scale mixture alpha={self.alpha} must be positive.") super().__init__(input_dim=input_dim) diff --git a/src/probnum/randprocs/kernels/_white_noise.py b/src/probnum/randprocs/kernels/_white_noise.py index 770a281eaa..81c1d1dc72 100644 --- a/src/probnum/randprocs/kernels/_white_noise.py +++ b/src/probnum/randprocs/kernels/_white_noise.py @@ -2,7 +2,7 @@ from typing import Optional -from probnum import backend, utils +from probnum import backend from probnum.typing import ArrayType, IntArgType, ScalarArgType from ._kernel import Kernel @@ -25,7 +25,7 @@ class WhiteNoise(Kernel): """ def __init__(self, input_dim: IntArgType, sigma: ScalarArgType = 1.0): - self.sigma = utils.as_scalar(sigma) + self.sigma = backend.as_scalar(sigma) self._sigma_sq = self.sigma ** 2 super().__init__(input_dim=input_dim) diff --git a/src/probnum/typing.py b/src/probnum/typing.py index 1ec1d969dc..28e84131ad 100644 --- a/src/probnum/typing.py +++ b/src/probnum/typing.py @@ -59,7 +59,7 @@ ScalarArgType = Union[int, float, complex, numbers.Number, np.number] """Type of a public API argument for supplying a scalar value. Values of this type should always be converted into :class:`np.generic` using the function -:func:`probnum.utils.as_scalar` before further internal processing.""" +:func:`probnum.backend.as_scalar` before further internal processing.""" LinearOperatorArgType = Union[ np.ndarray, diff --git a/src/probnum/utils/__init__.py b/src/probnum/utils/__init__.py index 032e42157a..c89157c2e3 100644 --- a/src/probnum/utils/__init__.py +++ b/src/probnum/utils/__init__.py @@ -7,7 +7,6 @@ __all__ = [ "as_colvec", "atleast_1d", - "as_scalar", "as_numpy_scalar", "as_shape", ] diff --git a/src/probnum/utils/argutils.py b/src/probnum/utils/argutils.py index 3fa51416ce..25668536d8 100644 --- a/src/probnum/utils/argutils.py +++ b/src/probnum/utils/argutils.py @@ -5,16 +5,9 @@ import numpy as np -from probnum import backend -from probnum.typing import ( - ArrayType, - DTypeArgType, - ScalarArgType, - ShapeArgType, - ShapeType, -) +from probnum.typing import DTypeArgType, ScalarArgType, ShapeArgType, ShapeType -__all__ = ["as_shape", "as_numpy_scalar", "as_scalar"] +__all__ = ["as_shape", "as_numpy_scalar"] def as_shape(x: ShapeArgType, ndim: Optional[numbers.Integral] = None) -> ShapeType: @@ -64,20 +57,3 @@ def as_numpy_scalar(x: ScalarArgType, dtype: DTypeArgType = None) -> np.generic: raise ValueError("The given input is not a scalar.") return np.asarray(x, dtype=dtype)[()] - - -def as_scalar(x: ScalarArgType, dtype: DTypeArgType = None) -> ArrayType: - """Convert a scalar into a NumPy scalar. - - Parameters - ---------- - x - Scalar value. - dtype - Data type of the scalar. - """ - - if backend.ndim(x) != 0: - raise ValueError("The given input is not a scalar.") - - return backend.asarray(x, dtype=dtype)[()]