Multivariate normal distribution in ndarray-rand #582
Fwiw, I've been using this for a while in my own code:

```rust
// Assumes serde (with the `derive` feature) and failure 0.1 (with `derive`)
// as dependencies, plus rand 0.6-era `rand::distributions::Normal`.
use failure::{Context, Fail, ResultExt};
use ndarray::{Data, DataClone, DataOwned, OwnedRepr, ViewRepr};
use ndarray::prelude::*;
use ndarray_linalg::cholesky::{CholeskyInto, UPLO};
use ndarray_rand::RandomExt;
use rand::distributions::{Distribution, Normal};
use rand::Rng;
use serde::{Deserialize, Serialize};
use std::fmt::{self, Debug};
use std::ops::AddAssign;
// ...
/// Multivariate Gaussian distribution.
#[derive(PartialEq, Deserialize, Serialize)]
#[serde(bound(deserialize = "S: DataOwned, S::Elem: ::serde::Deserialize<'de>"))]
#[serde(bound(serialize = "S: Data, S::Elem: ::serde::Serialize"))]
pub struct GaussianDistroBase<S>
where
    S: Data<Elem = f64>,
{
    pub mean: ArrayBase<S, Ix1>,
    pub covariance: ArrayBase<S, Ix2>,
}

pub type GaussianDistro = GaussianDistroBase<OwnedRepr<f64>>;
pub type GaussianDistroView<'a> = GaussianDistroBase<ViewRepr<&'a f64>>;

impl<S> GaussianDistroBase<S>
where
    S: Data<Elem = f64>,
{
    /// Dimension of the distribution; panics if mean and covariance disagree.
    pub fn len(&self) -> usize {
        assert_eq!(self.mean.len(), self.covariance.len_of(Axis(0)));
        assert_eq!(self.mean.len(), self.covariance.len_of(Axis(1)));
        self.mean.len()
    }

    pub fn to_owned(&self) -> GaussianDistro {
        GaussianDistro {
            mean: self.mean.to_owned(),
            covariance: self.covariance.to_owned(),
        }
    }

    pub fn view(&self) -> GaussianDistroView<'_> {
        GaussianDistroView {
            mean: self.mean.view(),
            covariance: self.covariance.view(),
        }
    }
}

impl<S> Clone for GaussianDistroBase<S>
where
    S: DataClone<Elem = f64>,
{
    fn clone(&self) -> Self {
        GaussianDistroBase {
            mean: self.mean.clone(),
            covariance: self.covariance.clone(),
        }
    }
}

impl<S> Debug for GaussianDistroBase<S>
where
    S: Data<Elem = f64>,
{
    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
        fmt.debug_struct("GaussianDistroBase")
            .field("mean", &self.mean)
            .field("covariance", &self.covariance)
            .finish()
    }
}

#[derive(Debug, Fail)]
#[fail(display = "error sampling from multivariate normal distribution: {}", _0)]
pub struct GaussianSampleError(Context<String>);

impl From<Context<String>> for GaussianSampleError {
    fn from(context: Context<String>) -> GaussianSampleError {
        GaussianSampleError(context)
    }
}

impl<S> Distribution<Result<Array1<f64>, GaussianSampleError>> for GaussianDistroBase<S>
where
    S: Data<Elem = f64>,
{
    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> Result<Array1<f64>, GaussianSampleError> {
        let mut cov = self.covariance.to_owned();
        // Add a small multiple of I for numerical reasons.
        cov.diag_mut().add_assign(1e2 * ::std::f64::EPSILON);
        // Lower-triangular factor L with cov = L * L^T.
        let chol = cov
            .cholesky_into(UPLO::Lower)
            .context("error factoring covariance".into())?;
        // x = L z + mean, with z a vector of independent N(0, 1) draws.
        Ok(chol.dot(&Array1::random_using(self.len(), Normal::new(0., 1.), rng)) + &self.mean)
    }
}
// ...
```

It's probably a bit more complex than what you're looking for (since …) For the purpose of …
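The sampler above uses the standard Cholesky construction: factor the slightly regularized covariance as a lower-triangular product, draw independent standard normals with `Array1::random_using`, then apply an affine map. Written out (the notation $\mu$, $\Sigma$, $L$, $z$, $\varepsilon$ is mine, not from the thread), with $\varepsilon$ being the `1e2 * f64::EPSILON` jitter:

```latex
\begin{aligned}
\Sigma + \varepsilon I &= L L^\top, \qquad \varepsilon = 100\,\epsilon_{\mathrm{mach}} \approx 2.2 \times 10^{-14},\\
z &\sim \mathcal{N}(0, I_n),\\
x &= \mu + L z,\\
\mathbb{E}[x] &= \mu + L\,\mathbb{E}[z] = \mu,\\
\operatorname{Cov}(x) &= L\,\operatorname{Cov}(z)\,L^\top = L L^\top = \Sigma + \varepsilon I.
\end{aligned}
```

So the draws have covariance $\Sigma + \varepsilon I$ rather than exactly $\Sigma$; the jitter presumably just keeps `cholesky_into` from failing on covariances that are only positive semidefinite, or barely positive definite in floating point. Note also that `sample` refactors the covariance on every call, so caching $L$ would be cheaper when drawing many samples from the same distribution.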
Yes, a feature flag seems appropriate; I'm using the same approach in an MCMC algorithms crate I'm working on (for adding support for multivariate distributions using …). Your code seems a bit overkill; I was only planning on an implementation for … I'll make a pull request so you can check my work.
Implementing multivariate normal distributions involves a bit of boilerplate, and possibly the use of ndarray-linalg to perform a Cholesky decomposition. Would it be worth implementing this on the crate's end? I made a fork and started writing some code.
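To show what that boilerplate amounts to without pulling in ndarray-linalg, here is a dependency-free sketch of the same technique: a row-by-row Cholesky factorization followed by the affine map x = μ + L z. It assumes the caller already has a vector `z` of independent standard-normal draws (from `rand` or elsewhere). The function names `cholesky_lower` and `mvn_transform` and the Vec-of-rows matrix layout are hypothetical, not part of ndarray-rand.

```rust
/// Lower-triangular Cholesky factor `L` of a symmetric positive-definite
/// matrix `a` (stored as a Vec of rows), so that `a = L * L^T`.
/// Only the lower triangle of `a` is read. Returns `None` if a pivot is
/// not strictly positive, i.e. `a` is not (numerically) positive definite.
fn cholesky_lower(a: &[Vec<f64>]) -> Option<Vec<Vec<f64>>> {
    let n = a.len();
    let mut l = vec![vec![0.0; n]; n];
    for i in 0..n {
        for j in 0..=i {
            // Dot product of the already-computed parts of rows i and j of L.
            let s: f64 = (0..j).map(|k| l[i][k] * l[j][k]).sum();
            if i == j {
                let pivot = a[i][i] - s;
                // `!(pivot > 0.0)` also rejects a NaN pivot.
                if !(pivot > 0.0) {
                    return None;
                }
                l[i][i] = pivot.sqrt();
            } else {
                l[i][j] = (a[i][j] - s) / l[j][j];
            }
        }
    }
    Some(l)
}

/// Turn a vector `z` of independent N(0, 1) draws into one draw from
/// N(mean, cov) via x = mean + L z, where cov = L L^T.
fn mvn_transform(mean: &[f64], cov: &[Vec<f64>], z: &[f64]) -> Option<Vec<f64>> {
    let n = mean.len();
    assert_eq!(cov.len(), n, "covariance must be n x n");
    assert_eq!(z.len(), n, "z must have one entry per dimension");
    let l = cholesky_lower(cov)?;
    // L is lower triangular, so row i only involves z[0..=i].
    let x = (0..n)
        .map(|i| mean[i] + (0..=i).map(|k| l[i][k] * z[k]).sum::<f64>())
        .collect();
    Some(x)
}
```

For example, with Σ = [[4, 2], [2, 3]] the factor is L = [[2, 0], [1, √2]], so μ = (1, −1) and z = (0.5, 1) map to x = (2, √2 − 0.5) ≈ (2, 0.914). Returning `None` on a non-positive pivot plays the same role as the `Result` in the snippet above; adding diagonal jitter before factoring, as that snippet does, would be the usual way to accept semidefinite covariances.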