Skip to content

Commit b02d892

Browse files
authored
Merge pull request #408 from hylkedonker/bug/inv-wish-scale
Bugfix InverseWishart scale matrix
2 parents d9f5334 + 8dba507 commit b02d892

File tree

2 files changed

+25
-2
lines changed

2 files changed

+25
-2
lines changed

dynamax/utils/distributions.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,8 @@ def __init__(self, df, scale):
4242
# Wishart distribution
4343
dim = scale.shape[-1]
4444
eye = jnp.broadcast_to(jnp.eye(dim), scale.shape)
45-
cho_scale = jnp.linalg.cholesky(scale)
46-
inv_scale_tril = solve_triangular(cho_scale, eye, lower=True)
45+
inv_scale = psd_solve(A=scale, b=eye)
46+
inv_scale_tril = jnp.linalg.cholesky(inv_scale)
4747

4848
super().__init__(
4949
tfd.WishartTriL(df, scale_tril=inv_scale_tril),

dynamax/utils/distributions_test.py

+23
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,29 @@ def test_inverse_wishart_sample(df=7.0, dim=3, scale_factor=3.0, n_samples=10000
5757
mc_std = jnp.sqrt(iw.variance() / n_samples)
5858
assert jnp.allclose(samples.mean(axis=0), iw.mean(), atol=num_std * mc_std)
5959

60+
def test_inverse_wishart_sample_non_diagonal_scale(n_samples: int = 10_000, num_std=3):
61+
"""Test sample mean of an inverse-Wishart distr. w/ non-diagonal scale matrix."""
62+
k = 2
63+
𝜈 = 5.5 # 𝜈 > k
64+
Ψ = jnp.array([[20.712932, 25.124634],
65+
[25.124634, 32.814785]], dtype=jnp.float32) # k x k
66+
Ψ_diag = jnp.diagonal(Ψ)
67+
assert all(jnp.linalg.eigvals(Ψ) > 0) # Is positive definite.
68+
69+
iw = InverseWishart(df=𝜈, scale=Ψ)
70+
Σs = iw.sample(sample_shape=n_samples, seed=jr.key(42))
71+
actual_Σ_avg = jnp.mean(Σs, axis=0)
72+
73+
# Closed form expression of mean.
74+
true_Σ_avg = Ψ / (𝜈 - k - 1)
75+
# Closed form expression of variance.
76+
numerator = (𝜈 - k + 1) * Ψ**2 + (𝜈 - k - 1) * jnp.outer(Ψ_diag, Ψ_diag)
77+
denominator = (𝜈 - k) * (𝜈 - k - 1)**2 * (𝜈 - k - 3)
78+
true_Σ_var = numerator / denominator
79+
80+
mc_std = jnp.sqrt(true_Σ_var / n_samples)
81+
assert jnp.allclose(actual_Σ_avg, true_Σ_avg, atol=num_std * mc_std)
82+
6083

6184
def test_normal_inverse_wishart_mode(loc=0., mean_conc=1.0, df=7.0, dim=3, scale_factor=3.0):
6285
"""

0 commit comments

Comments
 (0)