Skip to content

Commit

Permalink
reorganized temporary reshaping in tensor math to make it easier to m…
Browse files Browse the repository at this point in the history
…odify. removed unused apply_scale flag. removed insertion of initial scale parameter to optimization variance function.
  • Loading branch information
bwpriest committed Dec 18, 2023
1 parent 41724e0 commit 4a1acf4
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 18 deletions.
22 changes: 15 additions & 7 deletions MuyGPyS/gp/mean.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,19 @@
from MuyGPyS.gp.noise import NoiseFn


def _reshaper_to_be_removed(fn: Callable) -> Callable:
def reshaped_fn(Kin, Kcross, *args, **kwargs):
if len(Kcross.shape) == 2:
batch_count, nn_count = Kcross.shape
Kcross = Kcross.reshape(batch_count, 1, nn_count)
ret = fn(Kin, Kcross, *args, **kwargs)
if len(ret.shape) == 1:
ret = ret.reshape(ret.shape[0], 1)
return ret

return reshaped_fn


class PosteriorMean:
def __init__(
self,
Expand All @@ -23,6 +36,7 @@ def __init__(
):
self._fn = _backend_fn
self._fn = noise.perturb_fn(self._fn)
self._fn = _reshaper_to_be_removed(self._fn)

def __call__(
self,
Expand All @@ -31,13 +45,7 @@ def __call__(
batch_nn_targets: mm.ndarray,
**kwargs,
) -> mm.ndarray:
if len(Kcross.shape) == 2:
batch_count, nn_count = Kcross.shape
Kcross = Kcross.reshape(batch_count, 1, nn_count)
responses = self._fn(Kin, Kcross, batch_nn_targets, **kwargs)
if len(responses.shape) == 1:
responses = responses.reshape(responses.shape[0], 1)
return responses
return self._fn(Kin, Kcross, batch_nn_targets, **kwargs)

def get_opt_fn(self) -> Callable:
return self.__call__
16 changes: 5 additions & 11 deletions MuyGPyS/gp/variance.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,34 +13,28 @@
from MuyGPyS._src.gp.muygps import _muygps_diagonal_variance
from MuyGPyS.gp.hyperparameter import ScaleFn
from MuyGPyS.gp.noise import NoiseFn
from MuyGPyS.gp.mean import _reshaper_to_be_removed


class PosteriorVariance:
def __init__(
self,
noise: NoiseFn,
scale: ScaleFn,
apply_scale: bool = True,
_backend_fn: Callable = _muygps_diagonal_variance,
):
self._fn = _backend_fn
self._fn = noise.perturb_fn(self._fn)
if apply_scale is True:
self._fn = scale.scale_fn(self._fn)
self._opt_fn = _reshaper_to_be_removed(self._fn)
self._fn = _reshaper_to_be_removed(scale.scale_fn(self._fn))

def __call__(
self,
Kin: mm.ndarray,
Kcross: mm.ndarray,
**kwargs,
) -> mm.ndarray:
if len(Kcross.shape) == 2:
batch_count, nn_count = Kcross.shape
Kcross = Kcross.reshape(batch_count, 1, nn_count)
variances = self._fn(Kin, Kcross, **kwargs)
if len(variances.shape) == 1:
variances = variances.reshape(variances.shape[0], 1)
return variances
return self._fn(Kin, Kcross, **kwargs)

def get_opt_fn(self) -> Callable:
return self.__call__
return self._opt_fn

0 comments on commit 4a1acf4

Please sign in to comment.