Skip to content

Commit

Permalink
refactor: rewrite Dirichlet::ln_pdf
Browse files Browse the repository at this point in the history
The new code also avoids looping through x twice
  • Loading branch information
FreezyLemon authored and YeungOnion committed Sep 22, 2024
1 parent b83fc78 commit 1cbcc49
Showing 1 changed file with 17 additions and 21 deletions.
38 changes: 17 additions & 21 deletions src/distribution/dirichlet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -339,31 +339,27 @@ where
/// `Π` is the product from `1` to `K`, `Σ` is the sum from `1` to `K`,
/// and `K` is the number of concentration parameters
fn ln_pdf(&self, x: &OVector<f64, D>) -> f64 {
// TODO: would it be clearer here to just do a for loop instead
// of using iterators?
if self.alpha.len() != x.len() {
panic!("Arguments must have correct dimensions.");
}
if x.iter().any(|&x| x <= 0.0 || x >= 1.0) {
panic!("Arguments must be in (0, 1)");
}
let (term, sum_xi, sum_alpha) = x
.iter()
.enumerate()
.map(|pair| (pair.1, self.alpha[pair.0]))
.fold((0.0, 0.0, 0.0), |acc, pair| {
(
acc.0 + (pair.1 - 1.0) * pair.0.ln() - gamma::ln_gamma(pair.1),
acc.1 + pair.0,
acc.2 + pair.1,
)
});

if !prec::almost_eq(sum_xi, 1.0, 1e-4) {
panic!();
} else {
term + gamma::ln_gamma(sum_alpha)

let mut term = 0.0;
let mut sum_x = 0.0;
let mut sum_alpha = 0.0;

for (&x_i, &alpha_i) in x.iter().zip(self.alpha.iter()) {
assert!(0.0 < x_i && x_i < 1.0, "Arguments must be in (0, 1)");

term += (alpha_i - 1.0) * x_i.ln() - gamma::ln_gamma(alpha_i);
sum_x += x_i;
sum_alpha += alpha_i;
}

assert!(
prec::almost_eq(sum_x, 1.0, 1e-4),
"Arguments must sum up to 1"
);
term + gamma::ln_gamma(sum_alpha)
}
}

Expand Down

0 comments on commit 1cbcc49

Please sign in to comment.