@@ -57,6 +57,29 @@ def test_inverse_wishart_sample(df=7.0, dim=3, scale_factor=3.0, n_samples=10000
57
57
mc_std = jnp .sqrt (iw .variance () / n_samples )
58
58
assert jnp .allclose (samples .mean (axis = 0 ), iw .mean (), atol = num_std * mc_std )
59
59
60
+ def test_inverse_wishart_sample_non_diagonal_scale (n_samples : int = 10_000 , num_std = 3 ):
61
+ """Test sample mean of an inverse-Wishart distr. w/ non-diagonal scale matrix."""
62
+ k = 2
63
+ 𝜈 = 5.5 # 𝜈 > k
64
+ Ψ = jnp .array ([[20.712932 , 25.124634 ],
65
+ [25.124634 , 32.814785 ]], dtype = jnp .float32 ) # k x k
66
+ Ψ_diag = jnp .diagonal (Ψ )
67
+ assert all (jnp .linalg .eigvals (Ψ ) > 0 ) # Is positive definite.
68
+
69
+ iw = InverseWishart (df = 𝜈 , scale = Ψ )
70
+ Σs = iw .sample (sample_shape = n_samples , seed = jr .key (42 ))
71
+ actual_Σ_avg = jnp .mean (Σs , axis = 0 )
72
+
73
+ # Closed form expression of mean.
74
+ true_Σ_avg = Ψ / (𝜈 - k - 1 )
75
+ # Closed form expression of variance.
76
+ numerator = (𝜈 - k + 1 ) * Ψ ** 2 + (𝜈 - k - 1 ) * jnp .outer (Ψ_diag , Ψ_diag )
77
+ denominator = (𝜈 - k ) * (𝜈 - k - 1 )** 2 * (𝜈 - k - 3 )
78
+ true_Σ_var = numerator / denominator
79
+
80
+ mc_std = jnp .sqrt (true_Σ_var / n_samples )
81
+ assert jnp .allclose (actual_Σ_avg , true_Σ_avg , atol = num_std * mc_std )
82
+
60
83
61
84
def test_normal_inverse_wishart_mode (loc = 0. , mean_conc = 1.0 , df = 7.0 , dim = 3 , scale_factor = 3.0 ):
62
85
"""
0 commit comments