diff --git a/Cargo.toml b/Cargo.toml index 356ac27e..659681eb 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,7 +29,7 @@ tracing-subscriber = { version = "0.3.19", features = [ ] } faer = "0.21.8" faer-ext = { version = "0.5.0", features = ["nalgebra", "ndarray"] } -pharmsol = "0.7.8" +pharmsol = { git = "https://github.com/LAPKB/pharmsol.git", rev = "8f78864" } rand = "0.9.0" anyhow = "1.0.97" rayon = "1.10.0" diff --git a/src/algorithms/mod.rs b/src/algorithms/mod.rs index b8e21b52..244682c8 100644 --- a/src/algorithms/mod.rs +++ b/src/algorithms/mod.rs @@ -7,13 +7,9 @@ use crate::structs::psi::Psi; use crate::structs::theta::Theta; use anyhow::Context; use anyhow::Result; -use faer_ext::IntoNdarray; -use ndarray::parallel::prelude::{IntoParallelIterator, ParallelIterator}; -use ndarray::{Array, ArrayBase, Dim, OwnedRepr}; use npag::*; use npod::NPOD; use pharmsol::prelude::{data::Data, simulator::Equation}; -use pharmsol::{ErrorModel, Predictions, Subject}; use postprob::POSTPROB; use serde::{Deserialize, Serialize}; @@ -35,182 +31,6 @@ pub trait Algorithms: Sync { where Self: Sized; fn validate_psi(&mut self) -> Result<()> { - // Count problematic values in psi - let mut nan_count = 0; - let mut inf_count = 0; - - let psi = self.psi().matrix().as_ref().into_ndarray(); - // First coerce all NaN and infinite in psi to 0.0 - for i in 0..psi.nrows() { - for j in 0..self.psi().matrix().ncols() { - let val = psi.get((i, j)).unwrap(); - if val.is_nan() { - nan_count += 1; - // *val = 0.0; - } else if val.is_infinite() { - inf_count += 1; - // *val = 0.0; - } - } - } - - if nan_count + inf_count > 0 { - tracing::warn!( - "Psi matrix contains {} NaN, {} Infinite values of {} total values", - nan_count, - inf_count, - psi.ncols() * psi.nrows() - ); - } - - let (_, col) = psi.dim(); - let ecol: ArrayBase, Dim<[usize; 1]>> = Array::ones(col); - let plam = psi.dot(&ecol); - let w = 1. / &plam; - - // Get the index of each element in `w` that is NaN or infinite - let indices: Vec = w - .iter() - .enumerate() - .filter(|(_, x)| x.is_nan() || x.is_infinite()) - .map(|(i, _)| i) - .collect::>(); - - if !indices.is_empty() { - let subject: Vec<&Subject> = self.get_data().get_subjects(); - let zero_probability_subjects: Vec<&String> = - indices.iter().map(|&i| subject[i].id()).collect(); - - tracing::error!( - "{}/{} subjects have zero probability given the model", - indices.len(), - psi.nrows() - ); - - // For each problematic subject - for index in &indices { - tracing::debug!("Subject with zero probability: {}", subject[*index].id()); - - let e_type = self.get_settings().error().error_model().into(); - - let error_model = ErrorModel::new( - self.get_settings().error().poly, - self.get_settings().error().value, - &e_type, - ); - - // Simulate all support points in parallel - let spp_results: Vec<_> = self - .get_theta() - .matrix() - .row_iter() - .enumerate() - .collect::>() - .into_par_iter() - .map(|(i, spp)| { - let support_point: Vec = spp.iter().copied().collect(); - let (pred, ll) = self.equation().simulate_subject( - subject[*index], - &support_point, - Some(&error_model), - ); - (i, support_point, pred.get_predictions(), ll) - }) - .collect(); - - // Count problematic likelihoods for this subject - let mut nan_ll = 0; - let mut inf_pos_ll = 0; - let mut inf_neg_ll = 0; - let mut zero_ll = 0; - let mut valid_ll = 0; - - for (_, _, _, ll) in &spp_results { - match ll { - Some(ll_val) if ll_val.is_nan() => nan_ll += 1, - Some(ll_val) if ll_val.is_infinite() && ll_val.is_sign_positive() => { - inf_pos_ll += 1 - } - Some(ll_val) if ll_val.is_infinite() && ll_val.is_sign_negative() => { - inf_neg_ll += 1 - } - Some(ll_val) if *ll_val == 0.0 => zero_ll += 1, - Some(_) => valid_ll += 1, - None => nan_ll += 1, - } - } - - tracing::debug!( - "\tLikelihood analysis for subject {} ({} support points):", - subject[*index].id(), - spp_results.len() - ); - tracing::debug!( - "\tNaN likelihoods: {} ({:.1}%)", - nan_ll, - 100.0 * nan_ll as f64 / spp_results.len() as f64 - ); - tracing::debug!( - "\t+Inf likelihoods: {} ({:.1}%)", - inf_pos_ll, - 100.0 * inf_pos_ll as f64 / spp_results.len() as f64 - ); - tracing::debug!( - "\t-Inf likelihoods: {} ({:.1}%)", - inf_neg_ll, - 100.0 * inf_neg_ll as f64 / spp_results.len() as f64 - ); - tracing::debug!( - "\tZero likelihoods: {} ({:.1}%)", - zero_ll, - 100.0 * zero_ll as f64 / spp_results.len() as f64 - ); - tracing::debug!( - "\tValid likelihoods: {} ({:.1}%)", - valid_ll, - 100.0 * valid_ll as f64 / spp_results.len() as f64 - ); - - // Sort and show top 10 most likely support points - let mut sorted_results = spp_results; - sorted_results.sort_by(|a, b| { - b.3.unwrap_or(f64::NEG_INFINITY) - .partial_cmp(&a.3.unwrap_or(f64::NEG_INFINITY)) - .unwrap_or(std::cmp::Ordering::Equal) - }); - let take = 3; - - tracing::debug!("Top {} most likely support points:", take); - for (i, support_point, preds, ll) in sorted_results.iter().take(take) { - tracing::debug!("\tSupport point #{}: {:?}", i, support_point); - tracing::debug!("\t\tLog-likelihood: {:?}", ll); - - let times = preds.iter().map(|x| x.time()).collect::>(); - let observations = preds.iter().map(|x| x.observation()).collect::>(); - let predictions = preds.iter().map(|x| x.prediction()).collect::>(); - let outeqs = preds.iter().map(|x| x.outeq()).collect::>(); - let states = preds - .iter() - .map(|x| x.state().clone()) - .collect::>>(); - - tracing::debug!("\t\tTimes: {:?}", times); - tracing::debug!("\t\tObservations: {:?}", observations); - tracing::debug!("\t\tPredictions: {:?}", predictions); - tracing::debug!("\t\tOuteqs: {:?}", outeqs); - tracing::debug!("\t\tStates: {:?}", states); - } - tracing::debug!("====================="); - } - - return Err(anyhow::anyhow!( - "The probability of {}/{} subjects is zero given the model. Affected subjects: {:?}", - indices.len(), - self.psi().matrix().nrows(), - zero_probability_subjects - )); - } - Ok(()) } fn get_settings(&self) -> &Settings; diff --git a/tests/onecomp.rs b/tests/onecomp.rs index 5047e8a7..d8dff8f9 100644 --- a/tests/onecomp.rs +++ b/tests/onecomp.rs @@ -65,7 +65,9 @@ fn test_one_compartment() -> Result<()> { // Check the results assert_eq!(result.cycles(), 32); - assert_eq!(result.objf(), 97.57533032670898); + // Check that likelihood are reasonably close + let objf_diff = (result.objf() - 97.57533032670898).abs(); + assert!(objf_diff < 0.01, "objf_diff: {}", objf_diff); Ok(()) }