From 1cbcc49c53c60df5ea73b1a991fcf20c1f0c2f4c Mon Sep 17 00:00:00 2001 From: FreezyLemon Date: Sun, 22 Sep 2024 14:01:00 +0200 Subject: [PATCH] refactor: rewrite Dirichlet::ln_pdf The new code also avoids looping through x twice --- src/distribution/dirichlet.rs | 38 ++++++++++++++++------------------- 1 file changed, 17 insertions(+), 21 deletions(-) diff --git a/src/distribution/dirichlet.rs b/src/distribution/dirichlet.rs index fffab765..5e08d0a4 100644 --- a/src/distribution/dirichlet.rs +++ b/src/distribution/dirichlet.rs @@ -339,31 +339,27 @@ where /// `Π` is the product from `1` to `K`, `Σ` is the sum from `1` to `K`, /// and `K` is the number of concentration parameters fn ln_pdf(&self, x: &OVector) -> f64 { - // TODO: would it be clearer here to just do a for loop instead - // of using iterators? if self.alpha.len() != x.len() { panic!("Arguments must have correct dimensions."); } - if x.iter().any(|&x| x <= 0.0 || x >= 1.0) { - panic!("Arguments must be in (0, 1)"); - } - let (term, sum_xi, sum_alpha) = x - .iter() - .enumerate() - .map(|pair| (pair.1, self.alpha[pair.0])) - .fold((0.0, 0.0, 0.0), |acc, pair| { - ( - acc.0 + (pair.1 - 1.0) * pair.0.ln() - gamma::ln_gamma(pair.1), - acc.1 + pair.0, - acc.2 + pair.1, - ) - }); - - if !prec::almost_eq(sum_xi, 1.0, 1e-4) { - panic!(); - } else { - term + gamma::ln_gamma(sum_alpha) + + let mut term = 0.0; + let mut sum_x = 0.0; + let mut sum_alpha = 0.0; + + for (&x_i, &alpha_i) in x.iter().zip(self.alpha.iter()) { + assert!(0.0 < x_i && x_i < 1.0, "Arguments must be in (0, 1)"); + + term += (alpha_i - 1.0) * x_i.ln() - gamma::ln_gamma(alpha_i); + sum_x += x_i; + sum_alpha += alpha_i; } + + assert!( + prec::almost_eq(sum_x, 1.0, 1e-4), + "Arguments must sum up to 1" + ); + term + gamma::ln_gamma(sum_alpha) } }