Skip to content

Commit

Permalink
feat: add into_params() to multivariate distributions
Browse files Browse the repository at this point in the history
  • Loading branch information
FreezyLemon committed Sep 5, 2024
1 parent f7f2960 commit 77fdbe0
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 0 deletions.
31 changes: 31 additions & 0 deletions src/distribution/dirichlet.rs
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,37 @@ where
-gamma::ln_gamma(sum) + (sum - self.alpha.len() as f64) * gamma::digamma(sum) - num;
Some(entr)
}

/// Consumes the [`Dirichlet`] and returns the `alpha` parameter
/// originally passed to [`new_from_nalgebra`][Self::new_from_nalgebra]
/// to construct it.
///
/// This can be used to avoid allocations when creating the same
/// distribution multiple times.
///
/// # Examples
///
/// ```
/// use statrs::distribution::Dirichlet;
/// use nalgebra::dvector;
///
/// let alpha = dvector![0.1, 0.3, 0.5, 0.8];
/// let dir_1 = Dirichlet::new_from_nalgebra(alpha).unwrap();
/// assert_eq!(dir_1.entropy(), Some(-17.46469081094079));
///
/// let mut alpha = dir_1.into_params();
/// alpha[1] = 0.2;
/// alpha[2] = 0.3;
/// alpha[3] = 0.4;
///
/// let dir_2 = Dirichlet::new_from_nalgebra(alpha).unwrap();
/// assert_eq!(dir_2.entropy(), Some(-21.53881433791513));
/// ```
#[must_use]
#[inline]
pub fn into_params(self) -> OVector<f64, D> {
self.alpha
}
}

impl<D> std::fmt::Display for Dirichlet<D>
Expand Down
31 changes: 31 additions & 0 deletions src/distribution/multivariate_normal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,37 @@ where
.ln(),
)
}

/// Consumes the [`MultivariateNormal`] and returns the parameters
/// originally passed to [`new_from_nalgebra`][Self::new_from_nalgebra]
/// to construct it.
///
/// This can be used to avoid allocations when creating the same
/// distribution multiple times.
///
/// # Examples
///
/// ```
/// use statrs::distribution::MultivariateNormal;
/// use nalgebra::{dmatrix, dvector};
///
/// let mean = dvector![0.0, 0.0];
/// let cov = dmatrix![1.0, 0.0; 0.0, 1.0];
/// let mvn_1 = MultivariateNormal::new_from_nalgebra(mean, cov).unwrap();
/// assert_eq!(mvn_1.entropy(), Some(2.8378770664093453));
///
/// let (mean, mut cov) = mvn_1.into_params();
/// cov[1] = 0.5;
/// cov[2] = 0.5;
///
/// let mvn_2 = MultivariateNormal::new_from_nalgebra(mean, cov).unwrap();
/// assert_eq!(mvn_2.entropy(), Some(2.694036030183455));
/// ```
#[must_use]
#[inline]
pub fn into_params(self) -> (OVector<f64, D>, OMatrix<f64, D, D>) {
(self.mu, self.cov)
}
}

impl<D> std::fmt::Display for MultivariateNormal<D>
Expand Down
34 changes: 34 additions & 0 deletions src/distribution/multivariate_students_t.rs
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,40 @@ where
pub fn ln_pdf_const(&self) -> f64 {
self.ln_pdf_const
}

/// Consumes the [`MultivariateStudent`] and returns the parameters
/// originally passed to [`new_from_nalgebra`][Self::new_from_nalgebra]
/// to construct it.
///
/// This can be used to avoid allocations when creating the same
/// distribution multiple times.
///
/// # Examples
///
/// ```
/// use statrs::distribution::MultivariateStudent;
/// use statrs::distribution::Continuous;
/// use nalgebra::{dmatrix, dvector};
///
/// let location = dvector![0.0, 0.0];
/// let scale = dmatrix![1.0, 0.0; 0.0, 1.0];
/// let freedom = 4.0;
/// let x = dvector![1.0, 2.0];
///
/// let mvs_1 = MultivariateStudent::new_from_nalgebra(location, scale, freedom).unwrap();
/// assert_eq!(mvs_1.pdf(&x), 0.01397245042233379);
///
/// let (location, scale, _) = mvs_1.into_params();
/// let freedom = 2.0;
///
/// let mvs_2 = MultivariateStudent::new_from_nalgebra(location, scale, freedom).unwrap();
/// assert_eq!(mvs_2.pdf(&x), 0.012992240252399626);
/// ```
#[must_use]
#[inline]
pub fn into_params(self) -> (OVector<f64, D>, OMatrix<f64, D, D>, f64) {
(self.location, self.scale, self.freedom)
}
}

impl<D> ::rand::distributions::Distribution<OVector<f64, D>> for MultivariateStudent<D>
Expand Down

0 comments on commit 77fdbe0

Please sign in to comment.