diff --git a/MuyGPyS/gp/mean.py b/MuyGPyS/gp/mean.py index 310a8514..36ecf168 100644 --- a/MuyGPyS/gp/mean.py +++ b/MuyGPyS/gp/mean.py @@ -14,6 +14,19 @@ from MuyGPyS.gp.noise import NoiseFn +def _reshaper_to_be_removed(fn: Callable) -> Callable: + def reshaped_fn(Kin, Kcross, *args, **kwargs): + if len(Kcross.shape) == 2: + batch_count, nn_count = Kcross.shape + Kcross = Kcross.reshape(batch_count, 1, nn_count) + ret = fn(Kin, Kcross, *args, **kwargs) + if len(ret.shape) == 1: + ret = ret.reshape(ret.shape[0], 1) + return ret + + return reshaped_fn + + class PosteriorMean: def __init__( self, @@ -23,6 +36,7 @@ def __init__( ): self._fn = _backend_fn self._fn = noise.perturb_fn(self._fn) + self._fn = _reshaper_to_be_removed(self._fn) def __call__( self, @@ -31,13 +45,7 @@ def __call__( batch_nn_targets: mm.ndarray, **kwargs, ) -> mm.ndarray: - if len(Kcross.shape) == 2: - batch_count, nn_count = Kcross.shape - Kcross = Kcross.reshape(batch_count, 1, nn_count) - responses = self._fn(Kin, Kcross, batch_nn_targets, **kwargs) - if len(responses.shape) == 1: - responses = responses.reshape(responses.shape[0], 1) - return responses + return self._fn(Kin, Kcross, batch_nn_targets, **kwargs) def get_opt_fn(self) -> Callable: return self.__call__ diff --git a/MuyGPyS/gp/variance.py b/MuyGPyS/gp/variance.py index 668e1264..4f87a812 100644 --- a/MuyGPyS/gp/variance.py +++ b/MuyGPyS/gp/variance.py @@ -13,6 +13,7 @@ from MuyGPyS._src.gp.muygps import _muygps_diagonal_variance from MuyGPyS.gp.hyperparameter import ScaleFn from MuyGPyS.gp.noise import NoiseFn +from MuyGPyS.gp.mean import _reshaper_to_be_removed class PosteriorVariance: @@ -20,13 +21,12 @@ def __init__( self, noise: NoiseFn, scale: ScaleFn, - apply_scale: bool = True, _backend_fn: Callable = _muygps_diagonal_variance, ): self._fn = _backend_fn self._fn = noise.perturb_fn(self._fn) - if apply_scale is True: - self._fn = scale.scale_fn(self._fn) + self._opt_fn = _reshaper_to_be_removed(self._fn) + self._fn = _reshaper_to_be_removed(scale.scale_fn(self._fn)) def __call__( self, @@ -34,13 +34,7 @@ def __call__( Kcross: mm.ndarray, **kwargs, ) -> mm.ndarray: - if len(Kcross.shape) == 2: - batch_count, nn_count = Kcross.shape - Kcross = Kcross.reshape(batch_count, 1, nn_count) - variances = self._fn(Kin, Kcross, **kwargs) - if len(variances.shape) == 1: - variances = variances.reshape(variances.shape[0], 1) - return variances + return self._fn(Kin, Kcross, **kwargs) def get_opt_fn(self) -> Callable: - return self.__call__ + return self._opt_fn