Skip to content

Commit

Permalink
heteroscedastic fixes to jax correctness tests
Browse files Browse the repository at this point in the history
  • Loading branch information
bwpriest committed Dec 8, 2023
1 parent f7f3cdd commit 0680e0a
Showing 1 changed file with 18 additions and 10 deletions.
28 changes: 18 additions & 10 deletions tests/backend/jax_correctness.py
Original file line number Diff line number Diff line change
Expand Up @@ -430,11 +430,21 @@ def setUpClass(cls):
cls.noise_heteroscedastic_n,
smoothness_bounds=cls.smoothness_bounds,
)
cls.muygps_heteroscedastic_train_n = cls._make_heteroscedastic_muygps_n(
cls.smoothness,
cls.noise_heteroscedastic_train_n,
smoothness_bounds=cls.smoothness_bounds,
)
cls.muygps_heteroscedastic_j = cls._make_heteroscedastic_muygps_j(
cls.smoothness,
cls.noise_heteroscedastic_j,
smoothness_bounds=cls.smoothness_bounds,
)
cls.muygps_heteroscedastic_train_j = cls._make_heteroscedastic_muygps_j(
cls.smoothness,
cls.noise_heteroscedastic_train_j,
smoothness_bounds=cls.smoothness_bounds,
)

cls.batch_indices_n, cls.batch_nn_indices_n = sample_batch(
cls.nbrs_lookup, cls.batch_count, cls.train_count
Expand Down Expand Up @@ -897,17 +907,15 @@ def setUpClass(cls):
)

cls.heteroscedastic_K_fast_n = (
cls.muygps_heteroscedastic_n.noise.perturb(
l2_n(cls.K_fast_n),
)
cls.muygps_heteroscedastic_train_n.noise.perturb(l2_n(cls.K_fast_n))
)

cls.fast_regress_coeffs_n = cls.muygps_gen_n.fast_coefficients(
cls.homoscedastic_K_fast_n, cls.train_nn_targets_fast_n
)

cls.fast_regress_coeffs_heteroscedastic_n = (
cls.muygps_heteroscedastic_n.fast_coefficients(
cls.muygps_heteroscedastic_train_n.fast_coefficients(
cls.heteroscedastic_K_fast_n, cls.train_nn_targets_fast_n
)
)
Expand Down Expand Up @@ -947,7 +955,7 @@ def setUpClass(cls):
)

cls.heteroscedastic_K_fast_j = (
cls.muygps_heteroscedastic_j.noise.perturb(
cls.muygps_heteroscedastic_train_j.noise.perturb(
l2_j(cls.K_fast_j),
)
)
Expand All @@ -957,7 +965,7 @@ def setUpClass(cls):
)

cls.fast_regress_coeffs_heteroscedastic_j = (
cls.muygps_heteroscedastic_n.fast_coefficients(
cls.muygps_heteroscedastic_train_n.fast_coefficients(
cls.heteroscedastic_K_fast_j, cls.train_nn_targets_fast_j
)
)
Expand Down Expand Up @@ -1132,12 +1140,12 @@ def setUpClass(cls):
)

cls.heteroscedastic_K_fast_n = (
cls.muygps_heteroscedastic_n.noise.perturb(
cls.muygps_heteroscedastic_train_n.noise.perturb(
l2_n(cls.K_fast_n),
)
)
cls.fast_regress_coeffs_heteroscedastic_n = (
cls.muygps_heteroscedastic_n.fast_coefficients(
cls.muygps_heteroscedastic_train_n.fast_coefficients(
cls.heteroscedastic_K_fast_n, cls.train_nn_targets_fast_n
)
)
Expand Down Expand Up @@ -1185,13 +1193,13 @@ def setUpClass(cls):
)

cls.heteroscedastic_K_fast_j = (
cls.muygps_heteroscedastic_j.noise.perturb(
cls.muygps_heteroscedastic_train_j.noise.perturb(
l2_n(cls.K_fast_j), cls.noise_heteroscedastic_train_j
)
)

cls.fast_regress_coeffs_heteroscedastic_j = (
cls.muygps_heteroscedastic_j.fast_coefficients(
cls.muygps_heteroscedastic_train_j.fast_coefficients(
cls.heteroscedastic_K_fast_j, cls.train_nn_targets_fast_j
)
)
Expand Down

0 comments on commit 0680e0a

Please sign in to comment.