From acbd3028d77ecefb796fa041346408135eb183a7 Mon Sep 17 00:00:00 2001 From: Orion Yeung <11580988+orionyeung001@users.noreply.github.com> Date: Wed, 18 Sep 2024 12:50:44 -0500 Subject: [PATCH] example: random_clap can pmf for multinomial example: refactoring for multivariate densities example: set some default parameters, on distn basis --- examples/random_clap.rs | 56 +++++++++++++++++++++++++++++++++-------- 1 file changed, 46 insertions(+), 10 deletions(-) diff --git a/examples/random_clap.rs b/examples/random_clap.rs index a1d1d4de..53235bfb 100644 --- a/examples/random_clap.rs +++ b/examples/random_clap.rs @@ -6,6 +6,7 @@ use statrs::statistics::Mode; use std::fmt::Display; use std::io::{self, BufWriter, Write}; +use std::str::{FromStr, Split}; use anyhow::Result; use clap::{ArgAction, Parser, Subcommand}; @@ -28,6 +29,7 @@ enum Commands { }, /// for evaluating distribution function density Density { + /// sample to evaluate density at, default=distribution's mode #[arg(short, long = "arg", action = ArgAction::Append, value_name = "SAMPLE")] args: Vec, #[command(subcommand)] @@ -48,14 +50,14 @@ enum DistributionAsCommand { Binomial { #[arg(value_name = "trial counts")] n: u64, - #[arg(value_name = "success probability")] + #[arg(value_name = "success probability", default_value = "0.5")] p: f64, }, /// the normal distribution Normal { - #[arg(value_name = "mean")] + #[arg(value_name = "mean", default_value = "0.0")] mu: f64, - #[arg(value_name = "standard deviation")] + #[arg(value_name = "standard deviation", default_value = "1.0")] sigma: f64, }, } @@ -70,15 +72,33 @@ fn main() -> Result<()> { Ok(()) } -fn run_command_density(args: &[String], dist: DistributionAsCommand) -> Result<()> { +fn run_command_density(args_str: &[String], dist: DistributionAsCommand) -> Result<()> { let densities = match dist { - DistributionAsCommand::Multinomial { .. } => { - unimplemented!() + DistributionAsCommand::Multinomial { n, p } => { + let dist = Multinomial::new(p, n)?; + if !args_str.is_empty() { + let mut densities = Vec::new(); + densities.reserve(args_str.len()); + + for arg_str in args_str { + let arg = parse_str_split_to_vec(arg_str.split(',')); + if arg.len() == dist.p().len() { + densities.push(dist.pmf(&arg.into())); + } else { + anyhow::bail!("dimension mismatch after parsing `--arg {arg_str}`"); + } + } + + densities + } else { + vec![dist.pmf(&dist.mode())] + } } DistributionAsCommand::Binomial { n, p } => { let dist = Binomial::new(p, n)?; - if !args.is_empty() { - args.iter() + if !args_str.is_empty() { + args_str + .iter() .map_while(|s| match s.parse() { Ok(x) => Some(x), Err(e) => { @@ -94,8 +114,9 @@ fn run_command_density(args: &[String], dist: DistributionAsCommand) -> Result<( } DistributionAsCommand::Normal { mu, sigma } => { let dist = Normal::new(mu, sigma)?; - if !args.is_empty() { - args.iter() + if !args_str.is_empty() { + args_str + .iter() .map_while(|s| match s.parse() { Ok(x) => Some(x), Err(e) => { @@ -116,6 +137,21 @@ fn run_command_density(args: &[String], dist: DistributionAsCommand) -> Result<( Ok(()) } +fn parse_str_split_to_vec(sp: Split) -> Vec +where + T: FromStr, + E: Display + std::error::Error, +{ + sp.map_while(|si| match si.parse::() { + Ok(x) => Some(x), + Err(e) => { + eprintln!("could not parse argment, got {e}"); + None + } + }) + .collect() +} + fn run_command_sample(count: Option, dist: DistributionAsCommand) -> Result<()> { let count = count.unwrap_or(10);