Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

renaming K to Kin throughout for clarity #210

Merged
merged 4 commits into from
Dec 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions MuyGPyS/_src/gp/muygps/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,22 +10,22 @@

@jit
def _muygps_posterior_mean(
K: jnp.ndarray,
Kin: jnp.ndarray,
Kcross: jnp.ndarray,
batch_nn_targets: jnp.ndarray,
**kwargs,
) -> jnp.ndarray:
return jnp.squeeze(Kcross @ jnp.linalg.solve(K, batch_nn_targets))
return jnp.squeeze(Kcross @ jnp.linalg.solve(Kin, batch_nn_targets))


@jit
def _muygps_diagonal_variance(
K: jnp.ndarray,
Kin: jnp.ndarray,
Kcross: jnp.ndarray,
**kwargs,
) -> jnp.ndarray:
return jnp.squeeze(
1 - Kcross @ jnp.linalg.solve(K, Kcross.transpose(0, 2, 1))
1 - Kcross @ jnp.linalg.solve(Kin, Kcross.transpose(0, 2, 1))
)


Expand All @@ -47,7 +47,7 @@ def _mmuygps_fast_posterior_mean(

@jit
def _muygps_fast_posterior_mean_precompute(
K: jnp.ndarray,
Kin: jnp.ndarray,
train_nn_targets_fast: jnp.ndarray,
) -> jnp.ndarray:
return jnp.linalg.solve(K, train_nn_targets_fast)
return jnp.linalg.solve(Kin, train_nn_targets_fast)
2 changes: 1 addition & 1 deletion MuyGPyS/_src/gp/muygps/mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def _mmuygps_fast_posterior_mean(


def _muygps_fast_posterior_mean_precompute(
K: np.ndarray,
Kin: np.ndarray,
train_nn_targets_fast: np.ndarray,
**kwargs,
) -> np.ndarray:
Expand Down
12 changes: 6 additions & 6 deletions MuyGPyS/_src/gp/muygps/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,21 @@


def _muygps_posterior_mean(
K: np.ndarray,
Kin: np.ndarray,
Kcross: np.ndarray,
batch_nn_targets: np.ndarray,
**kwargs,
) -> np.ndarray:
return np.squeeze(Kcross @ np.linalg.solve(K, batch_nn_targets))
return np.squeeze(Kcross @ np.linalg.solve(Kin, batch_nn_targets))


def _muygps_diagonal_variance(
K: np.ndarray,
Kin: np.ndarray,
Kcross: np.ndarray,
**kwargs,
) -> np.ndarray:
return np.squeeze(
1 - Kcross @ np.linalg.solve(K, Kcross.transpose(0, 2, 1))
1 - Kcross @ np.linalg.solve(Kin, Kcross.transpose(0, 2, 1))
)


Expand All @@ -42,8 +42,8 @@ def _mmuygps_fast_posterior_mean(


def _muygps_fast_posterior_mean_precompute(
K: np.ndarray,
Kin: np.ndarray,
train_nn_targets_fast: np.ndarray,
**kwargs,
) -> np.ndarray:
return np.linalg.solve(K, train_nn_targets_fast)
return np.linalg.solve(Kin, train_nn_targets_fast)
12 changes: 6 additions & 6 deletions MuyGPyS/_src/gp/muygps/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,21 @@


def _muygps_posterior_mean(
K: torch.ndarray,
Kin: torch.ndarray,
Kcross: torch.ndarray,
batch_nn_targets: torch.ndarray,
**kwargs,
) -> torch.ndarray:
return torch.squeeze(Kcross @ torch.linalg.solve(K, batch_nn_targets))
return torch.squeeze(Kcross @ torch.linalg.solve(Kin, batch_nn_targets))


def _muygps_diagonal_variance(
K: torch.ndarray,
Kin: torch.ndarray,
Kcross: torch.ndarray,
**kwargs,
) -> torch.ndarray:
return torch.squeeze(
1 - Kcross @ torch.linalg.solve(K, Kcross.transpose(1, -1))
1 - Kcross @ torch.linalg.solve(Kin, Kcross.transpose(1, -1))
)


Expand All @@ -40,7 +40,7 @@ def _mmuygps_fast_posterior_mean(


def _muygps_fast_posterior_mean_precompute(
K: torch.ndarray,
Kin: torch.ndarray,
train_nn_targets_fast: torch.ndarray,
) -> torch.ndarray:
return torch.linalg.solve(K, train_nn_targets_fast)
return torch.linalg.solve(Kin, train_nn_targets_fast)
12 changes: 6 additions & 6 deletions MuyGPyS/_src/gp/noise/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,18 @@

@jit
def _homoscedastic_perturb(
K: jnp.ndarray, noise_variance: float
Kin: jnp.ndarray, noise_variance: float
) -> jnp.ndarray:
_, nn_count, _ = K.shape
return K + noise_variance * jnp.eye(nn_count)
_, nn_count, _ = Kin.shape
return Kin + noise_variance * jnp.eye(nn_count)


@jit
def _heteroscedastic_perturb(
K: jnp.ndarray, noise_variances: jnp.ndarray
Kin: jnp.ndarray, noise_variances: jnp.ndarray
) -> jnp.ndarray:
batch_count, nn_count, _ = K.shape
ret = K.copy()
batch_count, nn_count, _ = Kin.shape
ret = Kin.copy()
indices = (
jnp.repeat(jnp.arange(batch_count), nn_count),
jnp.tile(jnp.arange(nn_count), batch_count),
Expand Down
2 changes: 1 addition & 1 deletion MuyGPyS/_src/gp/noise/mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,6 @@


def _heteroscedastic_perturb(
K: np.ndarray, noise_variances: np.ndarray
Kin: np.ndarray, noise_variances: np.ndarray
) -> np.ndarray:
raise NotImplementedError("heteroscedastic noise does not support mpi!")
12 changes: 6 additions & 6 deletions MuyGPyS/_src/gp/noise/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,16 @@
import MuyGPyS._src.math.numpy as np


def _homoscedastic_perturb(K: np.ndarray, noise_variance: float) -> np.ndarray:
_, nn_count, _ = K.shape
return K + noise_variance * np.eye(nn_count)
def _homoscedastic_perturb(Kin: np.ndarray, noise_variance: float) -> np.ndarray:
_, nn_count, _ = Kin.shape
return Kin + noise_variance * np.eye(nn_count)


def _heteroscedastic_perturb(
K: np.ndarray, noise_variances: np.ndarray
Kin: np.ndarray, noise_variances: np.ndarray
) -> np.ndarray:
ret = K.copy()
batch_count, nn_count, _ = K.shape
ret = Kin.copy()
batch_count, nn_count, _ = Kin.shape
indices = (
np.repeat(range(batch_count), nn_count),
np.tile(np.arange(nn_count), batch_count),
Expand Down
12 changes: 6 additions & 6 deletions MuyGPyS/_src/gp/noise/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,17 @@


def _homoscedastic_perturb(
K: torch.ndarray, noise_variance: float
Kin: torch.ndarray, noise_variance: float
) -> torch.ndarray:
_, nn_count, _ = K.shape
return K + noise_variance * torch.eye(nn_count)
_, nn_count, _ = Kin.shape
return Kin + noise_variance * torch.eye(nn_count)


def _heteroscedastic_perturb(
K: torch.ndarray, noise_variances: torch.ndarray
Kin: torch.ndarray, noise_variances: torch.ndarray
) -> torch.ndarray:
ret = K.clone()
batch_count, nn_count, _ = K.shape
ret = Kin.clone()
batch_count, nn_count, _ = Kin.shape
indices = (
torch.repeat(torch.arange(batch_count), nn_count),
torch.arange(nn_count).repeat(batch_count),
Expand Down
8 changes: 4 additions & 4 deletions MuyGPyS/_src/optimize/scale/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,21 +10,21 @@

@jit
def _analytic_scale_optim_unnormalized(
K: jnp.ndarray,
Kin: jnp.ndarray,
nn_targets: jnp.ndarray,
) -> jnp.ndarray:
return jnp.sum(
jnp.einsum("ijk,ijk->ik", nn_targets, jnp.linalg.solve(K, nn_targets)),
jnp.einsum("ijk,ijk->ik", nn_targets, jnp.linalg.solve(Kin, nn_targets)),
axis=0,
)


@jit
def _analytic_scale_optim(
K: jnp.ndarray,
Kin: jnp.ndarray,
nn_targets: jnp.ndarray,
) -> jnp.ndarray:
batch_count, nn_count, _ = nn_targets.shape
return _analytic_scale_optim_unnormalized(K, nn_targets) / (
return _analytic_scale_optim_unnormalized(Kin, nn_targets) / (
batch_count * nn_count
)
4 changes: 2 additions & 2 deletions MuyGPyS/_src/optimize/scale/mpi.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@


def _analytic_scale_optim(
K: np.ndarray,
Kin: np.ndarray,
nn_targets: np.ndarray,
) -> np.ndarray:
local_batch_count, nn_count, _ = nn_targets.shape
local_sum = _analytic_scale_optim_unnormalized(K, nn_targets)
local_sum = _analytic_scale_optim_unnormalized(Kin, nn_targets)
global_sum = world.allreduce(local_sum, op=MPI.SUM)
global_batch_count = world.allreduce(local_batch_count, op=MPI.SUM)
return global_sum / (nn_count * global_batch_count)
8 changes: 4 additions & 4 deletions MuyGPyS/_src/optimize/scale/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,20 +7,20 @@


def _analytic_scale_optim_unnormalized(
K: np.ndarray,
Kin: np.ndarray,
nn_targets: np.ndarray,
) -> np.ndarray:
return np.sum(
np.einsum("ijk,ijk->ik", nn_targets, np.linalg.solve(K, nn_targets)),
np.einsum("ijk,ijk->ik", nn_targets, np.linalg.solve(Kin, nn_targets)),
axis=0,
)


def _analytic_scale_optim(
K: np.ndarray,
Kin: np.ndarray,
nn_targets: np.ndarray,
) -> np.ndarray:
batch_count, nn_count, _ = nn_targets.shape
return _analytic_scale_optim_unnormalized(K, nn_targets) / (
return _analytic_scale_optim_unnormalized(Kin, nn_targets) / (
nn_count * batch_count
)
8 changes: 4 additions & 4 deletions MuyGPyS/_src/optimize/scale/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,22 @@


def _analytic_scale_optim_unnormalized(
K: torch.ndarray,
Kin: torch.ndarray,
nn_targets: torch.ndarray,
) -> torch.ndarray:
return torch.sum(
torch.einsum(
"ijk,ijk->ik", nn_targets, torch.linalg.solve(K, nn_targets)
"ijk,ijk->ik", nn_targets, torch.linalg.solve(Kin, nn_targets)
),
axis=0,
)


def _analytic_scale_optim(
K: torch.ndarray,
Kin: torch.ndarray,
nn_targets: torch.ndarray,
) -> torch.ndarray:
batch_count, nn_count, _ = nn_targets.shape
return _analytic_scale_optim_unnormalized(K, nn_targets) / (
return _analytic_scale_optim_unnormalized(Kin, nn_targets) / (
nn_count * batch_count
)
6 changes: 3 additions & 3 deletions MuyGPyS/_test/gp.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,6 @@ def benchmark_sample_from_cholK(cholK: np.ndarray) -> np.ndarray:
).reshape(data_count, 1)


def get_analytic_scale(K, y):
assert y.shape[0] == K.shape[0]
return (1 / y.shape[0]) * y.T @ np.linalg.solve(K, y)
def get_analytic_scale(Kin, y):
assert y.shape[0] == Kin.shape[0]
return (1 / y.shape[0]) * y.T @ np.linalg.solve(Kin, y)
16 changes: 8 additions & 8 deletions MuyGPyS/_test/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ def _normalize(X: mm.ndarray) -> mm.ndarray:


def _get_scale_series(
K: mm.ndarray,
Kin: mm.ndarray,
nn_targets_column: mm.ndarray,
noise_variance: float,
) -> mm.ndarray:
Expand All @@ -308,7 +308,7 @@ def _get_scale_series(
NOTE[bwp]: This function is only for testing purposes.

Args:
K:
Kin:
A tensor of shape `(batch_count, nn_count, nn_count)` containing
the `(nn_count, nn_count` -shaped kernel matrices corresponding
to each of the batch elements.
Expand All @@ -324,13 +324,13 @@ def _get_scale_series(
batch_count, nn_count, _ = nn_targets_column.shape

scales = np.zeros((batch_count,))
for i, el in enumerate(_get_scale(K, nn_targets_column, noise_variance)):
for i, el in enumerate(_get_scale(Kin, nn_targets_column, noise_variance)):
scales[i] = el
return mm.array(scales / nn_count)


def _get_scale(
K: mm.ndarray,
Kin: mm.ndarray,
nn_targets_column: mm.ndarray,
noise_variance: float,
) -> Generator[float, None, None]:
Expand All @@ -339,14 +339,14 @@ def _get_scale(
individual solve along a single dimension:

.. math::
\\sigma^2 = \\frac{1}{k} * Y_{nn}^T K_{nn}^{-1} Y_{nn}
\\sigma^2 = \\frac{1}{k} * Y_{nn}^T Kin_{nn}^{-1} Y_{nn}

Here :math:`Y_{nn}` and :math:`K_{nn}` are the target and kernel
Here :math:`Y_{nn}` and :math:`Kin_{nn}` are the target and kernel
matrices with respect to the nearest neighbor set in scope, where
:math:`k` is the number of nearest neighbors.

Args:
K:
Kin:
A tensor of shape `(batch_count, nn_count, nn_count)` containing
the `(nn_count, nn_count` -shaped kernel matrices corresponding
to each of the batch elements.
Expand All @@ -364,7 +364,7 @@ def _get_scale(
for j in range(batch_count):
Y_0 = nn_targets_column[j, :, 0]
yield Y_0 @ mm.linalg.solve(
K[j, :, :] + noise_variance * mm.eye(nn_count), Y_0
Kin[j, :, :] + noise_variance * mm.eye(nn_count), Y_0
)


Expand Down
4 changes: 2 additions & 2 deletions MuyGPyS/examples/fast_posterior_mean.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,10 @@ def make_fast_regressor(
nn_indices = fast_nn_update(nn_indices)

train_nn_targets = train_targets[nn_indices]
K = muygps.kernel(pairwise_tensor(train_features, nn_indices))
Kin = muygps.kernel(pairwise_tensor(train_features, nn_indices))

precomputed_coefficients_matrix = muygps.fast_coefficients(
K, train_nn_targets
Kin, train_nn_targets
)

return precomputed_coefficients_matrix, nn_indices
Expand Down
Loading
Loading