diff --git a/tests/backend/jax_correctness.py b/tests/backend/jax_correctness.py index 5c4e04e7..e820f495 100644 --- a/tests/backend/jax_correctness.py +++ b/tests/backend/jax_correctness.py @@ -965,7 +965,7 @@ def setUpClass(cls): ) cls.fast_regress_coeffs_heteroscedastic_j = ( - cls.muygps_heteroscedastic_train_n.fast_coefficients( + cls.muygps_heteroscedastic_train_j.fast_coefficients( cls.heteroscedastic_K_fast_j, cls.train_nn_targets_fast_j ) ) @@ -1193,9 +1193,7 @@ def setUpClass(cls): ) cls.heteroscedastic_K_fast_j = ( - cls.muygps_heteroscedastic_train_j.noise.perturb( - l2_n(cls.K_fast_j), cls.noise_heteroscedastic_train_j - ) + cls.muygps_heteroscedastic_train_j.noise.perturb(l2_n(cls.K_fast_j)) ) cls.fast_regress_coeffs_heteroscedastic_j = (