diff --git a/.github/dependabot.yml b/.github/dependabot.yml new file mode 100644 index 00000000..70f86105 --- /dev/null +++ b/.github/dependabot.yml @@ -0,0 +1,16 @@ +# To get started with Dependabot version updates, you'll need to specify which +# package ecosystems to update and where the package manifests are located. +# Please see the documentation for all configuration options: +# https://docs.github.com/code-security/dependabot/dependabot-version-updates/configuration-options-for-the-dependabot.yml-file + +version: 2 +updates: + - package-ecosystem: "cargo" + directory: "/" + open-pull-requests-limit: 10 + schedule: + interval: "monthly" + - package-ecosystem: "github-actions" + directory: "/" + schedule: + interval: "monthly" diff --git a/.github/workflows/coverage.yml b/.github/workflows/coverage.yml new file mode 100644 index 00000000..1485889c --- /dev/null +++ b/.github/workflows/coverage.yml @@ -0,0 +1,39 @@ +name: Coverage + +on: + push: + branches: [ master ] + pull_request: + branches: [ master ] +jobs: + coverage: + name: Coverage + runs-on: ubuntu-latest + env: + RUSTFLAGS: -D warnings + CARGO_TERM_COLOR: always + steps: + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@nightly + with: + components: llvm-tools-preview + + - uses: taiki-e/install-action@v2 + with: + tool: nextest + - uses: taiki-e/install-action@v2 + with: + tool: cargo-llvm-cov + + - name: Collect coverage + run: | + cargo llvm-cov --no-report nextest + cargo llvm-cov --no-report --doc + cargo llvm-cov report --doctests --lcov --output-path lcov.info + + - name: Upload to codecov.io + uses: codecov/codecov-action@v4 + with: + files: lcov.info + token: ${{secrets.CODECOV_TOKEN}} + fail_ci_if_error: false diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 4f9fd9cd..65d901eb 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -6,34 +6,71 @@ on: pull_request: branches: [ master ] +env: + CARGO_INCREMENTAL: 0 + RUSTFLAGS: 
"-Dwarnings" + jobs: + clippy: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Install Rust stable with clippy + uses: dtolnay/rust-toolchain@stable + with: + components: clippy + + - name: Run cargo clippy (default features) + run: cargo clippy --all-targets + + fmt: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Install Rust stable with rustfmt + uses: dtolnay/rust-toolchain@stable + with: + components: rustfmt + + - name: Run rustfmt --check + run: cargo fmt -- --check + + msrv: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Install cargo-hack + uses: taiki-e/install-action@cargo-hack + - uses: Swatinem/rust-cache@v2 + - name: Use predefined lockfile + run: mv -v Cargo.lock.MSRV Cargo.lock + - name: Build (lib only) + run: cargo hack check --rust-version --locked + test: - name: Test + needs: [clippy, fmt, msrv] runs-on: ${{ matrix.os }} strategy: - fail-fast: false matrix: - include: - - os: ubuntu-latest - target: x86_64-unknown-linux-gnu - toolchain: stable - - os: ubuntu-latest - target: x86_64-unknown-linux-gnu - toolchain: nightly + os: [ubuntu-latest, macos-latest, windows-latest] + steps: - - uses: actions/checkout@v2 - - name: Install toolchain - uses: actions-rs/toolchain@v1 - with: - profile: minimal - target: ${{ matrix.target }} - toolchain: ${{ matrix.toolchain }} - override: true - - name: Test nightly feature (if possible) - if: ${{ matrix.toolchain == 'nightly' }} - run: | - cargo test --target ${{ matrix.target }} --features=nightly - cargo test --target ${{ matrix.target }} --benches --features=nightly + - uses: actions/checkout@v4 + - name: Install Rust stable + uses: dtolnay/rust-toolchain@stable + - name: Test default features - run: | - cargo test --target ${{ matrix.target }} \ No newline at end of file + run: cargo test + + features: + needs: [clippy, fmt] + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Install Rust stable + uses: 
dtolnay/rust-toolchain@stable + - name: Install cargo-hack + uses: taiki-e/install-action@cargo-hack + - uses: Swatinem/rust-cache@v2 + - name: Check all possible feature sets + run: cargo hack check --feature-powerset --no-dev-deps diff --git a/.gitignore b/.gitignore index dea682e4..fe9e65ec 100644 --- a/.gitignore +++ b/.gitignore @@ -7,10 +7,17 @@ # Executables *.exe +# Test data for integration tests +tests/*.dat + # Generated by Cargo /target/ *.lock #editor specific /.vscode/ -.idea/ \ No newline at end of file +.idea/ +*.iml + +# macOS +.DS_Store diff --git a/CHANGELOG.md b/CHANGELOG.md index a65a3c75..2992f26f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,15 +1,61 @@ -v0.16.0 +# Changelog +All notable changes to this project will be documented in this file. + +The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), +and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). + +## [Unreleased] + +## [0.17.0](https://github.com/statrs-dev/statrs/compare/v0.16.0...v0.17.0) - 2024-05-30 + +### Added +- specializes `inverse_cdf()` for Uniform (#166) +- Add way to get standard normal distribution easily. 
(#228) +- reject constructing Uniform of infinite support (#218) +- extend `StatsError` for finiteness (#218) +- default implementation of survival function with generics (#179) +- update `MultivariateNormal` API + - construct from nalgebra with `MultivariateNormal::new_from_nalgebra` (#177) + - support `std::vec` vector input in addition to `nalgebra` vectors (#199) + +### Fixed +- Update nalgebra to 0.32 (#187) +- for Gamma with shape<1 there is no mode, returns `None` instead of some negative number (#212) +- fix precision of ::inverse_cdf with some newton raphson steps (#227) + - adds test case from #200 +- fix integer bisection for default implementation of `::inverse_cdf` (#220) + - also add tests from (#185) + +### Other +- Remove "nightly" feature and drop testing requirement for `nightly` (#234) +- Allow some imprecision in specific test case (#215) +- Update CI (#215) + - Check formatting in CI via rustfmt + - Expand CI test job + - Add clippy job to CI +- update README with formatting and adding to "Contributing" (#213) +- Add test asserting that `StatsError` is Sync & Send (#226) +- Rename private struct NonNAN to NonNan (#222) +- Remove `lazy-static` dependency and make FCACHE a proper const (#211) +- crate examples shall be in docstrings instead of README (#213) +- alias `inverse_cdf` as "quantile function" in docs (#213) +- docstrings with math shall be `text` instead of `ignore` (#213) + + + +## [0.16.0] - Adds an `sf` method to the `ContinuousCDF` and `DiscreteCDF` traits - Calculates the survival function (CDF complement) for the distribution. 
- Survival function implemented for all distributions implementing `ContinuousCDF` and `DiscreteCDF` - See [PR description](https://github.com/statrs-dev/statrs/pull/172) for in-depth changes +- update `nalgebra` to `0.29` -v0.15.0 +## [v0.15.0](https://www.github.com/statrs-dev/statrs/compare/v0.14.0...v0.15.0) - upgrade `nalgebra` to `0.27.1` to avoid RUSTSEC-2021-0070 -v0.14.0 +## [v0.14.0](https://www.github.com/statrs-dev/statrs/compare/v0.13.0...v0.14.0) - upgrade `rand` dependency to `0.8` - fix inaccurate sampling of `Gamma` @@ -23,28 +69,28 @@ v0.14.0 - Moved to dynamic vectors in the MultivariateNormal distribution - Reduced a number of distribution-specific traits into the Distribution and DiscreteDistribution traits -v0.13.0 +## [v0.13.0](https://www.github.com/statrs-dev/statrs/compare/v0.12.0...v0.13.0) - Implemented `MultivariateNormal` distribution (depends on `nalgebra 0.19`) - Implemented `Dirac` distribution - Implemented `Negative Binomial` distribution -v0.12.0 +## [v0.12.0](https://www.github.com/statrs-dev/statrs/compare/v0.11.0...v0.12.0) - upgrade `rand` dependency to `0.7` -v0.11.0 +## [v0.11.0](https://www.github.com/statrs-dev/statrs/compare/v0.10.0...v0.11.0) - upgrade `rand` dependency to `0.6` - Implement `CheckedInverseCDF` and `InverseCDF` for `Normal` distribution -v0.10.0 +## [v0.10.0](https://www.github.com/statrs-dev/statrs/compare/v0.9.0...v0.10.0) - upgrade `rand` dependency to `0.5` - Removes the `Distribution` trait in favor of the `rand::distributions::Distribution` trait - Removed functions deprecated in `0.8.0` (`periodic`, `periodic_custom`, `sinusoidal`, `sinusoidal_custom`) -v0.9.0 +## [v0.9.0](https://www.github.com/statrs-dev/statrs/compare/v0.8.0...v0.9.0) - implemented infinite sequence generator for periodic sequence - implemented infinite sequence generator for sinusoidal sequence @@ -56,7 +102,7 @@ v0.9.0 - Implemented `Entropy` trait for the `Categorical` distribution - Add a `checked_` interface to all 
distribution methods and functions that may panic -v0.8.0 +## [v0.8.0](https://www.github.com/statrs-dev/statrs/compare/v0.7.0...v0.8.0) - `cdf(x)`, `pdf(x)` and `pmf(x)` now return the correct value instead of panicking when `x` is outside the range of values that the distribution can attain. - Fixed a bug in the `Uniform` distribution implementation where samples were drawn from range `[min, max + 1)` instead of `[min, max]`. The samples are now drawn correctly from the range `[min, max]`. @@ -93,14 +139,14 @@ assert!(x.min().is_nan()); Since the regression affects a very slim edge-case and the fix is very simple, no breaking changes to the `Statistics` API was deemed necessary -v0.7.0 +## [v0.7.0](https://www.github.com/statrs-dev/statrs/compare/v0.6.0...v0.7.0) - Implemented `Categorical` distribution - Implemented `Erlang` distribution - Implemented `Multinomial` distribution - New `InverseCDF` trait for distributions that implement the inverse cdf function -v0.6.0 +## [v0.6.0](https://www.github.com/statrs-dev/statrs/compare/v0.5.1...v0.6.0) - `gamma::gamma_ur`, `gamma::gamma_ui`, `gamma::gamma_lr`, and `gamma::gamma_li` now follow strict gamma function domain, panicking if `a` or `x` are not in `(0, +inf)` - `beta::beta_reg` no longer allows `0.0` for `a` or `b` arguments @@ -131,11 +177,11 @@ v0.6.0 - `Hypergeometric` now implements `Discrete` rather than `Discrete` - `Poisson` now implements `Discrete` rather than `Discrete` -v0.5.1 +## [v0.5.1](https://www.github.com/statrs-dev/statrs/compare/v0.5.0...v0.5.1) - Fixed critical bug in `normal::sample_unchecked` where it was returning `NaN` -v0.5.0 +## [v0.5.0](https://www.github.com/statrs-dev/statrs/compare/v0.4.0...v0.5.0) - Implemented the `logistic::logistic` special function - Implemented the `logistic::logit` special function @@ -150,22 +196,22 @@ v0.5.0 - `Binomial::pdf` and `Binomial::ln_pdf` now panic if `x > n` or `x < 0` - `Bernoulli::pdf` and `Bernoulli::ln_pdf` now panic if `x > 1` or `x < 
0` -v0.4.0 +## [v0.4.0] - Implemented the `exponential::integral` special function - Implemented the `Cauchy` (otherwise known as the `Lorenz`) distribution - Implemented the `Dirichlet` distribution - `Continuous` and `Discrete` traits no longer dependent on `Distribution` trait -v0.3.2 +## [v0.3.2] - Implemented the `FisherSnedecor` (F) distribution -v0.3.1 +## [v0.3.1] - Removed print statements from `ln_pdf` method in `Beta` distribution -v0.3.0 +## [v0.3.0] - Moved methods `min` and `max` out of trait `Univariate` into their own respective traits `Min` and `Max` - Traits `Min`, `Max`, `Mean`, `Variance`, `Entropy`, `Skewness`, `Median`, and `Mode` moved from `distribution` module to `statistics` module @@ -180,7 +226,7 @@ v0.3.0 - `InplaceStatistics` renamed to `OrderStatistics`, all methods in `InplaceStatistics` have `_inplace` trimmed from method name. - Inverse DiGamma function implemented with signature `gamma::inv_digamma(x: f64) -> f64` -v0.2.0 +## [v0.2.0] - Created `statistics` module and `Statistics` trait - `Statistics` trait implementation for `[f64]` diff --git a/Cargo.lock.MSRV b/Cargo.lock.MSRV new file mode 100644 index 00000000..3dcbc271 --- /dev/null +++ b/Cargo.lock.MSRV @@ -0,0 +1,848 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. 
+version = 3 + +[[package]] +name = "aho-corasick" +version = "1.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e60d3430d3a69478ad0993f19238d2df97c507009a52b3c10addcd7f6bcb916" +dependencies = [ + "memchr", +] + +[[package]] +name = "anes" +version = "0.1.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299" + +[[package]] +name = "anstyle" +version = "1.0.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1bec1de6f59aedf83baf9ff929c98f2ad654b97c9510f4e70cf6f661d49fd5b1" + +[[package]] +name = "anyhow" +version = "1.0.86" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b3d1d046238990b9cf5bcde22a3fb3584ee5cf65fb2765f454ed428c7a0063da" + +[[package]] +name = "approx" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cab112f0a86d568ea0e627cc1d6be74a1e9cd55214684db5561995f6dad897c6" +dependencies = [ + "num-traits", +] + +[[package]] +name = "autocfg" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c4b4d0bd25bd0b74681c0ad21497610ce1b7c91b1022cd21c80c6fbdd9476b0" + +[[package]] +name = "bumpalo" +version = "3.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c" + +[[package]] +name = "bytemuck" +version = "1.17.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "773d90827bc3feecfb67fab12e24de0749aad83c74b9504ecde46237b5cd24e2" + +[[package]] +name = "byteorder" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" + +[[package]] +name = "cast" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" + +[[package]] +name = "cfg-if" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" + +[[package]] +name = "ciborium" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42e69ffd6f0917f5c029256a24d0161db17cea3997d185db0d35926308770f0e" +dependencies = [ + "ciborium-io", + "ciborium-ll", + "serde", +] + +[[package]] +name = "ciborium-io" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "05afea1e0a06c9be33d539b876f1ce3692f4afea2cb41f740e7743225ed1c757" + +[[package]] +name = "ciborium-ll" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57663b653d948a338bfb3eeba9bb2fd5fcfaecb9e199e87e1eda4d9e8b240fd9" +dependencies = [ + "ciborium-io", + "half", +] + +[[package]] +name = "clap" +version = "4.5.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3e5a21b8495e732f1b3c364c9949b201ca7bae518c502c80256c96ad79eaf6ac" +dependencies = [ + "clap_builder", +] + +[[package]] +name = "clap_builder" +version = "4.5.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8cf2dd12af7a047ad9d6da2b6b249759a22a7abc0f474c1dae1777afa4b21a73" +dependencies = [ + "anstyle", + "clap_lex", +] + +[[package]] +name = "clap_lex" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1462739cb27611015575c0c11df5df7601141071f07518d56fcc1be504cbec97" + +[[package]] +name = "criterion" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2b12d017a929603d80db1831cd3a24082f8137ce19c69e6447f54f5fc8d692f" +dependencies = [ + "anes", + "cast", + "ciborium", + "clap", + "criterion-plot", + "is-terminal", + "itertools", + "num-traits", + 
"once_cell", + "oorandom", + "plotters", + "rayon", + "regex", + "serde", + "serde_derive", + "serde_json", + "tinytemplate", + "walkdir", +] + +[[package]] +name = "criterion-plot" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1" +dependencies = [ + "cast", + "itertools", +] + +[[package]] +name = "crossbeam-deque" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "613f8cc01fe9cf1a3eb3d7f488fd2fa8388403e97039e2f73692932e291a770d" +dependencies = [ + "crossbeam-epoch", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-epoch" +version = "0.9.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e" +dependencies = [ + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-utils" +version = "0.8.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22ec99545bb0ed0ea7bb9b8e1e9122ea386ff8a48c0922e43f36d45ab09e0e80" + +[[package]] +name = "crunchy" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7" + +[[package]] +name = "either" +version = "1.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0" + +[[package]] +name = "getrandom" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7" +dependencies = [ + "cfg-if", + "libc", + "wasi", +] + +[[package]] +name = "half" +version = "2.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6dd08c532ae367adf81c312a4580bc67f1d0fe8bc9c460520283f4c0ff277888" +dependencies = [ + "cfg-if", + "crunchy", 
+] + +[[package]] +name = "hermit-abi" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fbf6a919d6cf397374f7dfeeea91d974c7c0a7221d0d0f4f20d859d329e53fcc" + +[[package]] +name = "is-terminal" +version = "0.4.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "261f68e344040fbd0edea105bef17c66edf46f984ddb1115b775ce31be948f4b" +dependencies = [ + "hermit-abi", + "libc", + "windows-sys 0.52.0", +] + +[[package]] +name = "itertools" +version = "0.10.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b0fd2260e829bddf4cb6ea802289de2f86d6a7a690192fbe91b3f46e0f2c8473" +dependencies = [ + "either", +] + +[[package]] +name = "itoa" +version = "1.0.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b" + +[[package]] +name = "js-sys" +version = "0.3.70" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1868808506b929d7b0cfa8f75951347aa71bb21144b7791bae35d9bccfcfe37a" +dependencies = [ + "wasm-bindgen", +] + +[[package]] +name = "libc" +version = "0.2.158" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d8adc4bb1803a324070e64a98ae98f38934d91957a99cfb3a43dcbc01bc56439" + +[[package]] +name = "libm" +version = "0.2.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058" + +[[package]] +name = "log" +version = "0.4.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" + +[[package]] +name = "matrixmultiply" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9380b911e3e96d10c1f415da0876389aaf1b56759054eeb0de7df940c456ba1a" +dependencies = [ + "autocfg", + "rawpointer", +] + +[[package]] +name 
= "memchr" +version = "2.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" + +[[package]] +name = "nalgebra" +version = "0.32.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b5c17de023a86f59ed79891b2e5d5a94c705dbe904a5b5c9c952ea6221b03e4" +dependencies = [ + "approx", + "matrixmultiply", + "nalgebra-macros", + "num-complex", + "num-rational", + "num-traits", + "rand", + "rand_distr", + "simba", + "typenum", +] + +[[package]] +name = "nalgebra-macros" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "254a5372af8fc138e36684761d3c0cdb758a4410e938babcff1c860ce14ddbfc" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "num-complex" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-integer" +version = "0.1.46" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-rational" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f83d14da390562dca69fc84082e73e548e1ad308d24accdedd2720017cb37824" +dependencies = [ + "num-integer", + "num-traits", +] + +[[package]] +name = "num-traits" +version = "0.2.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" +dependencies = [ + "autocfg", + "libm", +] + +[[package]] +name = "once_cell" +version = "1.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" + 
+[[package]] +name = "oorandom" +version = "11.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b410bbe7e14ab526a0e86877eb47c6996a2bd7746f027ba551028c925390e4e9" + +[[package]] +name = "paste" +version = "1.0.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" + +[[package]] +name = "plotters" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a15b6eccb8484002195a3e44fe65a4ce8e93a625797a063735536fd59cb01cf3" +dependencies = [ + "num-traits", + "plotters-backend", + "plotters-svg", + "wasm-bindgen", + "web-sys", +] + +[[package]] +name = "plotters-backend" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "414cec62c6634ae900ea1c56128dfe87cf63e7caece0852ec76aba307cebadb7" + +[[package]] +name = "plotters-svg" +version = "0.3.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "81b30686a7d9c3e010b84284bdd26a29f2138574f52f5eb6f794fc0ad924e705" +dependencies = [ + "plotters-backend", +] + +[[package]] +name = "ppv-lite86" +version = "0.2.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77957b295656769bb8ad2b6a6b09d897d94f05c41b069aede1fcdaa675eaea04" +dependencies = [ + "zerocopy", +] + +[[package]] +name = "proc-macro2" +version = "1.0.86" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e719e8df665df0d1c8fbfd238015744736151d4445ec0836b8e628aae103b77" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "quote" +version = "1.0.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b5b9d34b8991d19d98081b46eacdd8eb58c6f2b201139f7c5f643cc155a633af" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "rand" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +dependencies = [ + "libc", + "rand_chacha", + "rand_core", +] + +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", + "rand_core", +] + +[[package]] +name = "rand_core" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +dependencies = [ + "getrandom", +] + +[[package]] +name = "rand_distr" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32cb0b9bc82b0a0876c2dd994a7e7a2683d3e7390ca40e6886785ef0c7e3ee31" +dependencies = [ + "num-traits", + "rand", +] + +[[package]] +name = "rawpointer" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "60a357793950651c4ed0f3f52338f53b2f809f32d83a07f72909fa13e4c6c1e3" + +[[package]] +name = "rayon" +version = "1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b418a60154510ca1a002a752ca9714984e21e4241e804d32555251faf8b78ffa" +dependencies = [ + "either", + "rayon-core", +] + +[[package]] +name = "rayon-core" +version = "1.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1465873a3dfdaa8ae7cb14b4383657caab0b3e8a0aa9ae8e04b044854c8dfce2" +dependencies = [ + "crossbeam-deque", + "crossbeam-utils", +] + +[[package]] +name = "regex" +version = "1.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4219d74c6b67a3654a9fbebc4b419e22126d13d2f3c4a07ee0cb61ff79a79619" +dependencies = [ + "aho-corasick", + "memchr", + "regex-automata", + "regex-syntax", +] + +[[package]] +name = "regex-automata" +version = "0.4.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"38caf58cc5ef2fed281f89292ef23f6365465ed9a41b7a7754eb4e26496c92df" +dependencies = [ + "aho-corasick", + "memchr", + "regex-syntax", +] + +[[package]] +name = "regex-syntax" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a66a03ae7c801facd77a29370b4faec201768915ac14a721ba36f20bc9c209b" + +[[package]] +name = "ryu" +version = "1.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f" + +[[package]] +name = "safe_arch" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c3460605018fdc9612bce72735cba0d27efbcd9904780d44c7e3a9948f96148a" +dependencies = [ + "bytemuck", +] + +[[package]] +name = "same-file" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502" +dependencies = [ + "winapi-util", +] + +[[package]] +name = "serde" +version = "1.0.209" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "99fce0ffe7310761ca6bf9faf5115afbc19688edd00171d81b1bb1b116c63e09" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.209" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5831b979fd7b5439637af1752d535ff49f4860c0f341d1baeb6faf0f4242170" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "serde_json" +version = "1.0.128" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ff5456707a1de34e7e37f2a6fd3d3f808c318259cbd01ab6377795054b483d8" +dependencies = [ + "itoa", + "memchr", + "ryu", + "serde", +] + +[[package]] +name = "simba" +version = "0.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "061507c94fc6ab4ba1c9a0305018408e312e17c041eb63bef8aa726fa33aceae" +dependencies = [ + "approx", + 
"num-complex", + "num-traits", + "paste", + "wide", +] + +[[package]] +name = "statrs" +version = "0.17.1" +dependencies = [ + "anyhow", + "approx", + "criterion", + "nalgebra", + "num-traits", + "rand", +] + +[[package]] +name = "syn" +version = "2.0.77" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9f35bcdf61fd8e7be6caf75f429fdca8beb3ed76584befb503b1569faee373ed" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "tinytemplate" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be4d6b5f19ff7664e8c98d03e2139cb510db9b0a60b55f8e8709b689d939b6bc" +dependencies = [ + "serde", + "serde_json", +] + +[[package]] +name = "typenum" +version = "1.17.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" + +[[package]] +name = "unicode-ident" +version = "1.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" + +[[package]] +name = "walkdir" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "29790946404f91d9c5d06f9874efddea1dc06c5efe94541a7d6863108e3a5e4b" +dependencies = [ + "same-file", + "winapi-util", +] + +[[package]] +name = "wasi" +version = "0.11.0+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" + +[[package]] +name = "wasm-bindgen" +version = "0.2.93" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a82edfc16a6c469f5f44dc7b571814045d60404b55a0ee849f9bcfa2e63dd9b5" +dependencies = [ + "cfg-if", + "once_cell", + "wasm-bindgen-macro", +] + +[[package]] +name = "wasm-bindgen-backend" +version = "0.2.93" +source = "registry+https://github.com/rust-lang/crates.io-index" 
+checksum = "9de396da306523044d3302746f1208fa71d7532227f15e347e2d93e4145dd77b" +dependencies = [ + "bumpalo", + "log", + "once_cell", + "proc-macro2", + "quote", + "syn", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-macro" +version = "0.2.93" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "585c4c91a46b072c92e908d99cb1dcdf95c5218eeb6f3bf1efa991ee7a68cccf" +dependencies = [ + "quote", + "wasm-bindgen-macro-support", +] + +[[package]] +name = "wasm-bindgen-macro-support" +version = "0.2.93" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "afc340c74d9005395cf9dd098506f7f44e38f2b4a21c6aaacf9a105ea5e1e836" +dependencies = [ + "proc-macro2", + "quote", + "syn", + "wasm-bindgen-backend", + "wasm-bindgen-shared", +] + +[[package]] +name = "wasm-bindgen-shared" +version = "0.2.93" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c62a0a307cb4a311d3a07867860911ca130c3494e8c2719593806c08bc5d0484" + +[[package]] +name = "web-sys" +version = "0.3.70" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "26fdeaafd9bd129f65e7c031593c24d62186301e0c72c8978fa1678be7d532c0" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + +[[package]] +name = "wide" +version = "0.7.28" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b828f995bf1e9622031f8009f8481a85406ce1f4d4588ff746d872043e855690" +dependencies = [ + "bytemuck", + "safe_arch", +] + +[[package]] +name = "winapi-util" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb" +dependencies = [ + "windows-sys 0.59.0", +] + +[[package]] +name = "windows-sys" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" +dependencies = [ + "windows-targets", +] + 
+[[package]] +name = "windows-sys" +version = "0.59.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b" +dependencies = [ + "windows-targets", +] + +[[package]] +name = "windows-targets" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" +dependencies = [ + "windows_aarch64_gnullvm", + "windows_aarch64_msvc", + "windows_i686_gnu", + "windows_i686_gnullvm", + "windows_i686_msvc", + "windows_x86_64_gnu", + "windows_x86_64_gnullvm", + "windows_x86_64_msvc", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" + +[[package]] +name = "windows_i686_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" + +[[package]] +name = "windows_i686_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" + +[[package]] +name = "windows_i686_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.52.6" 
+source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" + +[[package]] +name = "zerocopy" +version = "0.7.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0" +dependencies = [ + "byteorder", + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.7.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] diff --git a/Cargo.toml b/Cargo.toml index d8e6eadf..6c80cc70 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,34 +1,58 @@ [package] - name = "statrs" -version = "0.16.0" +version = "0.17.1" authors = ["Michael Ma"] description = "Statistical computing library for Rust" license = "MIT" keywords = ["probability", "statistics", "stats", "distribution", "math"] categories = ["science"] -documentation = "https://docs.rs/statrs/0.15.0/statrs/" -homepage = "https://github.com/boxtown/statrs" -repository = "https://github.com/boxtown/statrs" -edition = "2018" +homepage = "https://github.com/statrs-dev/statrs" +repository = "https://github.com/statrs-dev/statrs" +edition = "2021" + +include = ["CHANGELOG.md", "LICENSE.md", "src/", "tests/"] + +# When changing MSRV: Also update the README +rust-version = "1.66.0" [lib] name = "statrs" path = "src/lib.rs" +[[bench]] +name = "order_statistics" +harness = false +required-features = ["rand"] + [features] -nightly = [] +default = ["nalgebra", "rand"] +nalgebra = ["dep:nalgebra"] +rand = ["dep:rand", "nalgebra?/rand"] [dependencies] -rand = "0.8" -nalgebra = { version = 
"0.29", features = ["rand"] } approx = "0.5.0" num-traits = "0.2.14" -lazy_static = "1.4.0" + +[dependencies.rand] +version = "0.8" +optional = true + +[dependencies.nalgebra] +version = "0.32" +optional = true +default-features = false +features = ["std"] [dev-dependencies] -criterion = "0.3.3" +criterion = "0.5" +anyhow = "1.0" -[[bench]] -name = "order_statistics" -harness = false +[dev-dependencies.nalgebra] +version = "0.32" +default-features = false +features = ["macros"] + +[lints.rust.unexpected_cfgs] +level = "warn" +# Set by cargo-llvm-cov when running on nightly +check-cfg = ['cfg(coverage_nightly)'] diff --git a/README.md b/README.md index 4521bb9b..973f0046 100644 --- a/README.md +++ b/README.md @@ -1,124 +1,117 @@ -# statrs - -[![Build Status](https://travis-ci.org/boxtown/statrs.svg?branch=master)](https://travis-ci.org/boxtown/statrs) -[![MIT licensed](https://img.shields.io/badge/license-MIT-blue.svg)](./LICENSE.md) -[![Crates.io](https://img.shields.io/crates/v/statrs.svg)](https://crates.io/crates/statrs) - -## Current Version: v0.16.0 - -Should work for both nightly and stable Rust. - -**NOTE:** While I will try to maintain backwards compatibility as much as possible, since this is still a 0.x.x project the API is not considered stable and thus subject to possible breaking changes up until v1.0.0 - -## Description - -Statrs provides a host of statistical utilities for Rust scientific computing. -Included are a number of common distributions that can be sampled (i.e. Normal, Exponential, -Student's T, Gamma, Uniform, etc.) plus common statistical functions like the gamma function, -beta function, and error function. - -This library is a work-in-progress port of the statistical capabilities -in the C# Math.NET library. All unit tests in the library borrowed from Math.NET when possible -and filled-in when not. - -This library is a work-in-progress and not complete. 
Planned for future releases are continued implementations -of distributions as well as porting over more statistical utilities - -Please check out the documentation [here](https://docs.rs/statrs/*/statrs/) - -## Usage - -Add the most recent release to your `Cargo.toml` - -```Rust -[dependencies] -statrs = "0.16" -``` - -## Examples - -Statrs comes with a number of commonly used distributions including Normal, Gamma, Student's T, Exponential, Weibull, etc. -The common use case is to set up the distributions and sample from them which depends on the `Rand` crate for random number generation - -```Rust -use statrs::distribution::Exp; -use rand::distributions::Distribution; - -let mut r = rand::rngs::OsRng; -let n = Exp::new(0.5).unwrap(); -print!("{}", n.sample(&mut r)); -``` - -Statrs also comes with a number of useful utility traits for more detailed introspection of distributions - -```Rust -use statrs::distribution::{Exp, Continuous, ContinuousCDF}; -use statrs::statistics::Distribution; - -let n = Exp::new(1.0).unwrap(); -assert_eq!(n.mean(), Some(1.0)); -assert_eq!(n.variance(), Some(1.0)); -assert_eq!(n.entropy(), Some(1.0)); -assert_eq!(n.skewness(), Some(2.0)); -assert_eq!(n.cdf(1.0), 0.6321205588285576784045); -assert_eq!(n.pdf(1.0), 0.3678794411714423215955); -``` - -as well as utility functions including `erf`, `gamma`, `ln_gamma`, `beta`, etc. - -```Rust -use statrs::statistics::Distribution; -use statrs::distribution::FisherSnedecor; - -let n = FisherSnedecor::new(1.0, 1.0).unwrap(); -assert!(n.variance().is_none()); -``` - -## Contributing - -Want to contribute? 
Check out some of the issues marked [help wanted](https://github.com/statrs-dev/statrs/issues?q=is%3Aissue+is%3Aopen+label%3A%22help+wanted%22) - -### How to contribute - -Clone the repo: - -``` -git clone https://github.com/statrs-dev/statrs -``` - -Create a feature branch: - -``` -git checkout -b master -``` - -After commiting your code: - -``` -git push -u origin -``` - -Then submit a PR, preferably referencing the relevant issue. - -### Style - -This repo makes use of `rustfmt` with the configuration specified in `rustfmt.toml`. -See https://github.com/rust-lang-nursery/rustfmt for instructions on installation -and usage and run the formatter using `rustfmt --write-mode overwrite *.rs` in -the `src` directory before committing. - -### Commit messages - -Please be explicit and and purposeful with commit messages. - -#### Bad - -``` -Modify test code -``` - -#### Good - -``` -test: Update statrs::distribution::Normal test_cdf -``` +# statrs + +![tests][actions-test-badge] +[![MIT licensed][license-badge]](./LICENSE.md) +[![Crate][crates-badge]][crates-url] +[![docs.rs](https://img.shields.io/docsrs/statrs)][docs-url] + +[actions-test-badge]: https://github.com/statrs-dev/statrs/actions/workflows/test.yml/badge.svg +[crates-badge]: https://img.shields.io/crates/v/statrs.svg +[crates-url]: https://crates.io/crates/statrs +[license-badge]: https://img.shields.io/badge/license-MIT-blue.svg +[docsrs-badge]: https://img.shields.io/docsrs/statrs +[docs-url]: https://docs.rs/statrs/*/statrs +[codecov-badge]: https://codecov.io/gh/statrs-dev/statrs/graph/badge.svg?token=XtMSMYXvIf +[codecov-url]: https://codecov.io/gh/statrs-dev/statrs + +Statrs provides a host of statistical utilities for Rust scientific computing. + +Included are a number of common distributions that can be sampled (i.e. Normal, Exponential, Student's T, Gamma, Uniform, etc.) plus common statistical functions like the gamma function, beta function, and error function. 
+ +This library began as a port of the statistical capabilities in the C# Math.NET library. +All unit tests in the library were borrowed from Math.NET when possible and filled-in when not. +Planned for future releases are continued implementations of distributions as well as porting over more statistical utilities. + +Please check out the documentation [here][docs-url]. + +## Usage + +Add the most recent release to your `Cargo.toml` + +```toml +[dependencies] +statrs = "*" # replace * by the latest version of the crate. +``` + +For examples, view [the docs](https://docs.rs/statrs/*/statrs/). + +### Running tests + +If you'd like to run all suggested tests, you'll need to download some data from +NIST; a script in the `tests/` folder downloads and formats this data. + +```sh +cargo test +./tests/gather_nist_data.sh && cargo test -- --include-ignored nist_ +``` + +If you'd like to modify where the data is downloaded, you can use the environment variable, +`STATRS_NIST_DATA_DIR` for running the script and the tests. + +## Minimum supported Rust version (MSRV) + +This crate requires a Rust version of 1.66.0 or higher. Increases in MSRV will be considered a semver non-breaking API change and require a version increase (PATCH until 1.0.0, MINOR after 1.0.0). + +## Contributing + +Thanks for your help to improve the project! +**No contribution is too small and all contributions are valued.** + +Suggestions if you don't know where to start, +- documentation is a great place to start, as you'll be able to identify the value of existing documentation better than its authors. +- tests are valuable in demonstrating correct behavior, you can review test coverage on the [CodeCov Report][codecov-url]*, not live until [#229](https://github.com/statrs-dev/statrs/pull/229) merged. +- check out some of the issues marked [help wanted](https://github.com/statrs-dev/statrs/issues?q=is%3Aissue+is%3Aopen+label%3A%22help+wanted%22).
+- look at what's not included from Math.NET's [Distributions](https://github.com/mathnet/mathnet-numerics/tree/master/src/Numerics/Distributions), [Statistics](https://github.com/mathnet/mathnet-numerics/tree/master/src/Numerics/Statistics), or related. + +### How to contribute + +Clone the repo: + +``` +git clone https://github.com/statrs-dev/statrs +``` + +Create a feature branch: + +``` +git checkout -b my-feature-branch master +``` + +Write your code and docs, then ensure it is formatted: + +``` +cargo fmt +``` + +Add `--check` to view the diff without making file changes. +Our CI will `fmt`, but fewer chores in commit history are appreciated. + +After committing your code: + +``` +git push -u origin my-feature-branch +``` + +Then submit a PR, preferably referencing the relevant issue, if it exists. + +### Commit messages + +Please be explicit and purposeful with commit messages. +[Conventional Commits](https://www.conventionalcommits.org/en/v1.0.0/#summary) encouraged. + +#### Bad + +``` +Modify test code +``` + +#### Good + +``` +test: Update statrs::distribution::Normal test_cdf +``` + +### Communication Expectations + +Please allow at least one week before pinging issues/PRs.
+ diff --git a/benches/order_statistics.rs b/benches/order_statistics.rs index fa6fdd26..d94902c9 100644 --- a/benches/order_statistics.rs +++ b/benches/order_statistics.rs @@ -1,4 +1,3 @@ -extern crate rand; extern crate statrs; use criterion::{black_box, criterion_group, criterion_main, BatchSize, Criterion}; use rand::prelude::*; diff --git a/data/nist/lew.txt b/data/nist/lew.txt deleted file mode 100644 index 9e38a720..00000000 --- a/data/nist/lew.txt +++ /dev/null @@ -1,200 +0,0 @@ --213 --564 --35 --15 -141 -115 --420 --360 -203 --338 --431 -194 --220 --513 -154 --125 --559 -92 --21 --579 --52 -99 --543 --175 -162 --457 --346 -204 --300 --474 -164 --107 --572 --8 -83 --541 --224 -180 --420 --374 -201 --236 --531 -83 -27 --564 --112 -131 --507 --254 -199 --311 --495 -143 --46 --579 --90 -136 --472 --338 -202 --287 --477 -169 --124 --568 -17 -48 --568 --135 -162 --430 --422 -172 --74 --577 --13 -92 --534 --243 -194 --355 --465 -156 --81 --578 --64 -139 --449 --384 -193 --198 --538 -110 --44 --577 --6 -66 --552 --164 -161 --460 --344 -205 --281 --504 -134 --28 --576 --118 -156 --437 --381 -200 --220 --540 -83 -11 --568 --160 -172 --414 --408 -188 --125 --572 --32 -139 --492 --321 -205 --262 --504 -142 --83 --574 -0 -48 --571 --106 -137 --501 --266 -190 --391 --406 -194 --186 --553 -83 --13 --577 --49 -103 --515 --280 -201 -300 --506 -131 --45 --578 --80 -138 --462 --361 -201 --211 --554 -32 -74 --533 --235 -187 --372 --442 -182 --147 --566 -25 -68 --535 --244 -194 --351 --463 -174 --125 --570 -15 -72 --550 --190 -172 --424 --385 -198 --218 --536 -96 \ No newline at end of file diff --git a/data/nist/lottery.txt b/data/nist/lottery.txt deleted file mode 100644 index a1880747..00000000 --- a/data/nist/lottery.txt +++ /dev/null @@ -1,218 +0,0 @@ -162 -671 -933 -414 -788 -730 -817 -33 -536 -875 -670 -236 -473 -167 -877 -980 -316 -950 -456 -92 -517 -557 -956 -954 -104 -178 -794 -278 -147 -773 -437 -435 -502 -610 -582 -780 -689 -562 -964 -791 -28 -97 -848 -281 -858 
-538 -660 -972 -671 -613 -867 -448 -738 -966 -139 -636 -847 -659 -754 -243 -122 -455 -195 -968 -793 -59 -730 -361 -574 -522 -97 -762 -431 -158 -429 -414 -22 -629 -788 -999 -187 -215 -810 -782 -47 -34 -108 -986 -25 -644 -829 -630 -315 -567 -919 -331 -207 -412 -242 -607 -668 -944 -749 -168 -864 -442 -533 -805 -372 -63 -458 -777 -416 -340 -436 -140 -919 -350 -510 -572 -905 -900 -85 -389 -473 -758 -444 -169 -625 -692 -140 -897 -672 -288 -312 -860 -724 -226 -884 -508 -976 -741 -476 -417 -831 -15 -318 -432 -241 -114 -799 -955 -833 -358 -935 -146 -630 -830 -440 -642 -356 -373 -271 -715 -367 -393 -190 -669 -8 -861 -108 -795 -269 -590 -326 -866 -64 -523 -862 -840 -219 -382 -998 -4 -628 -305 -747 -247 -34 -747 -729 -645 -856 -974 -24 -568 -24 -694 -608 -480 -410 -729 -947 -293 -53 -930 -223 -203 -677 -227 -62 -455 -387 -318 -562 -242 -428 -968 \ No newline at end of file diff --git a/data/nist/mavro.txt b/data/nist/mavro.txt deleted file mode 100644 index b904e6aa..00000000 --- a/data/nist/mavro.txt +++ /dev/null @@ -1,50 +0,0 @@ -2.00180 -2.00170 -2.00180 -2.00190 -2.00180 -2.00170 -2.00150 -2.00140 -2.00150 -2.00150 -2.00170 -2.00180 -2.00180 -2.00190 -2.00190 -2.00210 -2.00200 -2.00160 -2.00140 -2.00130 -2.00130 -2.00150 -2.00150 -2.00160 -2.00150 -2.00140 -2.00130 -2.00140 -2.00150 -2.00140 -2.00150 -2.00160 -2.00150 -2.00160 -2.00190 -2.00200 -2.00200 -2.00210 -2.00220 -2.00230 -2.00240 -2.00250 -2.00270 -2.00260 -2.00260 -2.00260 -2.00270 -2.00260 -2.00250 -2.00240 \ No newline at end of file diff --git a/data/nist/michaelso.txt b/data/nist/michaelso.txt deleted file mode 100644 index 2e436816..00000000 --- a/data/nist/michaelso.txt +++ /dev/null @@ -1,100 +0,0 @@ -299.85 -299.74 -299.90 -300.07 -299.93 -299.85 -299.95 -299.98 -299.98 -299.88 -300.00 -299.98 -299.93 -299.65 -299.76 -299.81 -300.00 -300.00 -299.96 -299.96 -299.96 -299.94 -299.96 -299.94 -299.88 -299.80 -299.85 -299.88 -299.90 -299.84 -299.83 -299.79 -299.81 -299.88 -299.88 -299.83 -299.80 -299.79 
-299.76 -299.80 -299.88 -299.88 -299.88 -299.86 -299.72 -299.72 -299.62 -299.86 -299.97 -299.95 -299.88 -299.91 -299.85 -299.87 -299.84 -299.84 -299.85 -299.84 -299.84 -299.84 -299.89 -299.81 -299.81 -299.82 -299.80 -299.77 -299.76 -299.74 -299.75 -299.76 -299.91 -299.92 -299.89 -299.86 -299.88 -299.72 -299.84 -299.85 -299.85 -299.78 -299.89 -299.84 -299.78 -299.81 -299.76 -299.81 -299.79 -299.81 -299.82 -299.85 -299.87 -299.87 -299.81 -299.74 -299.81 -299.94 -299.95 -299.80 -299.81 -299.87 \ No newline at end of file diff --git a/data/nist/numacc1.txt b/data/nist/numacc1.txt deleted file mode 100644 index 79dec5da..00000000 --- a/data/nist/numacc1.txt +++ /dev/null @@ -1,3 +0,0 @@ -10000001 -10000003 -10000002 \ No newline at end of file diff --git a/data/nist/numacc2.txt b/data/nist/numacc2.txt deleted file mode 100644 index 8a345dad..00000000 --- a/data/nist/numacc2.txt +++ /dev/null @@ -1,1001 +0,0 @@ -1.2 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 
-1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 
-1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 -1.1 -1.3 \ No newline at end of file diff --git a/data/nist/numacc3.txt 
b/data/nist/numacc3.txt deleted file mode 100644 index c7313205..00000000 --- a/data/nist/numacc3.txt +++ /dev/null @@ -1,1001 +0,0 @@ -1000000.2 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 
-1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 
-1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 
-1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 
-1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 
-1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 -1000000.1 -1000000.3 \ No newline at end of file diff --git a/data/nist/numacc4.txt b/data/nist/numacc4.txt deleted file mode 100644 index 63647051..00000000 --- a/data/nist/numacc4.txt +++ /dev/null @@ -1,1001 +0,0 @@ -10000000.2 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 
-10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 
-10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 
-10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 
-10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 
-10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 
-10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 -10000000.1 -10000000.3 \ No newline at end of file diff --git a/rustfmt.toml b/rustfmt.toml deleted file mode 100644 index b42e764f..00000000 --- a/rustfmt.toml +++ /dev/null @@ -1,4 +0,0 @@ -# Run using `rustfmt --write-mode overwrite *.rs` in the -# root of the src directory. You may still get some -# formatting errors (whitespace etc) which should be -# fixed manually before committing. 
diff --git a/src/distribution/bernoulli.rs b/src/distribution/bernoulli.rs index e31f9c5f..28f9d104 100644 --- a/src/distribution/bernoulli.rs +++ b/src/distribution/bernoulli.rs @@ -1,7 +1,5 @@ -use crate::distribution::{Binomial, Discrete, DiscreteCDF}; +use crate::distribution::{Binomial, BinomialError, Discrete, DiscreteCDF}; use crate::statistics::*; -use crate::Result; -use rand::Rng; /// Implements the /// [Bernoulli](https://en.wikipedia.org/wiki/Bernoulli_distribution) @@ -20,7 +18,7 @@ use rand::Rng; /// assert_eq!(n.pmf(0), 0.5); /// assert_eq!(n.pmf(1), 0.5); /// ``` -#[derive(Debug, Copy, Clone, PartialEq)] +#[derive(Copy, Clone, PartialEq, Debug)] pub struct Bernoulli { b: Binomial, } @@ -45,7 +43,7 @@ impl Bernoulli { /// result = Bernoulli::new(-0.5); /// assert!(result.is_err()); /// ``` - pub fn new(p: f64) -> Result { + pub fn new(p: f64) -> Result { Binomial::new(p, 1).map(|b| Bernoulli { b }) } @@ -80,8 +78,15 @@ impl Bernoulli { } } +impl std::fmt::Display for Bernoulli { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Bernoulli({})", self.p()) + } +} + +#[cfg(feature = "rand")] impl ::rand::distributions::Distribution for Bernoulli { - fn sample(&self, rng: &mut R) -> f64 { + fn sample(&self, rng: &mut R) -> f64 { rng.gen_bool(self.p()) as u8 as f64 } } @@ -92,7 +97,7 @@ impl DiscreteCDF for Bernoulli { /// /// # Formula /// - /// ```ignore + /// ```text /// if x < 0 { 0 } /// else if x >= 1 { 1 } /// else { 1 - p } @@ -101,12 +106,12 @@ impl DiscreteCDF for Bernoulli { self.b.cdf(x) } - /// Calculates the survival function for the + /// Calculates the survival function for the /// bernoulli distribution at `x`. 
/// /// # Formula /// - /// ```ignore + /// ```text /// if x < 0 { 1 } /// else if x >= 1 { 0 } /// else { p } @@ -123,7 +128,7 @@ impl Min for Bernoulli { /// /// # Formula /// - /// ```ignore + /// ```text /// 0 /// ``` fn min(&self) -> u64 { @@ -138,7 +143,7 @@ impl Max for Bernoulli { /// /// # Formula /// - /// ```ignore + /// ```text /// 1 /// ``` fn max(&self) -> u64 { @@ -152,41 +157,44 @@ impl Distribution for Bernoulli { /// /// # Formula /// - /// ```ignore + /// ```text /// p /// ``` fn mean(&self) -> Option { self.b.mean() } + /// Returns the variance of the bernoulli /// distribution /// /// # Formula /// - /// ```ignore + /// ```text /// p * (1 - p) /// ``` fn variance(&self) -> Option { self.b.variance() } + /// Returns the entropy of the bernoulli /// distribution /// /// # Formula /// - /// ```ignore + /// ```text /// q = (1 - p) /// -q * ln(q) - p * ln(p) /// ``` fn entropy(&self) -> Option { self.b.entropy() } + /// Returns the skewness of the bernoulli /// distribution /// /// # Formula /// - /// ```ignore + /// ```text /// q = (1 - p) /// (1 - 2p) / sqrt(p * q) /// ``` @@ -201,7 +209,7 @@ impl Median for Bernoulli { /// /// # Formula /// - /// ```ignore + /// ```text /// if p < 0.5 { 0 } /// else if p > 0.5 { 1 } /// else { 0.5 } @@ -216,7 +224,7 @@ impl Mode> for Bernoulli { /// /// # Formula /// - /// ```ignore + /// ```text /// if p < 0.5 { 0 } /// else { 1 } /// ``` @@ -231,7 +239,7 @@ impl Discrete for Bernoulli { /// /// # Formula /// - /// ```ignore + /// ```text /// if x == 0 { 1 - p } /// else { p } /// ``` @@ -244,7 +252,7 @@ impl Discrete for Bernoulli { /// /// # Formula /// - /// ```ignore + /// ```text /// else if x == 0 { ln(1 - p) } /// else { ln(p) } /// ``` @@ -254,92 +262,54 @@ impl Discrete for Bernoulli { } #[rustfmt::skip] -#[cfg(all(test, feature = "nightly"))] +#[cfg(test)] mod testing { - use std::fmt::Debug; - use crate::distribution::DiscreteCDF; - use super::Bernoulli; - - fn try_create(p: f64) -> Bernoulli { - let 
n = Bernoulli::new(p); - assert!(n.is_ok()); - n.unwrap() - } + use super::*; + use crate::testing_boiler; - fn create_case(p: f64) { - let dist = try_create(p); - assert_eq!(p, dist.p()); - } - - fn bad_create_case(p: f64) { - let n = Bernoulli::new(p); - assert!(n.is_err()); - } - - fn get_value(p: f64, eval: F) -> T - where T: PartialEq + Debug, - F: Fn(Bernoulli) -> T - { - let n = try_create(p); - eval(n) - } - - fn test_case(p: f64, expected: T, eval: F) - where T: PartialEq + Debug, - F: Fn(Bernoulli) -> T - { - let x = get_value(p, eval); - assert_eq!(expected, x); - } - - fn test_almost(p: f64, expected: f64, acc: f64, eval: F) - where F: Fn(Bernoulli) -> f64 - { - let x = get_value(p, eval); - assert_almost_eq!(expected, x, acc); - } + testing_boiler!(p: f64; Bernoulli; BinomialError); #[test] fn test_create() { - create_case(0.0); - create_case(0.3); - create_case(1.0); + create_ok(0.0); + create_ok(0.3); + create_ok(1.0); } #[test] fn test_bad_create() { - bad_create_case(f64::NAN); - bad_create_case(-1.0); - bad_create_case(2.0); + create_err(f64::NAN); + create_err(-1.0); + create_err(2.0); } #[test] fn test_cdf_upper_bound() { let cdf = |arg: u64| move |x: Bernoulli| x.cdf(arg); - test_case(0.3, 1., cdf(1)); + test_relative(0.3, 1., cdf(1)); } #[test] fn test_sf_upper_bound() { let sf = |arg: u64| move |x: Bernoulli| x.sf(arg); - test_case(0.3, 0., sf(1)); + test_relative(0.3, 0., sf(1)); } #[test] fn test_cdf() { let cdf = |arg: u64| move |x: Bernoulli| x.cdf(arg); - test_case(0.0, 1.0, cdf(0)); - test_case(0.0, 1.0, cdf(1)); - test_almost(0.3, 0.7, 1e-15, cdf(0)); - test_almost(0.7, 0.3, 1e-15, cdf(0)); + test_relative(0.0, 1.0, cdf(0)); + test_relative(0.0, 1.0, cdf(1)); + test_absolute(0.3, 0.7, 1e-15, cdf(0)); + test_absolute(0.7, 0.3, 1e-15, cdf(0)); } #[test] fn test_sf() { let sf = |arg: u64| move |x: Bernoulli| x.sf(arg); - test_case(0.0, 0.0, sf(0)); - test_case(0.0, 0.0, sf(1)); - test_almost(0.3, 0.3, 1e-15, sf(0)); - test_almost(0.7, 
0.7, 1e-15, sf(0)); + test_relative(0.0, 0.0, sf(0)); + test_relative(0.0, 0.0, sf(1)); + test_absolute(0.3, 0.3, 1e-15, sf(0)); + test_absolute(0.7, 0.7, 1e-15, sf(0)); } } diff --git a/src/distribution/beta.rs b/src/distribution/beta.rs index 6dd5adc3..2741889e 100644 --- a/src/distribution/beta.rs +++ b/src/distribution/beta.rs @@ -1,10 +1,6 @@ use crate::distribution::{Continuous, ContinuousCDF}; use crate::function::{beta, gamma}; -use crate::is_zero; use crate::statistics::*; -use crate::{Result, StatsError}; -use core::f64::INFINITY as INF; -use rand::Rng; /// Implements the [Beta](https://en.wikipedia.org/wiki/Beta_distribution) /// distribution @@ -20,12 +16,39 @@ use rand::Rng; /// assert_eq!(n.mean().unwrap(), 0.5); /// assert!(prec::almost_eq(n.pdf(0.5), 1.5, 1e-14)); /// ``` -#[derive(Debug, Copy, Clone, PartialEq)] +#[derive(Copy, Clone, PartialEq, Debug)] pub struct Beta { shape_a: f64, shape_b: f64, } +/// Represents the errors that can occur when creating a [`Beta`]. +#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] +#[non_exhaustive] +pub enum BetaError { + /// Shape A is NaN, zero or negative. + ShapeAInvalid, + + /// Shape B is NaN, zero or negative. + ShapeBInvalid, + + /// Shape A and Shape B are infinite. 
+ BothShapesInfinite, +} + +impl std::fmt::Display for BetaError { + #[cfg_attr(coverage_nightly, coverage(off))] + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + BetaError::ShapeAInvalid => write!(f, "Shape A is NaN, zero or negative"), + BetaError::ShapeBInvalid => write!(f, "Shape B is NaN, zero or negative"), + BetaError::BothShapesInfinite => write!(f, "Shape A and shape B are infinite"), + } + } +} + +impl std::error::Error for BetaError {} + impl Beta { /// Constructs a new beta distribution with shapeA (α) of `shape_a` /// and shapeB (β) of `shape_b` @@ -46,15 +69,19 @@ impl Beta { /// result = Beta::new(0.0, 0.0); /// assert!(result.is_err()); /// ``` - pub fn new(shape_a: f64, shape_b: f64) -> Result { - if shape_a.is_nan() - || shape_b.is_nan() - || shape_a.is_infinite() && shape_b.is_infinite() - || shape_a <= 0.0 - || shape_b <= 0.0 - { - return Err(StatsError::BadParams); - }; + pub fn new(shape_a: f64, shape_b: f64) -> Result { + if shape_a.is_nan() || shape_a <= 0.0 { + return Err(BetaError::ShapeAInvalid); + } + + if shape_b.is_nan() || shape_b <= 0.0 { + return Err(BetaError::ShapeBInvalid); + } + + if shape_a.is_infinite() && shape_b.is_infinite() { + return Err(BetaError::BothShapesInfinite); + } + Ok(Beta { shape_a, shape_b }) } @@ -87,8 +114,15 @@ impl Beta { } } +impl std::fmt::Display for Beta { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Beta(a={}, b={})", self.shape_a, self.shape_b) + } +} + +#[cfg(feature = "rand")] impl ::rand::distributions::Distribution for Beta { - fn sample(&self, rng: &mut R) -> f64 { + fn sample(&self, rng: &mut R) -> f64 { // Generated by sampling two gamma distributions and normalizing. 
let x = super::gamma::sample_unchecked(rng, self.shape_a, 1.0); let y = super::gamma::sample_unchecked(rng, self.shape_b, 1.0); @@ -103,7 +137,7 @@ impl ContinuousCDF for Beta { /// /// # Formula /// - /// ```ignore + /// ```text /// I_x(α, β) /// ``` /// @@ -134,7 +168,7 @@ impl ContinuousCDF for Beta { /// /// # Formula /// - /// ```ignore + /// ```text /// I_(1-x)(β, α) /// ``` /// @@ -156,7 +190,27 @@ impl ContinuousCDF for Beta { } else if ulps_eq!(self.shape_a, 1.0) && ulps_eq!(self.shape_b, 1.0) { 1. - x } else { - beta::beta_reg(self.shape_b, self.shape_a, 1.0 - x) + beta::beta_reg(self.shape_b, self.shape_a, 1.0 - x) + } + } + + /// Calculates the inverse cumulative distribution function for the beta + /// distribution + /// at `x` + /// + /// # Formula + /// + /// ```text + /// I^{-1}_x(α, β) + /// ``` + /// + /// where `α` is shapeA, `β` is shapeB, and `I_x` is the inverse of the + /// regularized lower incomplete beta function + fn inverse_cdf(&self, x: f64) -> f64 { + if !(0.0..=1.0).contains(&x) { + panic!("x must be in [0, 1]"); + } else { + beta::inv_beta_reg(self.shape_a, self.shape_b, x) } } } @@ -168,7 +222,7 @@ impl Min for Beta { /// /// # Formula /// - /// ```ignore + /// ```text /// 0 /// ``` fn min(&self) -> f64 { @@ -183,7 +237,7 @@ impl Max for Beta { /// /// # Formula /// - /// ```ignore + /// ```text /// 1 /// ``` fn max(&self) -> f64 { @@ -196,7 +250,7 @@ impl Distribution for Beta { /// /// # Formula /// - /// ```ignore + /// ```text /// α / (α + β) /// ``` /// @@ -209,13 +263,14 @@ impl Distribution for Beta { }; Some(mean) } + /// Returns the variance of the beta distribution /// /// # Remarks /// /// # Formula /// - /// ```ignore + /// ```text /// (α * β) / ((α + β)^2 * (α + β + 1)) /// ``` /// @@ -231,11 +286,12 @@ impl Distribution for Beta { }; Some(var) } + /// Returns the entropy of the beta distribution /// /// # Formula /// - /// ```ignore + /// ```text /// ln(B(α, β)) - (α - 1)ψ(α) - (β - 1)ψ(β) + (α + β - 2)ψ(α + β) /// ``` 
/// @@ -252,11 +308,12 @@ impl Distribution for Beta { }; Some(entr) } + /// Returns the skewness of the Beta distribution /// /// # Formula /// - /// ```ignore + /// ```text /// 2(β - α) * sqrt(α + β + 1) / ((α + β + 2) * sqrt(αβ)) /// ``` /// @@ -290,7 +347,7 @@ impl Mode> for Beta { /// /// # Formula /// - /// ```ignore + /// ```text /// (α - 1) / (α + β - 2) /// ``` /// @@ -314,7 +371,7 @@ impl Continuous for Beta { /// /// # Formula /// - /// ```ignore + /// ```text /// let B(α, β) = Γ(α)Γ(β)/Γ(α + β) /// /// x^(α - 1) * (1 - x)^(β - 1) / B(α, β) @@ -326,13 +383,13 @@ impl Continuous for Beta { 0.0 } else if self.shape_a.is_infinite() { if ulps_eq!(x, 1.0) { - INF + f64::INFINITY } else { 0.0 } } else if self.shape_b.is_infinite() { - if is_zero(x) { - INF + if x == 0.0 { + f64::INFINITY } else { 0.0 } @@ -352,7 +409,7 @@ impl Continuous for Beta { /// /// # Formula /// - /// ```ignore + /// ```text /// let B(α, β) = Γ(α)Γ(β)/Γ(α + β) /// /// ln(x^(α - 1) * (1 - x)^(β - 1) / B(α, β)) @@ -361,18 +418,18 @@ impl Continuous for Beta { /// where `α` is shapeA, `β` is shapeB, and `Γ` is the gamma function fn ln_pdf(&self, x: f64) -> f64 { if !(0.0..=1.0).contains(&x) { - -INF + f64::NEG_INFINITY } else if self.shape_a.is_infinite() { if ulps_eq!(x, 1.0) { - INF + f64::INFINITY } else { - -INF + f64::NEG_INFINITY } } else if self.shape_b.is_infinite() { - if is_zero(x) { - INF + if x == 0.0 { + f64::INFINITY } else { - -INF + f64::NEG_INFINITY } } else if ulps_eq!(self.shape_a, 1.0) && ulps_eq!(self.shape_b, 1.0) { 0.0 @@ -380,17 +437,17 @@ impl Continuous for Beta { let aa = gamma::ln_gamma(self.shape_a + self.shape_b) - gamma::ln_gamma(self.shape_a) - gamma::ln_gamma(self.shape_b); - let bb = if ulps_eq!(self.shape_a, 1.0) && is_zero(x) { + let bb = if ulps_eq!(self.shape_a, 1.0) && x == 0.0 { 0.0 - } else if is_zero(x) { - -INF + } else if x == 0.0 { + f64::NEG_INFINITY } else { (self.shape_a - 1.0) * x.ln() }; let cc = if ulps_eq!(self.shape_b, 1.0) && 
ulps_eq!(x, 1.0) { 0.0 } else if ulps_eq!(x, 1.0) { - -INF + f64::NEG_INFINITY } else { (self.shape_b - 1.0) * (1.0 - x).ln() }; @@ -400,21 +457,19 @@ impl Continuous for Beta { } #[rustfmt::skip] -#[cfg(all(test, feature = "nightly"))] +#[cfg(test)] mod tests { use super::*; - use crate::consts::ACC; use super::super::internal::*; - use crate::statistics::*; use crate::testing_boiler; - testing_boiler!((f64, f64), Beta); + testing_boiler!(a: f64, b: f64; Beta; BetaError); #[test] fn test_create() { - let valid = [(1.0, 1.0), (9.0, 1.0), (5.0, 100.0), (1.0, INF), (INF, 1.0)]; - for &arg in valid.iter() { - try_create(arg); + let valid = [(1.0, 1.0), (9.0, 1.0), (5.0, 100.0), (1.0, f64::INFINITY), (f64::INFINITY, 1.0)]; + for (a, b) in valid { + create_ok(a, b); } } @@ -424,18 +479,18 @@ mod tests { (0.0, 0.0), (0.0, 0.1), (1.0, 0.0), - (0.0, INF), - (INF, 0.0), + (0.0, f64::INFINITY), + (f64::INFINITY, 0.0), (f64::NAN, 1.0), (1.0, f64::NAN), (f64::NAN, f64::NAN), (1.0, -1.0), (-1.0, 1.0), (-1.0, -1.0), - (INF, INF), + (f64::INFINITY, f64::INFINITY), ]; - for &arg in invalid.iter() { - bad_create_case(arg); + for (a, b) in invalid { + create_err(a, b); } } @@ -446,11 +501,11 @@ mod tests { ((1.0, 1.0), 0.5), ((9.0, 1.0), 0.9), ((5.0, 100.0), 0.047619047619047619047616), - ((1.0, INF), 0.0), - ((INF, 1.0), 1.0), + ((1.0, f64::INFINITY), 0.0), + ((f64::INFINITY, 1.0), 1.0), ]; - for &(arg, res) in test.iter() { - test_case(arg, res, f); + for ((a, b), res) in test { + test_relative(a, b, res, f); } } @@ -461,11 +516,11 @@ mod tests { ((1.0, 1.0), 1.0 / 12.0), ((9.0, 1.0), 9.0 / 1100.0), ((5.0, 100.0), 500.0 / 1168650.0), - ((1.0, INF), 0.0), - ((INF, 1.0), 0.0), + ((1.0, f64::INFINITY), 0.0), + ((f64::INFINITY, 1.0), 0.0), ]; - for &(arg, res) in test.iter() { - test_case(arg, res, f); + for ((a, b), res) in test { + test_relative(a, b, res, f); } } @@ -476,53 +531,49 @@ mod tests { ((9.0, 1.0), -1.3083356884473304939016015), ((5.0, 100.0), 
-2.52016231876027436794592), ]; - for &(arg, res) in test.iter() { - test_case(arg, res, f); + for ((a, b), res) in test { + test_relative(a, b, res, f); } - test_case_special((1.0, 1.0), 0.0, 1e-14, f); + test_absolute(1.0, 1.0, 0.0, 1e-14, f); let entropy = |x: Beta| x.entropy(); - test_none((1.0, INF), entropy); - test_none((INF, 1.0), entropy); + test_none(1.0, f64::INFINITY, entropy); + test_none(f64::INFINITY, 1.0, entropy); } #[test] fn test_skewness() { let skewness = |x: Beta| x.skewness().unwrap(); - test_case((1.0, 1.0), 0.0, skewness); - test_case((9.0, 1.0), -1.4740554623801777107177478829, skewness); - test_case((5.0, 100.0), 0.817594109275534303545831591, skewness); - test_case((1.0, INF), 2.0, skewness); - test_case((INF, 1.0), -2.0, skewness); + test_relative(1.0, 1.0, 0.0, skewness); + test_relative(9.0, 1.0, -1.4740554623801777107177478829, skewness); + test_relative(5.0, 100.0, 0.817594109275534303545831591, skewness); + test_relative(1.0, f64::INFINITY, 2.0, skewness); + test_relative(f64::INFINITY, 1.0, -2.0, skewness); } #[test] fn test_mode() { let mode = |x: Beta| x.mode().unwrap(); - test_case((5.0, 100.0), 0.038834951456310676243255386, mode); - test_case((92.0, INF), 0.0, mode); - test_case((INF, 2.0), 1.0, mode); + test_relative(5.0, 100.0, 0.038834951456310676243255386, mode); + test_relative(92.0, f64::INFINITY, 0.0, mode); + test_relative(f64::INFINITY, 2.0, 1.0, mode); } #[test] - #[should_panic] fn test_mode_shape_a_lte_1() { - let mode = |x: Beta| x.mode().unwrap(); - get_value((1.0, 5.0), mode); + test_none(1.0, 5.0, |dist| dist.mode()); } #[test] - #[should_panic] fn test_mode_shape_b_lte_1() { - let mode = |x: Beta| x.mode().unwrap(); - get_value((5.0, 1.0), mode); + test_none(5.0, 1.0, |dist| dist.mode()); } #[test] fn test_min_max() { let min = |x: Beta| x.min(); let max = |x: Beta| x.max(); - test_case((1.0, 1.0), 0.0, min); - test_case((1.0, 1.0), 1.0, max); + test_relative(1.0, 1.0, 0.0, min); + test_relative(1.0, 1.0, 
1.0, max); } #[test] @@ -539,28 +590,28 @@ mod tests { ((5.0, 100.0), 0.5, 4.534102298350337661e-23), ((5.0, 100.0), 1.0, 0.0), ((5.0, 100.0), 1.0, 0.0), - ((1.0, INF), 0.0, INF), - ((1.0, INF), 0.5, 0.0), - ((1.0, INF), 1.0, 0.0), - ((INF, 1.0), 0.0, 0.0), - ((INF, 1.0), 0.5, 0.0), - ((INF, 1.0), 1.0, INF), + ((1.0, f64::INFINITY), 0.0, f64::INFINITY), + ((1.0, f64::INFINITY), 0.5, 0.0), + ((1.0, f64::INFINITY), 1.0, 0.0), + ((f64::INFINITY, 1.0), 0.0, 0.0), + ((f64::INFINITY, 1.0), 0.5, 0.0), + ((f64::INFINITY, 1.0), 1.0, f64::INFINITY), ]; - for &(arg, x, expect) in test.iter() { - test_case(arg, expect, f(x)); + for ((a, b), x, expect) in test { + test_relative(a, b, expect, f(x)); } } #[test] fn test_pdf_input_lt_0() { let pdf = |arg: f64| move |x: Beta| x.pdf(arg); - test_case((1.0, 1.0), 0.0, pdf(-1.0)); + test_relative(1.0, 1.0, 0.0, pdf(-1.0)); } #[test] fn test_pdf_input_gt_0() { let pdf = |arg: f64| move |x: Beta| x.pdf(arg); - test_case((1.0, 1.0), 0.0, pdf(2.0)); + test_relative(1.0, 1.0, 0.0, pdf(2.0)); } #[test] @@ -570,34 +621,34 @@ mod tests { ((1.0, 1.0), 0.0, 0.0), ((1.0, 1.0), 0.5, 0.0), ((1.0, 1.0), 1.0, 0.0), - ((9.0, 1.0), 0.0, -INF), + ((9.0, 1.0), 0.0, f64::NEG_INFINITY), ((9.0, 1.0), 0.5, -3.347952867143343092547366497), ((9.0, 1.0), 1.0, 2.1972245773362193827904904738), - ((5.0, 100.0), 0.0, -INF), + ((5.0, 100.0), 0.0, f64::NEG_INFINITY), ((5.0, 100.0), 0.5, -51.447830024537682154565870), - ((5.0, 100.0), 1.0, -INF), - ((1.0, INF), 0.0, INF), - ((1.0, INF), 0.5, -INF), - ((1.0, INF), 1.0, -INF), - ((INF, 1.0), 0.0, -INF), - ((INF, 1.0), 0.5, -INF), - ((INF, 1.0), 1.0, INF), + ((5.0, 100.0), 1.0, f64::NEG_INFINITY), + ((1.0, f64::INFINITY), 0.0, f64::INFINITY), + ((1.0, f64::INFINITY), 0.5, f64::NEG_INFINITY), + ((1.0, f64::INFINITY), 1.0, f64::NEG_INFINITY), + ((f64::INFINITY, 1.0), 0.0, f64::NEG_INFINITY), + ((f64::INFINITY, 1.0), 0.5, f64::NEG_INFINITY), + ((f64::INFINITY, 1.0), 1.0, f64::INFINITY), ]; - for &(arg, x, expect) in 
test.iter() { - test_case(arg, expect, f(x)); + for ((a, b), x, expect) in test { + test_relative(a, b, expect, f(x)); } } #[test] fn test_ln_pdf_input_lt_0() { let ln_pdf = |arg: f64| move |x: Beta| x.ln_pdf(arg); - test_case((1.0, 1.0), -INF, ln_pdf(-1.0)); + test_relative(1.0, 1.0, f64::NEG_INFINITY, ln_pdf(-1.0)); } #[test] fn test_ln_pdf_input_gt_1() { let ln_pdf = |arg: f64| move |x: Beta| x.ln_pdf(arg); - test_case((1.0, 1.0), -INF, ln_pdf(2.0)); + test_relative(1.0, 1.0, f64::NEG_INFINITY, ln_pdf(2.0)); } #[test] @@ -613,15 +664,15 @@ mod tests { ((5.0, 100.0), 0.0, 0.0), ((5.0, 100.0), 0.5, 1.0), ((5.0, 100.0), 1.0, 1.0), - ((1.0, INF), 0.0, 1.0), - ((1.0, INF), 0.5, 1.0), - ((1.0, INF), 1.0, 1.0), - ((INF, 1.0), 0.0, 0.0), - ((INF, 1.0), 0.5, 0.0), - ((INF, 1.0), 1.0, 1.0), + ((1.0, f64::INFINITY), 0.0, 1.0), + ((1.0, f64::INFINITY), 0.5, 1.0), + ((1.0, f64::INFINITY), 1.0, 1.0), + ((f64::INFINITY, 1.0), 0.0, 0.0), + ((f64::INFINITY, 1.0), 0.5, 0.0), + ((f64::INFINITY, 1.0), 1.0, 1.0), ]; - for &(arg, x, expect) in test.iter() { - test_case(arg, expect, cdf(x)); + for ((a, b), x, expect) in test { + test_relative(a, b, expect, cdf(x)); } } @@ -638,45 +689,66 @@ mod tests { ((5.0, 100.0), 0.0, 1.0), ((5.0, 100.0), 0.5, 0.0), ((5.0, 100.0), 1.0, 0.0), - ((1.0, INF), 0.0, 0.0), - ((1.0, INF), 0.5, 0.0), - ((1.0, INF), 1.0, 0.0), - ((INF, 1.0), 0.0, 1.0), - ((INF, 1.0), 0.5, 1.0), - ((INF, 1.0), 1.0, 0.0), + ((1.0, f64::INFINITY), 0.0, 0.0), + ((1.0, f64::INFINITY), 0.5, 0.0), + ((1.0, f64::INFINITY), 1.0, 0.0), + ((f64::INFINITY, 1.0), 0.0, 1.0), + ((f64::INFINITY, 1.0), 0.5, 1.0), + ((f64::INFINITY, 1.0), 1.0, 0.0), ]; - for &(arg, x, expect) in test.iter() { - test_case(arg, expect, sf(x)); + for ((a, b), x, expect) in test { + test_relative(a, b, expect, sf(x)); } } + #[test] + fn test_inverse_cdf() { + // let inverse_cdf = |arg: f64| move |x: Beta| x.inverse_cdf(arg); + let func = |arg: f64| move |x: Beta| x.inverse_cdf(x.cdf(arg)); + let test = [ + 
((1.0, 1.0), 0.0, 0.0), + ((1.0, 1.0), 0.5, 0.5), + ((1.0, 1.0), 1.0, 1.0), + ((9.0, 1.0), 0.0, 0.0), + ((9.0, 1.0), 0.001953125, 0.001953125), + ((9.0, 1.0), 0.5, 0.5), + ((9.0, 1.0), 1.0, 1.0), + ((5.0, 100.0), 0.0, 0.0), + ((5.0, 100.0), 0.01, 0.01), + ((5.0, 100.0), 1.0, 1.0), + ]; + for ((a, b), x, expect) in test { + test_relative(a, b, expect, func(x)); + }; + } + #[test] fn test_cdf_input_lt_0() { let cdf = |arg: f64| move |x: Beta| x.cdf(arg); - test_case((1.0, 1.0), 0.0, cdf(-1.0)); + test_relative(1.0, 1.0, 0.0, cdf(-1.0)); } #[test] fn test_cdf_input_gt_1() { let cdf = |arg: f64| move |x: Beta| x.cdf(arg); - test_case((1.0, 1.0), 1.0, cdf(2.0)); + test_relative(1.0, 1.0, 1.0, cdf(2.0)); } #[test] fn test_sf_input_lt_0() { let sf = |arg: f64| move |x: Beta| x.sf(arg); - test_case((1.0, 1.0), 1.0, sf(-1.0)); + test_relative(1.0, 1.0, 1.0, sf(-1.0)); } #[test] fn test_sf_input_gt_1() { let sf = |arg: f64| move |x: Beta| x.sf(arg); - test_case((1.0, 1.0), 0.0, sf(2.0)); + test_relative(1.0, 1.0, 0.0, sf(2.0)); } #[test] fn test_continuous() { - test::check_continuous_distribution(&try_create((1.2, 3.4)), 0.0, 1.0); - test::check_continuous_distribution(&try_create((4.5, 6.7)), 0.0, 1.0); + test::check_continuous_distribution(&create_ok(1.2, 3.4), 0.0, 1.0); + test::check_continuous_distribution(&create_ok(4.5, 6.7), 0.0, 1.0); } } diff --git a/src/distribution/binomial.rs b/src/distribution/binomial.rs index 85ddecaa..c24bf7b5 100644 --- a/src/distribution/binomial.rs +++ b/src/distribution/binomial.rs @@ -1,9 +1,6 @@ use crate::distribution::{Discrete, DiscreteCDF}; use crate::function::{beta, factorial}; -use crate::is_zero; use crate::statistics::*; -use crate::{Result, StatsError}; -use rand::Rng; use std::f64; /// Implements the @@ -21,12 +18,31 @@ use std::f64; /// assert_eq!(n.pmf(0), 0.03125); /// assert_eq!(n.pmf(3), 0.3125); /// ``` -#[derive(Debug, Copy, Clone, PartialEq)] +#[derive(Copy, Clone, PartialEq, Debug)] pub struct Binomial { p: f64, n: 
u64, } +/// Represents the errors that can occur when creating a [`Binomial`]. +#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] +#[non_exhaustive] +pub enum BinomialError { + /// The probability is NaN or not in `[0, 1]`. + ProbabilityInvalid, +} + +impl std::fmt::Display for BinomialError { + #[cfg_attr(coverage_nightly, coverage(off))] + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + BinomialError::ProbabilityInvalid => write!(f, "Probability is NaN or not in [0, 1]"), + } + } +} + +impl std::error::Error for BinomialError {} + impl Binomial { /// Constructs a new binomial distribution /// with a given `p` probability of success of `n` @@ -48,9 +64,9 @@ impl Binomial { /// result = Binomial::new(-0.5, 5); /// assert!(result.is_err()); /// ``` - pub fn new(p: f64, n: u64) -> Result { - if p.is_nan() || p < 0.0 || p > 1.0 { - Err(StatsError::BadParams) + pub fn new(p: f64, n: u64) -> Result { + if p.is_nan() || !(0.0..=1.0).contains(&p) { + Err(BinomialError::ProbabilityInvalid) } else { Ok(Binomial { p, n }) } @@ -87,8 +103,15 @@ impl Binomial { } } +impl std::fmt::Display for Binomial { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Bin({},{})", self.p, self.n) + } +} + +#[cfg(feature = "rand")] impl ::rand::distributions::Distribution for Binomial { - fn sample(&self, rng: &mut R) -> f64 { + fn sample(&self, rng: &mut R) -> f64 { (0..self.n).fold(0.0, |acc, _| { let n: f64 = rng.gen(); if n < self.p { @@ -106,7 +129,7 @@ impl DiscreteCDF for Binomial { /// /// # Formula /// - /// ```ignore + /// ```text /// I_(1 - p)(n - x, 1 + x) /// ``` /// @@ -125,7 +148,7 @@ impl DiscreteCDF for Binomial { /// /// # Formula /// - /// ```ignore + /// ```text /// I_(p)(x + 1, n - x) /// ``` /// @@ -147,7 +170,7 @@ impl Min for Binomial { /// /// # Formula /// - /// ```ignore + /// ```text /// 0 /// ``` fn min(&self) -> u64 { @@ -162,7 +185,7 @@ impl Max for Binomial { /// /// # Formula /// - /// 
```ignore + /// ```text /// n /// ``` fn max(&self) -> u64 { @@ -175,31 +198,33 @@ impl Distribution for Binomial { /// /// # Formula /// - /// ```ignore + /// ```text /// p * n /// ``` fn mean(&self) -> Option { Some(self.p * self.n as f64) } + /// Returns the variance of the binomial distribution /// /// # Formula /// - /// ```ignore + /// ```text /// n * p * (1 - p) /// ``` fn variance(&self) -> Option { Some(self.p * (1.0 - self.p) * self.n as f64) } + /// Returns the entropy of the binomial distribution /// /// # Formula /// - /// ```ignore + /// ```text /// (1 / 2) * ln (2 * π * e * n * p * (1 - p)) /// ``` fn entropy(&self) -> Option { - let entr = if is_zero(self.p) || ulps_eq!(self.p, 1.0) { + let entr = if self.p == 0.0 || ulps_eq!(self.p, 1.0) { 0.0 } else { (0..self.n + 1).fold(0.0, |acc, x| { @@ -209,11 +234,12 @@ impl Distribution for Binomial { }; Some(entr) } + /// Returns the skewness of the binomial distribution /// /// # Formula /// - /// ```ignore + /// ```text /// (1 - 2p) / sqrt(n * p * (1 - p))) /// ``` fn skewness(&self) -> Option { @@ -226,7 +252,7 @@ impl Median for Binomial { /// /// # Formula /// - /// ```ignore + /// ```text /// floor(n * p) /// ``` fn median(&self) -> f64 { @@ -239,11 +265,11 @@ impl Mode> for Binomial { /// /// # Formula /// - /// ```ignore + /// ```text /// floor((n + 1) * p) /// ``` fn mode(&self) -> Option { - let mode = if is_zero(self.p) { + let mode = if self.p == 0.0 { 0 } else if ulps_eq!(self.p, 1.0) { self.n @@ -260,13 +286,13 @@ impl Discrete for Binomial { /// /// # Formula /// - /// ```ignore + /// ```text /// (n choose k) * p^k * (1 - p)^(n - k) /// ``` fn pmf(&self, x: u64) -> f64 { if x > self.n { 0.0 - } else if is_zero(self.p) { + } else if self.p == 0.0 { if x == 0 { 1.0 } else { @@ -279,7 +305,7 @@ impl Discrete for Binomial { 0.0 } } else { - (factorial::ln_binomial(self.n as u64, x as u64) + (factorial::ln_binomial(self.n, x) + x as f64 * self.p.ln() + (self.n - x) as f64 * (1.0 - self.p).ln()) 
.exp() @@ -291,13 +317,13 @@ impl Discrete for Binomial { /// /// # Formula /// - /// ```ignore + /// ```text /// ln((n choose k) * p^k * (1 - p)^(n - k)) /// ``` fn ln_pmf(&self, x: u64) -> f64 { if x > self.n { f64::NEG_INFINITY - } else if is_zero(self.p) { + } else if self.p == 0.0 { if x == 0 { 0.0 } else { @@ -310,7 +336,7 @@ impl Discrete for Binomial { f64::NEG_INFINITY } } else { - factorial::ln_binomial(self.n as u64, x as u64) + factorial::ln_binomial(self.n, x) + x as f64 * self.p.ln() + (self.n - x) as f64 * (1.0 - self.p).ln() } @@ -318,256 +344,233 @@ impl Discrete for Binomial { } #[rustfmt::skip] -#[cfg(all(test, feature = "nightly"))] +#[cfg(test)] mod tests { - use std::fmt::Debug; - use crate::statistics::*; - use crate::distribution::{DiscreteCDF, Discrete, Binomial}; + use super::*; use crate::distribution::internal::*; - use crate::consts::ACC; - - fn try_create(p: f64, n: u64) -> Binomial { - let n = Binomial::new(p, n); - assert!(n.is_ok()); - n.unwrap() - } + use crate::testing_boiler; - fn create_case(p: f64, n: u64) { - let dist = try_create(p, n); - assert_eq!(p, dist.p()); - assert_eq!(n, dist.n()); - } - - fn bad_create_case(p: f64, n: u64) { - let n = Binomial::new(p, n); - assert!(n.is_err()); - } - - fn get_value(p: f64, n: u64, eval: F) -> T - where T: PartialEq + Debug, - F: Fn(Binomial) -> T - { - let n = try_create(p, n); - eval(n) - } - - fn test_case(p: f64, n: u64, expected: T, eval: F) - where T: PartialEq + Debug, - F: Fn(Binomial) -> T - { - let x = get_value(p, n, eval); - println!("{} {} {:?}", p, n, expected); - assert_eq!(expected, x); - } - - fn test_almost(p: f64, n: u64, expected: f64, acc: f64, eval: F) - where F: Fn(Binomial) -> f64 - { - let x = get_value(p, n, eval); - assert_almost_eq!(expected, x, acc); - } + testing_boiler!(p: f64, n: u64; Binomial; BinomialError); #[test] fn test_create() { - create_case(0.0, 4); - create_case(0.3, 3); - create_case(1.0, 2); + create_ok(0.0, 4); + create_ok(0.3, 3); + 
create_ok(1.0, 2); } #[test] fn test_bad_create() { - bad_create_case(f64::NAN, 1); - bad_create_case(-1.0, 1); - bad_create_case(2.0, 1); + create_err(f64::NAN, 1); + create_err(-1.0, 1); + create_err(2.0, 1); } #[test] fn test_mean() { let mean = |x: Binomial| x.mean().unwrap(); - test_case(0.0, 4, 0.0, mean); - test_almost(0.3, 3, 0.9, 1e-15, mean); - test_case(1.0, 2, 2.0, mean); + test_exact(0.0, 4, 0.0, mean); + test_absolute(0.3, 3, 0.9, 1e-15, mean); + test_exact(1.0, 2, 2.0, mean); } #[test] fn test_variance() { let variance = |x: Binomial| x.variance().unwrap(); - test_case(0.0, 4, 0.0, variance); - test_case(0.3, 3, 0.63, variance); - test_case(1.0, 2, 0.0, variance); + test_exact(0.0, 4, 0.0, variance); + test_exact(0.3, 3, 0.63, variance); + test_exact(1.0, 2, 0.0, variance); } #[test] fn test_entropy() { let entropy = |x: Binomial| x.entropy().unwrap(); - test_case(0.0, 4, 0.0, entropy); - test_almost(0.3, 3, 1.1404671643037712668976423399228972051669206536461, 1e-15, entropy); - test_case(1.0, 2, 0.0, entropy); + test_exact(0.0, 4, 0.0, entropy); + test_absolute(0.3, 3, 1.1404671643037712668976423399228972051669206536461, 1e-15, entropy); + test_exact(1.0, 2, 0.0, entropy); } #[test] fn test_skewness() { let skewness = |x: Binomial| x.skewness().unwrap(); - test_case(0.0, 4, f64::INFINITY, skewness); - test_case(0.3, 3, 0.503952630678969636286, skewness); - test_case(1.0, 2, f64::NEG_INFINITY, skewness); + test_exact(0.0, 4, f64::INFINITY, skewness); + test_exact(0.3, 3, 0.503952630678969636286, skewness); + test_exact(1.0, 2, f64::NEG_INFINITY, skewness); } #[test] fn test_median() { let median = |x: Binomial| x.median(); - test_case(0.0, 4, 0.0, median); - test_case(0.3, 3, 0.0, median); - test_case(1.0, 2, 2.0, median); + test_exact(0.0, 4, 0.0, median); + test_exact(0.3, 3, 0.0, median); + test_exact(1.0, 2, 2.0, median); } #[test] fn test_mode() { let mode = |x: Binomial| x.mode().unwrap(); - test_case(0.0, 4, 0, mode); - test_case(0.3, 3, 1, 
mode); - test_case(1.0, 2, 2, mode); + test_exact(0.0, 4, 0, mode); + test_exact(0.3, 3, 1, mode); + test_exact(1.0, 2, 2, mode); } #[test] fn test_min_max() { let min = |x: Binomial| x.min(); let max = |x: Binomial| x.max(); - test_case(0.3, 10, 0, min); - test_case(0.3, 10, 10, max); + test_exact(0.3, 10, 0, min); + test_exact(0.3, 10, 10, max); } #[test] fn test_pmf() { let pmf = |arg: u64| move |x: Binomial| x.pmf(arg); - test_case(0.0, 1, 1.0, pmf(0)); - test_case(0.0, 1, 0.0, pmf(1)); - test_case(0.0, 3, 1.0, pmf(0)); - test_case(0.0, 3, 0.0, pmf(1)); - test_case(0.0, 3, 0.0, pmf(3)); - test_case(0.0, 10, 1.0, pmf(0)); - test_case(0.0, 10, 0.0, pmf(1)); - test_case(0.0, 10, 0.0, pmf(10)); - test_case(0.3, 1, 0.69999999999999995559107901499373838305473327636719, pmf(0)); - test_case(0.3, 1, 0.2999999999999999888977697537484345957636833190918, pmf(1)); - test_case(0.3, 3, 0.34299999999999993471888615204079956461021032657166, pmf(0)); - test_almost(0.3, 3, 0.44099999999999992772448109690231306411849135972008, 1e-15, pmf(1)); - test_almost(0.3, 3, 0.026999999999999997002397833512077451789759292859569, 1e-16, pmf(3)); - test_almost(0.3, 10, 0.02824752489999998207939855277004937778546385011091, 1e-17, pmf(0)); - test_almost(0.3, 10, 0.12106082099999992639752977030555903089040470780077, 1e-15, pmf(1)); - test_almost(0.3, 10, 0.0000059048999999999978147480206303047454017251032868501, 1e-20, pmf(10)); - test_case(1.0, 1, 0.0, pmf(0)); - test_case(1.0, 1, 1.0, pmf(1)); - test_case(1.0, 3, 0.0, pmf(0)); - test_case(1.0, 3, 0.0, pmf(1)); - test_case(1.0, 3, 1.0, pmf(3)); - test_case(1.0, 10, 0.0, pmf(0)); - test_case(1.0, 10, 0.0, pmf(1)); - test_case(1.0, 10, 1.0, pmf(10)); + test_exact(0.0, 1, 1.0, pmf(0)); + test_exact(0.0, 1, 0.0, pmf(1)); + test_exact(0.0, 3, 1.0, pmf(0)); + test_exact(0.0, 3, 0.0, pmf(1)); + test_exact(0.0, 3, 0.0, pmf(3)); + test_exact(0.0, 10, 1.0, pmf(0)); + test_exact(0.0, 10, 0.0, pmf(1)); + test_exact(0.0, 10, 0.0, pmf(10)); + test_exact(0.3, 
1, 0.69999999999999995559107901499373838305473327636719, pmf(0)); + test_exact(0.3, 1, 0.2999999999999999888977697537484345957636833190918, pmf(1)); + test_exact(0.3, 3, 0.34299999999999993471888615204079956461021032657166, pmf(0)); + test_absolute(0.3, 3, 0.44099999999999992772448109690231306411849135972008, 1e-15, pmf(1)); + test_absolute(0.3, 3, 0.026999999999999997002397833512077451789759292859569, 1e-16, pmf(3)); + test_absolute(0.3, 10, 0.02824752489999998207939855277004937778546385011091, 1e-17, pmf(0)); + test_absolute(0.3, 10, 0.12106082099999992639752977030555903089040470780077, 1e-15, pmf(1)); + test_absolute(0.3, 10, 0.0000059048999999999978147480206303047454017251032868501, 1e-20, pmf(10)); + test_exact(1.0, 1, 0.0, pmf(0)); + test_exact(1.0, 1, 1.0, pmf(1)); + test_exact(1.0, 3, 0.0, pmf(0)); + test_exact(1.0, 3, 0.0, pmf(1)); + test_exact(1.0, 3, 1.0, pmf(3)); + test_exact(1.0, 10, 0.0, pmf(0)); + test_exact(1.0, 10, 0.0, pmf(1)); + test_exact(1.0, 10, 1.0, pmf(10)); } #[test] fn test_ln_pmf() { let ln_pmf = |arg: u64| move |x: Binomial| x.ln_pmf(arg); - test_case(0.0, 1, 0.0, ln_pmf(0)); - test_case(0.0, 1, f64::NEG_INFINITY, ln_pmf(1)); - test_case(0.0, 3, 0.0, ln_pmf(0)); - test_case(0.0, 3, f64::NEG_INFINITY, ln_pmf(1)); - test_case(0.0, 3, f64::NEG_INFINITY, ln_pmf(3)); - test_case(0.0, 10, 0.0, ln_pmf(0)); - test_case(0.0, 10, f64::NEG_INFINITY, ln_pmf(1)); - test_case(0.0, 10, f64::NEG_INFINITY, ln_pmf(10)); - test_case(0.3, 1, -0.3566749439387324423539544041072745145718090708995, ln_pmf(0)); - test_case(0.3, 1, -1.2039728043259360296301803719337238685164245381839, ln_pmf(1)); - test_case(0.3, 3, -1.0700248318161973270618632123218235437154272126985, ln_pmf(0)); - test_almost(0.3, 3, -0.81871040353529122294284394322574719301255212216016, 1e-15, ln_pmf(1)); - test_almost(0.3, 3, -3.6119184129778080888905411158011716055492736145517, 1e-15, ln_pmf(3)); - test_case(0.3, 10, -3.566749439387324423539544041072745145718090708995, ln_pmf(0)); - 
test_almost(0.3, 10, -2.1114622067804823267977785542148302920616046876506, 1e-14, ln_pmf(1)); - test_case(0.3, 10, -12.039728043259360296301803719337238685164245381839, ln_pmf(10)); - test_case(1.0, 1, f64::NEG_INFINITY, ln_pmf(0)); - test_case(1.0, 1, 0.0, ln_pmf(1)); - test_case(1.0, 3, f64::NEG_INFINITY, ln_pmf(0)); - test_case(1.0, 3, f64::NEG_INFINITY, ln_pmf(1)); - test_case(1.0, 3, 0.0, ln_pmf(3)); - test_case(1.0, 10, f64::NEG_INFINITY, ln_pmf(0)); - test_case(1.0, 10, f64::NEG_INFINITY, ln_pmf(1)); - test_case(1.0, 10, 0.0, ln_pmf(10)); + test_exact(0.0, 1, 0.0, ln_pmf(0)); + test_exact(0.0, 1, f64::NEG_INFINITY, ln_pmf(1)); + test_exact(0.0, 3, 0.0, ln_pmf(0)); + test_exact(0.0, 3, f64::NEG_INFINITY, ln_pmf(1)); + test_exact(0.0, 3, f64::NEG_INFINITY, ln_pmf(3)); + test_exact(0.0, 10, 0.0, ln_pmf(0)); + test_exact(0.0, 10, f64::NEG_INFINITY, ln_pmf(1)); + test_exact(0.0, 10, f64::NEG_INFINITY, ln_pmf(10)); + test_exact(0.3, 1, -0.3566749439387324423539544041072745145718090708995, ln_pmf(0)); + test_exact(0.3, 1, -1.2039728043259360296301803719337238685164245381839, ln_pmf(1)); + test_exact(0.3, 3, -1.0700248318161973270618632123218235437154272126985, ln_pmf(0)); + test_absolute(0.3, 3, -0.81871040353529122294284394322574719301255212216016, 1e-15, ln_pmf(1)); + test_absolute(0.3, 3, -3.6119184129778080888905411158011716055492736145517, 1e-15, ln_pmf(3)); + test_exact(0.3, 10, -3.566749439387324423539544041072745145718090708995, ln_pmf(0)); + test_absolute(0.3, 10, -2.1114622067804823267977785542148302920616046876506, 1e-14, ln_pmf(1)); + test_exact(0.3, 10, -12.039728043259360296301803719337238685164245381839, ln_pmf(10)); + test_exact(1.0, 1, f64::NEG_INFINITY, ln_pmf(0)); + test_exact(1.0, 1, 0.0, ln_pmf(1)); + test_exact(1.0, 3, f64::NEG_INFINITY, ln_pmf(0)); + test_exact(1.0, 3, f64::NEG_INFINITY, ln_pmf(1)); + test_exact(1.0, 3, 0.0, ln_pmf(3)); + test_exact(1.0, 10, f64::NEG_INFINITY, ln_pmf(0)); + test_exact(1.0, 10, f64::NEG_INFINITY, ln_pmf(1)); + 
test_exact(1.0, 10, 0.0, ln_pmf(10)); } #[test] fn test_cdf() { let cdf = |arg: u64| move |x: Binomial| x.cdf(arg); - test_case(0.0, 1, 1.0, cdf(0)); - test_case(0.0, 1, 1.0, cdf(1)); - test_case(0.0, 3, 1.0, cdf(0)); - test_case(0.0, 3, 1.0, cdf(1)); - test_case(0.0, 3, 1.0, cdf(3)); - test_case(0.0, 10, 1.0, cdf(0)); - test_case(0.0, 10, 1.0, cdf(1)); - test_case(0.0, 10, 1.0, cdf(10)); - test_almost(0.3, 1, 0.7, 1e-15, cdf(0)); - test_case(0.3, 1, 1.0, cdf(1)); - test_almost(0.3, 3, 0.343, 1e-14, cdf(0)); - test_almost(0.3, 3, 0.784, 1e-15, cdf(1)); - test_case(0.3, 3, 1.0, cdf(3)); - test_almost(0.3, 10, 0.0282475249, 1e-16, cdf(0)); - test_almost(0.3, 10, 0.1493083459, 1e-14, cdf(1)); - test_case(0.3, 10, 1.0, cdf(10)); - test_case(1.0, 1, 0.0, cdf(0)); - test_case(1.0, 1, 1.0, cdf(1)); - test_case(1.0, 3, 0.0, cdf(0)); - test_case(1.0, 3, 0.0, cdf(1)); - test_case(1.0, 3, 1.0, cdf(3)); - test_case(1.0, 10, 0.0, cdf(0)); - test_case(1.0, 10, 0.0, cdf(1)); - test_case(1.0, 10, 1.0, cdf(10)); + test_exact(0.0, 1, 1.0, cdf(0)); + test_exact(0.0, 1, 1.0, cdf(1)); + test_exact(0.0, 3, 1.0, cdf(0)); + test_exact(0.0, 3, 1.0, cdf(1)); + test_exact(0.0, 3, 1.0, cdf(3)); + test_exact(0.0, 10, 1.0, cdf(0)); + test_exact(0.0, 10, 1.0, cdf(1)); + test_exact(0.0, 10, 1.0, cdf(10)); + test_absolute(0.3, 1, 0.7, 1e-15, cdf(0)); + test_exact(0.3, 1, 1.0, cdf(1)); + test_absolute(0.3, 3, 0.343, 1e-14, cdf(0)); + test_absolute(0.3, 3, 0.784, 1e-15, cdf(1)); + test_exact(0.3, 3, 1.0, cdf(3)); + test_absolute(0.3, 10, 0.0282475249, 1e-16, cdf(0)); + test_absolute(0.3, 10, 0.1493083459, 1e-14, cdf(1)); + test_exact(0.3, 10, 1.0, cdf(10)); + test_exact(1.0, 1, 0.0, cdf(0)); + test_exact(1.0, 1, 1.0, cdf(1)); + test_exact(1.0, 3, 0.0, cdf(0)); + test_exact(1.0, 3, 0.0, cdf(1)); + test_exact(1.0, 3, 1.0, cdf(3)); + test_exact(1.0, 10, 0.0, cdf(0)); + test_exact(1.0, 10, 0.0, cdf(1)); + test_exact(1.0, 10, 1.0, cdf(10)); } #[test] fn test_sf() { let sf = |arg: u64| move |x: Binomial| 
x.sf(arg); - test_case(0.0, 1, 0.0, sf(0)); - test_case(0.0, 1, 0.0, sf(1)); - test_case(0.0, 3, 0.0, sf(0)); - test_case(0.0, 3, 0.0, sf(1)); - test_case(0.0, 3, 0.0, sf(3)); - test_case(0.0, 10, 0.0, sf(0)); - test_case(0.0, 10, 0.0, sf(1)); - test_case(0.0, 10, 0.0, sf(10)); - test_almost(0.3, 1, 0.3, 1e-15, sf(0)); - test_case(0.3, 1, 0.0, sf(1)); - test_almost(0.3, 3, 0.657, 1e-14, sf(0)); - test_almost(0.3, 3, 0.216, 1e-15, sf(1)); - test_case(0.3, 3, 0.0, sf(3)); - test_almost(0.3, 10, 0.9717524751000001, 1e-16, sf(0)); - test_almost(0.3, 10, 0.850691654100002, 1e-14, sf(1)); - test_case(0.3, 10, 0.0, sf(10)); - test_case(1.0, 1, 1.0, sf(0)); - test_case(1.0, 1, 0.0, sf(1)); - test_case(1.0, 3, 1.0, sf(0)); - test_case(1.0, 3, 1.0, sf(1)); - test_case(1.0, 3, 0.0, sf(3)); - test_case(1.0, 10, 1.0, sf(0)); - test_case(1.0, 10, 1.0, sf(1)); - test_case(1.0, 10, 0.0, sf(10)); + test_exact(0.0, 1, 0.0, sf(0)); + test_exact(0.0, 1, 0.0, sf(1)); + test_exact(0.0, 3, 0.0, sf(0)); + test_exact(0.0, 3, 0.0, sf(1)); + test_exact(0.0, 3, 0.0, sf(3)); + test_exact(0.0, 10, 0.0, sf(0)); + test_exact(0.0, 10, 0.0, sf(1)); + test_exact(0.0, 10, 0.0, sf(10)); + test_absolute(0.3, 1, 0.3, 1e-15, sf(0)); + test_exact(0.3, 1, 0.0, sf(1)); + test_absolute(0.3, 3, 0.657, 1e-14, sf(0)); + test_absolute(0.3, 3, 0.216, 1e-15, sf(1)); + test_exact(0.3, 3, 0.0, sf(3)); + test_absolute(0.3, 10, 0.9717524751000001, 1e-16, sf(0)); + test_absolute(0.3, 10, 0.850691654100002, 1e-14, sf(1)); + test_exact(0.3, 10, 0.0, sf(10)); + test_exact(1.0, 1, 1.0, sf(0)); + test_exact(1.0, 1, 0.0, sf(1)); + test_exact(1.0, 3, 1.0, sf(0)); + test_exact(1.0, 3, 1.0, sf(1)); + test_exact(1.0, 3, 0.0, sf(3)); + test_exact(1.0, 10, 1.0, sf(0)); + test_exact(1.0, 10, 1.0, sf(1)); + test_exact(1.0, 10, 0.0, sf(10)); } #[test] fn test_cdf_upper_bound() { let cdf = |arg: u64| move |x: Binomial| x.cdf(arg); - test_case(0.5, 3, 1.0, cdf(5)); + test_exact(0.5, 3, 1.0, cdf(5)); } #[test] fn test_sf_upper_bound() { 
let sf = |arg: u64| move |x: Binomial| x.sf(arg); - test_case(0.5, 3, 0.0, sf(5)); + test_exact(0.5, 3, 0.0, sf(5)); + } + + #[test] + fn test_inverse_cdf() { + let invcdf = |arg: f64| move |x: Binomial| x.inverse_cdf(arg); + test_exact(0.4, 5, 2, invcdf(0.3456)); + + // cases in issue #185 + test_exact(0.018, 465, 1, invcdf(3.472e-4)); + test_exact(0.5, 6, 4, invcdf(0.75)); + } + + #[test] + fn test_cdf_inverse_cdf() { + let cdf_invcdf = |arg: u64| move |x: Binomial| x.inverse_cdf(x.cdf(arg)); + test_exact(0.3, 10, 3, cdf_invcdf(3)); + test_exact(0.3, 10, 4, cdf_invcdf(4)); + test_exact(0.5, 6, 4, cdf_invcdf(4)); } #[test] fn test_discrete() { - test::check_discrete_distribution(&try_create(0.3, 5), 5); - test::check_discrete_distribution(&try_create(0.7, 10), 10); + test::check_discrete_distribution(&create_ok(0.3, 5), 5); + test::check_discrete_distribution(&create_ok(0.7, 10), 10); } } diff --git a/src/distribution/categorical.rs b/src/distribution/categorical.rs index f489d653..7d3a7c1c 100644 --- a/src/distribution/categorical.rs +++ b/src/distribution/categorical.rs @@ -1,7 +1,5 @@ use crate::distribution::{Discrete, DiscreteCDF}; use crate::statistics::*; -use crate::{Result, StatsError}; -use rand::Rng; use std::f64; /// Implements the @@ -12,7 +10,6 @@ use std::f64; /// # Examples /// /// ``` -/// /// use statrs::distribution::{Categorical, Discrete}; /// use statrs::statistics::Distribution; /// use statrs::prec; @@ -21,13 +18,43 @@ use std::f64; /// assert!(prec::almost_eq(n.mean().unwrap(), 5.0 / 3.0, 1e-15)); /// assert_eq!(n.pmf(1), 1.0 / 3.0); /// ``` -#[derive(Debug, Clone, PartialEq)] +#[derive(Clone, PartialEq, Debug)] pub struct Categorical { norm_pmf: Vec, cdf: Vec, - sf: Vec + sf: Vec, +} + +/// Represents the errors that can occur when creating a [`Categorical`]. +#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] +#[non_exhaustive] +pub enum CategoricalError { + /// The probability mass is empty. 
+ ProbMassEmpty, + + /// The probabilities sums up to zero. + ProbMassSumZero, + + /// The probability mass contains at least one element which is NaN or less than zero. + ProbMassHasInvalidElements, +} + +impl std::fmt::Display for CategoricalError { + #[cfg_attr(coverage_nightly, coverage(off))] + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + CategoricalError::ProbMassEmpty => write!(f, "Probability mass is empty"), + CategoricalError::ProbMassSumZero => write!(f, "Probabilities sum up to zero"), + CategoricalError::ProbMassHasInvalidElements => write!( + f, + "Probability mass contains at least one element which is NaN or less than zero" + ), + } + } } +impl std::error::Error for CategoricalError {} + impl Categorical { /// Constructs a new categorical distribution /// with the probabilities masses defined by `prob_mass` @@ -53,23 +80,36 @@ impl Categorical { /// result = Categorical::new(&[0.0, -1.0, 2.0]); /// assert!(result.is_err()); /// ``` - pub fn new(prob_mass: &[f64]) -> Result { - if !super::internal::is_valid_multinomial(prob_mass, true) { - Err(StatsError::BadParams) - } else { - // extract un-normalized cdf - let cdf = prob_mass_to_cdf(prob_mass); - // extract un-normalized sf - let sf = cdf_to_sf(&cdf); - // extract normalized probability mass - let sum = cdf[cdf.len() - 1]; - let mut norm_pmf = vec![0.0; prob_mass.len()]; - norm_pmf - .iter_mut() - .zip(prob_mass.iter()) - .for_each(|(np, pm)| *np = *pm / sum); - Ok(Categorical { norm_pmf, cdf, sf }) + pub fn new(prob_mass: &[f64]) -> Result { + if prob_mass.is_empty() { + return Err(CategoricalError::ProbMassEmpty); + } + + let mut prob_sum = 0.0; + for &p in prob_mass { + if p.is_nan() || p < 0.0 { + return Err(CategoricalError::ProbMassHasInvalidElements); + } + + prob_sum += p; } + + if prob_sum == 0.0 { + return Err(CategoricalError::ProbMassSumZero); + } + + // extract un-normalized cdf + let cdf = prob_mass_to_cdf(prob_mass); + // extract un-normalized sf 
+ let sf = cdf_to_sf(&cdf); + // extract normalized probability mass + let sum = cdf[cdf.len() - 1]; + let mut norm_pmf = vec![0.0; prob_mass.len()]; + norm_pmf + .iter_mut() + .zip(prob_mass.iter()) + .for_each(|(np, pm)| *np = *pm / sum); + Ok(Categorical { norm_pmf, cdf, sf }) } fn cdf_max(&self) -> f64 { @@ -77,8 +117,15 @@ impl Categorical { } } +impl std::fmt::Display for Categorical { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Cat({:#?})", self.norm_pmf) + } +} + +#[cfg(feature = "rand")] impl ::rand::distributions::Distribution for Categorical { - fn sample(&self, rng: &mut R) -> f64 { + fn sample(&self, rng: &mut R) -> f64 { sample_unchecked(rng, &self.cdf) } } @@ -89,7 +136,7 @@ impl DiscreteCDF for Categorical { /// /// # Formula /// - /// ```ignore + /// ```text /// sum(p_j) from 0..x /// ``` /// @@ -107,7 +154,7 @@ impl DiscreteCDF for Categorical { /// /// # Formula /// - /// ```ignore + /// ```text /// [ sum(p_j) from x..end ] /// ``` fn sf(&self, x: u64) -> f64 { @@ -128,7 +175,7 @@ impl DiscreteCDF for Categorical { /// /// # Formula /// - /// ```ignore + /// ```text /// i /// ``` /// @@ -151,7 +198,7 @@ impl Min for Categorical { /// /// # Formula /// - /// ```ignore + /// ```text /// 0 /// ``` fn min(&self) -> u64 { @@ -166,7 +213,7 @@ impl Max for Categorical { /// /// # Formula /// - /// ```ignore + /// ```text /// n /// ``` fn max(&self) -> u64 { @@ -179,7 +226,7 @@ impl Distribution for Categorical { /// /// # Formula /// - /// ```ignore + /// ```text /// Σ(j * p_j) /// ``` /// @@ -194,11 +241,12 @@ impl Distribution for Categorical { .fold(0.0, |acc, (idx, &val)| acc + idx as f64 * val), ) } + /// Returns the variance of the categorical distribution /// /// # Formula /// - /// ```ignore + /// ```text /// Σ(p_j * (j - μ)^2) /// ``` /// @@ -217,11 +265,12 @@ impl Distribution for Categorical { }); Some(var) } + /// Returns the entropy of the categorical distribution /// /// # Formula /// - /// ```ignore 
+ /// ```text /// -Σ(p_j * ln(p_j)) /// ``` /// @@ -243,7 +292,7 @@ impl Median for Categorical { /// /// # Formula /// - /// ```ignore + /// ```text /// CDF^-1(0.5) /// ``` fn median(&self) -> f64 { @@ -257,7 +306,7 @@ impl Discrete for Categorical { /// /// # Formula /// - /// ```ignore + /// ```text /// p_x /// ``` fn pmf(&self, x: u64) -> f64 { @@ -273,7 +322,8 @@ impl Discrete for Categorical { /// Draws a sample from the categorical distribution described by `cdf` /// without doing any bounds checking -pub fn sample_unchecked(rng: &mut R, cdf: &[f64]) -> f64 { +#[cfg(feature = "rand")] +pub fn sample_unchecked(rng: &mut R, cdf: &[f64]) -> f64 { let draw = rng.gen::() * cdf.last().unwrap(); cdf.iter() .enumerate() @@ -294,7 +344,7 @@ pub fn prob_mass_to_cdf(prob_mass: &[f64]) -> Vec { cdf } -/// Computes the sf from the given cumulative densities. +/// Computes the sf from the given cumulative densities. /// Performs no parameter or bounds checking. pub fn cdf_to_sf(cdf: &[f64]) -> Vec { let max = *cdf.last().unwrap(); @@ -342,166 +392,135 @@ fn test_binary_index() { } #[rustfmt::skip] -#[cfg(all(test, feature = "nightly"))] +#[cfg(test)] mod tests { - use std::fmt::Debug; - use crate::statistics::*; - use crate::distribution::{Categorical, Discrete, DiscreteCDF}; + use super::*; use crate::distribution::internal::*; - use crate::consts::ACC; - - fn try_create(prob_mass: &[f64]) -> Categorical { - let n = Categorical::new(prob_mass); - assert!(n.is_ok()); - n.unwrap() - } - - fn create_case(prob_mass: &[f64]) { - try_create(prob_mass); - } - - fn bad_create_case(prob_mass: &[f64]) { - let n = Categorical::new(prob_mass); - assert!(n.is_err()); - } + use crate::testing_boiler; - fn get_value(prob_mass: &[f64], eval: F) -> T - where T: PartialEq + Debug, - F: Fn(Categorical) -> T - { - let n = try_create(prob_mass); - eval(n) - } - - fn test_case(prob_mass: &[f64], expected: T, eval: F) - where T: PartialEq + Debug, - F: Fn(Categorical) -> T - { - let x = 
get_value(prob_mass, eval); - assert_eq!(expected, x); - } - - fn test_almost(prob_mass: &[f64], expected: f64, acc: f64, eval: F) - where F: Fn(Categorical) -> f64 - { - let x = get_value(prob_mass, eval); - assert_almost_eq!(expected, x, acc); - } + testing_boiler!(prob_mass: &[f64]; Categorical; CategoricalError); #[test] fn test_create() { - create_case(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]); + create_ok(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]); } #[test] fn test_bad_create() { - bad_create_case(&[-1.0, 1.0]); - bad_create_case(&[0.0, 0.0]); + let invalid: &[(&[f64], CategoricalError)] = &[ + (&[], CategoricalError::ProbMassEmpty), + (&[-1.0, 1.0], CategoricalError::ProbMassHasInvalidElements), + (&[0.0, 0.0, 0.0], CategoricalError::ProbMassSumZero), + ]; + + for &(prob_mass, err) in invalid { + test_create_err(prob_mass, err); + } } #[test] fn test_mean() { let mean = |x: Categorical| x.mean().unwrap(); - test_case(&[0.0, 0.25, 0.5, 0.25], 2.0, mean); - test_case(&[0.0, 1.0, 2.0, 1.0], 2.0, mean); - test_case(&[0.0, 0.5, 0.5], 1.5, mean); - test_case(&[0.75, 0.25], 0.25, mean); - test_case(&[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 5.0, mean); + test_exact(&[0.0, 0.25, 0.5, 0.25], 2.0, mean); + test_exact(&[0.0, 1.0, 2.0, 1.0], 2.0, mean); + test_exact(&[0.0, 0.5, 0.5], 1.5, mean); + test_exact(&[0.75, 0.25], 0.25, mean); + test_exact(&[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0], 5.0, mean); } #[test] fn test_variance() { let variance = |x: Categorical| x.variance().unwrap(); - test_case(&[0.0, 0.25, 0.5, 0.25], 0.5, variance); - test_case(&[0.0, 1.0, 2.0, 1.0], 0.5, variance); - test_case(&[0.0, 0.5, 0.5], 0.25, variance); - test_case(&[0.75, 0.25], 0.1875, variance); - test_case(&[1.0, 0.0, 1.0], 1.0, variance); + test_exact(&[0.0, 0.25, 0.5, 0.25], 0.5, variance); + test_exact(&[0.0, 1.0, 2.0, 1.0], 0.5, variance); + test_exact(&[0.0, 0.5, 0.5], 0.25, variance); + test_exact(&[0.75, 0.25], 0.1875, variance); + 
test_exact(&[1.0, 0.0, 1.0], 1.0, variance); } #[test] fn test_entropy() { let entropy = |x: Categorical| x.entropy().unwrap(); - test_case(&[0.0, 1.0], 0.0, entropy); - test_almost(&[0.0, 1.0, 1.0], 2f64.ln(), 1e-15, entropy); - test_almost(&[1.0, 1.0, 1.0], 3f64.ln(), 1e-15, entropy); - test_almost(&vec![1.0; 100], 100f64.ln(), 1e-14, entropy); - test_almost(&[0.0, 0.25, 0.5, 0.25], 1.0397207708399179, 1e-15, entropy); + test_exact(&[0.0, 1.0], 0.0, entropy); + test_absolute(&[0.0, 1.0, 1.0], 2f64.ln(), 1e-15, entropy); + test_absolute(&[1.0, 1.0, 1.0], 3f64.ln(), 1e-15, entropy); + test_absolute(&vec![1.0; 100], 100f64.ln(), 1e-14, entropy); + test_absolute(&[0.0, 0.25, 0.5, 0.25], 1.0397207708399179, 1e-15, entropy); } #[test] fn test_median() { let median = |x: Categorical| x.median(); - test_case(&[0.0, 3.0, 1.0, 1.0], 1.0, median); - test_case(&[4.0, 2.5, 2.5, 1.0], 1.0, median); + test_exact(&[0.0, 3.0, 1.0, 1.0], 1.0, median); + test_exact(&[4.0, 2.5, 2.5, 1.0], 1.0, median); } #[test] fn test_min_max() { let min = |x: Categorical| x.min(); let max = |x: Categorical| x.max(); - test_case(&[4.0, 2.5, 2.5, 1.0], 0, min); - test_case(&[4.0, 2.5, 2.5, 1.0], 3, max); + test_exact(&[4.0, 2.5, 2.5, 1.0], 0, min); + test_exact(&[4.0, 2.5, 2.5, 1.0], 3, max); } #[test] fn test_pmf() { let pmf = |arg: u64| move |x: Categorical| x.pmf(arg); - test_case(&[0.0, 0.25, 0.5, 0.25], 0.0, pmf(0)); - test_case(&[0.0, 0.25, 0.5, 0.25], 0.25, pmf(1)); - test_case(&[0.0, 0.25, 0.5, 0.25], 0.25, pmf(3)); + test_exact(&[0.0, 0.25, 0.5, 0.25], 0.0, pmf(0)); + test_exact(&[0.0, 0.25, 0.5, 0.25], 0.25, pmf(1)); + test_exact(&[0.0, 0.25, 0.5, 0.25], 0.25, pmf(3)); } #[test] fn test_pmf_x_too_high() { let pmf = |arg: u64| move |x: Categorical| x.pmf(arg); - test_case(&[4.0, 2.5, 2.5, 1.0], 0.0, pmf(4)); + test_exact(&[4.0, 2.5, 2.5, 1.0], 0.0, pmf(4)); } #[test] fn test_ln_pmf() { let ln_pmf = |arg: u64| move |x: Categorical| x.ln_pmf(arg); - test_case(&[0.0, 0.25, 0.5, 0.25], 
0f64.ln(), ln_pmf(0)); - test_case(&[0.0, 0.25, 0.5, 0.25], 0.25f64.ln(), ln_pmf(1)); - test_case(&[0.0, 0.25, 0.5, 0.25], 0.25f64.ln(), ln_pmf(3)); + test_exact(&[0.0, 0.25, 0.5, 0.25], 0f64.ln(), ln_pmf(0)); + test_exact(&[0.0, 0.25, 0.5, 0.25], 0.25f64.ln(), ln_pmf(1)); + test_exact(&[0.0, 0.25, 0.5, 0.25], 0.25f64.ln(), ln_pmf(3)); } #[test] fn test_ln_pmf_x_too_high() { let ln_pmf = |arg: u64| move |x: Categorical| x.ln_pmf(arg); - test_case(&[4.0, 2.5, 2.5, 1.0], f64::NEG_INFINITY, ln_pmf(4)); + test_exact(&[4.0, 2.5, 2.5, 1.0], f64::NEG_INFINITY, ln_pmf(4)); } #[test] fn test_cdf() { let cdf = |arg: u64| move |x: Categorical| x.cdf(arg); - test_case(&[0.0, 3.0, 1.0, 1.0], 3.0 / 5.0, cdf(1)); - test_case(&[1.0, 1.0, 1.0, 1.0], 0.25, cdf(0)); - test_case(&[4.0, 2.5, 2.5, 1.0], 0.4, cdf(0)); - test_case(&[4.0, 2.5, 2.5, 1.0], 1.0, cdf(3)); - test_case(&[4.0, 2.5, 2.5, 1.0], 1.0, cdf(4)); + test_exact(&[0.0, 3.0, 1.0, 1.0], 3.0 / 5.0, cdf(1)); + test_exact(&[1.0, 1.0, 1.0, 1.0], 0.25, cdf(0)); + test_exact(&[4.0, 2.5, 2.5, 1.0], 0.4, cdf(0)); + test_exact(&[4.0, 2.5, 2.5, 1.0], 1.0, cdf(3)); + test_exact(&[4.0, 2.5, 2.5, 1.0], 1.0, cdf(4)); } #[test] fn test_sf() { let sf = |arg: u64| move |x: Categorical| x.sf(arg); - test_case(&[0.0, 3.0, 1.0, 1.0], 2.0 / 5.0, sf(1)); - test_case(&[1.0, 1.0, 1.0, 1.0], 0.75, sf(0)); - test_case(&[4.0, 2.5, 2.5, 1.0], 0.6, sf(0)); - test_case(&[4.0, 2.5, 2.5, 1.0], 0.0, sf(3)); - test_case(&[4.0, 2.5, 2.5, 1.0], 0.0, sf(4)); + test_exact(&[0.0, 3.0, 1.0, 1.0], 2.0 / 5.0, sf(1)); + test_exact(&[1.0, 1.0, 1.0, 1.0], 0.75, sf(0)); + test_exact(&[4.0, 2.5, 2.5, 1.0], 0.6, sf(0)); + test_exact(&[4.0, 2.5, 2.5, 1.0], 0.0, sf(3)); + test_exact(&[4.0, 2.5, 2.5, 1.0], 0.0, sf(4)); } #[test] fn test_cdf_input_high() { let cdf = |arg: u64| move |x: Categorical| x.cdf(arg); - test_case(&[4.0, 2.5, 2.5, 1.0], 1.0, cdf(4)); + test_exact(&[4.0, 2.5, 2.5, 1.0], 1.0, cdf(4)); } #[test] fn test_sf_input_high() { let sf = |arg: u64| move |x: 
Categorical| x.sf(arg); - test_case(&[4.0, 2.5, 2.5, 1.0], 0.0, sf(4)); + test_exact(&[4.0, 2.5, 2.5, 1.0], 0.0, sf(4)); } #[test] @@ -517,31 +536,31 @@ mod tests { #[test] fn test_inverse_cdf() { let inverse_cdf = |arg: f64| move |x: Categorical| x.inverse_cdf(arg); - test_case(&[0.0, 3.0, 1.0, 1.0], 1, inverse_cdf(0.2)); - test_case(&[0.0, 3.0, 1.0, 1.0], 1, inverse_cdf(0.5)); - test_case(&[0.0, 3.0, 1.0, 1.0], 3, inverse_cdf(0.95)); - test_case(&[4.0, 2.5, 2.5, 1.0], 0, inverse_cdf(0.2)); - test_case(&[4.0, 2.5, 2.5, 1.0], 1, inverse_cdf(0.5)); - test_case(&[4.0, 2.5, 2.5, 1.0], 3, inverse_cdf(0.95)); + test_exact(&[0.0, 3.0, 1.0, 1.0], 1, inverse_cdf(0.2)); + test_exact(&[0.0, 3.0, 1.0, 1.0], 1, inverse_cdf(0.5)); + test_exact(&[0.0, 3.0, 1.0, 1.0], 3, inverse_cdf(0.95)); + test_exact(&[4.0, 2.5, 2.5, 1.0], 0, inverse_cdf(0.2)); + test_exact(&[4.0, 2.5, 2.5, 1.0], 1, inverse_cdf(0.5)); + test_exact(&[4.0, 2.5, 2.5, 1.0], 3, inverse_cdf(0.95)); } #[test] #[should_panic] fn test_inverse_cdf_input_low() { - let inverse_cdf = |arg: f64| move |x: Categorical| x.inverse_cdf(arg); - get_value(&[4.0, 2.5, 2.5, 1.0], inverse_cdf(0.0)); + let dist = create_ok(&[4.0, 2.5, 2.5, 1.0]); + dist.inverse_cdf(0.0); } #[test] #[should_panic] fn test_inverse_cdf_input_high() { - let inverse_cdf = |arg: f64| move |x: Categorical| x.inverse_cdf(arg); - get_value(&[4.0, 2.5, 2.5, 1.0], inverse_cdf(1.0)); + let dist = create_ok(&[4.0, 2.5, 2.5, 1.0]); + dist.inverse_cdf(1.0); } #[test] fn test_discrete() { - test::check_discrete_distribution(&try_create(&[1.0, 2.0, 3.0, 4.0]), 4); - test::check_discrete_distribution(&try_create(&[0.0, 1.0, 2.0, 3.0, 4.0]), 5); + test::check_discrete_distribution(&create_ok(&[1.0, 2.0, 3.0, 4.0]), 4); + test::check_discrete_distribution(&create_ok(&[0.0, 1.0, 2.0, 3.0, 4.0]), 5); } } diff --git a/src/distribution/cauchy.rs b/src/distribution/cauchy.rs index e42919ea..c9fc4ae2 100644 --- a/src/distribution/cauchy.rs +++ b/src/distribution/cauchy.rs @@ 
-1,7 +1,5 @@ use crate::distribution::{Continuous, ContinuousCDF}; use crate::statistics::*; -use crate::{Result, StatsError}; -use rand::Rng; use std::f64; /// Implements the [Cauchy](https://en.wikipedia.org/wiki/Cauchy_distribution) @@ -17,12 +15,35 @@ use std::f64; /// assert_eq!(n.mode().unwrap(), 0.0); /// assert_eq!(n.pdf(1.0), 0.1591549430918953357689); /// ``` -#[derive(Debug, Copy, Clone, PartialEq)] +#[derive(Copy, Clone, PartialEq, Debug)] pub struct Cauchy { location: f64, scale: f64, } +/// Represents the errors that can occur when creating a [`Cauchy`]. +#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] +#[non_exhaustive] +pub enum CauchyError { + /// The location is NaN. + LocationInvalid, + + /// The scale is NaN, zero or less than zero. + ScaleInvalid, +} + +impl std::fmt::Display for CauchyError { + #[cfg_attr(coverage_nightly, coverage(off))] + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + CauchyError::LocationInvalid => write!(f, "Location is NaN"), + CauchyError::ScaleInvalid => write!(f, "Scale is NaN, zero or less than zero"), + } + } +} + +impl std::error::Error for CauchyError {} + impl Cauchy { /// Constructs a new cauchy distribution with the given /// location and scale. 
@@ -42,12 +63,16 @@ impl Cauchy { /// result = Cauchy::new(0.0, -1.0); /// assert!(result.is_err()); /// ``` - pub fn new(location: f64, scale: f64) -> Result { - if location.is_nan() || scale.is_nan() || scale <= 0.0 { - Err(StatsError::BadParams) - } else { - Ok(Cauchy { location, scale }) + pub fn new(location: f64, scale: f64) -> Result { + if location.is_nan() { + return Err(CauchyError::LocationInvalid); + } + + if scale.is_nan() || scale <= 0.0 { + return Err(CauchyError::ScaleInvalid); } + + Ok(Cauchy { location, scale }) } /// Returns the location of the cauchy distribution @@ -79,8 +104,15 @@ impl Cauchy { } } +impl std::fmt::Display for Cauchy { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Cauchy({}, {})", self.location, self.scale) + } +} + +#[cfg(feature = "rand")] impl ::rand::distributions::Distribution for Cauchy { - fn sample(&self, r: &mut R) -> f64 { + fn sample(&self, r: &mut R) -> f64 { self.location + self.scale * (f64::consts::PI * (r.gen::() - 0.5)).tan() } } @@ -91,7 +123,7 @@ impl ContinuousCDF for Cauchy { /// /// # Formula /// - /// ```ignore + /// ```text /// (1 / π) * arctan((x - x_0) / γ) + 0.5 /// ``` /// @@ -105,7 +137,7 @@ impl ContinuousCDF for Cauchy { /// /// # Formula /// - /// ```ignore + /// ```text /// (1 / π) * arctan(-(x - x_0) / γ) + 0.5 /// ``` /// @@ -115,6 +147,24 @@ impl ContinuousCDF for Cauchy { fn sf(&self, x: f64) -> f64 { (1.0 / f64::consts::PI) * ((self.location - x) / self.scale).atan() + 0.5 } + + /// Calculates the inverse cumulative distribution function for the + /// cauchy distribution at `x` + /// + /// # Formula + /// + /// ```text + /// x_0 + γ tan((x - 0.5) π) + /// ``` + /// + /// where `x_0` is the location and `γ` is the scale + fn inverse_cdf(&self, x: f64) -> f64 { + if !(0.0..=1.0).contains(&x) { + panic!("x must be in [0, 1]"); + } else { + self.location + self.scale * (f64::consts::PI * (x - 0.5)).tan() + } + } } impl Min for Cauchy { @@ -123,7 +173,7 @@ 
impl Min for Cauchy { /// /// # Formula /// - /// ```ignore + /// ```text /// NEG_INF /// ``` fn min(&self) -> f64 { @@ -137,8 +187,8 @@ impl Max for Cauchy { /// /// # Formula /// - /// ```ignore - /// INF + /// ```text + /// f64::INFINITY /// ``` fn max(&self) -> f64 { f64::INFINITY @@ -150,7 +200,7 @@ impl Distribution for Cauchy { /// /// # Formula /// - /// ```ignore + /// ```text /// ln(γ) + ln(4π) /// ``` /// @@ -165,7 +215,7 @@ impl Median for Cauchy { /// /// # Formula /// - /// ```ignore + /// ```text /// x_0 /// ``` /// @@ -180,7 +230,7 @@ impl Mode> for Cauchy { /// /// # Formula /// - /// ```ignore + /// ```text /// x_0 /// ``` /// @@ -196,7 +246,7 @@ impl Continuous for Cauchy { /// /// # Formula /// - /// ```ignore + /// ```text /// 1 / (πγ * (1 + ((x - x_0) / γ)^2)) /// ``` /// @@ -212,7 +262,7 @@ impl Continuous for Cauchy { /// /// # Formula /// - /// ```ignore + /// ```text /// ln(1 / (πγ * (1 + ((x - x_0) / γ)^2))) /// ``` /// @@ -226,244 +276,243 @@ impl Continuous for Cauchy { } #[rustfmt::skip] -#[cfg(all(test, feature = "nightly"))] +#[cfg(test)] mod tests { - use crate::statistics::*; - use crate::distribution::{ContinuousCDF, Continuous, Cauchy}; + use super::*; use crate::distribution::internal::*; - use crate::consts::ACC; - - fn try_create(location: f64, scale: f64) -> Cauchy { - let n = Cauchy::new(location, scale); - assert!(n.is_ok()); - n.unwrap() - } - - fn create_case(location: f64, scale: f64) { - let n = try_create(location, scale); - assert_eq!(location, n.location()); - assert_eq!(scale, n.scale()); - } + use crate::testing_boiler; - fn bad_create_case(location: f64, scale: f64) { - let n = Cauchy::new(location, scale); - assert!(n.is_err()); - } - - fn test_case(location: f64, scale: f64, expected: f64, eval: F) - where F: Fn(Cauchy) -> f64 - { - let n = try_create(location, scale); - let x = eval(n); - assert_eq!(expected, x); - } - - fn test_almost(location: f64, scale: f64, expected: f64, acc: f64, eval: F) - where F: 
Fn(Cauchy) -> f64 - { - let n = try_create(location, scale); - let x = eval(n); - assert_almost_eq!(expected, x, acc); - } + testing_boiler!(location: f64, scale: f64; Cauchy; CauchyError); #[test] fn test_create() { - create_case(0.0, 0.1); - create_case(0.0, 1.0); - create_case(0.0, 10.0); - create_case(10.0, 11.0); - create_case(-5.0, 100.0); - create_case(0.0, f64::INFINITY); + create_ok(0.0, 0.1); + create_ok(0.0, 1.0); + create_ok(0.0, 10.0); + create_ok(10.0, 11.0); + create_ok(-5.0, 100.0); + create_ok(0.0, f64::INFINITY); } #[test] fn test_bad_create() { - bad_create_case(f64::NAN, 1.0); - bad_create_case(1.0, f64::NAN); - bad_create_case(f64::NAN, f64::NAN); - bad_create_case(1.0, 0.0); + let invalid = [ + (f64::NAN, 1.0, CauchyError::LocationInvalid), + (1.0, f64::NAN, CauchyError::ScaleInvalid), + (f64::NAN, f64::NAN, CauchyError::LocationInvalid), + (1.0, 0.0, CauchyError::ScaleInvalid), + ]; + + for (location, scale, err) in invalid { + test_create_err(location, scale, err); + } } #[test] fn test_entropy() { let entropy = |x: Cauchy| x.entropy().unwrap(); - test_case(0.0, 2.0, 3.224171427529236102395, entropy); - test_case(0.1, 4.0, 3.917318608089181411812, entropy); - test_case(1.0, 10.0, 4.833609339963336476996, entropy); - test_case(10.0, 11.0, 4.92891951976766133704, entropy); + test_exact(0.0, 2.0, 3.224171427529236102395, entropy); + test_exact(0.1, 4.0, 3.917318608089181411812, entropy); + test_exact(1.0, 10.0, 4.833609339963336476996, entropy); + test_exact(10.0, 11.0, 4.92891951976766133704, entropy); } #[test] fn test_mode() { let mode = |x: Cauchy| x.mode().unwrap(); - test_case(0.0, 2.0, 0.0, mode); - test_case(0.1, 4.0, 0.1, mode); - test_case(1.0, 10.0, 1.0, mode); - test_case(10.0, 11.0, 10.0, mode); - test_case(0.0, f64::INFINITY, 0.0, mode); + test_exact(0.0, 2.0, 0.0, mode); + test_exact(0.1, 4.0, 0.1, mode); + test_exact(1.0, 10.0, 1.0, mode); + test_exact(10.0, 11.0, 10.0, mode); + test_exact(0.0, f64::INFINITY, 0.0, mode); } 
#[test] fn test_median() { let median = |x: Cauchy| x.median(); - test_case(0.0, 2.0, 0.0, median); - test_case(0.1, 4.0, 0.1, median); - test_case(1.0, 10.0, 1.0, median); - test_case(10.0, 11.0, 10.0, median); - test_case(0.0, f64::INFINITY, 0.0, median); + test_exact(0.0, 2.0, 0.0, median); + test_exact(0.1, 4.0, 0.1, median); + test_exact(1.0, 10.0, 1.0, median); + test_exact(10.0, 11.0, 10.0, median); + test_exact(0.0, f64::INFINITY, 0.0, median); } #[test] fn test_min_max() { let min = |x: Cauchy| x.min(); let max = |x: Cauchy| x.max(); - test_case(0.0, 1.0, f64::NEG_INFINITY, min); - test_case(0.0, 1.0, f64::INFINITY, max); + test_exact(0.0, 1.0, f64::NEG_INFINITY, min); + test_exact(0.0, 1.0, f64::INFINITY, max); } #[test] fn test_pdf() { let pdf = |arg: f64| move |x: Cauchy| x.pdf(arg); - test_case(0.0, 0.1, 0.001272730452554141029739, pdf(-5.0)); - test_case(0.0, 0.1, 0.03151583031522679916216, pdf(-1.0)); - test_almost(0.0, 0.1, 3.183098861837906715378, 1e-14, pdf(0.0)); - test_case(0.0, 0.1, 0.03151583031522679916216, pdf(1.0)); - test_case(0.0, 0.1, 0.001272730452554141029739, pdf(5.0)); - test_almost(0.0, 1.0, 0.01224268793014579505914, 1e-17, pdf(-5.0)); - test_case(0.0, 1.0, 0.1591549430918953357689, pdf(-1.0)); - test_case(0.0, 1.0, 0.3183098861837906715378, pdf(0.0)); - test_case(0.0, 1.0, 0.1591549430918953357689, pdf(1.0)); - test_almost(0.0, 1.0, 0.01224268793014579505914, 1e-17, pdf(5.0)); - test_case(0.0, 10.0, 0.02546479089470325372302, pdf(-5.0)); - test_case(0.0, 10.0, 0.03151583031522679916216, pdf(-1.0)); - test_case(0.0, 10.0, 0.03183098861837906715378, pdf(0.0)); - test_case(0.0, 10.0, 0.03151583031522679916216, pdf(1.0)); - test_case(0.0, 10.0, 0.02546479089470325372302, pdf(5.0)); - test_case(-5.0, 100.0, 0.003183098861837906715378, pdf(-5.0)); - test_almost(-5.0, 100.0, 0.003178014039374906864395, 1e-17, pdf(-1.0)); - test_case(-5.0, 100.0, 0.003175160959439308444267, pdf(0.0)); - test_case(-5.0, 100.0, 0.003171680810918599756255, 
pdf(1.0)); - test_almost(-5.0, 100.0, 0.003151583031522679916216, 1e-17, pdf(5.0)); - test_case(0.0, f64::INFINITY, 0.0, pdf(-5.0)); - test_case(0.0, f64::INFINITY, 0.0, pdf(-1.0)); - test_case(0.0, f64::INFINITY, 0.0, pdf(0.0)); - test_case(0.0, f64::INFINITY, 0.0, pdf(1.0)); - test_case(0.0, f64::INFINITY, 0.0, pdf(5.0)); - test_case(f64::INFINITY, 1.0, 0.0, pdf(-5.0)); - test_case(f64::INFINITY, 1.0, 0.0, pdf(-1.0)); - test_case(f64::INFINITY, 1.0, 0.0, pdf(0.0)); - test_case(f64::INFINITY, 1.0, 0.0, pdf(1.0)); - test_case(f64::INFINITY, 1.0, 0.0, pdf(5.0)); + test_exact(0.0, 0.1, 0.001272730452554141029739, pdf(-5.0)); + test_exact(0.0, 0.1, 0.03151583031522679916216, pdf(-1.0)); + test_absolute(0.0, 0.1, 3.183098861837906715378, 1e-14, pdf(0.0)); + test_exact(0.0, 0.1, 0.03151583031522679916216, pdf(1.0)); + test_exact(0.0, 0.1, 0.001272730452554141029739, pdf(5.0)); + test_absolute(0.0, 1.0, 0.01224268793014579505914, 1e-17, pdf(-5.0)); + test_exact(0.0, 1.0, 0.1591549430918953357689, pdf(-1.0)); + test_exact(0.0, 1.0, 0.3183098861837906715378, pdf(0.0)); + test_exact(0.0, 1.0, 0.1591549430918953357689, pdf(1.0)); + test_absolute(0.0, 1.0, 0.01224268793014579505914, 1e-17, pdf(5.0)); + test_exact(0.0, 10.0, 0.02546479089470325372302, pdf(-5.0)); + test_exact(0.0, 10.0, 0.03151583031522679916216, pdf(-1.0)); + test_exact(0.0, 10.0, 0.03183098861837906715378, pdf(0.0)); + test_exact(0.0, 10.0, 0.03151583031522679916216, pdf(1.0)); + test_exact(0.0, 10.0, 0.02546479089470325372302, pdf(5.0)); + test_exact(-5.0, 100.0, 0.003183098861837906715378, pdf(-5.0)); + test_absolute(-5.0, 100.0, 0.003178014039374906864395, 1e-17, pdf(-1.0)); + test_exact(-5.0, 100.0, 0.003175160959439308444267, pdf(0.0)); + test_exact(-5.0, 100.0, 0.003171680810918599756255, pdf(1.0)); + test_absolute(-5.0, 100.0, 0.003151583031522679916216, 1e-17, pdf(5.0)); + test_exact(0.0, f64::INFINITY, 0.0, pdf(-5.0)); + test_exact(0.0, f64::INFINITY, 0.0, pdf(-1.0)); + test_exact(0.0, 
f64::INFINITY, 0.0, pdf(0.0)); + test_exact(0.0, f64::INFINITY, 0.0, pdf(1.0)); + test_exact(0.0, f64::INFINITY, 0.0, pdf(5.0)); + test_exact(f64::INFINITY, 1.0, 0.0, pdf(-5.0)); + test_exact(f64::INFINITY, 1.0, 0.0, pdf(-1.0)); + test_exact(f64::INFINITY, 1.0, 0.0, pdf(0.0)); + test_exact(f64::INFINITY, 1.0, 0.0, pdf(1.0)); + test_exact(f64::INFINITY, 1.0, 0.0, pdf(5.0)); } #[test] fn test_ln_pdf() { let ln_pdf = |arg: f64| move |x: Cauchy| x.ln_pdf(arg); - test_case(0.0, 0.1, -6.666590723732973542744, ln_pdf(-5.0)); - test_almost(0.0, 0.1, -3.457265309696613941009, 1e-14, ln_pdf(-1.0)); - test_case(0.0, 0.1, 1.157855207144645509875, ln_pdf(0.0)); - test_almost(0.0, 0.1, -3.457265309696613941009, 1e-14, ln_pdf(1.0)); - test_case(0.0, 0.1, -6.666590723732973542744, ln_pdf(5.0)); - test_case(0.0, 1.0, -4.402826423870882219615, ln_pdf(-5.0)); - test_almost(0.0, 1.0, -1.837877066409345483561, 1e-15, ln_pdf(-1.0)); - test_case(0.0, 1.0, -1.144729885849400174143, ln_pdf(0.0)); - test_almost(0.0, 1.0, -1.837877066409345483561, 1e-15, ln_pdf(1.0)); - test_case(0.0, 1.0, -4.402826423870882219615, ln_pdf(5.0)); - test_case(0.0, 10.0, -3.670458530157655613928, ln_pdf(-5.0)); - test_almost(0.0, 10.0, -3.457265309696613941009, 1e-14, ln_pdf(-1.0)); - test_case(0.0, 10.0, -3.447314978843445858161, ln_pdf(0.0)); - test_almost(0.0, 10.0, -3.457265309696613941009, 1e-14, ln_pdf(1.0)); - test_case(0.0, 10.0, -3.670458530157655613928, ln_pdf(5.0)); - test_case(-5.0, 100.0, -5.749900071837491542179, ln_pdf(-5.0)); - test_case(-5.0, 100.0, -5.751498793201188569872, ln_pdf(-1.0)); - test_case(-5.0, 100.0, -5.75239695203607874116, ln_pdf(0.0)); - test_case(-5.0, 100.0, -5.75349360734762171285, ln_pdf(1.0)); - test_case(-5.0, 100.0, -5.759850402690659625027, ln_pdf(5.0)); - test_case(0.0, f64::INFINITY, f64::NEG_INFINITY, ln_pdf(-5.0)); - test_case(0.0, f64::INFINITY, f64::NEG_INFINITY, ln_pdf(-1.0)); - test_case(0.0, f64::INFINITY, f64::NEG_INFINITY, ln_pdf(0.0)); - test_case(0.0, 
f64::INFINITY, f64::NEG_INFINITY, ln_pdf(1.0)); - test_case(0.0, f64::INFINITY, f64::NEG_INFINITY, ln_pdf(5.0)); - test_case(f64::INFINITY, 1.0, f64::NEG_INFINITY, ln_pdf(-5.0)); - test_case(f64::INFINITY, 1.0, f64::NEG_INFINITY, ln_pdf(-1.0)); - test_case(f64::INFINITY, 1.0, f64::NEG_INFINITY, ln_pdf(0.0)); - test_case(f64::INFINITY, 1.0, f64::NEG_INFINITY, ln_pdf(1.0)); - test_case(f64::INFINITY, 1.0, f64::NEG_INFINITY, ln_pdf(5.0)); + test_exact(0.0, 0.1, -6.666590723732973542744, ln_pdf(-5.0)); + test_absolute(0.0, 0.1, -3.457265309696613941009, 1e-14, ln_pdf(-1.0)); + test_exact(0.0, 0.1, 1.157855207144645509875, ln_pdf(0.0)); + test_absolute(0.0, 0.1, -3.457265309696613941009, 1e-14, ln_pdf(1.0)); + test_exact(0.0, 0.1, -6.666590723732973542744, ln_pdf(5.0)); + test_exact(0.0, 1.0, -4.402826423870882219615, ln_pdf(-5.0)); + test_absolute(0.0, 1.0, -1.837877066409345483561, 1e-15, ln_pdf(-1.0)); + test_exact(0.0, 1.0, -1.144729885849400174143, ln_pdf(0.0)); + test_absolute(0.0, 1.0, -1.837877066409345483561, 1e-15, ln_pdf(1.0)); + test_exact(0.0, 1.0, -4.402826423870882219615, ln_pdf(5.0)); + test_exact(0.0, 10.0, -3.670458530157655613928, ln_pdf(-5.0)); + test_absolute(0.0, 10.0, -3.457265309696613941009, 1e-14, ln_pdf(-1.0)); + test_exact(0.0, 10.0, -3.447314978843445858161, ln_pdf(0.0)); + test_absolute(0.0, 10.0, -3.457265309696613941009, 1e-14, ln_pdf(1.0)); + test_exact(0.0, 10.0, -3.670458530157655613928, ln_pdf(5.0)); + test_exact(-5.0, 100.0, -5.749900071837491542179, ln_pdf(-5.0)); + test_exact(-5.0, 100.0, -5.751498793201188569872, ln_pdf(-1.0)); + test_exact(-5.0, 100.0, -5.75239695203607874116, ln_pdf(0.0)); + test_exact(-5.0, 100.0, -5.75349360734762171285, ln_pdf(1.0)); + test_exact(-5.0, 100.0, -5.759850402690659625027, ln_pdf(5.0)); + test_exact(0.0, f64::INFINITY, f64::NEG_INFINITY, ln_pdf(-5.0)); + test_exact(0.0, f64::INFINITY, f64::NEG_INFINITY, ln_pdf(-1.0)); + test_exact(0.0, f64::INFINITY, f64::NEG_INFINITY, ln_pdf(0.0)); + 
test_exact(0.0, f64::INFINITY, f64::NEG_INFINITY, ln_pdf(1.0)); + test_exact(0.0, f64::INFINITY, f64::NEG_INFINITY, ln_pdf(5.0)); + test_exact(f64::INFINITY, 1.0, f64::NEG_INFINITY, ln_pdf(-5.0)); + test_exact(f64::INFINITY, 1.0, f64::NEG_INFINITY, ln_pdf(-1.0)); + test_exact(f64::INFINITY, 1.0, f64::NEG_INFINITY, ln_pdf(0.0)); + test_exact(f64::INFINITY, 1.0, f64::NEG_INFINITY, ln_pdf(1.0)); + test_exact(f64::INFINITY, 1.0, f64::NEG_INFINITY, ln_pdf(5.0)); } #[test] fn test_cdf() { let cdf = |arg: f64| move |x: Cauchy| x.cdf(arg); - test_almost(0.0, 0.1, 0.006365349100972796679298, 1e-16, cdf(-5.0)); - test_almost(0.0, 0.1, 0.03172551743055356951498, 1e-16, cdf(-1.0)); - test_case(0.0, 0.1, 0.5, cdf(0.0)); - test_case(0.0, 0.1, 0.968274482569446430485, cdf(1.0)); - test_case(0.0, 0.1, 0.9936346508990272033207, cdf(5.0)); - test_almost(0.0, 1.0, 0.06283295818900118381375, 1e-16, cdf(-5.0)); - test_case(0.0, 1.0, 0.25, cdf(-1.0)); - test_case(0.0, 1.0, 0.5, cdf(0.0)); - test_case(0.0, 1.0, 0.75, cdf(1.0)); - test_case(0.0, 1.0, 0.9371670418109988161863, cdf(5.0)); - test_case(0.0, 10.0, 0.3524163823495667258246, cdf(-5.0)); - test_case(0.0, 10.0, 0.468274482569446430485, cdf(-1.0)); - test_case(0.0, 10.0, 0.5, cdf(0.0)); - test_case(0.0, 10.0, 0.531725517430553569515, cdf(1.0)); - test_case(0.0, 10.0, 0.6475836176504332741754, cdf(5.0)); - test_case(-5.0, 100.0, 0.5, cdf(-5.0)); - test_case(-5.0, 100.0, 0.5127256113479918307809, cdf(-1.0)); - test_case(-5.0, 100.0, 0.5159022512561763751816, cdf(0.0)); - test_case(-5.0, 100.0, 0.5190757242358362337495, cdf(1.0)); - test_case(-5.0, 100.0, 0.531725517430553569515, cdf(5.0)); - test_case(0.0, f64::INFINITY, 0.5, cdf(-5.0)); - test_case(0.0, f64::INFINITY, 0.5, cdf(-1.0)); - test_case(0.0, f64::INFINITY, 0.5, cdf(0.0)); - test_case(0.0, f64::INFINITY, 0.5, cdf(1.0)); - test_case(0.0, f64::INFINITY, 0.5, cdf(5.0)); - test_case(f64::INFINITY, 1.0, 0.0, cdf(-5.0)); - test_case(f64::INFINITY, 1.0, 0.0, cdf(-1.0)); - 
test_case(f64::INFINITY, 1.0, 0.0, cdf(0.0)); - test_case(f64::INFINITY, 1.0, 0.0, cdf(1.0)); - test_case(f64::INFINITY, 1.0, 0.0, cdf(5.0)); + test_absolute(0.0, 0.1, 0.006365349100972796679298, 1e-16, cdf(-5.0)); + test_absolute(0.0, 0.1, 0.03172551743055356951498, 1e-16, cdf(-1.0)); + test_exact(0.0, 0.1, 0.5, cdf(0.0)); + test_exact(0.0, 0.1, 0.968274482569446430485, cdf(1.0)); + test_exact(0.0, 0.1, 0.9936346508990272033207, cdf(5.0)); + test_absolute(0.0, 1.0, 0.06283295818900118381375, 1e-16, cdf(-5.0)); + test_exact(0.0, 1.0, 0.25, cdf(-1.0)); + test_exact(0.0, 1.0, 0.5, cdf(0.0)); + test_exact(0.0, 1.0, 0.75, cdf(1.0)); + test_exact(0.0, 1.0, 0.9371670418109988161863, cdf(5.0)); + test_exact(0.0, 10.0, 0.3524163823495667258246, cdf(-5.0)); + test_exact(0.0, 10.0, 0.468274482569446430485, cdf(-1.0)); + test_exact(0.0, 10.0, 0.5, cdf(0.0)); + test_exact(0.0, 10.0, 0.531725517430553569515, cdf(1.0)); + test_exact(0.0, 10.0, 0.6475836176504332741754, cdf(5.0)); + test_exact(-5.0, 100.0, 0.5, cdf(-5.0)); + test_exact(-5.0, 100.0, 0.5127256113479918307809, cdf(-1.0)); + test_exact(-5.0, 100.0, 0.5159022512561763751816, cdf(0.0)); + test_exact(-5.0, 100.0, 0.5190757242358362337495, cdf(1.0)); + test_exact(-5.0, 100.0, 0.531725517430553569515, cdf(5.0)); + test_exact(0.0, f64::INFINITY, 0.5, cdf(-5.0)); + test_exact(0.0, f64::INFINITY, 0.5, cdf(-1.0)); + test_exact(0.0, f64::INFINITY, 0.5, cdf(0.0)); + test_exact(0.0, f64::INFINITY, 0.5, cdf(1.0)); + test_exact(0.0, f64::INFINITY, 0.5, cdf(5.0)); + test_exact(f64::INFINITY, 1.0, 0.0, cdf(-5.0)); + test_exact(f64::INFINITY, 1.0, 0.0, cdf(-1.0)); + test_exact(f64::INFINITY, 1.0, 0.0, cdf(0.0)); + test_exact(f64::INFINITY, 1.0, 0.0, cdf(1.0)); + test_exact(f64::INFINITY, 1.0, 0.0, cdf(5.0)); } #[test] fn test_sf() { let sf = |arg: f64| move |x: Cauchy| x.sf(arg); - test_almost(0.0, 0.1, 0.9936346508990272, 1e-16, sf(-5.0)); - test_almost(0.0, 0.1, 0.9682744825694465, 1e-16, sf(-1.0)); - test_case(0.0, 0.1, 0.5, 
sf(0.0)); - test_case(0.0, 0.1, 0.03172551743055352, sf(1.0)); - test_case(0.0, 0.1, 0.006365349100972806, sf(5.0)); - test_almost(0.0, 1.0, 0.9371670418109989, 1e-16, sf(-5.0)); - test_case(0.0, 1.0, 0.75, sf(-1.0)); - test_case(0.0, 1.0, 0.5, sf(0.0)); - test_case(0.0, 1.0, 0.25, sf(1.0)); - test_case(0.0, 1.0, 0.06283295818900114, sf(5.0)); - test_case(0.0, 10.0, 0.6475836176504333, sf(-5.0)); - test_case(0.0, 10.0, 0.5317255174305535, sf(-1.0)); - test_case(0.0, 10.0, 0.5, sf(0.0)); - test_case(0.0, 10.0, 0.4682744825694464, sf(1.0)); - test_case(0.0, 10.0, 0.35241638234956674, sf(5.0)); - test_case(-5.0, 100.0, 0.5, sf(-5.0)); - test_case(-5.0, 100.0, 0.4872743886520082, sf(-1.0)); - test_case(-5.0, 100.0, 0.4840977487438236, sf(0.0)); - test_case(-5.0, 100.0, 0.48092427576416374, sf(1.0)); - test_case(-5.0, 100.0, 0.4682744825694464, sf(5.0)); - test_case(0.0, f64::INFINITY, 0.5, sf(-5.0)); - test_case(0.0, f64::INFINITY, 0.5, sf(-1.0)); - test_case(0.0, f64::INFINITY, 0.5, sf(0.0)); - test_case(0.0, f64::INFINITY, 0.5, sf(1.0)); - test_case(0.0, f64::INFINITY, 0.5, sf(5.0)); - test_case(f64::INFINITY, 1.0, 1.0, sf(-5.0)); - test_case(f64::INFINITY, 1.0, 1.0, sf(-1.0)); - test_case(f64::INFINITY, 1.0, 1.0, sf(0.0)); - test_case(f64::INFINITY, 1.0, 1.0, sf(1.0)); - test_case(f64::INFINITY, 1.0, 1.0, sf(5.0)); + test_absolute(0.0, 0.1, 0.9936346508990272, 1e-16, sf(-5.0)); + test_absolute(0.0, 0.1, 0.9682744825694465, 1e-16, sf(-1.0)); + test_exact(0.0, 0.1, 0.5, sf(0.0)); + test_absolute(0.0, 0.1, 0.03172551743055352, 1e-16, sf(1.0)); + test_exact(0.0, 0.1, 0.006365349100972806, sf(5.0)); + test_absolute(0.0, 1.0, 0.9371670418109989, 1e-16, sf(-5.0)); + test_exact(0.0, 1.0, 0.75, sf(-1.0)); + test_exact(0.0, 1.0, 0.5, sf(0.0)); + test_exact(0.0, 1.0, 0.25, sf(1.0)); + test_exact(0.0, 1.0, 0.06283295818900114, sf(5.0)); + test_exact(0.0, 10.0, 0.6475836176504333, sf(-5.0)); + test_exact(0.0, 10.0, 0.5317255174305535, sf(-1.0)); + test_exact(0.0, 10.0, 0.5, 
sf(0.0)); + test_exact(0.0, 10.0, 0.4682744825694464, sf(1.0)); + test_exact(0.0, 10.0, 0.35241638234956674, sf(5.0)); + test_exact(-5.0, 100.0, 0.5, sf(-5.0)); + test_exact(-5.0, 100.0, 0.4872743886520082, sf(-1.0)); + test_exact(-5.0, 100.0, 0.4840977487438236, sf(0.0)); + test_exact(-5.0, 100.0, 0.48092427576416374, sf(1.0)); + test_exact(-5.0, 100.0, 0.4682744825694464, sf(5.0)); + test_exact(0.0, f64::INFINITY, 0.5, sf(-5.0)); + test_exact(0.0, f64::INFINITY, 0.5, sf(-1.0)); + test_exact(0.0, f64::INFINITY, 0.5, sf(0.0)); + test_exact(0.0, f64::INFINITY, 0.5, sf(1.0)); + test_exact(0.0, f64::INFINITY, 0.5, sf(5.0)); + test_exact(f64::INFINITY, 1.0, 1.0, sf(-5.0)); + test_exact(f64::INFINITY, 1.0, 1.0, sf(-1.0)); + test_exact(f64::INFINITY, 1.0, 1.0, sf(0.0)); + test_exact(f64::INFINITY, 1.0, 1.0, sf(1.0)); + test_exact(f64::INFINITY, 1.0, 1.0, sf(5.0)); + } + + #[test] + fn test_inverse_cdf() { + let func = |arg: f64| move |x: Cauchy| x.inverse_cdf(x.cdf(arg)); + test_absolute(0.0, 0.1, -5.0, 1e-10, func(-5.0)); + test_absolute(0.0, 0.1, -1.0, 1e-14, func(-1.0)); + test_exact(0.0, 0.1, 0.0, func(0.0)); + test_absolute(0.0, 0.1, 1.0, 1e-14, func(1.0)); + test_absolute(0.0, 0.1, 5.0, 1e-10, func(5.0)); + test_absolute(0.0, 1.0, -5.0, 1e-14, func(-5.0)); + test_absolute(0.0, 1.0, -1.0, 1e-15, func(-1.0)); + test_exact(0.0, 1.0, 0.0, func(0.0)); + test_absolute(0.0, 1.0, 1.0, 1e-15, func(1.0)); + test_absolute(0.0, 1.0, 5.0, 1e-14, func(5.0)); + test_absolute(0.0, 10.0, -5.0, 1e-14, func(-5.0)); + test_absolute(0.0, 10.0, -1.0, 1e-14, func(-1.0)); + test_exact(0.0, 10.0, 0.0, func(0.0)); + test_absolute(0.0, 10.0, 1.0, 1e-14, func(1.0)); + test_absolute(0.0, 10.0, 5.0, 1e-14, func(5.0)); + test_exact(-5.0, 100.0, -5.0, func(-5.0)); + test_absolute(-5.0, 100.0, -1.0, 1e-10, func(-1.0)); + test_absolute(-5.0, 100.0, 0.0, 1e-14, func(0.0)); + test_absolute(-5.0, 100.0, 1.0, 1e-14, func(1.0)); + test_absolute(-5.0, 100.0, 5.0, 1e-10, func(5.0)); } #[test] fn 
test_continuous() { - test::check_continuous_distribution(&try_create(-1.2, 3.4), -1500.0, 1500.0); - test::check_continuous_distribution(&try_create(-4.5, 6.7), -5000.0, 5000.0); + test::check_continuous_distribution(&create_ok(-1.2, 3.4), -1500.0, 1500.0); + test::check_continuous_distribution(&create_ok(-4.5, 6.7), -5000.0, 5000.0); } } diff --git a/src/distribution/chi.rs b/src/distribution/chi.rs index 7fcccf5e..bcb98481 100644 --- a/src/distribution/chi.rs +++ b/src/distribution/chi.rs @@ -1,8 +1,6 @@ use crate::distribution::{Continuous, ContinuousCDF}; use crate::function::gamma; use crate::statistics::*; -use crate::{Result, StatsError}; -use rand::Rng; use std::f64; /// Implements the [Chi](https://en.wikipedia.org/wiki/Chi_distribution) @@ -19,11 +17,32 @@ use std::f64; /// assert!(prec::almost_eq(n.mean().unwrap(), 1.25331413731550025121, 1e-14)); /// assert!(prec::almost_eq(n.pdf(1.0), 0.60653065971263342360, 1e-15)); /// ``` -#[derive(Debug, Copy, Clone, PartialEq)] +#[derive(Copy, Clone, PartialEq, Debug)] pub struct Chi { freedom: f64, } +/// Represents the errors that can occur when creating a [`Chi`]. +#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] +#[non_exhaustive] +pub enum ChiError { + /// The degrees of freedom are NaN, zero or less than zero. 
+ FreedomInvalid, +} + +impl std::fmt::Display for ChiError { + #[cfg_attr(coverage_nightly, coverage(off))] + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + ChiError::FreedomInvalid => { + write!(f, "Degrees of freedom are NaN, zero or less than zero") + } + } + } +} + +impl std::error::Error for ChiError {} + impl Chi { /// Constructs a new chi distribution /// with `freedom` degrees of freedom @@ -44,9 +63,9 @@ impl Chi { /// result = Chi::new(0.0); /// assert!(result.is_err()); /// ``` - pub fn new(freedom: f64) -> Result { + pub fn new(freedom: f64) -> Result { if freedom.is_nan() || freedom <= 0.0 { - Err(StatsError::BadParams) + Err(ChiError::FreedomInvalid) } else { Ok(Chi { freedom }) } @@ -68,8 +87,15 @@ impl Chi { } } +impl std::fmt::Display for Chi { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "χ_{}", self.freedom) + } +} + +#[cfg(feature = "rand")] impl ::rand::distributions::Distribution for Chi { - fn sample(&self, rng: &mut R) -> f64 { + fn sample(&self, rng: &mut R) -> f64 { (0..self.freedom as i64) .fold(0.0, |acc, _| { acc + super::normal::sample_unchecked(rng, 0.0, 1.0).powf(2.0) @@ -84,7 +110,7 @@ impl ContinuousCDF for Chi { /// /// # Formula /// - /// ```ignore + /// ```text /// P(k / 2, x^2 / 2) /// ``` /// @@ -105,7 +131,7 @@ impl ContinuousCDF for Chi { /// /// # Formula /// - /// ```ignore + /// ```text /// P(k / 2, x^2 / 2) /// ``` /// @@ -128,7 +154,7 @@ impl Min for Chi { /// /// # Formula /// - /// ```ignore + /// ```text /// 0 /// ``` fn min(&self) -> f64 { @@ -142,8 +168,8 @@ impl Max for Chi { /// /// # Formula /// - /// ```ignore - /// INF + /// ```text + /// f64::INFINITY /// ``` fn max(&self) -> f64 { f64::INFINITY @@ -159,7 +185,7 @@ impl Distribution for Chi { /// /// # Formula /// - /// ```ignore + /// ```text /// sqrt2 * Γ((k + 1) / 2) / Γ(k / 2) /// ``` /// @@ -185,6 +211,7 @@ impl Distribution for Chi { Some(mean) } } + /// Returns the variance of 
the chi distribution /// /// # Remarks @@ -193,7 +220,7 @@ impl Distribution for Chi { /// /// # Formula /// - /// ```ignore + /// ```text /// k - μ^2 /// ``` /// @@ -203,6 +230,7 @@ impl Distribution for Chi { let mean = self.mean()?; Some(self.freedom - mean * mean) } + /// Returns the entropy of the chi distribution /// /// # Remarks @@ -211,7 +239,7 @@ impl Distribution for Chi { /// /// # Formula /// - /// ```ignore + /// ```text /// ln(Γ(k / 2)) + 0.5 * (k - ln2 - (k - 1) * ψ(k / 2)) /// ``` /// @@ -228,6 +256,7 @@ impl Distribution for Chi { / 2.0; Some(entr) } + /// Returns the skewness of the chi distribution /// /// # Remarks @@ -236,7 +265,7 @@ impl Distribution for Chi { /// /// # Formula /// - /// ```ignore + /// ```text /// (μ / σ^3) * (1 - 2σ^2) /// ``` /// where `μ` is the mean and `σ` the standard deviation @@ -257,7 +286,7 @@ impl Mode> for Chi { /// /// # Formula /// - /// ```ignore + /// ```text /// sqrt(k - 1) /// ``` /// @@ -276,7 +305,7 @@ impl Continuous for Chi { /// /// # Formula /// - /// ```ignore + /// ```text /// (2^(1 - (k / 2)) * x^(k - 1) * e^(-x^2 / 2)) / Γ(k / 2) /// ``` /// @@ -299,7 +328,7 @@ impl Continuous for Chi { /// /// # Formula /// - /// ```ignore + /// ```text /// ln((2^(1 - (k / 2)) * x^(k - 1) * e^(-x^2 / 2)) / Γ(k / 2)) /// ``` fn ln_pdf(&self, x: f64) -> f64 { @@ -314,78 +343,38 @@ impl Continuous for Chi { } #[rustfmt::skip] -#[cfg(all(test, feature = "nightly"))] +#[cfg(test)] mod tests { - use std::f64; + use super::*; use crate::distribution::internal::*; - use crate::distribution::{Chi, Continuous, ContinuousCDF}; - use crate::statistics::*; - use crate::consts::ACC; + use crate::testing_boiler; - fn try_create(freedom: f64) -> Chi { - let n = Chi::new(freedom); - assert!(n.is_ok()); - n.unwrap() - } - - fn create_case(freedom: f64) { - let n = try_create(freedom); - assert_eq!(freedom, n.freedom()); - } - - fn bad_create_case(freedom: f64) { - let n = Chi::new(freedom); - assert!(n.is_err()); - } - - fn 
get_value(freedom: f64, eval: F) -> f64 - where - F: Fn(Chi) -> f64, - { - let n = try_create(freedom); - eval(n) - } - - fn test_case(freedom: f64, expected: f64, eval: F) - where - F: Fn(Chi) -> f64, - { - let x = get_value(freedom, eval); - assert_eq!(expected, x); - } - - fn test_almost(freedom: f64, expected: f64, acc: f64, eval: F) - where - F: Fn(Chi) -> f64, - { - let x = get_value(freedom, eval); - assert_almost_eq!(expected, x, acc); - } + testing_boiler!(freedom: f64; Chi; ChiError); #[test] fn test_create() { - create_case(1.0); - create_case(3.0); - create_case(f64::INFINITY); + create_ok(1.0); + create_ok(3.0); + create_ok(f64::INFINITY); } #[test] fn test_bad_create() { - bad_create_case(0.0); - bad_create_case(-1.0); - bad_create_case(-100.0); - bad_create_case(f64::NEG_INFINITY); - bad_create_case(f64::NAN); + create_err(0.0); + create_err(-1.0); + create_err(-100.0); + create_err(f64::NEG_INFINITY); + create_err(f64::NAN); } #[test] fn test_mean() { let mean = |x: Chi| x.mean().unwrap(); - test_almost(1.0, 0.7978845608028653558799, 1e-15, mean); - test_almost(2.0, 1.25331413731550025121, 1e-14, mean); - test_almost(2.5, 1.43396639245837498609, 1e-14, mean); - test_almost(5.0, 2.12769216214097428235, 1e-14, mean); - test_almost(336.0, 18.31666925443713, 1e-12, mean); + test_absolute(1.0, 0.7978845608028653558799, 1e-15, mean); + test_absolute(2.0, 1.25331413731550025121, 1e-14, mean); + test_absolute(2.5, 1.43396639245837498609, 1e-14, mean); + test_absolute(5.0, 2.12769216214097428235, 1e-14, mean); + test_absolute(336.0, 18.31666925443713, 1e-12, mean); } #[test] @@ -397,223 +386,213 @@ mod tests { } #[test] - #[should_panic] fn test_mean_degen() { - let mean = |x: Chi| x.mean().unwrap(); - get_value(f64::INFINITY, mean); + test_none(f64::INFINITY, |dist| dist.mean()); } #[test] fn test_variance() { let variance = |x: Chi| x.variance().unwrap(); - test_almost(1.0, 0.3633802276324186569245, 1e-15, variance); - test_almost(2.0, 
0.42920367320510338077, 1e-14, variance); - test_almost(2.5, 0.44374038529991368581, 1e-13, variance); - test_almost(3.0, 0.4535209105296746277, 1e-14, variance); + test_absolute(1.0, 0.3633802276324186569245, 1e-15, variance); + test_absolute(2.0, 0.42920367320510338077, 1e-14, variance); + test_absolute(2.5, 0.44374038529991368581, 1e-13, variance); + test_absolute(3.0, 0.4535209105296746277, 1e-14, variance); } #[test] - #[should_panic] fn test_variance_degen() { - let variance = |x: Chi| x.variance().unwrap(); - get_value(f64::INFINITY, variance); + test_none(f64::INFINITY, |dist| dist.variance()); } #[test] fn test_entropy() { let entropy = |x: Chi| x.entropy().unwrap(); - test_almost(1.0, 0.7257913526447274323631, 1e-15, entropy); - test_almost(2.0, 0.9420342421707937755946, 1e-15, entropy); - test_almost(2.5, 0.97574472333041323989, 1e-14, entropy); - test_almost(3.0, 0.99615419810620560239, 1e-14, entropy); + test_absolute(1.0, 0.7257913526447274323631, 1e-15, entropy); + test_absolute(2.0, 0.9420342421707937755946, 1e-15, entropy); + test_absolute(2.5, 0.97574472333041323989, 1e-14, entropy); + test_absolute(3.0, 0.99615419810620560239, 1e-14, entropy); } #[test] - #[should_panic] fn test_entropy_degen() { - let entropy = |x: Chi| x.entropy().unwrap(); - get_value(f64::INFINITY, entropy); + test_none(f64::INFINITY, |dist| dist.entropy()); } #[test] fn test_skewness() { let skewness = |x: Chi| x.skewness().unwrap(); - test_almost(1.0, 0.995271746431156042444, 1e-14, skewness); - test_almost(2.0, 0.6311106578189371382, 1e-13, skewness); - test_almost(2.5, 0.5458487096285153216, 1e-12, skewness); - test_almost(3.0, 0.485692828049590809, 1e-12, skewness); + test_absolute(1.0, 0.995271746431156042444, 1e-14, skewness); + test_absolute(2.0, 0.6311106578189371382, 1e-13, skewness); + test_absolute(2.5, 0.5458487096285153216, 1e-12, skewness); + test_absolute(3.0, 0.485692828049590809, 1e-12, skewness); } #[test] - #[should_panic] fn test_skewness_degen() { - let 
skewness = |x: Chi| x.skewness().unwrap(); - get_value(f64::INFINITY, skewness); + test_none(f64::INFINITY, |dist| dist.skewness()); } #[test] fn test_mode() { let mode = |x: Chi| x.mode().unwrap(); - test_case(1.0, 0.0, mode); - test_case(2.0, 1.0, mode); - test_case(2.5, 1.224744871391589049099, mode); - test_case(3.0, f64::consts::SQRT_2, mode); - test_case(f64::INFINITY, f64::INFINITY, mode); + test_exact(1.0, 0.0, mode); + test_exact(2.0, 1.0, mode); + test_exact(2.5, 1.224744871391589049099, mode); + test_exact(3.0, f64::consts::SQRT_2, mode); + test_exact(f64::INFINITY, f64::INFINITY, mode); } #[test] - #[should_panic] fn test_mode_freedom_lt_1() { - let mode = |x: Chi| x.mode().unwrap(); - get_value(0.5, mode); + test_none(0.5, |dist| dist.mode()); } #[test] fn test_min_max() { let min = |x: Chi| x.min(); let max = |x: Chi| x.max(); - test_case(1.0, 0.0, min); - test_case(2.0, 0.0, min); - test_case(2.5, 0.0, min); - test_case(3.0, 0.0, min); - test_case(f64::INFINITY, 0.0, min); - test_case(1.0, f64::INFINITY, max); - test_case(2.0, f64::INFINITY, max); - test_case(2.5, f64::INFINITY, max); - test_case(3.0, f64::INFINITY, max); - test_case(f64::INFINITY, f64::INFINITY, max); + test_exact(1.0, 0.0, min); + test_exact(2.0, 0.0, min); + test_exact(2.5, 0.0, min); + test_exact(3.0, 0.0, min); + test_exact(f64::INFINITY, 0.0, min); + test_exact(1.0, f64::INFINITY, max); + test_exact(2.0, f64::INFINITY, max); + test_exact(2.5, f64::INFINITY, max); + test_exact(3.0, f64::INFINITY, max); + test_exact(f64::INFINITY, f64::INFINITY, max); } #[test] fn test_pdf() { let pdf = |arg: f64| move |x: Chi| x.pdf(arg); - test_case(1.0, 0.0, pdf(0.0)); - test_almost(1.0, 0.79390509495402353102, 1e-15, pdf(0.1)); - test_almost(1.0, 0.48394144903828669960, 1e-15, pdf(1.0)); - test_almost(1.0, 2.1539520085086552718e-7, 1e-22, pdf(5.5)); - test_case(1.0, 0.0, pdf(f64::INFINITY)); - test_case(2.0, 0.0, pdf(0.0)); - test_almost(2.0, 0.099501247919268231335, 1e-16, pdf(0.1)); - 
test_almost(2.0, 0.60653065971263342360, 1e-15, pdf(1.0)); - test_almost(2.0, 1.4847681768496578863e-6, 1e-21, pdf(5.5)); - test_case(2.0, 0.0, pdf(f64::INFINITY)); - test_case(2.5, 0.0, pdf(0.0)); - test_almost(2.5, 0.029191065334961657461, 1e-16, pdf(0.1)); - test_almost(2.5, 0.56269645152636456261, 1e-15, pdf(1.0)); - test_almost(2.5, 3.2304380188895211768e-6, 1e-20, pdf(5.5)); - test_case(2.5, 0.0, pdf(f64::INFINITY)); - test_case(f64::INFINITY, 0.0, pdf(0.0)); - test_case(f64::INFINITY, 0.0, pdf(0.1)); - test_case(f64::INFINITY, 0.0, pdf(1.0)); - test_case(f64::INFINITY, 0.0, pdf(5.5)); - test_case(f64::INFINITY, 0.0, pdf(f64::INFINITY)); - test_almost(170.0, 0.5644678498668440878, 1e-13, pdf(13.0)); + test_exact(1.0, 0.0, pdf(0.0)); + test_absolute(1.0, 0.79390509495402353102, 1e-15, pdf(0.1)); + test_absolute(1.0, 0.48394144903828669960, 1e-15, pdf(1.0)); + test_absolute(1.0, 2.1539520085086552718e-7, 1e-22, pdf(5.5)); + test_exact(1.0, 0.0, pdf(f64::INFINITY)); + test_exact(2.0, 0.0, pdf(0.0)); + test_absolute(2.0, 0.099501247919268231335, 1e-16, pdf(0.1)); + test_absolute(2.0, 0.60653065971263342360, 1e-15, pdf(1.0)); + test_absolute(2.0, 1.4847681768496578863e-6, 1e-21, pdf(5.5)); + test_exact(2.0, 0.0, pdf(f64::INFINITY)); + test_exact(2.5, 0.0, pdf(0.0)); + test_absolute(2.5, 0.029191065334961657461, 1e-16, pdf(0.1)); + test_absolute(2.5, 0.56269645152636456261, 1e-15, pdf(1.0)); + test_absolute(2.5, 3.2304380188895211768e-6, 1e-20, pdf(5.5)); + test_exact(2.5, 0.0, pdf(f64::INFINITY)); + test_exact(f64::INFINITY, 0.0, pdf(0.0)); + test_exact(f64::INFINITY, 0.0, pdf(0.1)); + test_exact(f64::INFINITY, 0.0, pdf(1.0)); + test_exact(f64::INFINITY, 0.0, pdf(5.5)); + test_exact(f64::INFINITY, 0.0, pdf(f64::INFINITY)); + test_absolute(170.0, 0.5644678498668440878, 1e-13, pdf(13.0)); } #[test] fn test_neg_pdf() { let pdf = |arg: f64| move |x: Chi| x.pdf(arg); - test_case(1.0, 0.0, pdf(-1.0)); + test_exact(1.0, 0.0, pdf(-1.0)); } #[test] fn test_ln_pdf() { let 
ln_pdf = |arg: f64| move |x: Chi| x.ln_pdf(arg); - test_case(1.0, f64::NEG_INFINITY, ln_pdf(0.0)); - test_almost(1.0, -0.23079135264472743236, 1e-15, ln_pdf(0.1)); - test_almost(1.0, -0.72579135264472743236, 1e-15, ln_pdf(1.0)); - test_almost(1.0, -15.350791352644727432, 1e-14, ln_pdf(5.5)); - test_case(1.0, f64::NEG_INFINITY, ln_pdf(f64::INFINITY)); - test_case(2.0, f64::NEG_INFINITY, ln_pdf(0.0)); - test_almost(2.0, -2.3075850929940456840, 1e-15, ln_pdf(0.1)); - test_almost(2.0, -0.5, 1e-15, ln_pdf(1.0)); - test_almost(2.0, -13.420251907761574765, 1e-15, ln_pdf(5.5)); - test_case(2.0, f64::NEG_INFINITY, ln_pdf(f64::INFINITY)); - test_case(2.5, f64::NEG_INFINITY, ln_pdf(0.0)); - test_almost(2.5, -3.5338925982092416919, 1e-15, ln_pdf(0.1)); - test_almost(2.5, -0.57501495871817316589, 1e-15, ln_pdf(1.0)); - test_almost(2.5, -12.642892820360535314, 1e-16, ln_pdf(5.5)); - test_case(2.5, f64::NEG_INFINITY, ln_pdf(f64::INFINITY)); - test_case(f64::INFINITY, f64::NEG_INFINITY, ln_pdf(0.0)); - test_case(f64::INFINITY, f64::NEG_INFINITY, ln_pdf(0.1)); - test_case(f64::INFINITY, f64::NEG_INFINITY, ln_pdf(1.0)); - test_case(f64::INFINITY, f64::NEG_INFINITY, ln_pdf(5.5)); - test_case(f64::INFINITY, f64::NEG_INFINITY, ln_pdf(f64::INFINITY)); - test_almost(170.0, -0.57187185030600516424237, 1e-13, ln_pdf(13.0)); + test_exact(1.0, f64::NEG_INFINITY, ln_pdf(0.0)); + test_absolute(1.0, -0.23079135264472743236, 1e-15, ln_pdf(0.1)); + test_absolute(1.0, -0.72579135264472743236, 1e-15, ln_pdf(1.0)); + test_absolute(1.0, -15.350791352644727432, 1e-14, ln_pdf(5.5)); + test_exact(1.0, f64::NEG_INFINITY, ln_pdf(f64::INFINITY)); + test_exact(2.0, f64::NEG_INFINITY, ln_pdf(0.0)); + test_absolute(2.0, -2.3075850929940456840, 1e-15, ln_pdf(0.1)); + test_absolute(2.0, -0.5, 1e-15, ln_pdf(1.0)); + test_absolute(2.0, -13.420251907761574765, 1e-15, ln_pdf(5.5)); + test_exact(2.0, f64::NEG_INFINITY, ln_pdf(f64::INFINITY)); + test_exact(2.5, f64::NEG_INFINITY, ln_pdf(0.0)); + test_absolute(2.5, 
-3.5338925982092416919, 1e-15, ln_pdf(0.1)); + test_absolute(2.5, -0.57501495871817316589, 1e-15, ln_pdf(1.0)); + test_absolute(2.5, -12.642892820360535314, 1e-16, ln_pdf(5.5)); + test_exact(2.5, f64::NEG_INFINITY, ln_pdf(f64::INFINITY)); + test_exact(f64::INFINITY, f64::NEG_INFINITY, ln_pdf(0.0)); + test_exact(f64::INFINITY, f64::NEG_INFINITY, ln_pdf(0.1)); + test_exact(f64::INFINITY, f64::NEG_INFINITY, ln_pdf(1.0)); + test_exact(f64::INFINITY, f64::NEG_INFINITY, ln_pdf(5.5)); + test_exact(f64::INFINITY, f64::NEG_INFINITY, ln_pdf(f64::INFINITY)); + test_absolute(170.0, -0.57187185030600516424237, 1e-13, ln_pdf(13.0)); } #[test] fn test_neg_ln_pdf() { let ln_pdf = |arg: f64| move |x: Chi| x.ln_pdf(arg); - test_case(1.0, f64::NEG_INFINITY, ln_pdf(-1.0)); + test_exact(1.0, f64::NEG_INFINITY, ln_pdf(-1.0)); } #[test] fn test_cdf() { let cdf = |arg: f64| move |x: Chi| x.cdf(arg); - test_case(1.0, 0.0, cdf(0.0)); - test_almost(1.0, 0.079655674554057962931, 1e-16, cdf(0.1)); - test_almost(1.0, 0.68268949213708589717, 1e-15, cdf(1.0)); - test_case(1.0, 0.99999996202087506822, cdf(5.5)); - test_case(1.0, 1.0, cdf(f64::INFINITY)); - test_case(2.0, 0.0, cdf(0.0)); - test_almost(2.0, 0.0049875208073176866474, 1e-17, cdf(0.1)); - test_almost(2.0, 0.39346934028736657640, 1e-15, cdf(1.0)); - test_case(2.0, 0.99999973004214966370, cdf(5.5)); - test_case(2.0, 1.0, cdf(f64::INFINITY)); - test_case(2.5, 0.0, cdf(0.0)); - test_almost(2.5, 0.0011702413714030096290, 1e-18, cdf(0.1)); - test_almost(2.5, 0.28378995266531297417, 1e-16, cdf(1.0)); - test_case(2.5, 0.99999940337322804750, cdf(5.5)); - test_case(2.5, 1.0, cdf(f64::INFINITY)); - test_case(f64::INFINITY, 1.0, cdf(0.0)); - test_case(f64::INFINITY, 1.0, cdf(0.1)); - test_case(f64::INFINITY, 1.0, cdf(1.0)); - test_case(f64::INFINITY, 1.0, cdf(5.5)); - test_case(f64::INFINITY, 1.0, cdf(f64::INFINITY)); + test_exact(1.0, 0.0, cdf(0.0)); + test_absolute(1.0, 0.079655674554057962931, 1e-16, cdf(0.1)); + test_absolute(1.0, 
0.68268949213708589717, 1e-15, cdf(1.0)); + test_exact(1.0, 0.99999996202087506822, cdf(5.5)); + test_exact(1.0, 1.0, cdf(f64::INFINITY)); + test_exact(2.0, 0.0, cdf(0.0)); + test_absolute(2.0, 0.0049875208073176866474, 1e-17, cdf(0.1)); + test_absolute(2.0, 0.39346934028736657640, 1e-15, cdf(1.0)); + test_exact(2.0, 0.99999973004214966370, cdf(5.5)); + test_exact(2.0, 1.0, cdf(f64::INFINITY)); + test_exact(2.5, 0.0, cdf(0.0)); + test_absolute(2.5, 0.0011702413714030096290, 1e-18, cdf(0.1)); + test_absolute(2.5, 0.28378995266531297417, 1e-16, cdf(1.0)); + test_exact(2.5, 0.99999940337322804750, cdf(5.5)); + test_exact(2.5, 1.0, cdf(f64::INFINITY)); + test_exact(f64::INFINITY, 1.0, cdf(0.0)); + test_exact(f64::INFINITY, 1.0, cdf(0.1)); + test_exact(f64::INFINITY, 1.0, cdf(1.0)); + test_exact(f64::INFINITY, 1.0, cdf(5.5)); + test_exact(f64::INFINITY, 1.0, cdf(f64::INFINITY)); } #[test] fn test_sf() { let sf = |arg: f64| move |x: Chi| x.sf(arg); - test_case(1.0, 1.0, sf(0.0)); - test_almost(1.0, 0.920344325445942, 1e-16, sf(0.1)); - test_almost(1.0, 0.31731050786291404, 1e-15, sf(1.0)); - test_almost(1.0, 3.797912493177544e-8, 1e-15, sf(5.5)); - test_case(1.0, 0.0, sf(f64::INFINITY)); - test_case(2.0, 1.0, sf(0.0)); - test_almost(2.0, 0.9950124791926823, 1e-17, sf(0.1)); - test_almost(2.0, 0.6065306597126333, 1e-15, sf(1.0)); - test_almost(2.0, 2.699578503363014e-7, 1e-15, sf(5.5)); - test_case(2.0, 0.0, sf(f64::INFINITY)); - test_case(2.5, 1.0, sf(0.0)); - test_almost(2.5, 0.998829758628597, 1e-18, sf(0.1)); - test_almost(2.5, 0.716210047334687, 1e-16, sf(1.0)); - test_almost(2.5, 5.966267719870189e-7, 1e-15, sf(5.5)); - test_case(2.5, 0.0, sf(f64::INFINITY)); - test_case(f64::INFINITY, 0.0, sf(0.0)); - test_case(f64::INFINITY, 0.0, sf(0.1)); - test_case(f64::INFINITY, 0.0, sf(1.0)); - test_case(f64::INFINITY, 0.0, sf(5.5)); - test_case(f64::INFINITY, 0.0, sf(f64::INFINITY)); + test_exact(1.0, 1.0, sf(0.0)); + test_absolute(1.0, 0.920344325445942, 1e-16, sf(0.1)); + 
test_absolute(1.0, 0.31731050786291404, 1e-15, sf(1.0)); + test_absolute(1.0, 3.797912493177544e-8, 1e-15, sf(5.5)); + test_exact(1.0, 0.0, sf(f64::INFINITY)); + test_exact(2.0, 1.0, sf(0.0)); + test_absolute(2.0, 0.9950124791926823, 1e-17, sf(0.1)); + test_absolute(2.0, 0.6065306597126333, 1e-15, sf(1.0)); + test_absolute(2.0, 2.699578503363014e-7, 1e-15, sf(5.5)); + test_exact(2.0, 0.0, sf(f64::INFINITY)); + test_exact(2.5, 1.0, sf(0.0)); + test_absolute(2.5, 0.998829758628597, 1e-18, sf(0.1)); + test_absolute(2.5, 0.716210047334687, 1e-16, sf(1.0)); + test_absolute(2.5, 5.966267719870189e-7, 1e-15, sf(5.5)); + test_exact(2.5, 0.0, sf(f64::INFINITY)); + test_exact(f64::INFINITY, 0.0, sf(0.0)); + test_exact(f64::INFINITY, 0.0, sf(0.1)); + test_exact(f64::INFINITY, 0.0, sf(1.0)); + test_exact(f64::INFINITY, 0.0, sf(5.5)); + test_exact(f64::INFINITY, 0.0, sf(f64::INFINITY)); } #[test] fn test_neg_cdf() { let cdf = |arg: f64| move |x: Chi| x.cdf(arg); - test_case(1.0, 0.0, cdf(-1.0)); + test_exact(1.0, 0.0, cdf(-1.0)); } #[test] fn test_neg_sf() { let sf = |arg: f64| move |x: Chi| x.sf(arg); - test_case(1.0, 1.0, sf(-1.0)); + test_exact(1.0, 1.0, sf(-1.0)); } #[test] fn test_continuous() { - test::check_continuous_distribution(&try_create(1.0), 0.0, 10.0); - test::check_continuous_distribution(&try_create(2.0), 0.0, 10.0); - test::check_continuous_distribution(&try_create(5.0), 0.0, 10.0); + test::check_continuous_distribution(&create_ok(1.0), 0.0, 10.0); + test::check_continuous_distribution(&create_ok(2.0), 0.0, 10.0); + test::check_continuous_distribution(&create_ok(5.0), 0.0, 10.0); } } diff --git a/src/distribution/chi_squared.rs b/src/distribution/chi_squared.rs index 05ad63ba..f61d6e19 100644 --- a/src/distribution/chi_squared.rs +++ b/src/distribution/chi_squared.rs @@ -1,7 +1,5 @@ -use crate::distribution::{Continuous, ContinuousCDF, Gamma}; +use crate::distribution::{Continuous, ContinuousCDF, Gamma, GammaError}; use crate::statistics::*; -use 
crate::Result; -use rand::Rng; use std::f64; /// Implements the @@ -21,7 +19,7 @@ use std::f64; /// assert_eq!(n.mean().unwrap(), 3.0); /// assert!(prec::almost_eq(n.pdf(4.0), 0.107981933026376103901, 1e-15)); /// ``` -#[derive(Debug, Copy, Clone, PartialEq)] +#[derive(Copy, Clone, PartialEq, Debug)] pub struct ChiSquared { freedom: f64, g: Gamma, @@ -48,7 +46,7 @@ impl ChiSquared { /// result = ChiSquared::new(0.0); /// assert!(result.is_err()); /// ``` - pub fn new(freedom: f64) -> Result { + pub fn new(freedom: f64) -> Result { Gamma::new(freedom / 2.0, 0.5).map(|g| ChiSquared { freedom, g }) } @@ -96,8 +94,15 @@ impl ChiSquared { } } +impl std::fmt::Display for ChiSquared { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "χ^2_{}", self.freedom) + } +} + +#[cfg(feature = "rand")] impl ::rand::distributions::Distribution for ChiSquared { - fn sample(&self, r: &mut R) -> f64 { + fn sample(&self, r: &mut R) -> f64 { ::rand::distributions::Distribution::sample(&self.g, r) } } @@ -108,7 +113,7 @@ impl ContinuousCDF for ChiSquared { /// /// # Formula /// - /// ```ignore + /// ```text /// (1 / Γ(k / 2)) * γ(k / 2, x / 2) /// ``` /// @@ -123,7 +128,7 @@ impl ContinuousCDF for ChiSquared { /// /// # Formula /// - /// ```ignore + /// ```text /// (1 / Γ(k / 2)) * γ(k / 2, x / 2) /// ``` /// @@ -132,6 +137,21 @@ impl ContinuousCDF for ChiSquared { fn sf(&self, x: f64) -> f64 { self.g.sf(x) } + + /// Calculates the inverse cumulative distribution function for the + /// chi-squared distribution at `x` + /// + /// # Formula + /// + /// ```text + /// γ^{-1}(k / 2, x * Γ(k / 2) / 2) + /// ``` + /// + /// where `k` is the degrees of freedom, `Γ` is the gamma function, + /// and `γ` is the lower incomplete gamma function + fn inverse_cdf(&self, p: f64) -> f64 { + self.g.inverse_cdf(p) + } } impl Min for ChiSquared { @@ -141,7 +161,7 @@ impl Min for ChiSquared { /// /// # Formula /// - /// ```ignore + /// ```text /// 0 /// ``` fn min(&self) -> f64 
{ @@ -156,8 +176,8 @@ impl Max for ChiSquared { /// /// # Formula /// - /// ```ignore - /// INF + /// ```text + /// f64::INFINITY /// ``` fn max(&self) -> f64 { f64::INFINITY @@ -169,7 +189,7 @@ impl Distribution for ChiSquared { /// /// # Formula /// - /// ```ignore + /// ```text /// k /// ``` /// @@ -177,11 +197,12 @@ impl Distribution for ChiSquared { fn mean(&self) -> Option { self.g.mean() } + /// Returns the variance of the chi-squared distribution /// /// # Formula /// - /// ```ignore + /// ```text /// 2k /// ``` /// @@ -189,11 +210,12 @@ impl Distribution for ChiSquared { fn variance(&self) -> Option { self.g.variance() } + /// Returns the entropy of the chi-squared distribution /// /// # Formula /// - /// ```ignore + /// ```text /// (k / 2) + ln(2 * Γ(k / 2)) + (1 - (k / 2)) * ψ(k / 2) /// ``` /// @@ -202,11 +224,12 @@ impl Distribution for ChiSquared { fn entropy(&self) -> Option { self.g.entropy() } + /// Returns the skewness of the chi-squared distribution /// /// # Formula /// - /// ```ignore + /// ```text /// sqrt(8 / k) /// ``` /// @@ -221,7 +244,7 @@ impl Median for ChiSquared { /// /// # Formula /// - /// ```ignore + /// ```text /// k * (1 - (2 / 9k))^3 /// ``` fn median(&self) -> f64 { @@ -241,7 +264,7 @@ impl Mode> for ChiSquared { /// /// # Formula /// - /// ```ignore + /// ```text /// k - 2 /// ``` /// @@ -257,7 +280,7 @@ impl Continuous for ChiSquared { /// /// # Formula /// - /// ```ignore + /// ```text /// 1 / (2^(k / 2) * Γ(k / 2)) * x^((k / 2) - 1) * e^(-x / 2) /// ``` /// @@ -271,7 +294,7 @@ impl Continuous for ChiSquared { /// /// # Formula /// - /// ```ignore + /// ```text /// ln(1 / (2^(k / 2) * Γ(k / 2)) * x^((k / 2) - 1) * e^(-x / 2)) /// ``` fn ln_pdf(&self, x: f64) -> f64 { @@ -280,50 +303,29 @@ impl Continuous for ChiSquared { } #[rustfmt::skip] -#[cfg(all(test, feature = "nightly"))] +#[cfg(test)] mod tests { - use crate::statistics::Median; - use crate::distribution::ChiSquared; + use super::*; use 
crate::distribution::internal::*; - use crate::consts::ACC; + use crate::testing_boiler; - fn try_create(freedom: f64) -> ChiSquared { - let n = ChiSquared::new(freedom); - assert!(n.is_ok()); - n.unwrap() - } - - fn test_case(freedom: f64, expected: f64, eval: F) - where F: Fn(ChiSquared) -> f64 - { - let n = try_create(freedom); - let x = eval(n); - assert_eq!(expected, x); - } - - fn test_almost(freedom: f64, expected: f64, acc: f64, eval: F) - where F: Fn(ChiSquared) -> f64 - { - let n = try_create(freedom); - let x = eval(n); - assert_almost_eq!(expected, x, acc); - } + testing_boiler!(freedom: f64; ChiSquared; GammaError); #[test] fn test_median() { let median = |x: ChiSquared| x.median(); - test_almost(0.5, 0.0857338820301783264746, 1e-16, median); - test_case(1.0, 1.0 - 2.0 / 3.0, median); - test_case(2.0, 2.0 - 2.0 / 3.0, median); - test_case(2.5, 2.5 - 2.0 / 3.0, median); - test_case(3.0, 3.0 - 2.0 / 3.0, median); + test_absolute(0.5, 0.0857338820301783264746, 1e-16, median); + test_exact(1.0, 1.0 - 2.0 / 3.0, median); + test_exact(2.0, 2.0 - 2.0 / 3.0, median); + test_exact(2.5, 2.5 - 2.0 / 3.0, median); + test_exact(3.0, 3.0 - 2.0 / 3.0, median); } #[test] fn test_continuous() { // TODO: figure out why this test fails: - //test::check_continuous_distribution(&try_create(1.0), 0.0, 10.0); - test::check_continuous_distribution(&try_create(2.0), 0.0, 10.0); - test::check_continuous_distribution(&try_create(5.0), 0.0, 50.0); + //test::check_continuous_distribution(&create_ok(1.0), 0.0, 10.0); + test::check_continuous_distribution(&create_ok(2.0), 0.0, 10.0); + test::check_continuous_distribution(&create_ok(5.0), 0.0, 50.0); } } diff --git a/src/distribution/dirac.rs b/src/distribution/dirac.rs index b58b676b..ec833d93 100644 --- a/src/distribution/dirac.rs +++ b/src/distribution/dirac.rs @@ -1,7 +1,5 @@ -use crate::distribution::{Continuous, ContinuousCDF}; +use crate::distribution::ContinuousCDF; use crate::statistics::*; -use crate::{Result, StatsError}; 
-use rand::Rng; /// Implements the [Dirac Delta](https://en.wikipedia.org/wiki/Dirac_delta_function#As_a_distribution) /// distribution @@ -18,8 +16,27 @@ use rand::Rng; #[derive(Debug, Copy, Clone, PartialEq)] pub struct Dirac(f64); +/// Represents the errors that can occur when creating a [`Dirac`]. +#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] +#[non_exhaustive] +pub enum DiracError { + /// The value v is NaN. + ValueInvalid, +} + +impl std::fmt::Display for DiracError { + #[cfg_attr(coverage_nightly, coverage(off))] + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + DiracError::ValueInvalid => write!(f, "Value v is NaN"), + } + } +} + +impl std::error::Error for DiracError {} + impl Dirac { - /// Constructs a new dirac distribution function at value `v`. + /// Constructs a new dirac distribution function at value `v`. /// /// # Errors /// @@ -36,17 +53,24 @@ impl Dirac { /// result = Dirac::new(f64::NAN); /// assert!(result.is_err()); /// ``` - pub fn new(v: f64) -> Result { + pub fn new(v: f64) -> Result { if v.is_nan() { - Err(StatsError::BadParams) + Err(DiracError::ValueInvalid) } else { Ok(Dirac(v)) } } } +impl std::fmt::Display for Dirac { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "δ_{}", self.0) + } +} + +#[cfg(feature = "rand")] impl ::rand::distributions::Distribution for Dirac { - fn sample(&self, _: &mut R) -> f64 { + fn sample(&self, _: &mut R) -> f64 { self.0 } } @@ -56,7 +80,6 @@ impl ContinuousCDF for Dirac { /// dirac distribution at `x` /// /// Where the value is 1 if x > `v`, 0 otherwise. - /// fn cdf(&self, x: f64) -> f64 { if x < self.0 { 0.0 @@ -69,7 +92,6 @@ impl ContinuousCDF for Dirac { /// dirac distribution at `x` /// /// Where the value is 0 if x > `v`, 1 otherwise. 
- /// fn sf(&self, x: f64) -> f64 { if x < self.0 { 1.0 @@ -85,7 +107,7 @@ impl Min for Dirac { /// /// # Formula /// - /// ```ignore + /// ```text /// v /// ``` fn min(&self) -> f64 { @@ -99,7 +121,7 @@ impl Max for Dirac { /// /// # Formula /// - /// ```ignore + /// ```text /// v /// ``` fn max(&self) -> f64 { @@ -117,11 +139,12 @@ impl Distribution for Dirac { fn mean(&self) -> Option { Some(self.0) } + /// Returns the variance of the dirac distribution /// /// # Formula /// - /// ```ignore + /// ```text /// 0 /// ``` /// @@ -129,11 +152,12 @@ impl Distribution for Dirac { fn variance(&self) -> Option { Some(0.0) } + /// Returns the entropy of the dirac distribution /// /// # Formula /// - /// ```ignore + /// ```text /// 0 /// ``` /// @@ -141,11 +165,12 @@ impl Distribution for Dirac { fn entropy(&self) -> Option { Some(0.0) } + /// Returns the skewness of the dirac distribution /// /// # Formula /// - /// ```ignore + /// ```text /// 0 /// ``` fn skewness(&self) -> Option { @@ -158,7 +183,7 @@ impl Median for Dirac { /// /// # Formula /// - /// ```ignore + /// ```text /// v /// ``` /// @@ -173,7 +198,7 @@ impl Mode> for Dirac { /// /// # Formula /// - /// ```ignore + /// ```text /// v /// ``` /// @@ -184,117 +209,95 @@ impl Mode> for Dirac { } #[rustfmt::skip] -#[cfg(all(test, feature = "nightly"))] +#[cfg(test)] mod tests { - use crate::statistics::*; - use crate::distribution::{ContinuousCDF, Continuous, Dirac}; - use crate::consts::ACC; - - fn try_create(v: f64) -> Dirac { - let d = Dirac::new(v); - assert!(d.is_ok()); - d.unwrap() - } + use super::*; + use crate::testing_boiler; - fn create_case(v: f64) { - let d = try_create(v); - assert_eq!(v, d.mean().unwrap()); - } - - fn bad_create_case(v: f64) { - let d = Dirac::new(v); - assert!(d.is_err()); - } - - fn test_case(v: f64, expected: f64, eval: F) - where F: Fn(Dirac) -> f64 - { - let x = eval(try_create(v)); - assert_eq!(expected, x); - } + testing_boiler!(v: f64; Dirac; DiracError); #[test] fn 
test_create() { - create_case(10.0); - create_case(-5.0); - create_case(10.0); - create_case(100.0); - create_case(f64::INFINITY); + create_ok(10.0); + create_ok(-5.0); + create_ok(10.0); + create_ok(100.0); + create_ok(f64::INFINITY); } #[test] fn test_bad_create() { - bad_create_case(f64::NAN); + create_err(f64::NAN); } #[test] fn test_variance() { let variance = |x: Dirac| x.variance().unwrap(); - test_case(0.0, 0.0, variance); - test_case(-5.0, 0.0, variance); - test_case(f64::INFINITY, 0.0, variance); + test_exact(0.0, 0.0, variance); + test_exact(-5.0, 0.0, variance); + test_exact(f64::INFINITY, 0.0, variance); } #[test] fn test_entropy() { let entropy = |x: Dirac| x.entropy().unwrap(); - test_case(0.0, 0.0, entropy); - test_case(f64::INFINITY, 0.0, entropy); + test_exact(0.0, 0.0, entropy); + test_exact(f64::INFINITY, 0.0, entropy); } #[test] fn test_skewness() { let skewness = |x: Dirac| x.skewness().unwrap(); - test_case(0.0, 0.0, skewness); - test_case(4.0, 0.0, skewness); - test_case(0.3, 0.0, skewness); - test_case(f64::INFINITY, 0.0, skewness); + test_exact(0.0, 0.0, skewness); + test_exact(4.0, 0.0, skewness); + test_exact(0.3, 0.0, skewness); + test_exact(f64::INFINITY, 0.0, skewness); } #[test] fn test_mode() { let mode = |x: Dirac| x.mode().unwrap(); - test_case(0.0, 0.0, mode); - test_case(3.0, 3.0, mode); - test_case(f64::INFINITY, f64::INFINITY, mode); + test_exact(0.0, 0.0, mode); + test_exact(3.0, 3.0, mode); + test_exact(f64::INFINITY, f64::INFINITY, mode); } #[test] fn test_median() { let median = |x: Dirac| x.median(); - test_case(0.0, 0.0, median); - test_case(3.0, 3.0, median); - test_case(f64::INFINITY, f64::INFINITY, median); + test_exact(0.0, 0.0, median); + test_exact(3.0, 3.0, median); + test_exact(f64::INFINITY, f64::INFINITY, median); } #[test] fn test_min_max() { let min = |x: Dirac| x.min(); let max = |x: Dirac| x.max(); - test_case(0.0, 0.0, min); - test_case(3.0, 3.0, min); - test_case(f64::INFINITY, f64::INFINITY, min); + 
test_exact(0.0, 0.0, min); + test_exact(3.0, 3.0, min); + test_exact(f64::INFINITY, f64::INFINITY, min); - test_case(0.0, 0.0, max); - test_case(3.0, 3.0, max); - test_case(f64::NEG_INFINITY, f64::NEG_INFINITY, max); + test_exact(0.0, 0.0, max); + test_exact(3.0, 3.0, max); + test_exact(f64::NEG_INFINITY, f64::NEG_INFINITY, max); } #[test] fn test_cdf() { let cdf = |arg: f64| move |x: Dirac| x.cdf(arg); - test_case(0.0, 1.0, cdf(0.0)); - test_case(3.0, 1.0, cdf(3.0)); - test_case(f64::INFINITY, 0.0, cdf(1.0)); - test_case(f64::INFINITY, 1.0, cdf(f64::INFINITY)); + test_exact(0.0, 1.0, cdf(0.0)); + test_exact(3.0, 1.0, cdf(3.0)); + test_exact(f64::INFINITY, 0.0, cdf(1.0)); + test_exact(f64::INFINITY, 1.0, cdf(f64::INFINITY)); } #[test] fn test_sf() { let sf = |arg: f64| move |x: Dirac| x.sf(arg); - test_case(0.0, 0.0, sf(0.0)); - test_case(3.0, 0.0, sf(3.0)); - test_case(f64::INFINITY, 1.0, sf(1.0)); - test_case(f64::INFINITY, 0.0, sf(f64::INFINITY)); + test_exact(0.0, 0.0, sf(0.0)); + test_exact(3.0, 0.0, sf(3.0)); + test_exact(f64::INFINITY, 1.0, sf(1.0)); + test_exact(f64::INFINITY, 0.0, sf(f64::INFINITY)); } } diff --git a/src/distribution/dirichlet.rs b/src/distribution/dirichlet.rs index 104a5981..7c7c9913 100644 --- a/src/distribution/dirichlet.rs +++ b/src/distribution/dirichlet.rs @@ -1,13 +1,8 @@ use crate::distribution::Continuous; use crate::function::gamma; +use crate::prec; use crate::statistics::*; -use crate::{prec, Result, StatsError}; -use nalgebra::DMatrix; -use nalgebra::DVector; -use nalgebra::{ - base::allocator::Allocator, base::dimension::DimName, DefaultAllocator, Dim, DimMin, U1, -}; -use rand::Rng; +use nalgebra::{Dim, Dyn, OMatrix, OVector}; use std::f64; /// Implements the @@ -26,11 +21,42 @@ use std::f64; /// assert_eq!(n.mean().unwrap(), DVector::from_vec(vec![1.0 / 6.0, 1.0 / 3.0, 0.5])); /// assert_eq!(n.pdf(&DVector::from_vec(vec![0.33333, 0.33333, 0.33333])), 2.222155556222205); /// ``` -#[derive(Debug, Clone, PartialEq)] -pub 
struct Dirichlet { - alpha: DVector, +#[derive(Clone, PartialEq, Debug)] +pub struct Dirichlet +where + D: Dim, + nalgebra::DefaultAllocator: nalgebra::allocator::Allocator, +{ + alpha: OVector, } -impl Dirichlet { + +/// Represents the errors that can occur when creating a [`Dirichlet`]. +#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] +#[non_exhaustive] +pub enum DirichletError { + /// Alpha contains less than two elements. + AlphaTooShort, + + /// Alpha contains an element that is NaN, infinite, zero or less than zero. + AlphaHasInvalidElements, +} + +impl std::fmt::Display for DirichletError { + #[cfg_attr(coverage_nightly, coverage(off))] + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + DirichletError::AlphaTooShort => write!(f, "Alpha contains less than two elements"), + DirichletError::AlphaHasInvalidElements => write!( + f, + "Alpha contains an element that is NaN, infinite, zero or less than zero" + ), + } + } +} + +impl std::error::Error for DirichletError {} + +impl Dirichlet { /// Constructs a new dirichlet distribution with the given /// concentration parameters (alpha) /// @@ -54,15 +80,8 @@ impl Dirichlet { /// result = Dirichlet::new(alpha_err); /// assert!(result.is_err()); /// ``` - pub fn new(alpha: Vec) -> Result { - if !is_valid_alpha(&alpha) { - Err(StatsError::BadParams) - } else { - // let vec = alpha.to_vec(); - Ok(Dirichlet { - alpha: DVector::from_vec(alpha.to_vec()), - }) - } + pub fn new(alpha: Vec) -> Result { + Self::new_from_nalgebra(alpha.into()) } /// Constructs a new dirichlet distribution with the given @@ -84,9 +103,34 @@ impl Dirichlet { /// result = Dirichlet::new_with_param(0.0, 1); /// assert!(result.is_err()); /// ``` - pub fn new_with_param(alpha: f64, n: usize) -> Result { + pub fn new_with_param(alpha: f64, n: usize) -> Result { Self::new(vec![alpha; n]) } +} + +impl Dirichlet +where + D: Dim, + nalgebra::DefaultAllocator: nalgebra::allocator::Allocator, +{ + /// Constructs a new 
distribution with the given vector for `alpha` + /// Does not clone the vector it takes ownership of + /// + /// # Error + /// + /// Returns an error if vector has length less than 2 or if any element + /// of alpha is NOT finite positive + pub fn new_from_nalgebra(alpha: OVector) -> Result { + if alpha.len() < 2 { + return Err(DirichletError::AlphaTooShort); + } + + if alpha.iter().any(|&a_i| !a_i.is_finite() || a_i <= 0.0) { + return Err(DirichletError::AlphaHasInvalidElements); + } + + Ok(Self { alpha }) + } /// Returns the concentration parameters of /// the dirichlet distribution as a slice @@ -100,24 +144,25 @@ impl Dirichlet { /// let n = Dirichlet::new(vec![1.0, 2.0, 3.0]).unwrap(); /// assert_eq!(n.alpha(), &DVector::from_vec(vec![1.0, 2.0, 3.0])); /// ``` - pub fn alpha(&self) -> &DVector { + pub fn alpha(&self) -> &nalgebra::OVector { &self.alpha } fn alpha_sum(&self) -> f64 { - self.alpha.fold(0.0, |acc, x| acc + x) + self.alpha.sum() } + /// Returns the entropy of the dirichlet distribution /// /// # Formula /// - /// ```ignore + /// ```text /// ln(B(α)) - (K - α_0)ψ(α_0) - Σ((α_i - 1)ψ(α_i)) /// ``` /// /// where /// - /// ```ignore + /// ```text /// B(α) = Π(Γ(α_i)) / Γ(Σ(α_i)) /// ``` /// @@ -136,55 +181,77 @@ impl Dirichlet { } } -impl ::rand::distributions::Distribution> for Dirichlet { - fn sample(&self, rng: &mut R) -> DVector { +impl std::fmt::Display for Dirichlet +where + D: Dim, + nalgebra::DefaultAllocator: nalgebra::allocator::Allocator, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Dir({}, {})", self.alpha.len(), &self.alpha) + } +} + +#[cfg(feature = "rand")] +impl ::rand::distributions::Distribution> for Dirichlet +where + D: Dim, + nalgebra::DefaultAllocator: nalgebra::allocator::Allocator, +{ + fn sample(&self, rng: &mut R) -> OVector { let mut sum = 0.0; - let mut samples: Vec<_> = self - .alpha - .iter() - .map(|&a| { + OVector::from_iterator_generic( + self.alpha.shape_generic().0, + 
nalgebra::Const::<1>, + self.alpha.iter().map(|&a| { let sample = super::gamma::sample_unchecked(rng, a, 1.0); sum += sample; sample - }) - .collect(); - for _ in samples.iter_mut().map(|x| *x /= sum) {} - DVector::from_vec(samples) + }), + ) } } -impl MeanN> for Dirichlet { +impl MeanN> for Dirichlet +where + D: Dim, + nalgebra::DefaultAllocator: nalgebra::allocator::Allocator, +{ /// Returns the means of the dirichlet distribution /// /// # Formula /// - /// ```ignore + /// ```text /// α_i / α_0 /// ``` /// /// for the `i`th element where `α_i` is the `i`th concentration parameter /// and `α_0` is the sum of all concentration parameters - fn mean(&self) -> Option> { + fn mean(&self) -> Option> { let sum = self.alpha_sum(); Some(self.alpha.map(|x| x / sum)) } } -impl VarianceN> for Dirichlet { +impl VarianceN> for Dirichlet +where + D: Dim, + nalgebra::DefaultAllocator: + nalgebra::allocator::Allocator + nalgebra::allocator::Allocator, +{ /// Returns the variances of the dirichlet distribution /// /// # Formula /// - /// ```ignore + /// ```text /// (α_i * (α_0 - α_i)) / (α_0^2 * (α_0 + 1)) /// ``` /// /// for the `i`th element where `α_i` is the `i`th concentration parameter /// and `α_0` is the sum of all concentration parameters - fn variance(&self) -> Option> { + fn variance(&self) -> Option> { let sum = self.alpha_sum(); let normalizing = sum * sum * (sum + 1.0); - let mut cov = DMatrix::from_diagonal(&self.alpha.map(|x| x * (sum - x) / normalizing)); + let mut cov = OMatrix::from_diagonal(&self.alpha.map(|x| x * (sum - x) / normalizing)); let mut offdiag = |x: usize, y: usize| { let elt = -self.alpha[x] * self.alpha[y] / normalizing; cov[(x, y)] = elt; @@ -199,7 +266,13 @@ impl VarianceN> for Dirichlet { } } -impl<'a> Continuous<&'a DVector, f64> for Dirichlet { +impl<'a, D> Continuous<&'a OVector, f64> for Dirichlet +where + D: Dim, + nalgebra::DefaultAllocator: nalgebra::allocator::Allocator + + nalgebra::allocator::Allocator + + 
nalgebra::allocator::Allocator, D>, +{ /// Calculates the probabiliy density function for the dirichlet /// distribution /// with given `x`'s corresponding to the concentration parameters for this @@ -215,13 +288,13 @@ impl<'a> Continuous<&'a DVector, f64> for Dirichlet { /// /// # Formula /// - /// ```ignore + /// ```text /// (1 / B(α)) * Π(x_i^(α_i - 1)) /// ``` /// /// where /// - /// ```ignore + /// ```text /// B(α) = Π(Γ(α_i)) / Γ(Σ(α_i)) /// ``` /// @@ -230,7 +303,7 @@ impl<'a> Continuous<&'a DVector, f64> for Dirichlet { /// the `i`th concentration parameter, `Γ` is the gamma function, /// `Π` is the product from `1` to `K`, `Σ` is the sum from `1` to `K`, /// and `K` is the number of concentration parameters - fn pdf(&self, x: &DVector) -> f64 { + fn pdf(&self, x: &OVector) -> f64 { self.ln_pdf(x).exp() } @@ -249,13 +322,13 @@ impl<'a> Continuous<&'a DVector, f64> for Dirichlet { /// /// # Formula /// - /// ```ignore + /// ```text /// ln((1 / B(α)) * Π(x_i^(α_i - 1))) /// ``` /// /// where /// - /// ```ignore + /// ```text /// B(α) = Π(Γ(α_i)) / Γ(Σ(α_i)) /// ``` /// @@ -264,7 +337,7 @@ impl<'a> Continuous<&'a DVector, f64> for Dirichlet { /// the `i`th concentration parameter, `Γ` is the gamma function, /// `Π` is the product from `1` to `K`, `Σ` is the sum from `1` to `K`, /// and `K` is the number of concentration parameters - fn ln_pdf(&self, x: &DVector) -> f64 { + fn ln_pdf(&self, x: &OVector) -> f64 { // TODO: would it be clearer here to just do a for loop instead // of using iterators? 
if self.alpha.len() != x.len() { @@ -293,84 +366,112 @@ impl<'a> Continuous<&'a DVector, f64> for Dirichlet { } } -// determines if `a` is a valid alpha array -// for the Dirichlet distribution -fn is_valid_alpha(a: &[f64]) -> bool { - a.len() >= 2 && super::internal::is_valid_multinomial(a, false) -} - #[rustfmt::skip] -#[cfg(all(test, feature = "nightly"))] +#[cfg(test)] mod tests { use super::*; - use nalgebra::{DVector}; - use crate::function::gamma; - use crate::statistics::*; - use crate::distribution::{Continuous, Dirichlet}; - use crate::consts::ACC; - #[test] - fn test_is_valid_alpha() { - let invalid = [1.0]; - assert!(!is_valid_alpha(&invalid)); - } + use std::fmt::{Debug, Display}; - fn try_create(alpha: &[f64]) -> Dirichlet + use nalgebra::{dmatrix, dvector, vector, DimMin, OVector}; + + fn try_create(alpha: OVector) -> Dirichlet + where + D: DimMin, + nalgebra::DefaultAllocator: nalgebra::allocator::Allocator, { - let n = Dirichlet::new(alpha.to_vec()); - assert!(n.is_ok()); - n.unwrap() + let mvn = Dirichlet::new_from_nalgebra(alpha); + assert!(mvn.is_ok()); + mvn.unwrap() } - fn create_case(alpha: &[f64]) + fn bad_create_case(alpha: OVector) + where + D: DimMin, + nalgebra::DefaultAllocator: nalgebra::allocator::Allocator, { - let n = try_create(alpha); - let a2 = n.alpha(); - for i in 0..alpha.len() { - assert_eq!(alpha[i], a2[i]); - } + let dd = Dirichlet::new_from_nalgebra(alpha); + assert!(dd.is_err()); } - fn bad_create_case(alpha: &[f64]) + fn test_almost(alpha: OVector, expected: T, acc: f64, eval: F) + where + T: Debug + Display + approx::RelativeEq, + F: FnOnce(Dirichlet) -> T, + D: DimMin, + nalgebra::DefaultAllocator: nalgebra::allocator::Allocator, { - let n = Dirichlet::new(alpha.to_vec()); - assert!(n.is_err()); + let dd = try_create(alpha); + let x = eval(dd); + assert_relative_eq!(expected, x, epsilon = acc); } #[test] fn test_create() { - create_case(&[1.0, 2.0, 3.0, 4.0, 5.0]); - create_case(&[0.001, f64::INFINITY, 3756.0]); + 
try_create(vector![1.0, 2.0]); + try_create(vector![1.0, 2.0, 3.0, 4.0, 5.0]); + assert!(Dirichlet::new(vec![1.0, 2.0, 3.0, 4.0, 5.0]).is_ok()); + // try_create(vector![0.001, f64::INFINITY, 3756.0]); // moved to bad case as this is degenerate } #[test] fn test_bad_create() { - bad_create_case(&[1.0]); - bad_create_case(&[1.0, 2.0, 0.0, 4.0, 5.0]); - bad_create_case(&[1.0, f64::NAN, 3.0, 4.0, 5.0]); - bad_create_case(&[0.0, 0.0, 0.0]); + bad_create_case(vector![1.0, f64::NAN]); + bad_create_case(vector![1.0, 0.0]); + bad_create_case(vector![1.0, f64::INFINITY]); + bad_create_case(vector![-1.0, 2.0]); + bad_create_case(vector![1.0]); + bad_create_case(vector![1.0, 2.0, 0.0, 4.0, 5.0]); + bad_create_case(vector![1.0, f64::NAN, 3.0, 4.0, 5.0]); + bad_create_case(vector![0.0, 0.0, 0.0]); + bad_create_case(vector![0.001, f64::INFINITY, 3756.0]); // moved to bad case as this is degenerate } - // #[test] - // fn test_mean() { - // let n = Dirichlet::new_with_param(0.3, 5).unwrap(); - // let res = n.mean(); - // for x in res { - // assert_eq!(x, 0.3 / 1.5); - // } - // } + #[test] + fn test_mean() { + let mean = |dd: Dirichlet<_>| dd.mean().unwrap(); + + test_almost(vec![0.5; 5].into(), vec![1.0 / 5.0; 5].into(), 1e-15, mean); + + test_almost( + dvector![0.1, 0.2, 0.3, 0.4], + dvector![0.1, 0.2, 0.3, 0.4], + 1e-15, + mean, + ); + + test_almost( + dvector![1.0, 2.0, 3.0, 4.0], + dvector![0.1, 0.2, 0.3, 0.4], + 1e-15, + mean, + ); + } - // #[test] - // fn test_variance() { - // let alpha = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]; - // let sum = alpha.iter().fold(0.0, |acc, x| acc + x); - // let n = Dirichlet::new(&alpha).unwrap(); - // let res = n.variance(); - // for i in 1..11 { - // let f = i as f64; - // assert_almost_eq!(res[i-1], f * (sum - f) / (sum * sum * (sum + 1.0)), 1e-15); - // } - // } + #[test] + fn test_variance() { + let variance = |dd: Dirichlet<_>| dd.variance().unwrap(); + + test_almost( + dvector![1.0, 2.0], + dmatrix![0.055555555555555, 
-0.055555555555555; + -0.055555555555555, 0.055555555555555; + ], + 1e-15, + variance, + ); + + test_almost( + dvector![0.1, 0.2, 0.3, 0.4], + dmatrix![0.045, -0.010, -0.015, -0.020; + -0.010, 0.080, -0.030, -0.040; + -0.015, -0.030, 0.105, -0.060; + -0.020, -0.040, -0.060, 0.120; + ], + 1e-15, + variance, + ); + } // #[test] // fn test_std_dev() { @@ -386,70 +487,100 @@ mod tests { #[test] fn test_entropy() { - let mut n = try_create(&[0.1, 0.3, 0.5, 0.8]); - assert_eq!(n.entropy().unwrap(), -17.46469081094079); - - n = try_create(&[0.1, 0.2, 0.3, 0.4]); - assert_eq!(n.entropy().unwrap(), -21.53881433791513); - } - - macro_rules! dvec { - ($($x:expr),*) => (DVector::from_vec(vec![$($x),*])); + let entropy = |x: Dirichlet<_>| x.entropy().unwrap(); + test_almost( + vector![0.1, 0.3, 0.5, 0.8], + -17.46469081094079, + 1e-30, + entropy, + ); + test_almost( + vector![0.1, 0.2, 0.3, 0.4], + -21.53881433791513, + 1e-30, + entropy, + ); } #[test] fn test_pdf() { - let n = try_create(&[0.1, 0.3, 0.5, 0.8]); - assert_almost_eq!(n.pdf(&dvec![0.01, 0.03, 0.5, 0.46]), 18.77225681167061, 1e-12); - assert_almost_eq!(n.pdf(&dvec![0.1,0.2,0.3,0.4]), 0.8314656481199253, 1e-14); + let pdf = |arg| move |x: Dirichlet<_>| x.pdf(&arg); + test_almost( + vector![0.1, 0.3, 0.5, 0.8], + 18.77225681167061, + 1e-12, + pdf([0.01, 0.03, 0.5, 0.46].into()), + ); + test_almost( + vector![0.1, 0.3, 0.5, 0.8], + 0.8314656481199253, + 1e-14, + pdf([0.1, 0.2, 0.3, 0.4].into()), + ); } #[test] fn test_ln_pdf() { - let n = try_create(&[0.1, 0.3, 0.5, 0.8]); - assert_almost_eq!(n.ln_pdf(&dvec![0.01, 0.03, 0.5, 0.46]), 18.77225681167061f64.ln(), 1e-12); - assert_almost_eq!(n.ln_pdf(&dvec![0.1,0.2,0.3,0.4]), 0.8314656481199253f64.ln(), 1e-14); + let ln_pdf = |arg| move |x: Dirichlet<_>| x.ln_pdf(&arg); + test_almost( + vector![0.1, 0.3, 0.5, 0.8], + 18.77225681167061_f64.ln(), + 1e-12, + ln_pdf([0.01, 0.03, 0.5, 0.46].into()), + ); + test_almost( + vector![0.1, 0.3, 0.5, 0.8], + 
0.8314656481199253_f64.ln(), + 1e-14, + ln_pdf([0.1, 0.2, 0.3, 0.4].into()), + ); } #[test] #[should_panic] fn test_pdf_bad_input_length() { - let n = try_create(&[0.1, 0.3, 0.5, 0.8]); - n.pdf(&dvec![0.5]); + let n = try_create(dvector![0.1, 0.3, 0.5, 0.8]); + n.pdf(&dvector![0.5]); } #[test] #[should_panic] fn test_pdf_bad_input_range() { - let n = try_create(&[0.1, 0.3, 0.5, 0.8]); - n.pdf(&dvec![1.5, 0.0, 0.0, 0.0]); + let n = try_create(vector![0.1, 0.3, 0.5, 0.8]); + n.pdf(&vector![1.5, 0.0, 0.0, 0.0]); } #[test] #[should_panic] fn test_pdf_bad_input_sum() { - let n = try_create(&[0.1, 0.3, 0.5, 0.8]); - n.pdf(&dvec![0.5, 0.25, 0.8, 0.9]); + let n = try_create(vector![0.1, 0.3, 0.5, 0.8]); + n.pdf(&vector![0.5, 0.25, 0.8, 0.9]); } #[test] #[should_panic] fn test_ln_pdf_bad_input_length() { - let n = try_create(&[0.1, 0.3, 0.5, 0.8]); - n.ln_pdf(&dvec![0.5]); + let n = try_create(dvector![0.1, 0.3, 0.5, 0.8]); + n.ln_pdf(&dvector![0.5]); } #[test] #[should_panic] fn test_ln_pdf_bad_input_range() { - let n = try_create(&[0.1, 0.3, 0.5, 0.8]); - n.ln_pdf(&dvec![1.5, 0.0, 0.0, 0.0]); + let n = try_create(vector![0.1, 0.3, 0.5, 0.8]); + n.ln_pdf(&vector![1.5, 0.0, 0.0, 0.0]); } #[test] #[should_panic] fn test_ln_pdf_bad_input_sum() { - let n = try_create(&[0.1, 0.3, 0.5, 0.8]); - n.ln_pdf(&dvec![0.5, 0.25, 0.8, 0.9]); + let n = try_create(vector![0.1, 0.3, 0.5, 0.8]); + n.ln_pdf(&vector![0.5, 0.25, 0.8, 0.9]); + } + + #[test] + fn test_error_is_sync_send() { + fn assert_sync_send() {} + assert_sync_send::(); } } diff --git a/src/distribution/discrete_uniform.rs b/src/distribution/discrete_uniform.rs index c151318f..85b26090 100644 --- a/src/distribution/discrete_uniform.rs +++ b/src/distribution/discrete_uniform.rs @@ -1,7 +1,5 @@ use crate::distribution::{Discrete, DiscreteCDF}; use crate::statistics::*; -use crate::{Result, StatsError}; -use rand::Rng; /// Implements the [Discrete /// Uniform](https://en.wikipedia.org/wiki/Discrete_uniform_distribution) @@ 
-17,12 +15,31 @@ use rand::Rng; /// assert_eq!(n.mean().unwrap(), 2.5); /// assert_eq!(n.pmf(3), 1.0 / 6.0); /// ``` -#[derive(Debug, Copy, Clone, PartialEq)] +#[derive(Copy, Clone, PartialEq, Debug)] pub struct DiscreteUniform { min: i64, max: i64, } +/// Represents the errors that can occur when creating a [`DiscreteUniform`]. +#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] +#[non_exhaustive] +pub enum DiscreteUniformError { + /// The maximum is less than the minimum. + MinMaxInvalid, +} + +impl std::fmt::Display for DiscreteUniformError { + #[cfg_attr(coverage_nightly, coverage(off))] + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + DiscreteUniformError::MinMaxInvalid => write!(f, "Maximum is less than minimum"), + } + } +} + +impl std::error::Error for DiscreteUniformError {} + impl DiscreteUniform { /// Constructs a new discrete uniform distribution with a minimum value /// of `min` and a maximum value of `max`. @@ -42,17 +59,24 @@ impl DiscreteUniform { /// result = DiscreteUniform::new(5, 0); /// assert!(result.is_err()); /// ``` - pub fn new(min: i64, max: i64) -> Result { + pub fn new(min: i64, max: i64) -> Result { if max < min { - Err(StatsError::BadParams) + Err(DiscreteUniformError::MinMaxInvalid) } else { Ok(DiscreteUniform { min, max }) } } } +impl std::fmt::Display for DiscreteUniform { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Uni([{}, {}])", self.min, self.max) + } +} + +#[cfg(feature = "rand")] impl ::rand::distributions::Distribution for DiscreteUniform { - fn sample(&self, rng: &mut R) -> f64 { + fn sample(&self, rng: &mut R) -> f64 { rng.gen_range(self.min..=self.max) as f64 } } @@ -63,7 +87,7 @@ impl DiscreteCDF for DiscreteUniform { /// /// # Formula /// - /// ```ignore + /// ```text /// (floor(x) - min + 1) / (max - min + 1) /// ``` fn cdf(&self, x: i64) -> f64 { @@ -84,7 +108,7 @@ impl DiscreteCDF for DiscreteUniform { } fn sf(&self, x: i64) -> f64 { - //1. 
- self.cdf(x) + // 1. - self.cdf(x) if x < self.min { 1.0 } else if x >= self.max { @@ -131,39 +155,42 @@ impl Distribution for DiscreteUniform { /// /// # Formula /// - /// ```ignore + /// ```text /// (min + max) / 2 /// ``` fn mean(&self) -> Option { Some((self.min + self.max) as f64 / 2.0) } + /// Returns the variance of the discrete uniform distribution /// /// # Formula /// - /// ```ignore + /// ```text /// ((max - min + 1)^2 - 1) / 12 /// ``` fn variance(&self) -> Option { let diff = (self.max - self.min) as f64; Some(((diff + 1.0) * (diff + 1.0) - 1.0) / 12.0) } + /// Returns the entropy of the discrete uniform distribution /// /// # Formula /// - /// ```ignore + /// ```text /// ln(max - min + 1) /// ``` fn entropy(&self) -> Option { let diff = (self.max - self.min) as f64; Some((diff + 1.0).ln()) } + /// Returns the skewness of the discrete uniform distribution /// /// # Formula /// - /// ```ignore + /// ```text /// 0 /// ``` fn skewness(&self) -> Option { @@ -176,7 +203,7 @@ impl Median for DiscreteUniform { /// /// # Formula /// - /// ```ignore + /// ```text /// (max + min) / 2 /// ``` fn median(&self) -> f64 { @@ -194,7 +221,7 @@ impl Mode> for DiscreteUniform { /// /// # Formula /// - /// ```ignore + /// ```text /// N/A // (max + min) / 2 for the middle element /// ``` fn mode(&self) -> Option { @@ -212,7 +239,7 @@ impl Discrete for DiscreteUniform { /// /// # Formula /// - /// ```ignore + /// ```text /// 1 / (max - min + 1) /// ``` fn pmf(&self, x: i64) -> f64 { @@ -232,7 +259,7 @@ impl Discrete for DiscreteUniform { /// /// # Formula /// - /// ```ignore + /// ```text /// ln(1 / (max - min + 1)) /// ``` fn ln_pmf(&self, x: i64) -> f64 { @@ -245,167 +272,134 @@ impl Discrete for DiscreteUniform { } #[rustfmt::skip] -#[cfg(all(test, feature = "nightly"))] +#[cfg(test)] mod tests { - use std::fmt::Debug; - use crate::statistics::*; - use crate::distribution::{DiscreteCDF, Discrete, DiscreteUniform}; - use crate::consts::ACC; - - fn try_create(min: i64, 
max: i64) -> DiscreteUniform { - let n = DiscreteUniform::new(min, max); - assert!(n.is_ok()); - n.unwrap() - } - - fn create_case(min: i64, max: i64) { - let n = try_create(min, max); - assert_eq!(min, n.min()); - assert_eq!(max, n.max()); - } + use super::*; + use crate::testing_boiler; - fn bad_create_case(min: i64, max: i64) { - let n = DiscreteUniform::new(min, max); - assert!(n.is_err()); - } - - fn get_value(min: i64, max: i64, eval: F) -> T - where T: PartialEq + Debug, - F: Fn(DiscreteUniform) -> T - { - let n = try_create(min, max); - eval(n) - } - - fn test_case(min: i64, max: i64, expected: T, eval: F) - where T: PartialEq + Debug, - F: Fn(DiscreteUniform) -> T - { - let x = get_value(min, max, eval); - assert_eq!(expected, x); - } + testing_boiler!(min: i64, max: i64; DiscreteUniform; DiscreteUniformError); #[test] fn test_create() { - create_case(-10, 10); - create_case(0, 4); - create_case(10, 20); - create_case(20, 20); + create_ok(-10, 10); + create_ok(0, 4); + create_ok(10, 20); + create_ok(20, 20); } #[test] fn test_bad_create() { - bad_create_case(-1, -2); - bad_create_case(6, 5); + create_err(-1, -2); + create_err(6, 5); } #[test] fn test_mean() { let mean = |x: DiscreteUniform| x.mean().unwrap(); - test_case(-10, 10, 0.0, mean); - test_case(0, 4, 2.0, mean); - test_case(10, 20, 15.0, mean); - test_case(20, 20, 20.0, mean); + test_exact(-10, 10, 0.0, mean); + test_exact(0, 4, 2.0, mean); + test_exact(10, 20, 15.0, mean); + test_exact(20, 20, 20.0, mean); } #[test] fn test_variance() { let variance = |x: DiscreteUniform| x.variance().unwrap(); - test_case(-10, 10, 36.66666666666666666667, variance); - test_case(0, 4, 2.0, variance); - test_case(10, 20, 10.0, variance); - test_case(20, 20, 0.0, variance); + test_exact(-10, 10, 36.66666666666666666667, variance); + test_exact(0, 4, 2.0, variance); + test_exact(10, 20, 10.0, variance); + test_exact(20, 20, 0.0, variance); } #[test] fn test_entropy() { let entropy = |x: DiscreteUniform| 
x.entropy().unwrap(); - test_case(-10, 10, 3.0445224377234229965005979803657054342845752874046093, entropy); - test_case(0, 4, 1.6094379124341003746007593332261876395256013542685181, entropy); - test_case(10, 20, 2.3978952727983705440619435779651292998217068539374197, entropy); - test_case(20, 20, 0.0, entropy); + test_exact(-10, 10, 3.0445224377234229965005979803657054342845752874046093, entropy); + test_exact(0, 4, 1.6094379124341003746007593332261876395256013542685181, entropy); + test_exact(10, 20, 2.3978952727983705440619435779651292998217068539374197, entropy); + test_exact(20, 20, 0.0, entropy); } #[test] fn test_skewness() { let skewness = |x: DiscreteUniform| x.skewness().unwrap(); - test_case(-10, 10, 0.0, skewness); - test_case(0, 4, 0.0, skewness); - test_case(10, 20, 0.0, skewness); - test_case(20, 20, 0.0, skewness); + test_exact(-10, 10, 0.0, skewness); + test_exact(0, 4, 0.0, skewness); + test_exact(10, 20, 0.0, skewness); + test_exact(20, 20, 0.0, skewness); } #[test] fn test_median() { let median = |x: DiscreteUniform| x.median(); - test_case(-10, 10, 0.0, median); - test_case(0, 4, 2.0, median); - test_case(10, 20, 15.0, median); - test_case(20, 20, 20.0, median); + test_exact(-10, 10, 0.0, median); + test_exact(0, 4, 2.0, median); + test_exact(10, 20, 15.0, median); + test_exact(20, 20, 20.0, median); } #[test] fn test_mode() { let mode = |x: DiscreteUniform| x.mode().unwrap(); - test_case(-10, 10, 0, mode); - test_case(0, 4, 2, mode); - test_case(10, 20, 15, mode); - test_case(20, 20, 20, mode); + test_exact(-10, 10, 0, mode); + test_exact(0, 4, 2, mode); + test_exact(10, 20, 15, mode); + test_exact(20, 20, 20, mode); } #[test] fn test_pmf() { let pmf = |arg: i64| move |x: DiscreteUniform| x.pmf(arg); - test_case(-10, 10, 0.04761904761904761904762, pmf(-5)); - test_case(-10, 10, 0.04761904761904761904762, pmf(1)); - test_case(-10, 10, 0.04761904761904761904762, pmf(10)); - test_case(-10, -10, 0.0, pmf(0)); - test_case(-10, -10, 1.0, pmf(-10)); 
+ test_exact(-10, 10, 0.04761904761904761904762, pmf(-5)); + test_exact(-10, 10, 0.04761904761904761904762, pmf(1)); + test_exact(-10, 10, 0.04761904761904761904762, pmf(10)); + test_exact(-10, -10, 0.0, pmf(0)); + test_exact(-10, -10, 1.0, pmf(-10)); } #[test] fn test_ln_pmf() { let ln_pmf = |arg: i64| move |x: DiscreteUniform| x.ln_pmf(arg); - test_case(-10, 10, -3.0445224377234229965005979803657054342845752874046093, ln_pmf(-5)); - test_case(-10, 10, -3.0445224377234229965005979803657054342845752874046093, ln_pmf(1)); - test_case(-10, 10, -3.0445224377234229965005979803657054342845752874046093, ln_pmf(10)); - test_case(-10, -10, f64::NEG_INFINITY, ln_pmf(0)); - test_case(-10, -10, 0.0, ln_pmf(-10)); + test_exact(-10, 10, -3.0445224377234229965005979803657054342845752874046093, ln_pmf(-5)); + test_exact(-10, 10, -3.0445224377234229965005979803657054342845752874046093, ln_pmf(1)); + test_exact(-10, 10, -3.0445224377234229965005979803657054342845752874046093, ln_pmf(10)); + test_exact(-10, -10, f64::NEG_INFINITY, ln_pmf(0)); + test_exact(-10, -10, 0.0, ln_pmf(-10)); } #[test] fn test_cdf() { let cdf = |arg: i64| move |x: DiscreteUniform| x.cdf(arg); - test_case(-10, 10, 0.2857142857142857142857, cdf(-5)); - test_case(-10, 10, 0.5714285714285714285714, cdf(1)); - test_case(-10, 10, 1.0, cdf(10)); - test_case(-10, -10, 1.0, cdf(-10)); + test_exact(-10, 10, 0.2857142857142857142857, cdf(-5)); + test_exact(-10, 10, 0.5714285714285714285714, cdf(1)); + test_exact(-10, 10, 1.0, cdf(10)); + test_exact(-10, -10, 1.0, cdf(-10)); } #[test] fn test_sf() { let sf = |arg: i64| move |x: DiscreteUniform| x.sf(arg); - test_case(-10, 10, 0.7142857142857142857143, sf(-5)); - test_case(-10, 10, 0.42857142857142855, sf(1)); - test_case(-10, 10, 0.0, sf(10)); - test_case(-10, -10, 0.0, sf(-10)); + test_exact(-10, 10, 0.7142857142857142857143, sf(-5)); + test_exact(-10, 10, 0.42857142857142855, sf(1)); + test_exact(-10, 10, 0.0, sf(10)); + test_exact(-10, -10, 0.0, sf(-10)); } #[test] 
fn test_cdf_lower_bound() { let cdf = |arg: i64| move |x: DiscreteUniform| x.cdf(arg); - test_case(0, 3, 0.0, cdf(-1)); + test_exact(0, 3, 0.0, cdf(-1)); } #[test] fn test_sf_lower_bound() { let sf = |arg: i64| move |x: DiscreteUniform| x.sf(arg); - test_case(0, 3, 1.0, sf(-1)); + test_exact(0, 3, 1.0, sf(-1)); } #[test] fn test_cdf_upper_bound() { let cdf = |arg: i64| move |x: DiscreteUniform| x.cdf(arg); - test_case(0, 3, 1.0, cdf(5)); + test_exact(0, 3, 1.0, cdf(5)); } } diff --git a/src/distribution/empirical.rs b/src/distribution/empirical.rs index 5804f7c1..965d8c7f 100644 --- a/src/distribution/empirical.rs +++ b/src/distribution/empirical.rs @@ -1,25 +1,22 @@ -use crate::distribution::{Continuous, ContinuousCDF, Uniform}; +use crate::distribution::ContinuousCDF; use crate::statistics::*; -use crate::{Result, StatsError}; -use ::num_traits::float::Float; use core::cmp::Ordering; -use rand::Rng; use std::collections::BTreeMap; -#[derive(Clone, Debug, PartialEq)] -struct NonNAN(T); +#[derive(Clone, PartialEq, Debug)] +struct NonNan(T); -impl Eq for NonNAN {} +impl Eq for NonNan {} -impl PartialOrd for NonNAN { +impl PartialOrd for NonNan { fn partial_cmp(&self, other: &Self) -> Option { - self.0.partial_cmp(&other.0) + Some(self.cmp(other)) } } -impl Ord for NonNAN { +impl Ord for NonNan { fn cmp(&self, other: &Self) -> Ordering { - self.partial_cmp(other).unwrap() + self.0.partial_cmp(&other.0).unwrap() } } @@ -37,18 +34,20 @@ impl Ord for NonNAN { /// let empirical = Empirical::from_vec(samples); /// assert_eq!(empirical.mean().unwrap(), 5.0); /// ``` -#[derive(Debug, Clone, PartialEq)] +#[derive(Clone, PartialEq, Debug)] pub struct Empirical { sum: f64, mean_and_var: Option<(f64, f64)>, // keys are data points, values are number of data points with equal value - data: BTreeMap, u64>, + data: BTreeMap, u64>, } impl Empirical { /// Constructs a new discrete uniform distribution with a minimum value /// of `min` and a maximum value of `max`. 
/// + /// Note that this will always succeed and never return the [`Err`][Result::Err] variant. + /// /// # Examples /// /// ``` @@ -56,15 +55,16 @@ impl Empirical { /// /// let mut result = Empirical::new(); /// assert!(result.is_ok()); - /// /// ``` - pub fn new() -> Result { + #[allow(clippy::result_unit_err)] + pub fn new() -> Result { Ok(Empirical { sum: 0., mean_and_var: None, data: BTreeMap::new(), }) } + pub fn from_vec(src: Vec) -> Empirical { let mut empirical = Empirical::new().unwrap(); for elt in src.into_iter() { @@ -72,6 +72,7 @@ impl Empirical { } empirical } + pub fn add(&mut self, data_point: f64) { if !data_point.is_nan() { self.sum += 1.; @@ -86,13 +87,14 @@ impl Empirical { self.mean_and_var = Some((data_point, 0.)); } } - *self.data.entry(NonNAN(data_point)).or_insert(0) += 1; + *self.data.entry(NonNan(data_point)).or_insert(0) += 1; } } + pub fn remove(&mut self, data_point: f64) { if !data_point.is_nan() { if let (Some(val), Some((mean, var))) = - (self.data.remove(&NonNAN(data_point)), self.mean_and_var) + (self.data.remove(&NonNan(data_point)), self.mean_and_var) { if val == 1 && self.data.is_empty() { self.mean_and_var = None; @@ -105,12 +107,13 @@ impl Empirical { var - (self.sum - 1.) * (data_point - mean) * (data_point - mean) / self.sum; self.sum -= 1.; if val != 1 { - self.data.insert(NonNAN(data_point), val - 1); + self.data.insert(NonNan(data_point), val - 1); }; self.mean_and_var = Some((mean, var)); } } } + // Due to issues with rounding and floating-point accuracy the default // implementation may be ill-behaved. // Specialized inverse cdfs should be used whenever possible. 
@@ -148,8 +151,35 @@ impl Empirical { } } +impl std::fmt::Display for Empirical { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + if let Some((&NonNan(x), _)) = self.data.first_key_value() { + write!(f, "Empirical([{:.3e}", x)?; + } else { + return write!(f, "Empirical(∅)"); + } + + let mut enumerated_values = self + .data + .iter() + .flat_map(|(&NonNan(x), &count)| std::iter::repeat(x).take(count as usize)) + .skip(1); + + for val in enumerated_values.by_ref().take(4) { + write!(f, ", {:.3e}", val)?; + } + if enumerated_values.next().is_some() { + write!(f, ", ...")?; + } + write!(f, "])") + } +} + +#[cfg(feature = "rand")] impl ::rand::distributions::Distribution for Empirical { - fn sample(&self, rng: &mut R) -> f64 { + fn sample(&self, rng: &mut R) -> f64 { + use crate::distribution::Uniform; + let uniform = Uniform::new(0.0, 1.0).unwrap(); self.__inverse_cdf(uniform.sample(rng)) } @@ -158,14 +188,14 @@ impl ::rand::distributions::Distribution for Empirical { /// Panics if number of samples is zero impl Max for Empirical { fn max(&self) -> f64 { - self.data.iter().rev().map(|(key, _)| key.0).next().unwrap() + self.data.keys().rev().map(|key| key.0).next().unwrap() } } /// Panics if number of samples is zero impl Min for Empirical { fn min(&self) -> f64 { - self.data.iter().map(|(key, _)| key.0).next().unwrap() + self.data.keys().map(|key| key.0).next().unwrap() } } @@ -173,6 +203,7 @@ impl Distribution for Empirical { fn mean(&self) -> Option { self.mean_and_var.map(|(mean, _)| mean) } + fn variance(&self) -> Option { self.mean_and_var.map(|(_, var)| var / (self.sum - 1.)) } @@ -200,9 +231,13 @@ impl ContinuousCDF for Empirical { } sum as f64 / self.sum } + + fn inverse_cdf(&self, p: f64) -> f64 { + self.__inverse_cdf(p) + } } -#[cfg(all(test, feature = "nightly"))] +#[cfg(test)] mod tests { use super::*; #[test] @@ -256,8 +291,35 @@ mod tests { let unchanged = empirical.clone(); empirical.add(2.0); empirical.remove(2.0); - //because 
of rounding errors, this doesn't hold in general - //due to the mean and variance being calculated in a streaming way + // because of rounding errors, this doesn't hold in general + // due to the mean and variance being calculated in a streaming way assert_eq!(unchanged, empirical); } + + #[test] + fn test_display() { + let mut e = Empirical::new().unwrap(); + assert_eq!(e.to_string(), "Empirical(∅)"); + e.add(1.0); + assert_eq!(e.to_string(), "Empirical([1.000e0])"); + e.add(1.0); + assert_eq!(e.to_string(), "Empirical([1.000e0, 1.000e0])"); + e.add(2.0); + assert_eq!(e.to_string(), "Empirical([1.000e0, 1.000e0, 2.000e0])"); + e.add(2.0); + assert_eq!( + e.to_string(), + "Empirical([1.000e0, 1.000e0, 2.000e0, 2.000e0])" + ); + e.add(5.0); + assert_eq!( + e.to_string(), + "Empirical([1.000e0, 1.000e0, 2.000e0, 2.000e0, 5.000e0])" + ); + e.add(5.0); + assert_eq!( + e.to_string(), + "Empirical([1.000e0, 1.000e0, 2.000e0, 2.000e0, 5.000e0, ...])" + ); + } } diff --git a/src/distribution/erlang.rs b/src/distribution/erlang.rs index 619ba698..2ad017f3 100644 --- a/src/distribution/erlang.rs +++ b/src/distribution/erlang.rs @@ -1,7 +1,5 @@ -use crate::distribution::{Continuous, ContinuousCDF, Gamma}; +use crate::distribution::{Continuous, ContinuousCDF, Gamma, GammaError}; use crate::statistics::*; -use crate::Result; -use rand::Rng; /// Implements the [Erlang](https://en.wikipedia.org/wiki/Erlang_distribution) /// distribution @@ -20,7 +18,7 @@ use rand::Rng; /// assert_eq!(n.mean().unwrap(), 3.0); /// assert!(prec::almost_eq(n.pdf(2.0), 0.270670566473225383788, 1e-15)); /// ``` -#[derive(Debug, Copy, Clone, PartialEq)] +#[derive(Copy, Clone, PartialEq, Debug)] pub struct Erlang { g: Gamma, } @@ -45,7 +43,7 @@ impl Erlang { /// result = Erlang::new(0, 0.0); /// assert!(result.is_err()); /// ``` - pub fn new(shape: u64, rate: f64) -> Result { + pub fn new(shape: u64, rate: f64) -> Result { Gamma::new(shape as f64, rate).map(|g| Erlang { g }) } @@ -78,8 +76,15 @@ impl 
Erlang { } } +impl std::fmt::Display for Erlang { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "E({}, {})", self.rate(), self.shape()) + } +} + +#[cfg(feature = "rand")] impl ::rand::distributions::Distribution for Erlang { - fn sample(&self, rng: &mut R) -> f64 { + fn sample(&self, rng: &mut R) -> f64 { ::rand::distributions::Distribution::sample(&self.g, rng) } } @@ -91,7 +96,7 @@ impl ContinuousCDF for Erlang { /// /// # Formula /// - /// ```ignore + /// ```text /// γ(k, λx) (k - 1)! /// ``` /// @@ -107,7 +112,7 @@ impl ContinuousCDF for Erlang { /// /// # Formula /// - /// ```ignore + /// ```text /// γ(k, λx) (k - 1)! /// ``` /// @@ -116,6 +121,21 @@ impl ContinuousCDF for Erlang { fn sf(&self, x: f64) -> f64 { self.g.sf(x) } + + /// Calculates the inverse cumulative distribution function for the erlang + /// distribution at `x` + /// + /// # Formula + /// + /// ```text + /// γ^{-1}(k, (k - 1)! x) / λ + /// ``` + /// + /// where `k` is the shape, `λ` is the rate, and `γ` is the upper + /// incomplete gamma function + fn inverse_cdf(&self, p: f64) -> f64 { + self.g.inverse_cdf(p) + } } impl Min for Erlang { @@ -125,7 +145,7 @@ impl Min for Erlang { /// /// # Formula /// - /// ```ignore + /// ```text /// 0 /// ``` fn min(&self) -> f64 { @@ -140,8 +160,8 @@ impl Max for Erlang { /// /// # Formula /// - /// ```ignore - /// INF + /// ```text + /// f64::INFINITY /// ``` fn max(&self) -> f64 { self.g.max() @@ -158,7 +178,7 @@ impl Distribution for Erlang { /// /// # Formula /// - /// ```ignore + /// ```text /// k / λ /// ``` /// @@ -166,11 +186,12 @@ impl Distribution for Erlang { fn mean(&self) -> Option { self.g.mean() } + /// Returns the variance of the erlang distribution /// /// # Formula /// - /// ```ignore + /// ```text /// k / λ^2 /// ``` /// @@ -178,11 +199,12 @@ impl Distribution for Erlang { fn variance(&self) -> Option { self.g.variance() } + /// Returns the entropy of the erlang distribution /// /// # Formula /// - 
/// ```ignore + /// ```text /// k - ln(λ) + ln(Γ(k)) + (1 - k) * ψ(k) /// ``` /// @@ -191,11 +213,12 @@ impl Distribution for Erlang { fn entropy(&self) -> Option { self.g.entropy() } + /// Returns the skewness of the erlang distribution /// /// # Formula /// - /// ```ignore + /// ```text /// 2 / sqrt(k) /// ``` /// @@ -215,7 +238,7 @@ impl Mode> for Erlang { /// /// # Formula /// - /// ```ignore + /// ```text /// (k - 1) / λ /// ``` /// @@ -236,7 +259,7 @@ impl Continuous for Erlang { /// /// # Formula /// - /// ```ignore + /// ```text /// (λ^k / Γ(k)) * x^(k - 1) * e^(-λ * x) /// ``` /// @@ -256,7 +279,7 @@ impl Continuous for Erlang { /// /// # Formula /// - /// ```ignore + /// ```text /// ln((λ^k / Γ(k)) * x^(k - 1) * e ^(-λ * x)) /// ``` /// @@ -267,50 +290,41 @@ impl Continuous for Erlang { } #[rustfmt::skip] -#[cfg(all(test, feature = "nightly"))] +#[cfg(test)] mod tests { - use crate::distribution::Erlang; + use super::*; use crate::distribution::internal::*; - use crate::consts::ACC; - - fn try_create(shape: u64, rate: f64) -> Erlang { - let n = Erlang::new(shape, rate); - assert!(n.is_ok()); - n.unwrap() - } + use crate::testing_boiler; - fn create_case(shape: u64, rate: f64) { - let n = try_create(shape, rate); - assert_eq!(shape, n.shape()); - assert_eq!(rate, n.rate()); - } - - fn bad_create_case(shape: u64, rate: f64) { - let n = Erlang::new(shape, rate); - assert!(n.is_err()); - } + testing_boiler!(shape: u64, rate: f64; Erlang; GammaError); #[test] fn test_create() { - create_case(1, 0.1); - create_case(1, 1.0); - create_case(10, 10.0); - create_case(10, 1.0); - create_case(10, f64::INFINITY); + create_ok(1, 0.1); + create_ok(1, 1.0); + create_ok(10, 10.0); + create_ok(10, 1.0); + create_ok(10, f64::INFINITY); } #[test] fn test_bad_create() { - bad_create_case(0, 1.0); - bad_create_case(1, 0.0); - bad_create_case(1, f64::NAN); - bad_create_case(1, -1.0); + let invalid = [ + (0, 1.0, GammaError::ShapeInvalid), + (1, 0.0, GammaError::RateInvalid), + 
(1, f64::NAN, GammaError::RateInvalid), + (1, -1.0, GammaError::RateInvalid), + ]; + + for (s, r, err) in invalid { + test_create_err(s, r, err); + } } #[test] fn test_continuous() { - test::check_continuous_distribution(&try_create(1, 2.5), 0.0, 20.0); - test::check_continuous_distribution(&try_create(2, 1.5), 0.0, 20.0); - test::check_continuous_distribution(&try_create(3, 0.5), 0.0, 20.0); + test::check_continuous_distribution(&create_ok(1, 2.5), 0.0, 20.0); + test::check_continuous_distribution(&create_ok(2, 1.5), 0.0, 20.0); + test::check_continuous_distribution(&create_ok(3, 0.5), 0.0, 20.0); } } diff --git a/src/distribution/exponential.rs b/src/distribution/exponential.rs index 890592d8..9c6c21fc 100644 --- a/src/distribution/exponential.rs +++ b/src/distribution/exponential.rs @@ -1,7 +1,5 @@ -use crate::distribution::{ziggurat, Continuous, ContinuousCDF}; +use crate::distribution::{Continuous, ContinuousCDF}; use crate::statistics::*; -use crate::{Result, StatsError}; -use rand::Rng; use std::f64; /// Implements the @@ -20,18 +18,37 @@ use std::f64; /// assert_eq!(n.mean().unwrap(), 1.0); /// assert_eq!(n.pdf(1.0), 0.3678794411714423215955); /// ``` -#[derive(Debug, Copy, Clone, PartialEq)] +#[derive(Copy, Clone, PartialEq, Debug)] pub struct Exp { rate: f64, } +/// Represents the errors that can occur when creating a [`Exp`]. +#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] +#[non_exhaustive] +pub enum ExpError { + /// The rate is NaN, zero or less than zero. + RateInvalid, +} + +impl std::fmt::Display for ExpError { + #[cfg_attr(coverage_nightly, coverage(off))] + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + ExpError::RateInvalid => write!(f, "Rate is NaN, zero or less than zero"), + } + } +} + +impl std::error::Error for ExpError {} + impl Exp { /// Constructs a new exponential distribution with a /// rate (λ) of `rate`. 
/// /// # Errors /// - /// Returns an error if rate is `NaN` or `rate <= 0.0` + /// Returns an error if rate is `NaN` or `rate <= 0.0`. /// /// # Examples /// @@ -44,9 +61,9 @@ impl Exp { /// result = Exp::new(-1.0); /// assert!(result.is_err()); /// ``` - pub fn new(rate: f64) -> Result { + pub fn new(rate: f64) -> Result { if rate.is_nan() || rate <= 0.0 { - Err(StatsError::BadParams) + Err(ExpError::RateInvalid) } else { Ok(Exp { rate }) } @@ -67,8 +84,17 @@ impl Exp { } } +impl std::fmt::Display for Exp { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Exp({})", self.rate) + } +} + +#[cfg(feature = "rand")] impl ::rand::distributions::Distribution for Exp { - fn sample(&self, r: &mut R) -> f64 { + fn sample(&self, r: &mut R) -> f64 { + use crate::distribution::ziggurat; + ziggurat::sample_exp_1(r) / self.rate } } @@ -79,7 +105,7 @@ impl ContinuousCDF for Exp { /// /// # Formula /// - /// ```ignore + /// ```text /// 1 - e^(-λ * x) /// ``` /// @@ -97,7 +123,7 @@ impl ContinuousCDF for Exp { /// /// # Formula /// - /// ```ignore + /// ```text /// e^(-λ * x) /// ``` /// @@ -109,6 +135,19 @@ impl ContinuousCDF for Exp { (-self.rate * x).exp() } } + + /// Calculates the inverse cumulative distribution function. 
+ /// + /// # Formula + /// + /// ```text + /// -ln(1 - p) / λ + /// ``` + /// + /// where `p` is the probability and `λ` is the rate + fn inverse_cdf(&self, p: f64) -> f64 { + -(-p).ln_1p() / self.rate + } } impl Min for Exp { @@ -117,7 +156,7 @@ impl Min for Exp { /// /// # Formula /// - /// ```ignore + /// ```text /// 0 /// ``` fn min(&self) -> f64 { @@ -131,8 +170,8 @@ impl Max for Exp { /// /// # Formula /// - /// ```ignore - /// INF + /// ```text + /// f64::INFINITY /// ``` fn max(&self) -> f64 { f64::INFINITY @@ -144,7 +183,7 @@ impl Distribution for Exp { /// /// # Formula /// - /// ```ignore + /// ```text /// 1 / λ /// ``` /// @@ -152,11 +191,12 @@ impl Distribution for Exp { fn mean(&self) -> Option { Some(1.0 / self.rate) } + /// Returns the variance of the exponential distribution /// /// # Formula /// - /// ```ignore + /// ```text /// 1 / λ^2 /// ``` /// @@ -164,11 +204,12 @@ impl Distribution for Exp { fn variance(&self) -> Option { Some(1.0 / (self.rate * self.rate)) } + /// Returns the entropy of the exponential distribution /// /// # Formula /// - /// ```ignore + /// ```text /// 1 - ln(λ) /// ``` /// @@ -176,11 +217,12 @@ impl Distribution for Exp { fn entropy(&self) -> Option { Some(1.0 - self.rate.ln()) } + /// Returns the skewness of the exponential distribution /// /// # Formula /// - /// ```ignore + /// ```text /// 2 /// ``` fn skewness(&self) -> Option { @@ -193,7 +235,7 @@ impl Median for Exp { /// /// # Formula /// - /// ```ignore + /// ```text /// (1 / λ) * ln2 /// ``` /// @@ -208,7 +250,7 @@ impl Mode> for Exp { /// /// # Formula /// - /// ```ignore + /// ```text /// 0 /// ``` fn mode(&self) -> Option { @@ -222,7 +264,7 @@ impl Continuous for Exp { /// /// # Formula /// - /// ```ignore + /// ```text /// λ * e^(-λ * x) /// ``` /// @@ -240,7 +282,7 @@ impl Continuous for Exp { /// /// # Formula /// - /// ```ignore + /// ```text /// ln(λ * e^(-λ * x)) /// ``` /// @@ -255,237 +297,214 @@ impl Continuous for Exp { } #[rustfmt::skip] 
-#[cfg(all(test, feature = "nightly"))] +#[cfg(test)] mod tests { - use std::f64; - use crate::statistics::*; - use crate::distribution::{ContinuousCDF, Continuous, Exp}; + use super::*; use crate::distribution::internal::*; - use crate::consts::ACC; + use crate::testing_boiler; - fn try_create(rate: f64) -> Exp { - let n = Exp::new(rate); - assert!(n.is_ok()); - n.unwrap() - } - - fn create_case(rate: f64) { - let n = try_create(rate); - assert_eq!(rate, n.rate()); - } - - fn bad_create_case(rate: f64) { - let n = Exp::new(rate); - assert!(n.is_err()); - } - - fn get_value(rate: f64, eval: F) -> f64 - where F: Fn(Exp) -> f64 - { - let n = try_create(rate); - eval(n) - } - - fn test_case(rate: f64, expected: f64, eval: F) - where F: Fn(Exp) -> f64 - { - let x = get_value(rate, eval); - assert_eq!(expected, x); - } - - fn test_almost(rate: f64, expected: f64, acc: f64, eval: F) - where F: Fn(Exp) -> f64 - { - let x = get_value(rate, eval); - assert_almost_eq!(expected, x, acc); - } - - fn test_is_nan(rate: f64, eval: F) - where F : Fn(Exp) -> f64 - { - let x = get_value(rate, eval); - assert!(x.is_nan()); - } + testing_boiler!(rate: f64; Exp; ExpError); #[test] fn test_create() { - create_case(0.1); - create_case(1.0); - create_case(10.0); + create_ok(0.1); + create_ok(1.0); + create_ok(10.0); } #[test] fn test_bad_create() { - bad_create_case(f64::NAN); - bad_create_case(0.0); - bad_create_case(-1.0); - bad_create_case(-10.0); + create_err(f64::NAN); + create_err(0.0); + create_err(-1.0); + create_err(-10.0); } #[test] fn test_mean() { let mean = |x: Exp| x.mean().unwrap(); - test_case(0.1, 10.0, mean); - test_case(1.0, 1.0, mean); - test_case(10.0, 0.1, mean); + test_exact(0.1, 10.0, mean); + test_exact(1.0, 1.0, mean); + test_exact(10.0, 0.1, mean); } #[test] fn test_variance() { let variance = |x: Exp| x.variance().unwrap(); - test_almost(0.1, 100.0, 1e-13, variance); - test_case(1.0, 1.0, variance); - test_case(10.0, 0.01, variance); + test_absolute(0.1, 100.0, 
1e-13, variance); + test_exact(1.0, 1.0, variance); + test_exact(10.0, 0.01, variance); } #[test] fn test_entropy() { let entropy = |x: Exp| x.entropy().unwrap(); - test_almost(0.1, 3.302585092994045684018, 1e-15, entropy); - test_case(1.0, 1.0, entropy); - test_almost(10.0, -1.302585092994045684018, 1e-15, entropy); + test_absolute(0.1, 3.302585092994045684018, 1e-15, entropy); + test_exact(1.0, 1.0, entropy); + test_absolute(10.0, -1.302585092994045684018, 1e-15, entropy); } #[test] fn test_skewness() { let skewness = |x: Exp| x.skewness().unwrap(); - test_case(0.1, 2.0, skewness); - test_case(1.0, 2.0, skewness); - test_case(10.0, 2.0, skewness); + test_exact(0.1, 2.0, skewness); + test_exact(1.0, 2.0, skewness); + test_exact(10.0, 2.0, skewness); } #[test] fn test_median() { let median = |x: Exp| x.median(); - test_almost(0.1, 6.931471805599453094172, 1e-15, median); - test_case(1.0, f64::consts::LN_2, median); - test_case(10.0, 0.06931471805599453094172, median); + test_absolute(0.1, 6.931471805599453094172, 1e-15, median); + test_exact(1.0, f64::consts::LN_2, median); + test_exact(10.0, 0.06931471805599453094172, median); } #[test] fn test_mode() { let mode = |x: Exp| x.mode().unwrap(); - test_case(0.1, 0.0, mode); - test_case(1.0, 0.0, mode); - test_case(10.0, 0.0, mode); + test_exact(0.1, 0.0, mode); + test_exact(1.0, 0.0, mode); + test_exact(10.0, 0.0, mode); } #[test] fn test_min_max() { let min = |x: Exp| x.min(); let max = |x: Exp| x.max(); - test_case(0.1, 0.0, min); - test_case(1.0, 0.0, min); - test_case(10.0, 0.0, min); - test_case(0.1, f64::INFINITY, max); - test_case(1.0, f64::INFINITY, max); - test_case(10.0, f64::INFINITY, max); + test_exact(0.1, 0.0, min); + test_exact(1.0, 0.0, min); + test_exact(10.0, 0.0, min); + test_exact(0.1, f64::INFINITY, max); + test_exact(1.0, f64::INFINITY, max); + test_exact(10.0, f64::INFINITY, max); } #[test] fn test_pdf() { let pdf = |arg: f64| move |x: Exp| x.pdf(arg); - test_case(0.1, 0.1, pdf(0.0)); - 
test_case(1.0, 1.0, pdf(0.0)); - test_case(10.0, 10.0, pdf(0.0)); + test_exact(0.1, 0.1, pdf(0.0)); + test_exact(1.0, 1.0, pdf(0.0)); + test_exact(10.0, 10.0, pdf(0.0)); test_is_nan(f64::INFINITY, pdf(0.0)); - test_case(0.1, 0.09900498337491680535739, pdf(0.1)); - test_almost(1.0, 0.9048374180359595731642, 1e-15, pdf(0.1)); - test_case(10.0, 3.678794411714423215955, pdf(0.1)); + test_exact(0.1, 0.09900498337491680535739, pdf(0.1)); + test_absolute(1.0, 0.9048374180359595731642, 1e-15, pdf(0.1)); + test_exact(10.0, 3.678794411714423215955, pdf(0.1)); test_is_nan(f64::INFINITY, pdf(0.1)); - test_case(0.1, 0.09048374180359595731642, pdf(1.0)); - test_case(1.0, 0.3678794411714423215955, pdf(1.0)); - test_almost(10.0, 4.539992976248485153559e-4, 1e-19, pdf(1.0)); + test_exact(0.1, 0.09048374180359595731642, pdf(1.0)); + test_exact(1.0, 0.3678794411714423215955, pdf(1.0)); + test_absolute(10.0, 4.539992976248485153559e-4, 1e-19, pdf(1.0)); test_is_nan(f64::INFINITY, pdf(1.0)); - test_case(0.1, 0.0, pdf(f64::INFINITY)); - test_case(1.0, 0.0, pdf(f64::INFINITY)); - test_case(10.0, 0.0, pdf(f64::INFINITY)); + test_exact(0.1, 0.0, pdf(f64::INFINITY)); + test_exact(1.0, 0.0, pdf(f64::INFINITY)); + test_exact(10.0, 0.0, pdf(f64::INFINITY)); test_is_nan(f64::INFINITY, pdf(f64::INFINITY)); } #[test] fn test_neg_pdf() { let pdf = |arg: f64| move |x: Exp| x.pdf(arg); - test_case(0.1, 0.0, pdf(-1.0)); + test_exact(0.1, 0.0, pdf(-1.0)); } #[test] fn test_ln_pdf() { let ln_pdf = |arg: f64| move |x: Exp| x.ln_pdf(arg); - test_almost(0.1, -2.302585092994045684018, 1e-15, ln_pdf(0.0)); - test_case(1.0, 0.0, ln_pdf(0.0)); - test_case(10.0, 2.302585092994045684018, ln_pdf(0.0)); + test_absolute(0.1, -2.302585092994045684018, 1e-15, ln_pdf(0.0)); + test_exact(1.0, 0.0, ln_pdf(0.0)); + test_exact(10.0, 2.302585092994045684018, ln_pdf(0.0)); test_is_nan(f64::INFINITY, ln_pdf(0.0)); - test_almost(0.1, -2.312585092994045684018, 1e-15, ln_pdf(0.1)); - test_case(1.0, -0.1, ln_pdf(0.1)); - 
test_almost(10.0, 1.302585092994045684018, 1e-15, ln_pdf(0.1)); + test_absolute(0.1, -2.312585092994045684018, 1e-15, ln_pdf(0.1)); + test_exact(1.0, -0.1, ln_pdf(0.1)); + test_absolute(10.0, 1.302585092994045684018, 1e-15, ln_pdf(0.1)); test_is_nan(f64::INFINITY, ln_pdf(0.1)); - test_case(0.1, -2.402585092994045684018, ln_pdf(1.0)); - test_case(1.0, -1.0, ln_pdf(1.0)); - test_case(10.0, -7.697414907005954315982, ln_pdf(1.0)); + test_exact(0.1, -2.402585092994045684018, ln_pdf(1.0)); + test_exact(1.0, -1.0, ln_pdf(1.0)); + test_exact(10.0, -7.697414907005954315982, ln_pdf(1.0)); test_is_nan(f64::INFINITY, ln_pdf(1.0)); - test_case(0.1, f64::NEG_INFINITY, ln_pdf(f64::INFINITY)); - test_case(1.0, f64::NEG_INFINITY, ln_pdf(f64::INFINITY)); - test_case(10.0, f64::NEG_INFINITY, ln_pdf(f64::INFINITY)); + test_exact(0.1, f64::NEG_INFINITY, ln_pdf(f64::INFINITY)); + test_exact(1.0, f64::NEG_INFINITY, ln_pdf(f64::INFINITY)); + test_exact(10.0, f64::NEG_INFINITY, ln_pdf(f64::INFINITY)); test_is_nan(f64::INFINITY, ln_pdf(f64::INFINITY)); } #[test] fn test_neg_ln_pdf() { let ln_pdf = |arg: f64| move |x: Exp| x.ln_pdf(arg); - test_case(0.1, f64::NEG_INFINITY, ln_pdf(-1.0)); + test_exact(0.1, f64::NEG_INFINITY, ln_pdf(-1.0)); } #[test] fn test_cdf() { let cdf = |arg: f64| move |x: Exp| x.cdf(arg); - test_case(0.1, 0.0, cdf(0.0)); - test_case(1.0, 0.0, cdf(0.0)); - test_case(10.0, 0.0, cdf(0.0)); + test_exact(0.1, 0.0, cdf(0.0)); + test_exact(1.0, 0.0, cdf(0.0)); + test_exact(10.0, 0.0, cdf(0.0)); test_is_nan(f64::INFINITY, cdf(0.0)); - test_almost(0.1, 0.009950166250831946426094, 1e-16, cdf(0.1)); - test_almost(1.0, 0.0951625819640404268358, 1e-16, cdf(0.1)); - test_case(10.0, 0.6321205588285576784045, cdf(0.1)); - test_case(f64::INFINITY, 1.0, cdf(0.1)); - test_almost(0.1, 0.0951625819640404268358, 1e-16, cdf(1.0)); - test_case(1.0, 0.6321205588285576784045, cdf(1.0)); - test_case(10.0, 0.9999546000702375151485, cdf(1.0)); - test_case(f64::INFINITY, 1.0, cdf(1.0)); - 
test_case(0.1, 1.0, cdf(f64::INFINITY)); - test_case(1.0, 1.0, cdf(f64::INFINITY)); - test_case(10.0, 1.0, cdf(f64::INFINITY)); - test_case(f64::INFINITY, 1.0, cdf(f64::INFINITY)); + test_absolute(0.1, 0.009950166250831946426094, 1e-16, cdf(0.1)); + test_absolute(1.0, 0.0951625819640404268358, 1e-16, cdf(0.1)); + test_exact(10.0, 0.6321205588285576784045, cdf(0.1)); + test_exact(f64::INFINITY, 1.0, cdf(0.1)); + test_absolute(0.1, 0.0951625819640404268358, 1e-16, cdf(1.0)); + test_exact(1.0, 0.6321205588285576784045, cdf(1.0)); + test_exact(10.0, 0.9999546000702375151485, cdf(1.0)); + test_exact(f64::INFINITY, 1.0, cdf(1.0)); + test_exact(0.1, 1.0, cdf(f64::INFINITY)); + test_exact(1.0, 1.0, cdf(f64::INFINITY)); + test_exact(10.0, 1.0, cdf(f64::INFINITY)); + test_exact(f64::INFINITY, 1.0, cdf(f64::INFINITY)); + } + + #[test] + fn test_inverse_cdf() { + let distribution = Exp::new(0.42).unwrap(); + assert_eq!(distribution.median(), distribution.inverse_cdf(0.5)); + + let distribution = Exp::new(0.042).unwrap(); + assert_eq!(distribution.median(), distribution.inverse_cdf(0.5)); + + let distribution = Exp::new(0.0042).unwrap(); + assert_eq!(distribution.median(), distribution.inverse_cdf(0.5)); + + let distribution = Exp::new(0.33).unwrap(); + assert_eq!(distribution.median(), distribution.inverse_cdf(0.5)); + + let distribution = Exp::new(0.033).unwrap(); + assert_eq!(distribution.median(), distribution.inverse_cdf(0.5)); + + let distribution = Exp::new(0.0033).unwrap(); + assert_eq!(distribution.median(), distribution.inverse_cdf(0.5)); } #[test] fn test_sf() { let sf = |arg: f64| move |x: Exp| x.sf(arg); - test_case(0.1, 1.0, sf(0.0)); - test_case(1.0, 1.0, sf(0.0)); - test_case(10.0, 1.0, sf(0.0)); + test_exact(0.1, 1.0, sf(0.0)); + test_exact(1.0, 1.0, sf(0.0)); + test_exact(10.0, 1.0, sf(0.0)); test_is_nan(f64::INFINITY, sf(0.0)); - test_almost(0.1, 0.9900498337491681, 1e-16, sf(0.1)); - test_almost(1.0, 0.9048374180359595, 1e-16, sf(0.1)); - test_almost(10.0, 
0.36787944117144233, 1e-15, sf(0.1)); - test_case(f64::INFINITY, 0.0, sf(0.1)); + test_absolute(0.1, 0.9900498337491681, 1e-16, sf(0.1)); + test_absolute(1.0, 0.9048374180359595, 1e-16, sf(0.1)); + test_absolute(10.0, 0.36787944117144233, 1e-15, sf(0.1)); + test_exact(f64::INFINITY, 0.0, sf(0.1)); } #[test] fn test_neg_cdf() { let cdf = |arg: f64| move |x: Exp| x.cdf(arg); - test_case(0.1, 0.0, cdf(-1.0)); + test_exact(0.1, 0.0, cdf(-1.0)); } #[test] fn test_neg_sf() { let sf = |arg: f64| move |x: Exp| x.sf(arg); - test_case(0.1, 1.0, sf(-1.0)); + test_exact(0.1, 1.0, sf(-1.0)); } #[test] fn test_continuous() { - test::check_continuous_distribution(&try_create(0.5), 0.0, 10.0); - test::check_continuous_distribution(&try_create(1.5), 0.0, 20.0); - test::check_continuous_distribution(&try_create(2.5), 0.0, 50.0); + test::check_continuous_distribution(&create_ok(0.5), 0.0, 10.0); + test::check_continuous_distribution(&create_ok(1.5), 0.0, 20.0); + test::check_continuous_distribution(&create_ok(2.5), 0.0, 50.0); } } diff --git a/src/distribution/fisher_snedecor.rs b/src/distribution/fisher_snedecor.rs index 4b8782be..3208f98f 100644 --- a/src/distribution/fisher_snedecor.rs +++ b/src/distribution/fisher_snedecor.rs @@ -1,8 +1,6 @@ use crate::distribution::{Continuous, ContinuousCDF}; use crate::function::beta; use crate::statistics::*; -use crate::{Result, StatsError}; -use rand::Rng; use std::f64; /// Implements the @@ -26,6 +24,33 @@ pub struct FisherSnedecor { freedom_2: f64, } +/// Represents the errors that can occur when creating a [`FisherSnedecor`]. +#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] +#[non_exhaustive] +pub enum FisherSnedecorError { + /// `freedom_1` is NaN, infinite, zero or less than zero. + Freedom1Invalid, + + /// `freedom_2` is NaN, infinite, zero or less than zero. 
+ Freedom2Invalid, +} + +impl std::fmt::Display for FisherSnedecorError { + #[cfg_attr(coverage_nightly, coverage(off))] + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + FisherSnedecorError::Freedom1Invalid => { + write!(f, "freedom_1 is NaN, infinite, zero or less than zero.") + } + FisherSnedecorError::Freedom2Invalid => { + write!(f, "freedom_2 is NaN, infinite, zero or less than zero.") + } + } + } +} + +impl std::error::Error for FisherSnedecorError {} + impl FisherSnedecor { /// Constructs a new fisher-snedecor distribution with /// degrees of freedom `freedom_1` and `freedom_2` @@ -46,16 +71,19 @@ impl FisherSnedecor { /// result = FisherSnedecor::new(0.0, 0.0); /// assert!(result.is_err()); /// ``` - pub fn new(freedom_1: f64, freedom_2: f64) -> Result { - if !freedom_1.is_finite() || freedom_1 <= 0.0 || !freedom_2.is_finite() || freedom_2 <= 0.0 - { - Err(StatsError::BadParams) - } else { - Ok(FisherSnedecor { - freedom_1, - freedom_2, - }) + pub fn new(freedom_1: f64, freedom_2: f64) -> Result { + if !freedom_1.is_finite() || freedom_1 <= 0.0 { + return Err(FisherSnedecorError::Freedom1Invalid); } + + if !freedom_2.is_finite() || freedom_2 <= 0.0 { + return Err(FisherSnedecorError::Freedom2Invalid); + } + + Ok(FisherSnedecor { + freedom_1, + freedom_2, + }) } /// Returns the first degree of freedom for the @@ -89,8 +117,15 @@ impl FisherSnedecor { } } +impl std::fmt::Display for FisherSnedecor { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "F({},{})", self.freedom_1, self.freedom_2) + } +} + +#[cfg(feature = "rand")] impl ::rand::distributions::Distribution for FisherSnedecor { - fn sample(&self, rng: &mut R) -> f64 { + fn sample(&self, rng: &mut R) -> f64 { (super::gamma::sample_unchecked(rng, self.freedom_1 / 2.0, 0.5) * self.freedom_2) / (super::gamma::sample_unchecked(rng, self.freedom_2 / 2.0, 0.5) * self.freedom_1) } @@ -103,7 +138,7 @@ impl ContinuousCDF for FisherSnedecor { 
/// /// # Formula /// - /// ```ignore + /// ```text /// I_((d1 * x) / (d1 * x + d2))(d1 / 2, d2 / 2) /// ``` /// @@ -129,7 +164,7 @@ impl ContinuousCDF for FisherSnedecor { /// /// # Formula /// - /// ```ignore + /// ```text /// I_(1 - ((d1 * x) / (d1 * x + d2))(d2 / 2, d1 / 2) /// ``` /// @@ -144,11 +179,33 @@ impl ContinuousCDF for FisherSnedecor { } else { beta::beta_reg( self.freedom_2 / 2.0, - self.freedom_1 / 2.0, - 1. - ((self.freedom_1 * x) / (self.freedom_1 * x + self.freedom_2)) + self.freedom_1 / 2.0, + 1. - ((self.freedom_1 * x) / (self.freedom_1 * x + self.freedom_2)), ) } } + + /// Calculates the inverse cumulative distribution function for the + /// fisher-snedecor distribution at `x` + /// + /// # Formula + /// + /// ```text + /// z = I^{-1}_(x)(d1 / 2, d2 / 2) + /// d2 / (d1 (1 / z - 1)) + /// ``` + /// + /// where `d1` is the first degree of freedom, `d2` is + /// the second degree of freedom, and `I` is the regularized incomplete + /// beta function + fn inverse_cdf(&self, x: f64) -> f64 { + if !(0.0..=1.0).contains(&x) { + panic!("x must be in [0, 1]"); + } else { + let z = beta::inv_beta_reg(self.freedom_1 / 2.0, self.freedom_2 / 2.0, x); + self.freedom_2 / (self.freedom_1 * (1.0 / z - 1.0)) + } + } } impl Min for FisherSnedecor { @@ -158,7 +215,7 @@ impl Min for FisherSnedecor { /// /// # Formula /// - /// ```ignore + /// ```text /// 0 /// ``` fn min(&self) -> f64 { @@ -173,8 +230,8 @@ impl Max for FisherSnedecor { /// /// # Formula /// - /// ```ignore - /// INF + /// ```text + /// f64::INFINITY /// ``` fn max(&self) -> f64 { f64::INFINITY @@ -194,7 +251,7 @@ impl Distribution for FisherSnedecor { /// /// # Formula /// - /// ```ignore + /// ```text /// d2 / (d2 - 2) /// ``` /// @@ -206,6 +263,7 @@ impl Distribution for FisherSnedecor { Some(self.freedom_2 / (self.freedom_2 - 2.0)) } } + /// Returns the variance of the fisher-snedecor distribution /// /// # Panics @@ -218,7 +276,7 @@ impl Distribution for FisherSnedecor { /// /// # Formula /// 
- /// ```ignore + /// ```text /// (2 * d2^2 * (d1 + d2 - 2)) / (d1 * (d2 - 2)^2 * (d2 - 4)) /// ``` /// @@ -237,6 +295,7 @@ impl Distribution for FisherSnedecor { Some(val) } } + /// Returns the skewness of the fisher-snedecor distribution /// /// # Panics @@ -249,7 +308,7 @@ impl Distribution for FisherSnedecor { /// /// # Formula /// - /// ```ignore + /// ```text /// ((2d1 + d2 - 2) * sqrt(8 * (d2 - 4))) / ((d2 - 6) * sqrt(d1 * (d1 + d2 /// - 2))) /// ``` @@ -282,7 +341,7 @@ impl Mode> for FisherSnedecor { /// /// # Formula /// - /// ```ignore + /// ```text /// ((d1 - 2) / d1) * (d2 / (d2 + 2)) /// ``` /// @@ -311,7 +370,7 @@ impl Continuous for FisherSnedecor { /// /// # Formula /// - /// ```ignore + /// ```text /// sqrt(((d1 * x) ^ d1 * d2 ^ d2) / (d1 * x + d2) ^ (d1 + d2)) / (x * β(d1 /// / 2, d2 / 2)) /// ``` @@ -340,7 +399,7 @@ impl Continuous for FisherSnedecor { /// /// # Formula /// - /// ```ignore + /// ```text /// ln(sqrt(((d1 * x) ^ d1 * d2 ^ d2) / (d1 * x + d2) ^ (d1 + d2)) / (x * /// β(d1 / 2, d2 / 2))) /// ``` @@ -353,240 +412,213 @@ impl Continuous for FisherSnedecor { } #[rustfmt::skip] -#[cfg(all(test, feature = "nightly"))] +#[cfg(test)] mod tests { - use crate::statistics::*; - use crate::distribution::{ContinuousCDF, Continuous, FisherSnedecor}; + use super::*; use crate::distribution::internal::*; - use crate::consts::ACC; - - fn try_create(freedom_1: f64, freedom_2: f64) -> FisherSnedecor { - let n = FisherSnedecor::new(freedom_1, freedom_2); - assert!(n.is_ok()); - n.unwrap() - } - - fn create_case(freedom_1: f64, freedom_2: f64) { - let n = try_create(freedom_1, freedom_2); - assert_eq!(freedom_1, n.freedom_1()); - assert_eq!(freedom_2, n.freedom_2()); - } - - fn bad_create_case(freedom_1: f64, freedom_2: f64) { - let n = FisherSnedecor::new(freedom_1, freedom_2); - assert!(n.is_err()); - } + use crate::testing_boiler; - fn get_value(freedom_1: f64, freedom_2: f64, eval: F) -> f64 - where F: Fn(FisherSnedecor) -> f64 - { - let n = 
try_create(freedom_1, freedom_2); - eval(n) - } - - fn test_case(freedom_1: f64, freedom_2: f64, expected: f64, eval: F) - where F: Fn(FisherSnedecor) -> f64 - { - let x = get_value(freedom_1, freedom_2, eval); - assert_eq!(expected, x); - } - - fn test_almost(freedom_1: f64, freedom_2: f64, expected: f64, acc: f64, eval: F) - where F: Fn(FisherSnedecor) -> f64 - { - let x = get_value(freedom_1, freedom_2, eval); - assert_almost_eq!(expected, x, acc); - } + testing_boiler!(freedom_1: f64, freedom_2: f64; FisherSnedecor; FisherSnedecorError); #[test] fn test_create() { - create_case(0.1, 0.1); - create_case(1.0, 0.1); - create_case(10.0, 0.1); - create_case(0.1, 1.0); - create_case(1.0, 1.0); - create_case(10.0, 1.0); + create_ok(0.1, 0.1); + create_ok(1.0, 0.1); + create_ok(10.0, 0.1); + create_ok(0.1, 1.0); + create_ok(1.0, 1.0); + create_ok(10.0, 1.0); } #[test] fn test_bad_create() { - bad_create_case(f64::NAN, f64::NAN); - bad_create_case(0.0, f64::NAN); - bad_create_case(-1.0, f64::NAN); - bad_create_case(-10.0, f64::NAN); - bad_create_case(f64::NAN, 0.0); - bad_create_case(0.0, 0.0); - bad_create_case(-1.0, 0.0); - bad_create_case(-10.0, 0.0); - bad_create_case(f64::NAN, -1.0); - bad_create_case(0.0, -1.0); - bad_create_case(-1.0, -1.0); - bad_create_case(-10.0, -1.0); - bad_create_case(f64::NAN, -10.0); - bad_create_case(0.0, -10.0); - bad_create_case(-1.0, -10.0); - bad_create_case(-10.0, -10.0); - bad_create_case(f64::INFINITY, 0.1); - bad_create_case(0.1, f64::INFINITY); - bad_create_case(f64::INFINITY, f64::INFINITY); + test_create_err(f64::INFINITY, 0.1, FisherSnedecorError::Freedom1Invalid); + test_create_err(0.1, f64::INFINITY, FisherSnedecorError::Freedom2Invalid); + + create_err(f64::NAN, f64::NAN); + create_err(0.0, f64::NAN); + create_err(-1.0, f64::NAN); + create_err(-10.0, f64::NAN); + create_err(f64::NAN, 0.0); + create_err(0.0, 0.0); + create_err(-1.0, 0.0); + create_err(-10.0, 0.0); + create_err(f64::NAN, -1.0); + create_err(0.0, -1.0); + 
create_err(-1.0, -1.0); + create_err(-10.0, -1.0); + create_err(f64::NAN, -10.0); + create_err(0.0, -10.0); + create_err(-1.0, -10.0); + create_err(-10.0, -10.0); + create_err(f64::INFINITY, f64::INFINITY); } #[test] fn test_mean() { let mean = |x: FisherSnedecor| x.mean().unwrap(); - test_case(0.1, 10.0, 1.25, mean); - test_case(1.0, 10.0, 1.25, mean); - test_case(10.0, 10.0, 1.25, mean); + test_exact(0.1, 10.0, 1.25, mean); + test_exact(1.0, 10.0, 1.25, mean); + test_exact(10.0, 10.0, 1.25, mean); } #[test] - #[should_panic] fn test_mean_with_low_d2() { - let mean = |x: FisherSnedecor| x.mean().unwrap(); - get_value(0.1, 0.1, mean); + test_none(0.1, 0.1, |dist| dist.mean()); } #[test] fn test_variance() { let variance = |x: FisherSnedecor| x.variance().unwrap(); - test_almost(0.1, 10.0, 42.1875, 1e-14, variance); - test_case(1.0, 10.0, 4.6875, variance); - test_case(10.0, 10.0, 0.9375, variance); + test_absolute(0.1, 10.0, 42.1875, 1e-14, variance); + test_exact(1.0, 10.0, 4.6875, variance); + test_exact(10.0, 10.0, 0.9375, variance); } #[test] - #[should_panic] fn test_variance_with_low_d2() { - let variance = |x: FisherSnedecor| x.variance().unwrap(); - get_value(0.1, 0.1, variance); + test_none(0.1, 0.1, |dist| dist.variance()); } #[test] fn test_skewness() { let skewness = |x: FisherSnedecor| x.skewness().unwrap(); - test_almost(0.1, 10.0, 15.78090735784977089658, 1e-14, skewness); - test_case(1.0, 10.0, 5.773502691896257645091, skewness); - test_case(10.0, 10.0, 3.614784456460255759501, skewness); + test_absolute(0.1, 10.0, 15.78090735784977089658, 1e-14, skewness); + test_exact(1.0, 10.0, 5.773502691896257645091, skewness); + test_exact(10.0, 10.0, 3.614784456460255759501, skewness); } #[test] - #[should_panic] fn test_skewness_with_low_d2() { - let skewness = |x: FisherSnedecor| x.skewness().unwrap(); - get_value(0.1, 0.1, skewness); + test_none(0.1, 0.1, |dist| dist.skewness()); } #[test] fn test_mode() { let mode = |x: FisherSnedecor| x.mode().unwrap(); 
- test_case(10.0, 0.1, 0.0380952380952380952381, mode); - test_case(10.0, 1.0, 4.0 / 15.0, mode); - test_case(10.0, 10.0, 2.0 / 3.0, mode); + test_exact(10.0, 0.1, 0.0380952380952380952381, mode); + test_exact(10.0, 1.0, 4.0 / 15.0, mode); + test_exact(10.0, 10.0, 2.0 / 3.0, mode); } #[test] - #[should_panic] fn test_mode_with_low_d1() { - let mode = |x: FisherSnedecor| x.mode().unwrap(); - get_value(0.1, 0.1, mode); + test_none(0.1, 0.1, |dist| dist.mode()); } #[test] fn test_min_max() { let min = |x: FisherSnedecor| x.min(); let max = |x: FisherSnedecor| x.max(); - test_case(1.0, 1.0, 0.0, min); - test_case(1.0, 1.0, f64::INFINITY, max); + test_exact(1.0, 1.0, 0.0, min); + test_exact(1.0, 1.0, f64::INFINITY, max); } #[test] fn test_pdf() { let pdf = |arg: f64| move |x: FisherSnedecor| x.pdf(arg); - test_almost(0.1, 0.1, 0.0234154207226588982471, 1e-16, pdf(1.0)); - test_almost(1.0, 0.1, 0.0396064560910663979961, 1e-16, pdf(1.0)); - test_almost(10.0, 0.1, 0.0418440630400545297349, 1e-14, pdf(1.0)); - test_almost(0.1, 1.0, 0.0396064560910663979961, 1e-16, pdf(1.0)); - test_almost(1.0, 1.0, 0.1591549430918953357689, 1e-16, pdf(1.0)); - test_almost(10.0, 1.0, 0.230361989229138647108, 1e-16, pdf(1.0)); - test_almost(0.1, 0.1, 0.00221546909694001013517, 1e-18, pdf(10.0)); - test_almost(1.0, 0.1, 0.00369960370387922619592, 1e-17, pdf(10.0)); - test_almost(10.0, 0.1, 0.00390179721174142927402, 1e-15, pdf(10.0)); - test_almost(0.1, 1.0, 0.00319864073359931548273, 1e-17, pdf(10.0)); - test_almost(1.0, 1.0, 0.009150765837179460915678, 1e-17, pdf(10.0)); - test_almost(10.0, 1.0, 0.0116493859171442148446, 1e-17, pdf(10.0)); - test_almost(0.1, 10.0, 0.00305087016058573989694, 1e-15, pdf(10.0)); - test_almost(1.0, 10.0, 0.00271897749113479577864, 1e-17, pdf(10.0)); - test_almost(10.0, 10.0, 2.4289227234060500084E-4, 1e-18, pdf(10.0)); + test_absolute(0.1, 0.1, 0.0234154207226588982471, 1e-16, pdf(1.0)); + test_absolute(1.0, 0.1, 0.0396064560910663979961, 1e-16, pdf(1.0)); + 
test_absolute(10.0, 0.1, 0.0418440630400545297349, 1e-14, pdf(1.0)); + test_absolute(0.1, 1.0, 0.0396064560910663979961, 1e-16, pdf(1.0)); + test_absolute(1.0, 1.0, 0.1591549430918953357689, 1e-16, pdf(1.0)); + test_absolute(10.0, 1.0, 0.230361989229138647108, 1e-16, pdf(1.0)); + test_absolute(0.1, 0.1, 0.00221546909694001013517, 1e-18, pdf(10.0)); + test_absolute(1.0, 0.1, 0.00369960370387922619592, 1e-17, pdf(10.0)); + test_absolute(10.0, 0.1, 0.00390179721174142927402, 1e-15, pdf(10.0)); + test_absolute(0.1, 1.0, 0.00319864073359931548273, 1e-17, pdf(10.0)); + test_absolute(1.0, 1.0, 0.009150765837179460915678, 1e-17, pdf(10.0)); + test_absolute(10.0, 1.0, 0.0116493859171442148446, 1e-17, pdf(10.0)); + test_absolute(0.1, 10.0, 0.00305087016058573989694, 1e-15, pdf(10.0)); + test_absolute(1.0, 10.0, 0.00271897749113479577864, 1e-17, pdf(10.0)); + test_absolute(10.0, 10.0, 2.4289227234060500084E-4, 1e-18, pdf(10.0)); } #[test] fn test_ln_pdf() { let ln_pdf = |arg: f64| move |x: FisherSnedecor| x.ln_pdf(arg); - test_almost(0.1, 0.1, 0.0234154207226588982471f64.ln(), 1e-15, ln_pdf(1.0)); - test_almost(1.0, 0.1, 0.0396064560910663979961f64.ln(), 1e-15, ln_pdf(1.0)); - test_almost(10.0, 0.1, 0.0418440630400545297349f64.ln(), 1e-13, ln_pdf(1.0)); - test_almost(0.1, 1.0, 0.0396064560910663979961f64.ln(), 1e-15, ln_pdf(1.0)); - test_almost(1.0, 1.0, 0.1591549430918953357689f64.ln(), 1e-15, ln_pdf(1.0)); - test_almost(10.0, 1.0, 0.230361989229138647108f64.ln(), 1e-15, ln_pdf(1.0)); - test_case(0.1, 0.1, 0.00221546909694001013517f64.ln(), ln_pdf(10.0)); - test_almost(1.0, 0.1, 0.00369960370387922619592f64.ln(), 1e-15, ln_pdf(10.0)); - test_almost(10.0, 0.1, 0.00390179721174142927402f64.ln(), 1e-13, ln_pdf(10.0)); - test_almost(0.1, 1.0, 0.00319864073359931548273f64.ln(), 1e-15, ln_pdf(10.0)); - test_almost(1.0, 1.0, 0.009150765837179460915678f64.ln(), 1e-15, ln_pdf(10.0)); - test_case(10.0, 1.0, 0.0116493859171442148446f64.ln(), ln_pdf(10.0)); - test_almost(0.1, 10.0, 
0.00305087016058573989694f64.ln(), 1e-13, ln_pdf(10.0)); - test_case(1.0, 10.0, 0.00271897749113479577864f64.ln(), ln_pdf(10.0)); - test_almost(10.0, 10.0, 2.4289227234060500084E-4f64.ln(), 1e-14, ln_pdf(10.0)); + test_absolute(0.1, 0.1, 0.0234154207226588982471f64.ln(), 1e-15, ln_pdf(1.0)); + test_absolute(1.0, 0.1, 0.0396064560910663979961f64.ln(), 1e-15, ln_pdf(1.0)); + test_absolute(10.0, 0.1, 0.0418440630400545297349f64.ln(), 1e-13, ln_pdf(1.0)); + test_absolute(0.1, 1.0, 0.0396064560910663979961f64.ln(), 1e-15, ln_pdf(1.0)); + test_absolute(1.0, 1.0, 0.1591549430918953357689f64.ln(), 1e-15, ln_pdf(1.0)); + test_absolute(10.0, 1.0, 0.230361989229138647108f64.ln(), 1e-15, ln_pdf(1.0)); + test_exact(0.1, 0.1, 0.00221546909694001013517f64.ln(), ln_pdf(10.0)); + test_absolute(1.0, 0.1, 0.00369960370387922619592f64.ln(), 1e-15, ln_pdf(10.0)); + test_absolute(10.0, 0.1, 0.00390179721174142927402f64.ln(), 1e-13, ln_pdf(10.0)); + test_absolute(0.1, 1.0, 0.00319864073359931548273f64.ln(), 1e-15, ln_pdf(10.0)); + test_absolute(1.0, 1.0, 0.009150765837179460915678f64.ln(), 1e-15, ln_pdf(10.0)); + test_exact(10.0, 1.0, 0.0116493859171442148446f64.ln(), ln_pdf(10.0)); + test_absolute(0.1, 10.0, 0.00305087016058573989694f64.ln(), 1e-13, ln_pdf(10.0)); + test_exact(1.0, 10.0, 0.00271897749113479577864f64.ln(), ln_pdf(10.0)); + test_absolute(10.0, 10.0, 2.4289227234060500084E-4f64.ln(), 1e-14, ln_pdf(10.0)); } #[test] fn test_cdf() { let cdf = |arg: f64| move |x: FisherSnedecor| x.cdf(arg); - test_almost(0.1, 0.1, 0.44712986033425140335, 1e-15, cdf(0.1)); - test_almost(1.0, 0.1, 0.08156522095104674015, 1e-15, cdf(0.1)); - test_almost(10.0, 0.1, 0.033184005716276536322, 1e-13, cdf(0.1)); - test_almost(0.1, 1.0, 0.74378710917986379989, 1e-15, cdf(0.1)); - test_almost(1.0, 1.0, 0.1949822290421366451595, 1e-16, cdf(0.1)); - test_almost(10.0, 1.0, 0.0101195597354337146205, 1e-17, cdf(0.1)); - test_almost(0.1, 0.1, 0.5, 1e-15, cdf(1.0)); - test_almost(1.0, 0.1, 
0.16734351500944271141, 1e-14, cdf(1.0)); - test_almost(10.0, 0.1, 0.12207560664741704938, 1e-13, cdf(1.0)); - test_almost(0.1, 1.0, 0.83265648499055728859, 1e-15, cdf(1.0)); - test_almost(1.0, 1.0, 0.5, 1e-15, cdf(1.0)); - test_almost(10.0, 1.0, 0.340893132302059872675, 1e-15, cdf(1.0)); + test_absolute(0.1, 0.1, 0.44712986033425140335, 1e-15, cdf(0.1)); + test_absolute(1.0, 0.1, 0.08156522095104674015, 1e-15, cdf(0.1)); + test_absolute(10.0, 0.1, 0.033184005716276536322, 1e-13, cdf(0.1)); + test_absolute(0.1, 1.0, 0.74378710917986379989, 1e-15, cdf(0.1)); + test_absolute(1.0, 1.0, 0.1949822290421366451595, 1e-16, cdf(0.1)); + test_absolute(10.0, 1.0, 0.0101195597354337146205, 1e-17, cdf(0.1)); + test_absolute(0.1, 0.1, 0.5, 1e-15, cdf(1.0)); + test_absolute(1.0, 0.1, 0.16734351500944271141, 1e-14, cdf(1.0)); + test_absolute(10.0, 0.1, 0.12207560664741704938, 1e-13, cdf(1.0)); + test_absolute(0.1, 1.0, 0.83265648499055728859, 1e-15, cdf(1.0)); + test_absolute(1.0, 1.0, 0.5, 1e-15, cdf(1.0)); + test_absolute(10.0, 1.0, 0.340893132302059872675, 1e-15, cdf(1.0)); } #[test] fn test_cdf_lower_bound() { let cdf = |arg: f64| move |x: FisherSnedecor| x.cdf(arg); - test_case(0.1, 0.1, 0.0, cdf(-1.0)); + test_exact(0.1, 0.1, 0.0, cdf(-1.0)); } #[test] fn test_sf() { let sf = |arg: f64| move |x: FisherSnedecor| x.sf(arg); - test_almost(0.1, 0.1, 0.5528701396657489, 1e-12, sf(0.1)); - test_almost(1.0, 0.1, 0.9184347790489533, 1e-12, sf(0.1)); - test_almost(10.0, 0.1, 0.9668159942836896, 1e-12, sf(0.1)); - test_almost(0.1, 1.0, 0.25621289082013654, 1e-12, sf(0.1)); - test_almost(1.0, 1.0, 0.8050177709578634, 1e-12, sf(0.1)); - test_almost(10.0, 1.0, 0.9898804402645662, 1e-12, sf(0.1)); - test_almost(0.1, 0.1, 0.5, 1e-15, sf(1.0)); - test_almost(1.0, 0.1, 0.8326564849905562, 1e-12, sf(1.0)); - test_almost(10.0, 0.1, 0.8779243933525519, 1e-12, sf(1.0)); - test_almost(0.1, 1.0, 0.16734351500944344, 1e-12, sf(1.0)); - test_almost(1.0, 1.0, 0.5, 1e-12, sf(1.0)); - test_almost(10.0, 
1.0, 0.65910686769794, 1e-12, sf(1.0)); + test_absolute(0.1, 0.1, 0.5528701396657489, 1e-12, sf(0.1)); + test_absolute(1.0, 0.1, 0.9184347790489533, 1e-12, sf(0.1)); + test_absolute(10.0, 0.1, 0.9668159942836896, 1e-12, sf(0.1)); + test_absolute(0.1, 1.0, 0.25621289082013654, 1e-12, sf(0.1)); + test_absolute(1.0, 1.0, 0.8050177709578634, 1e-12, sf(0.1)); + test_absolute(10.0, 1.0, 0.9898804402645662, 1e-12, sf(0.1)); + test_absolute(0.1, 0.1, 0.5, 1e-15, sf(1.0)); + test_absolute(1.0, 0.1, 0.8326564849905562, 1e-12, sf(1.0)); + test_absolute(10.0, 0.1, 0.8779243933525519, 1e-12, sf(1.0)); + test_absolute(0.1, 1.0, 0.16734351500944344, 1e-12, sf(1.0)); + test_absolute(1.0, 1.0, 0.5, 1e-12, sf(1.0)); + test_absolute(10.0, 1.0, 0.65910686769794, 1e-12, sf(1.0)); + } + + #[test] + fn test_inverse_cdf() { + let func = |arg: f64| move |x: FisherSnedecor| x.inverse_cdf(x.cdf(arg)); + test_absolute(0.1, 0.1, 0.1, 1e-12, func(0.1)); + test_absolute(1.0, 0.1, 0.1, 1e-12, func(0.1)); + test_absolute(10.0, 0.1, 0.1, 1e-12, func(0.1)); + test_absolute(0.1, 1.0, 0.1, 1e-12, func(0.1)); + test_absolute(1.0, 1.0, 0.1, 1e-12, func(0.1)); + test_absolute(10.0, 1.0, 0.1, 1e-12, func(0.1)); + test_absolute(0.1, 0.1, 1.0, 1e-13, func(1.0)); + test_absolute(1.0, 0.1, 1.0, 1e-12, func(1.0)); + test_absolute(10.0, 0.1, 1.0, 1e-12, func(1.0)); + test_absolute(0.1, 1.0, 1.0, 1e-12, func(1.0)); + test_absolute(1.0, 1.0, 1.0, 1e-12, func(1.0)); + test_absolute(10.0, 1.0, 1.0, 1e-12, func(1.0)); } #[test] fn test_sf_lower_bound() { let sf = |arg: f64| move |x: FisherSnedecor| x.sf(arg); - test_case(0.1, 0.1, 1.0, sf(-1.0)); + test_exact(0.1, 0.1, 1.0, sf(-1.0)); } #[test] fn test_continuous() { - test::check_continuous_distribution(&try_create(10.0, 10.0), 0.0, 10.0); + test::check_continuous_distribution(&create_ok(10.0, 10.0), 0.0, 10.0); } } diff --git a/src/distribution/gamma.rs b/src/distribution/gamma.rs index 7a36a30f..89341439 100644 --- a/src/distribution/gamma.rs +++ 
b/src/distribution/gamma.rs @@ -1,9 +1,7 @@ use crate::distribution::{Continuous, ContinuousCDF}; use crate::function::gamma; +use crate::prec; use crate::statistics::*; -use crate::{Result, StatsError}; -use core::f64::INFINITY as INF; -use rand::Rng; /// Implements the [Gamma](https://en.wikipedia.org/wiki/Gamma_distribution) /// distribution @@ -19,12 +17,39 @@ use rand::Rng; /// assert_eq!(n.mean().unwrap(), 3.0); /// assert!(prec::almost_eq(n.pdf(2.0), 0.270670566473225383788, 1e-15)); /// ``` -#[derive(Debug, Copy, Clone, PartialEq)] +#[derive(Copy, Clone, PartialEq, Debug)] pub struct Gamma { shape: f64, rate: f64, } +/// Represents the errors that can occur when creating a [`Gamma`]. +#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] +#[non_exhaustive] +pub enum GammaError { + /// The shape is NaN, zero or less than zero. + ShapeInvalid, + + /// The rate is NaN, zero or less than zero. + RateInvalid, + + /// The shape and rate are both infinite. + ShapeAndRateInfinite, +} + +impl std::fmt::Display for GammaError { + #[cfg_attr(coverage_nightly, coverage(off))] + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + GammaError::ShapeInvalid => write!(f, "Shape is NaN, zero, or less than zero."), + GammaError::RateInvalid => write!(f, "Rate is NaN, zero, or less than zero."), + GammaError::ShapeAndRateInfinite => write!(f, "Shape and rate are infinite"), + } + } +} + +impl std::error::Error for GammaError {} + impl Gamma { /// Constructs a new gamma distribution with a shape (α) /// of `shape` and a rate (β) of `rate` @@ -45,15 +70,19 @@ impl Gamma { /// result = Gamma::new(0.0, 0.0); /// assert!(result.is_err()); /// ``` - pub fn new(shape: f64, rate: f64) -> Result { - if shape.is_nan() - || rate.is_nan() - || shape.is_infinite() && rate.is_infinite() - || shape <= 0.0 - || rate <= 0.0 - { - return Err(StatsError::BadParams); + pub fn new(shape: f64, rate: f64) -> Result { + if shape.is_nan() || shape <= 0.0 { + return
Err(GammaError::ShapeInvalid); + } + + if rate.is_nan() || rate <= 0.0 { + return Err(GammaError::RateInvalid); } + + if shape.is_infinite() && rate.is_infinite() { + return Err(GammaError::ShapeAndRateInfinite); + } + Ok(Gamma { shape, rate }) } @@ -86,8 +115,15 @@ impl Gamma { } } +impl std::fmt::Display for Gamma { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Γ({}, {})", self.shape, self.rate) + } +} + +#[cfg(feature = "rand")] impl ::rand::distributions::Distribution for Gamma { - fn sample(&self, rng: &mut R) -> f64 { + fn sample(&self, rng: &mut R) -> f64 { sample_unchecked(rng, self.shape, self.rate) } } @@ -99,7 +135,7 @@ impl ContinuousCDF for Gamma { /// /// # Formula /// - /// ```ignore + /// ```text /// (1 / Γ(α)) * γ(α, β * x) /// ``` /// @@ -124,7 +160,7 @@ impl ContinuousCDF for Gamma { /// /// # Formula /// - /// ```ignore + /// ```text /// (1 / Γ(α)) * γ(α, β * x) /// ``` /// @@ -133,20 +169,60 @@ impl ContinuousCDF for Gamma { fn sf(&self, x: f64) -> f64 { if x <= 0.0 { 1.0 - } - else if ulps_eq!(x, self.shape) && self.rate.is_infinite() { + } else if ulps_eq!(x, self.shape) && self.rate.is_infinite() { 0.0 - } - else if self.rate.is_infinite() { + } else if self.rate.is_infinite() { 1.0 - } - else if x.is_infinite() { + } else if x.is_infinite() { 0.0 - } - else { + } else { gamma::gamma_ur(self.shape, x * self.rate) } } + + fn inverse_cdf(&self, p: f64) -> f64 { + if !(0.0..=1.0).contains(&p) { + panic!("default inverse_cdf implementation should be provided probability on [0,1]") + } + if p == 0.0 { + return self.min(); + }; + if p == 1.0 { + return self.max(); + }; + + // Bracket p between low and high, then bisect for at most 8 iterations + let mut high = 2.0; + let mut low = 1.0; + while self.cdf(low) > p { + low /= 2.0; + } + while self.cdf(high) < p { + high *= 2.0; + } + let mut x_0 = (high + low) / 2.0; + + for _ in 0..8 { + if self.cdf(x_0) >= p { + high = x_0; + } else { + low = x_0; + } + if prec::convergence(&mut x_0,
(high + low) / 2.0) { + break; + } + } + + // Newton Raphson, for at least one step + for _ in 0..4 { + let x_next = x_0 - (self.cdf(x_0) - p) / self.pdf(x_0); + if prec::convergence(&mut x_0, x_next) { + break; + } + } + + x_0 + } } impl Min for Gamma { @@ -156,7 +232,7 @@ impl Min for Gamma { /// /// # Formula /// - /// ```ignore + /// ```text /// 0 /// ``` fn min(&self) -> f64 { @@ -171,11 +247,11 @@ impl Max for Gamma { /// /// # Formula /// - /// ```ignore - /// INF + /// ```text + /// f64::INFINITY /// ``` fn max(&self) -> f64 { - INF + f64::INFINITY } } @@ -184,7 +260,7 @@ impl Distribution for Gamma { /// /// # Formula /// - /// ```ignore + /// ```text /// α / β /// ``` /// @@ -192,11 +268,12 @@ impl Distribution for Gamma { fn mean(&self) -> Option { Some(self.shape / self.rate) } + /// Returns the variance of the gamma distribution /// /// # Formula /// - /// ```ignore + /// ```text /// α / β^2 /// ``` /// @@ -204,11 +281,12 @@ impl Distribution for Gamma { fn variance(&self) -> Option { Some(self.shape / (self.rate * self.rate)) } + /// Returns the entropy of the gamma distribution /// /// # Formula /// - /// ```ignore + /// ```text /// α - ln(β) + ln(Γ(α)) + (1 - α) * ψ(α) /// ``` /// @@ -220,11 +298,12 @@ impl Distribution for Gamma { + (1.0 - self.shape) * gamma::digamma(self.shape); Some(entr) } + /// Returns the skewness of the gamma distribution /// /// # Formula /// - /// ```ignore + /// ```text /// 2 / sqrt(α) /// ``` /// @@ -239,13 +318,17 @@ impl Mode> for Gamma { /// /// # Formula /// - /// ```ignore - /// (α - 1) / β + /// ```text + /// (α - 1) / β, where α≥1 /// ``` /// /// where `α` is the shape and `β` is the rate fn mode(&self) -> Option { - Some((self.shape - 1.0) / self.rate) + if self.shape < 1.0 { + None + } else { + Some((self.shape - 1.0) / self.rate) + } } } @@ -255,12 +338,12 @@ impl Continuous for Gamma { /// /// # Remarks /// - /// Returns `NAN` if any of `shape` or `rate` are `INF` - /// or if `x` is `INF` + /// Returns `NAN` 
if any of `shape` or `rate` are `f64::INFINITY` + /// or if `x` is `f64::INFINITY` /// /// # Formula /// - /// ```ignore + /// ```text /// (β^α / Γ(α)) * x^(α - 1) * e^(-β * x) /// ``` /// @@ -286,12 +369,12 @@ impl Continuous for Gamma { /// /// # Remarks /// - /// Returns `NAN` if any of `shape` or `rate` are `INF` - /// or if `x` is `INF` + /// Returns `NAN` if any of `shape` or `rate` are `f64::INFINITY` + /// or if `x` is `f64::INFINITY` /// /// # Formula /// - /// ```ignore + /// ```text /// ln((β^α / Γ(α)) * x^(α - 1) * e ^(-β * x)) /// ``` /// @@ -312,16 +395,13 @@ impl Continuous for Gamma { } /// Samples from a gamma distribution with a shape of `shape` and a /// rate of `rate` using `rng` as the source of randomness. Implementation from: -///
-///
-/// "A Simple Method for Generating Gamma Variables" - Marsaglia & Tsang -///
-///
+/// +/// _"A Simple Method for Generating Gamma Variables"_ - Marsaglia & Tsang +/// /// ACM Transactions on Mathematical Software, Vol. 26, No. 3, September 2000, /// Pages 363-372 -///
-///
-pub fn sample_unchecked(rng: &mut R, shape: f64, rate: f64) -> f64 { +#[cfg(feature = "rand")] +pub fn sample_unchecked(rng: &mut R, shape: f64, rate: f64) -> f64 { let mut a = shape; let mut afix = 1.0; if shape < 1.0 { @@ -342,8 +422,8 @@ pub fn sample_unchecked(rng: &mut R, shape: f64, rate: f64) -> }; } - v *= v * v; - x *= x; + v = v * v * v; + x = x * x; let u: f64 = rng.gen(); if u < 1.0 - 0.0331 * x * x || u.ln() < 0.5 * x + d * (1.0 - v + v.ln()) { return afix * d * v / rate; @@ -351,14 +431,13 @@ pub fn sample_unchecked(rng: &mut R, shape: f64, rate: f64) -> } } -#[cfg(all(test, feature = "nightly"))] +#[cfg(test)] mod tests { use super::*; - use crate::consts::ACC; use crate::distribution::internal::*; use crate::testing_boiler; - testing_boiler!((f64, f64), Gamma); + testing_boiler!(shape: f64, rate: f64; Gamma; GammaError); #[test] fn test_create() { @@ -367,26 +446,31 @@ mod tests { (1.0, 1.0), (10.0, 10.0), (10.0, 1.0), - (10.0, INF), + (10.0, f64::INFINITY), ]; - for &arg in valid.iter() { - try_create(arg); + for (s, r) in valid { + create_ok(s, r); } } #[test] fn test_bad_create() { let invalid = [ - (0.0, 0.0), - (1.0, f64::NAN), - (1.0, -1.0), - (-1.0, 1.0), - (-1.0, -1.0), - (-1.0, f64::NAN), + (0.0, 0.0, GammaError::ShapeInvalid), + (1.0, f64::NAN, GammaError::RateInvalid), + (1.0, -1.0, GammaError::RateInvalid), + (-1.0, 1.0, GammaError::ShapeInvalid), + (-1.0, -1.0, GammaError::ShapeInvalid), + (-1.0, f64::NAN, GammaError::ShapeInvalid), + ( + f64::INFINITY, + f64::INFINITY, + GammaError::ShapeAndRateInfinite, + ), ]; - for &arg in invalid.iter() { - bad_create_case(arg); + for (s, r, err) in invalid { + test_create_err(s, r, err); } } @@ -398,10 +482,10 @@ mod tests { ((1.0, 1.0), 1.0), ((10.0, 10.0), 1.0), ((10.0, 1.0), 10.0), - ((10.0, INF), 0.0), + ((10.0, f64::INFINITY), 0.0), ]; - for &(arg, res) in test.iter() { - test_case(arg, res, f); + for ((s, r), res) in test { + test_relative(s, r, res, f); } } @@ -413,10 +497,10 @@ mod tests 
{ ((1.0, 1.0), 1.0), ((10.0, 10.0), 0.1), ((10.0, 1.0), 10.0), - ((10.0, INF), 0.0), + ((10.0, f64::INFINITY), 0.0), ]; - for &(arg, res) in test.iter() { - test_case(arg, res, f); + for ((s, r), res) in test { + test_relative(s, r, res, f); } } @@ -428,10 +512,10 @@ mod tests { ((1.0, 1.0), 1.0), ((10.0, 10.0), 0.2334690854869339583626209), ((10.0, 1.0), 2.53605417848097964238061239), - ((10.0, INF), f64::NEG_INFINITY), + ((10.0, f64::INFINITY), f64::NEG_INFINITY), ]; - for &(arg, res) in test.iter() { - test_case(arg, res, f); + for ((s, r), res) in test { + test_relative(s, r, res, f); } } @@ -443,10 +527,10 @@ mod tests { ((1.0, 1.0), 2.0), ((10.0, 10.0), 0.6324555320336758663997787), ((10.0, 1.0), 0.63245553203367586639977870), - ((10.0, INF), 0.6324555320336758), + ((10.0, f64::INFINITY), 0.6324555320336758), ]; - for &(arg, res) in test.iter() { - test_case(arg, res, f); + for ((s, r), res) in test { + test_relative(s, r, res, f); } } @@ -454,12 +538,16 @@ mod tests { fn test_mode() { let f = |x: Gamma| x.mode().unwrap(); let test = [((1.0, 0.1), 0.0), ((1.0, 1.0), 0.0)]; - for &(arg, res) in test.iter() { - test_case_special(arg, res, 10e-6, f); + for &((s, r), res) in test.iter() { + test_absolute(s, r, res, 10e-6, f); } - let test = [((10.0, 10.0), 0.9), ((10.0, 1.0), 9.0), ((10.0, INF), 0.0)]; - for &(arg, res) in test.iter() { - test_case(arg, res, f); + let test = [ + ((10.0, 10.0), 0.9), + ((10.0, 1.0), 9.0), + ((10.0, f64::INFINITY), 0.0), + ]; + for ((s, r), res) in test { + test_relative(s, r, res, f); } } @@ -471,21 +559,21 @@ mod tests { ((1.0, 1.0), 0.0), ((10.0, 10.0), 0.0), ((10.0, 1.0), 0.0), - ((10.0, INF), 0.0), + ((10.0, f64::INFINITY), 0.0), ]; - for &(arg, res) in test.iter() { - test_case(arg, res, f); + for ((s, r), res) in test { + test_relative(s, r, res, f); } let f = |x: Gamma| x.max(); let test = [ - ((1.0, 0.1), INF), - ((1.0, 1.0), INF), - ((10.0, 10.0), INF), - ((10.0, 1.0), INF), - ((10.0, INF), INF), + ((1.0, 0.1), 
f64::INFINITY), + ((1.0, 1.0), f64::INFINITY), + ((10.0, 10.0), f64::INFINITY), + ((10.0, 1.0), f64::INFINITY), + ((10.0, f64::INFINITY), f64::INFINITY), ]; - for &(arg, res) in test.iter() { - test_case(arg, res, f); + for ((s, r), res) in test { + test_relative(s, r, res, f); } } @@ -502,19 +590,19 @@ mod tests { ((10.0, 1.0), 1.0, 0.000001013777119630297402), ((10.0, 1.0), 10.0, 0.125110035721133298984764), ]; - for &(arg, x, res) in test.iter() { - test_case(arg, res, f(x)); + for ((s, r), x, res) in test { + test_relative(s, r, res, f(x)); } - //TODO: test special - // test_is_nan((10.0, INF), pdf(1.0)); // is this really the behavior we want? - //TODO: test special - // (10.0, INF, INF, 0.0, pdf(INF)),]; + // TODO: test special + // test_is_nan((10.0, f64::INFINITY), pdf(1.0)); // is this really the behavior we want? + // TODO: test special + // (10.0, f64::INFINITY, f64::INFINITY, 0.0, pdf(f64::INFINITY)),]; } #[test] fn test_pdf_at_zero() { - test_case((1.0, 0.1), 0.1, |x| x.pdf(0.0)); - test_case((1.0, 0.1), 0.1f64.ln(), |x| x.ln_pdf(0.0)); + test_relative(1.0, 0.1, 0.1, |x| x.pdf(0.0)); + test_relative(1.0, 0.1, 0.1f64.ln(), |x| x.ln_pdf(0.0)); } #[test] @@ -529,13 +617,13 @@ mod tests { ((10.0, 10.0), 10.0, -69.0527107131946016148658), ((10.0, 1.0), 1.0, -13.8018274800814696112077), ((10.0, 1.0), 10.0, -2.07856164313505845504579), - ((10.0, INF), INF, f64::NEG_INFINITY), + ((10.0, f64::INFINITY), f64::INFINITY, f64::NEG_INFINITY), ]; - for &(arg, x, res) in test.iter() { - test_case(arg, res, f(x)); + for ((s, r), x, res) in test { + test_relative(s, r, res, f(x)); } // TODO: test special - // test_is_nan((10.0, INF), f(1.0)); // is this really the behavior we want? + // test_is_nan((10.0, f64::INFINITY), f(1.0)); // is this really the behavior we want? 
} #[test] @@ -550,17 +638,43 @@ mod tests { ((10.0, 10.0), 10.0, 0.999999999999999999999999), ((10.0, 1.0), 1.0, 0.000000111425478338720677), ((10.0, 1.0), 10.0, 0.542070285528147791685835), - ((10.0, INF), 1.0, 0.0), - ((10.0, INF), 10.0, 1.0), + ((10.0, f64::INFINITY), 1.0, 0.0), + ((10.0, f64::INFINITY), 10.0, 1.0), ]; - for &(arg, x, res) in test.iter() { - test_case(arg, res, f(x)); + for ((s, r), x, res) in test { + test_relative(s, r, res, f(x)); } } #[test] fn test_cdf_at_zero() { - test_case((1.0, 0.1), 0.0, |x| x.cdf(0.0)); + test_relative(1.0, 0.1, 0.0, |x| x.cdf(0.0)); + } + + #[test] + fn test_cdf_inverse_identity() { + let f = |p: f64| move |g: Gamma| g.cdf(g.inverse_cdf(p)); + let params = [ + (1.0, 0.1), + (1.0, 1.0), + (10.0, 10.0), + (10.0, 1.0), + (100.0, 200.0), + ]; + + for (s, r) in params { + for n in -5..0 { + let p = 10.0f64.powi(n); + test_relative(s, r, p, f(p)); + } + } + + // test case from issue #200 + { + let x = 20.5567; + let f = |x: f64| move |g: Gamma| g.inverse_cdf(g.cdf(x)); + test_relative(3.0, 0.5, x, f(x)) + } } #[test] @@ -575,22 +689,22 @@ mod tests { ((10.0, 10.0), 10.0, 1.1253473960842808e-31), ((10.0, 1.0), 1.0, 0.9999998885745217), ((10.0, 1.0), 10.0, 0.4579297144718528), - ((10.0, INF), 1.0, 1.0), - ((10.0, INF), 10.0, 0.0), + ((10.0, f64::INFINITY), 1.0, 1.0), + ((10.0, f64::INFINITY), 10.0, 0.0), ]; - for &(arg, x, res) in test.iter() { - test_case(arg, res, f(x)); + for ((s, r), x, res) in test { + test_relative(s, r, res, f(x)); } } #[test] fn test_sf_at_zero() { - test_case((1.0, 0.1), 1.0, |x| x.sf(0.0)); + test_relative(1.0, 0.1, 1.0, |x| x.sf(0.0)); } #[test] fn test_continuous() { - test::check_continuous_distribution(&try_create((1.0, 0.5)), 0.0, 20.0); - test::check_continuous_distribution(&try_create((9.0, 2.0)), 0.0, 20.0); + test::check_continuous_distribution(&create_ok(1.0, 0.5), 0.0, 20.0); + test::check_continuous_distribution(&create_ok(9.0, 2.0), 0.0, 20.0); } } diff --git 
a/src/distribution/geometric.rs b/src/distribution/geometric.rs index a6e390d7..82af5eef 100644 --- a/src/distribution/geometric.rs +++ b/src/distribution/geometric.rs @@ -1,9 +1,6 @@ use crate::distribution::{Discrete, DiscreteCDF}; use crate::statistics::*; -use crate::{Result, StatsError}; -use rand::distributions::OpenClosed01; -use rand::Rng; -use std::{f64, u64}; +use std::f64; /// Implements the /// [Geometric](https://en.wikipedia.org/wiki/Geometric_distribution) @@ -20,11 +17,30 @@ use std::{f64, u64}; /// assert_eq!(n.pmf(1), 0.3); /// assert_eq!(n.pmf(2), 0.21); /// ``` -#[derive(Debug, Copy, Clone, PartialEq)] +#[derive(Copy, Clone, PartialEq, Debug)] pub struct Geometric { p: f64, } +/// Represents the errors that can occur when creating a [`Geometric`]. +#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] +#[non_exhaustive] +pub enum GeometricError { + /// The probability is NaN or not in `(0, 1]`. + ProbabilityInvalid, +} + +impl std::fmt::Display for GeometricError { + #[cfg_attr(coverage_nightly, coverage(off))] + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + GeometricError::ProbabilityInvalid => write!(f, "Probability is NaN or not in (0, 1]"), + } + } +} + +impl std::error::Error for GeometricError {} + impl Geometric { /// Constructs a new shifted geometric distribution with a probability /// of `p` @@ -44,9 +60,9 @@ impl Geometric { /// result = Geometric::new(0.0); /// assert!(result.is_err()); /// ``` - pub fn new(p: f64) -> Result { + pub fn new(p: f64) -> Result { if p <= 0.0 || p > 1.0 || p.is_nan() { - Err(StatsError::BadParams) + Err(GeometricError::ProbabilityInvalid) } else { Ok(Geometric { p }) } @@ -68,8 +84,17 @@ impl Geometric { } } +impl std::fmt::Display for Geometric { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Geom({})", self.p) + } +} + +#[cfg(feature = "rand")] impl ::rand::distributions::Distribution for Geometric { - fn sample(&self, r: &mut R) -> 
f64 { + fn sample(&self, r: &mut R) -> f64 { + use ::rand::distributions::OpenClosed01; + if ulps_eq!(self.p, 1.0) { 1.0 } else { @@ -85,7 +110,7 @@ impl DiscreteCDF for Geometric { /// /// # Formula /// - /// ```ignore + /// ```text /// 1 - (1 - p) ^ x /// ``` fn cdf(&self, x: u64) -> f64 { @@ -104,7 +129,7 @@ impl DiscreteCDF for Geometric { /// /// # Formula /// - /// ```ignore + /// ```text /// (1 - p) ^ x /// ``` fn sf(&self, x: u64) -> f64 { @@ -125,7 +150,7 @@ impl Min for Geometric { /// /// # Formula /// - /// ```ignore + /// ```text /// 1 /// ``` fn min(&self) -> u64 { @@ -140,7 +165,7 @@ impl Max for Geometric { /// /// # Formula /// - /// ```ignore + /// ```text /// 2^63 - 1 /// ``` fn max(&self) -> u64 { @@ -153,38 +178,41 @@ impl Distribution for Geometric { /// /// # Formula /// - /// ```ignore + /// ```text /// 1 / p /// ``` fn mean(&self) -> Option { Some(1.0 / self.p) } + /// Returns the standard deviation of the geometric distribution /// /// # Formula /// - /// ```ignore + /// ```text /// (1 - p) / p^2 /// ``` fn variance(&self) -> Option { Some((1.0 - self.p) / (self.p * self.p)) } + /// Returns the entropy of the geometric distribution /// /// # Formula /// - /// ```ignore + /// ```text /// (-(1 - p) * log_2(1 - p) - p * log_2(p)) / p /// ``` fn entropy(&self) -> Option { let inv = 1.0 / self.p; Some(-inv * (1. 
- self.p).log(2.0) + (inv - 1.).log(2.0)) } + /// Returns the skewness of the geometric distribution /// /// # Formula /// - /// ```ignore + /// ```text /// (2 - p) / sqrt(1 - p) /// ``` fn skewness(&self) -> Option { @@ -200,7 +228,7 @@ impl Mode> for Geometric { /// /// # Formula /// - /// ```ignore + /// ```text /// 1 /// ``` fn mode(&self) -> Option { @@ -215,7 +243,7 @@ impl Median for Geometric { /// /// # Formula /// - /// ```ignore + /// ```text /// ceil(-1 / log_2(1 - p)) /// ``` fn median(&self) -> f64 { @@ -229,7 +257,7 @@ impl Discrete for Geometric { /// /// # Formula /// - /// ```ignore + /// ```text /// (1 - p)^(x - 1) * p /// ``` fn pmf(&self, x: u64) -> f64 { @@ -245,7 +273,7 @@ impl Discrete for Geometric { /// /// # Formula /// - /// ```ignore + /// ```text /// ln((1 - p)^(x - 1) * p) /// ``` fn ln_pmf(&self, x: u64) -> f64 { @@ -262,176 +290,130 @@ impl Discrete for Geometric { } #[rustfmt::skip] -#[cfg(all(test, feature = "nightly"))] +#[cfg(test)] mod tests { - use std::fmt::Debug; - use crate::statistics::*; - use crate::distribution::{DiscreteCDF, Discrete, Geometric}; + use super::*; use crate::distribution::internal::*; - use crate::consts::ACC; + use crate::testing_boiler; - fn try_create(p: f64) -> Geometric { - let n = Geometric::new(p); - assert!(n.is_ok()); - n.unwrap() - } - - fn create_case(p: f64) { - let n = try_create(p); - assert_eq!(p, n.p()); - } - - fn bad_create_case(p: f64) { - let n = Geometric::new(p); - assert!(n.is_err()); - } - - fn get_value(p: f64, eval: F) -> T - where T: PartialEq + Debug, - F: Fn(Geometric) -> T - { - let n = try_create(p); - eval(n) - } - - fn test_case(p: f64, expected: T, eval: F) - where T: PartialEq + Debug, - F: Fn(Geometric) -> T - { - let x = get_value(p, eval); - assert_eq!(expected, x); - } - - fn test_almost(p: f64, expected: f64, acc: f64, eval: F) - where F: Fn(Geometric) -> f64 - { - let x = get_value(p, eval); - assert_almost_eq!(expected, x, acc); - } - - fn test_is_nan(p: f64, 
eval: F) - where F: Fn(Geometric) -> f64 - { - let x = get_value(p, eval); - assert!(x.is_nan()); - } + testing_boiler!(p: f64; Geometric; GeometricError); #[test] fn test_create() { - create_case(0.3); - create_case(1.0); + create_ok(0.3); + create_ok(1.0); } #[test] fn test_bad_create() { - bad_create_case(f64::NAN); - bad_create_case(0.0); - bad_create_case(-1.0); - bad_create_case(2.0); + create_err(f64::NAN); + create_err(0.0); + create_err(-1.0); + create_err(2.0); } #[test] fn test_mean() { let mean = |x: Geometric| x.mean().unwrap(); - test_case(0.3, 1.0 / 0.3, mean); - test_case(1.0, 1.0, mean); + test_exact(0.3, 1.0 / 0.3, mean); + test_exact(1.0, 1.0, mean); } #[test] fn test_variance() { let variance = |x: Geometric| x.variance().unwrap(); - test_case(0.3, 0.7 / (0.3 * 0.3), variance); - test_case(1.0, 0.0, variance); + test_exact(0.3, 0.7 / (0.3 * 0.3), variance); + test_exact(1.0, 0.0, variance); } #[test] fn test_entropy() { let entropy = |x: Geometric| x.entropy().unwrap(); - test_almost(0.3, 2.937636330768973333333, 1e-14, entropy); + test_absolute(0.3, 2.937636330768973333333, 1e-14, entropy); test_is_nan(1.0, entropy); } #[test] fn test_skewness() { let skewness = |x: Geometric| x.skewness().unwrap(); - test_almost(0.3, 2.031888635868469187947, 1e-15, skewness); - test_case(1.0, f64::INFINITY, skewness); + test_absolute(0.3, 2.031888635868469187947, 1e-15, skewness); + test_exact(1.0, f64::INFINITY, skewness); } #[test] fn test_median() { let median = |x: Geometric| x.median(); - test_case(0.0001, 6932.0, median); - test_case(0.1, 7.0, median); - test_case(0.3, 2.0, median); - test_case(0.9, 1.0, median); - // test_case(0.99, 1.0, median); - test_case(1.0, 0.0, median); + test_exact(0.0001, 6932.0, median); + test_exact(0.1, 7.0, median); + test_exact(0.3, 2.0, median); + test_exact(0.9, 1.0, median); + // test_exact(0.99, 1.0, median); + test_exact(1.0, 0.0, median); } #[test] fn test_mode() { let mode = |x: Geometric| x.mode().unwrap(); - 
test_case(0.3, 1, mode); - test_case(1.0, 1, mode); + test_exact(0.3, 1, mode); + test_exact(1.0, 1, mode); } #[test] fn test_min_max() { let min = |x: Geometric| x.min(); let max = |x: Geometric| x.max(); - test_case(0.3, 1, min); - test_case(0.3, u64::MAX, max); + test_exact(0.3, 1, min); + test_exact(0.3, u64::MAX, max); } #[test] fn test_pmf() { let pmf = |arg: u64| move |x: Geometric| x.pmf(arg); - test_case(0.3, 0.3, pmf(1)); - test_case(0.3, 0.21, pmf(2)); - test_case(1.0, 1.0, pmf(1)); - test_case(1.0, 0.0, pmf(2)); - test_almost(0.5, 0.5, 1e-10, pmf(1)); - test_almost(0.5, 0.25, 1e-10, pmf(2)); + test_exact(0.3, 0.3, pmf(1)); + test_exact(0.3, 0.21, pmf(2)); + test_exact(1.0, 1.0, pmf(1)); + test_exact(1.0, 0.0, pmf(2)); + test_absolute(0.5, 0.5, 1e-10, pmf(1)); + test_absolute(0.5, 0.25, 1e-10, pmf(2)); } #[test] fn test_pmf_lower_bound() { let pmf = |arg: u64| move |x: Geometric| x.pmf(arg); - test_case(0.3, 0.0, pmf(0)); + test_exact(0.3, 0.0, pmf(0)); } #[test] fn test_ln_pmf() { let ln_pmf = |arg: u64| move |x: Geometric| x.ln_pmf(arg); - test_almost(0.3, -1.203972804325935992623, 1e-15, ln_pmf(1)); - test_almost(0.3, -1.560647748264668371535, 1e-15, ln_pmf(2)); - test_case(1.0, 0.0, ln_pmf(1)); - test_case(1.0, f64::NEG_INFINITY, ln_pmf(2)); + test_absolute(0.3, -1.203972804325935992623, 1e-15, ln_pmf(1)); + test_absolute(0.3, -1.560647748264668371535, 1e-15, ln_pmf(2)); + test_exact(1.0, 0.0, ln_pmf(1)); + test_exact(1.0, f64::NEG_INFINITY, ln_pmf(2)); } #[test] fn test_ln_pmf_lower_bound() { let ln_pmf = |arg: u64| move |x: Geometric| x.ln_pmf(arg); - test_case(0.3, f64::NEG_INFINITY, ln_pmf(0)); + test_exact(0.3, f64::NEG_INFINITY, ln_pmf(0)); } #[test] fn test_cdf() { let cdf = |arg: u64| move |x: Geometric| x.cdf(arg); - test_case(1.0, 1.0, cdf(1)); - test_case(1.0, 1.0, cdf(2)); - test_almost(0.5, 0.5, 1e-15, cdf(1)); - test_almost(0.5, 0.75, 1e-15, cdf(2)); + test_exact(1.0, 1.0, cdf(1)); + test_exact(1.0, 1.0, cdf(2)); + test_absolute(0.5, 
0.5, 1e-15, cdf(1)); + test_absolute(0.5, 0.75, 1e-15, cdf(2)); } #[test] fn test_sf() { let sf = |arg: u64| move |x: Geometric| x.sf(arg); - test_case(1.0, 0.0, sf(1)); - test_case(1.0, 0.0, sf(2)); - test_almost(0.5, 0.5, 1e-15, sf(1)); - test_almost(0.5, 0.25, 1e-15, sf(2)); + test_exact(1.0, 0.0, sf(1)); + test_exact(1.0, 0.0, sf(2)); + test_absolute(0.5, 0.5, 1e-15, sf(1)); + test_absolute(0.5, 0.25, 1e-15, sf(2)); } #[test] @@ -503,19 +485,19 @@ mod tests { #[test] fn test_cdf_lower_bound() { let cdf = |arg: u64| move |x: Geometric| x.cdf(arg); - test_case(0.3, 0.0, cdf(0)); + test_exact(0.3, 0.0, cdf(0)); } #[test] fn test_sf_lower_bound() { let sf = |arg: u64| move |x: Geometric| x.sf(arg); - test_case(0.3, 1.0, sf(0)); + test_exact(0.3, 1.0, sf(0)); } #[test] fn test_discrete() { - test::check_discrete_distribution(&try_create(0.3), 100); - test::check_discrete_distribution(&try_create(0.6), 100); - test::check_discrete_distribution(&try_create(1.0), 1); + test::check_discrete_distribution(&create_ok(0.3), 100); + test::check_discrete_distribution(&create_ok(0.6), 100); + test::check_discrete_distribution(&create_ok(1.0), 1); } } diff --git a/src/distribution/hypergeometric.rs b/src/distribution/hypergeometric.rs index 0ac8e750..ac39917d 100644 --- a/src/distribution/hypergeometric.rs +++ b/src/distribution/hypergeometric.rs @@ -1,35 +1,52 @@ use crate::distribution::{Discrete, DiscreteCDF}; use crate::function::factorial; use crate::statistics::*; -use crate::{Result, StatsError}; -use rand::Rng; use std::cmp; use std::f64; /// Implements the /// [Hypergeometric](http://en.wikipedia.org/wiki/Hypergeometric_distribution) /// distribution -/// -/// # Examples -/// -/// ``` -/// ``` -#[derive(Debug, Copy, Clone, PartialEq)] +// TODO: Add examples +#[derive(Copy, Clone, PartialEq, Eq, Debug)] pub struct Hypergeometric { population: u64, successes: u64, draws: u64, } +/// Represents the errors that can occur when creating a [`Hypergeometric`]. 
+#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] +#[non_exhaustive] +pub enum HypergeometricError { + /// The number of successes is greater than the population. + TooManySuccesses, + + /// The number of draws is greater than the population. + TooManyDraws, +} + +impl std::fmt::Display for HypergeometricError { + #[cfg_attr(coverage_nightly, coverage(off))] + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + HypergeometricError::TooManySuccesses => write!(f, "successes > population"), + HypergeometricError::TooManyDraws => write!(f, "draws > population"), + } + } +} + +impl std::error::Error for HypergeometricError {} + impl Hypergeometric { /// Constructs a new hypergeometric distribution /// with a population (N) of `population`, number /// of successes (K) of `successes`, and number of draws - /// (n) of `draws` + /// (n) of `draws`. /// /// # Errors /// - /// If `successes > population` or `draws > population` + /// If `successes > population` or `draws > population`. 
/// /// # Examples /// @@ -42,16 +59,24 @@ impl Hypergeometric { /// result = Hypergeometric::new(2, 3, 2); /// assert!(result.is_err()); /// ``` - pub fn new(population: u64, successes: u64, draws: u64) -> Result { - if successes > population || draws > population { - Err(StatsError::BadParams) - } else { - Ok(Hypergeometric { - population, - successes, - draws, - }) + pub fn new( + population: u64, + successes: u64, + draws: u64, + ) -> Result { + if successes > population { + return Err(HypergeometricError::TooManySuccesses); + } + + if draws > population { + return Err(HypergeometricError::TooManyDraws); } + + Ok(Hypergeometric { + population, + successes, + draws, + }) } /// Returns the population size of the hypergeometric @@ -110,8 +135,19 @@ impl Hypergeometric { } } +impl std::fmt::Display for Hypergeometric { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "Hypergeometric({},{},{})", + self.population, self.successes, self.draws + ) + } +} + +#[cfg(feature = "rand")] impl ::rand::distributions::Distribution for Hypergeometric { - fn sample(&self, rng: &mut R) -> f64 { + fn sample(&self, rng: &mut R) -> f64 { let mut population = self.population as f64; let mut successes = self.successes as f64; let mut draws = self.draws; @@ -139,18 +175,17 @@ impl DiscreteCDF for Hypergeometric { /// /// # Formula /// - /// ```ignore - /// 1 - ((n choose k+1) * (N-n choose K-k-1)) / (N choose K) * 3_F_2(1, - /// k+1-K, k+1-n; k+2, N+k+2-K-n; 1) + /// ```text + /// 1 - ((n choose x+1) * (N-n choose K-x-1)) / (N choose K) * 3_F_2(1, + /// x+1-K, x+1-n; k+2, N+x+2-K-n; 1) /// ``` /// /// where `N` is population, `K` is successes, `n` is draws, - /// and `p_F_q` is the [generalized hypergeometric - /// function](https://en.wikipedia. 
- /// org/wiki/Generalized_hypergeometric_function) + /// and `p_F_q` is the + /// [generalized hypergeometric function](https://en.wikipedia.org/wiki/Generalized_hypergeometric_function) /// /// Calculated as a discrete integral over the probability mass - /// function evaluated from 0..k+1 + /// function evaluated from 0..x+1 fn cdf(&self, x: u64) -> f64 { if x < self.min() { 0.0 @@ -173,18 +208,17 @@ impl DiscreteCDF for Hypergeometric { /// /// # Formula /// - /// ```ignore - /// 1 - ((n choose k+1) * (N-n choose K-k-1)) / (N choose K) * 3_F_2(1, - /// k+1-K, k+1-n; k+2, N+k+2-K-n; 1) + /// ```text + /// 1 - ((n choose x+1) * (N-n choose K-x-1)) / (N choose K) * 3_F_2(1, + /// x+1-K, x+1-n; x+2, N+x+2-K-n; 1) /// ``` /// /// where `N` is population, `K` is successes, `n` is draws, - /// and `p_F_q` is the [generalized hypergeometric - /// function](https://en.wikipedia. - /// org/wiki/Generalized_hypergeometric_function) + /// and `p_F_q` is the + /// [generalized hypergeometric function](https://en.wikipedia.org/wiki/Generalized_hypergeometric_function) /// /// Calculated as a discrete integral over the probability mass - /// function evaluated from (k+1)..max + /// function evaluated from (x+1)..max fn sf(&self, x: u64) -> f64 { if x < self.min() { 1.0 @@ -193,7 +227,7 @@ impl DiscreteCDF for Hypergeometric { } else { let k = x; let ln_denom = factorial::ln_binomial(self.population, self.draws); - (k + 1 .. 
self.max() + 1).fold(0.0, |acc, i| { + (k + 1..=self.max()).fold(0.0, |acc, i| { acc + (factorial::ln_binomial(self.successes, i) + factorial::ln_binomial(self.population - self.successes, self.draws - i) - ln_denom) @@ -210,7 +244,7 @@ impl Min for Hypergeometric { /// /// # Formula /// - /// ```ignore + /// ```text /// max(0, n + K - N) /// ``` /// @@ -227,7 +261,7 @@ impl Max for Hypergeometric { /// /// # Formula /// - /// ```ignore + /// ```text /// min(K, n) /// ``` /// @@ -246,7 +280,7 @@ impl Distribution for Hypergeometric { /// /// # Formula /// - /// ```ignore + /// ```text /// K * n / N /// ``` /// @@ -258,6 +292,7 @@ impl Distribution for Hypergeometric { Some(self.successes as f64 * self.draws as f64 / self.population as f64) } } + /// Returns the variance of the hypergeometric distribution /// /// # None @@ -266,7 +301,7 @@ impl Distribution for Hypergeometric { /// /// # Formula /// - /// ```ignore + /// ```text /// n * (K / N) * ((N - K) / N) * ((N - n) / (N - 1)) /// ``` /// @@ -281,6 +316,7 @@ impl Distribution for Hypergeometric { Some(val) } } + /// Returns the skewness of the hypergeometric distribution /// /// # None @@ -289,7 +325,7 @@ impl Distribution for Hypergeometric { /// /// # Formula /// - /// ```ignore + /// ```text /// ((N - 2K) * (N - 1)^(1 / 2) * (N - 2n)) / ([n * K * (N - K) * (N - /// n)]^(1 / 2) * (N - 2)) /// ``` @@ -315,7 +351,7 @@ impl Mode> for Hypergeometric { /// /// # Formula /// - /// ```ignore + /// ```text /// floor((n + 1) * (k + 1) / (N + 2)) /// ``` /// @@ -331,7 +367,7 @@ impl Discrete for Hypergeometric { /// /// # Formula /// - /// ```ignore + /// ```text /// (K choose x) * (N-K choose n-x) / (N choose n) /// ``` /// @@ -351,7 +387,7 @@ impl Discrete for Hypergeometric { /// /// # Formula /// - /// ```ignore + /// ```text /// ln((K choose x) * (N-K choose n-x) / (N choose n)) /// ``` /// @@ -364,229 +400,182 @@ impl Discrete for Hypergeometric { } #[rustfmt::skip] -#[cfg(all(test, feature = "nightly"))] 
+#[cfg(test)] mod tests { - use std::fmt::Debug; - use crate::statistics::*; - use crate::distribution::{DiscreteCDF, Discrete, Hypergeometric}; + use super::*; use crate::distribution::internal::*; - use crate::consts::ACC; - - fn try_create(population: u64, successes: u64, draws: u64) -> Hypergeometric { - let n = Hypergeometric::new(population, successes, draws); - assert!(n.is_ok()); - n.unwrap() - } - - fn create_case(population: u64, successes: u64, draws: u64) { - let n = try_create(population, successes, draws); - assert_eq!(population, n.population()); - assert_eq!(successes, n.successes()); - assert_eq!(draws, n.draws()); - } - - fn bad_create_case(population: u64, successes: u64, draws: u64) { - let n = Hypergeometric::new(population, successes, draws); - assert!(n.is_err()); - } + use crate::testing_boiler; - fn get_value(population: u64, successes: u64, draws: u64, eval: F) -> T - where T: PartialEq + Debug, - F: Fn(Hypergeometric) -> T - { - let n = try_create(population, successes, draws); - eval(n) - } - - fn test_case(population: u64, successes: u64, draws: u64, expected: T, eval: F) - where T: PartialEq + Debug, - F: Fn(Hypergeometric) -> T - { - let x = get_value(population, successes, draws, eval); - assert_eq!(expected, x); - } - - fn test_almost(population: u64, successes: u64, draws: u64, expected: f64, acc: f64, eval: F) - where F: Fn(Hypergeometric) -> f64 - { - let x = get_value(population, successes, draws, eval); - assert_almost_eq!(expected, x, acc); - } + testing_boiler!(population: u64, successes: u64, draws: u64; Hypergeometric; HypergeometricError); #[test] fn test_create() { - create_case(0, 0, 0); - create_case(1, 1, 1,); - create_case(2, 1, 1); - create_case(2, 2, 2); - create_case(10, 1, 1); - create_case(10, 5, 3); + create_ok(0, 0, 0); + create_ok(1, 1, 1,); + create_ok(2, 1, 1); + create_ok(2, 2, 2); + create_ok(10, 1, 1); + create_ok(10, 5, 3); } #[test] fn test_bad_create() { - bad_create_case(2, 3, 2); - 
bad_create_case(10, 5, 20); - bad_create_case(0, 1, 1); + test_create_err(2, 3, 2, HypergeometricError::TooManySuccesses); + test_create_err(10, 5, 20, HypergeometricError::TooManyDraws); + create_err(0, 1, 1); } #[test] fn test_mean() { let mean = |x: Hypergeometric| x.mean().unwrap(); - test_case(1, 1, 1, 1.0, mean); - test_case(2, 1, 1, 0.5, mean); - test_case(2, 2, 2, 2.0, mean); - test_case(10, 1, 1, 0.1, mean); - test_case(10, 5, 3, 15.0 / 10.0, mean); + test_exact(1, 1, 1, 1.0, mean); + test_exact(2, 1, 1, 0.5, mean); + test_exact(2, 2, 2, 2.0, mean); + test_exact(10, 1, 1, 0.1, mean); + test_exact(10, 5, 3, 15.0 / 10.0, mean); } #[test] - #[should_panic] fn test_mean_with_population_0() { - let mean = |x: Hypergeometric| x.mean().unwrap(); - get_value(0, 0, 0, mean); + test_none(0, 0, 0, |dist| dist.mean()); } #[test] fn test_variance() { let variance = |x: Hypergeometric| x.variance().unwrap(); - test_case(2, 1, 1, 0.25, variance); - test_case(2, 2, 2, 0.0, variance); - test_case(10, 1, 1, 81.0 / 900.0, variance); - test_case(10, 5, 3, 525.0 / 900.0, variance); + test_exact(2, 1, 1, 0.25, variance); + test_exact(2, 2, 2, 0.0, variance); + test_exact(10, 1, 1, 81.0 / 900.0, variance); + test_exact(10, 5, 3, 525.0 / 900.0, variance); } #[test] - #[should_panic] fn test_variance_with_pop_lte_1() { - let variance = |x: Hypergeometric| x.variance().unwrap(); - get_value(1, 1, 1, variance); + test_none(1, 1, 1, |dist| dist.variance()); } #[test] fn test_skewness() { let skewness = |x: Hypergeometric| x.skewness().unwrap(); - test_case(10, 1, 1, 8.0 / 3.0, skewness); - test_case(10, 5, 3, 0.0, skewness); + test_exact(10, 1, 1, 8.0 / 3.0, skewness); + test_exact(10, 5, 3, 0.0, skewness); } #[test] - #[should_panic] fn test_skewness_with_pop_lte_2() { - let skewness = |x: Hypergeometric| x.skewness().unwrap(); - get_value(2, 2, 2, skewness); + test_none(2, 2, 2, |dist| dist.skewness()); } #[test] fn test_mode() { let mode = |x: Hypergeometric| x.mode().unwrap(); - 
test_case(0, 0, 0, 0, mode); - test_case(1, 1, 1, 1, mode); - test_case(2, 1, 1, 1, mode); - test_case(2, 2, 2, 2, mode); - test_case(10, 1, 1, 0, mode); - test_case(10, 5, 3, 2, mode); + test_exact(0, 0, 0, 0, mode); + test_exact(1, 1, 1, 1, mode); + test_exact(2, 1, 1, 1, mode); + test_exact(2, 2, 2, 2, mode); + test_exact(10, 1, 1, 0, mode); + test_exact(10, 5, 3, 2, mode); } #[test] fn test_min() { let min = |x: Hypergeometric| x.min(); - test_case(0, 0, 0, 0, min); - test_case(1, 1, 1, 1, min); - test_case(2, 1, 1, 0, min); - test_case(2, 2, 2, 2, min); - test_case(10, 1, 1, 0, min); - test_case(10, 5, 3, 0, min); + test_exact(0, 0, 0, 0, min); + test_exact(1, 1, 1, 1, min); + test_exact(2, 1, 1, 0, min); + test_exact(2, 2, 2, 2, min); + test_exact(10, 1, 1, 0, min); + test_exact(10, 5, 3, 0, min); } #[test] fn test_max() { let max = |x: Hypergeometric| x.max(); - test_case(0, 0, 0, 0, max); - test_case(1, 1, 1, 1, max); - test_case(2, 1, 1, 1, max); - test_case(2, 2, 2, 2, max); - test_case(10, 1, 1, 1, max); - test_case(10, 5, 3, 3, max); + test_exact(0, 0, 0, 0, max); + test_exact(1, 1, 1, 1, max); + test_exact(2, 1, 1, 1, max); + test_exact(2, 2, 2, 2, max); + test_exact(10, 1, 1, 1, max); + test_exact(10, 5, 3, 3, max); } #[test] fn test_pmf() { let pmf = |arg: u64| move |x: Hypergeometric| x.pmf(arg); - test_case(0, 0, 0, 1.0, pmf(0)); - test_case(1, 1, 1, 1.0, pmf(1)); - test_case(2, 1, 1, 0.5, pmf(0)); - test_case(2, 1, 1, 0.5, pmf(1)); - test_case(2, 2, 2, 1.0, pmf(2)); - test_case(10, 1, 1, 0.9, pmf(0)); - test_case(10, 1, 1, 0.1, pmf(1)); - test_case(10, 5, 3, 0.41666666666666666667, pmf(1)); - test_case(10, 5, 3, 0.083333333333333333333, pmf(3)); + test_exact(0, 0, 0, 1.0, pmf(0)); + test_exact(1, 1, 1, 1.0, pmf(1)); + test_exact(2, 1, 1, 0.5, pmf(0)); + test_exact(2, 1, 1, 0.5, pmf(1)); + test_exact(2, 2, 2, 1.0, pmf(2)); + test_exact(10, 1, 1, 0.9, pmf(0)); + test_exact(10, 1, 1, 0.1, pmf(1)); + test_exact(10, 5, 3, 0.41666666666666666667, 
pmf(1)); + test_exact(10, 5, 3, 0.083333333333333333333, pmf(3)); } #[test] fn test_ln_pmf() { let ln_pmf = |arg: u64| move |x: Hypergeometric| x.ln_pmf(arg); - test_case(0, 0, 0, 0.0, ln_pmf(0)); - test_case(1, 1, 1, 0.0, ln_pmf(1)); - test_case(2, 1, 1, -0.6931471805599453094172, ln_pmf(0)); - test_case(2, 1, 1, -0.6931471805599453094172, ln_pmf(1)); - test_case(2, 2, 2, 0.0, ln_pmf(2)); - test_almost(10, 1, 1, -0.1053605156578263012275, 1e-14, ln_pmf(0)); - test_almost(10, 1, 1, -2.302585092994045684018, 1e-14, ln_pmf(1)); - test_almost(10, 5, 3, -0.875468737353899935621, 1e-14, ln_pmf(1)); - test_almost(10, 5, 3, -2.484906649788000310234, 1e-14, ln_pmf(3)); + test_exact(0, 0, 0, 0.0, ln_pmf(0)); + test_exact(1, 1, 1, 0.0, ln_pmf(1)); + test_exact(2, 1, 1, -0.6931471805599453094172, ln_pmf(0)); + test_exact(2, 1, 1, -0.6931471805599453094172, ln_pmf(1)); + test_exact(2, 2, 2, 0.0, ln_pmf(2)); + test_absolute(10, 1, 1, -0.1053605156578263012275, 1e-14, ln_pmf(0)); + test_absolute(10, 1, 1, -2.302585092994045684018, 1e-14, ln_pmf(1)); + test_absolute(10, 5, 3, -0.875468737353899935621, 1e-14, ln_pmf(1)); + test_absolute(10, 5, 3, -2.484906649788000310234, 1e-14, ln_pmf(3)); } #[test] fn test_cdf() { let cdf = |arg: u64| move |x: Hypergeometric| x.cdf(arg); - test_case(2, 1, 1, 0.5, cdf(0)); - test_almost(10, 1, 1, 0.9, 1e-14, cdf(0)); - test_almost(10, 5, 3, 0.5, 1e-15, cdf(1)); - test_almost(10, 5, 3, 11.0 / 12.0, 1e-14, cdf(2)); - test_almost(10000, 2, 9800, 199.0 / 499950.0, 1e-14, cdf(0)); - test_almost(10000, 2, 9800, 19799.0 / 499950.0, 1e-12, cdf(1)); + test_exact(2, 1, 1, 0.5, cdf(0)); + test_absolute(10, 1, 1, 0.9, 1e-14, cdf(0)); + test_absolute(10, 5, 3, 0.5, 1e-15, cdf(1)); + test_absolute(10, 5, 3, 11.0 / 12.0, 1e-14, cdf(2)); + test_absolute(10000, 2, 9800, 199.0 / 499950.0, 1e-14, cdf(0)); + test_absolute(10000, 2, 9800, 19799.0 / 499950.0, 1e-12, cdf(1)); } #[test] fn test_sf() { let sf = |arg: u64| move |x: Hypergeometric| x.sf(arg); - 
test_case(2, 1, 1, 0.5, sf(0)); - test_almost(10, 1, 1, 0.1, 1e-14, sf(0)); - test_almost(10, 5, 3, 0.5, 1e-15, sf(1)); - test_almost(10, 5, 3, 1.0 / 12.0, 1e-14, sf(2)); - test_almost(10000, 2, 9800, 499751. / 499950.0, 1e-10, sf(0)); - test_almost(10000, 2, 9800, 480151. / 499950.0, 1e-10, sf(1)); + test_exact(2, 1, 1, 0.5, sf(0)); + test_absolute(10, 1, 1, 0.1, 1e-14, sf(0)); + test_absolute(10, 5, 3, 0.5, 1e-15, sf(1)); + test_absolute(10, 5, 3, 1.0 / 12.0, 1e-14, sf(2)); + test_absolute(10000, 2, 9800, 499751. / 499950.0, 1e-10, sf(0)); + test_absolute(10000, 2, 9800, 480151. / 499950.0, 1e-10, sf(1)); } #[test] fn test_cdf_arg_too_big() { let cdf = |arg: u64| move |x: Hypergeometric| x.cdf(arg); - test_case(0, 0, 0, 1.0, cdf(0)); + test_exact(0, 0, 0, 1.0, cdf(0)); } #[test] fn test_cdf_arg_too_small() { let cdf = |arg: u64| move |x: Hypergeometric| x.cdf(arg); - test_case(2, 2, 2, 0.0, cdf(0)); + test_exact(2, 2, 2, 0.0, cdf(0)); } #[test] fn test_sf_arg_too_big() { let sf = |arg: u64| move |x: Hypergeometric| x.sf(arg); - test_case(0, 0, 0, 0.0, sf(0)); + test_exact(0, 0, 0, 0.0, sf(0)); } #[test] fn test_sf_arg_too_small() { let sf = |arg: u64| move |x: Hypergeometric| x.sf(arg); - test_case(2, 2, 2, 1.0, sf(0)); + test_exact(2, 2, 2, 1.0, sf(0)); } #[test] fn test_discrete() { - test::check_discrete_distribution(&try_create(5, 4, 3), 4); - test::check_discrete_distribution(&try_create(3, 2, 1), 2); + test::check_discrete_distribution(&create_ok(5, 4, 3), 4); + test::check_discrete_distribution(&create_ok(3, 2, 1), 2); } } diff --git a/src/distribution/internal.rs b/src/distribution/internal.rs index 95c93872..9e7651b0 100644 --- a/src/distribution/internal.rs +++ b/src/distribution/internal.rs @@ -1,77 +1,355 @@ -/// Returns true if there are no elements in `x` in `arr` -/// such that `x <= 0.0` or `x` is `f64::NAN` and `sum(arr) > 0.0`. 
-/// IF `incl_zero` is true, it tests for `x < 0.0` instead of `x <= 0.0` -pub fn is_valid_multinomial(arr: &[f64], incl_zero: bool) -> bool { - let mut sum = 0.0; - for &elt in arr { - if incl_zero && elt < 0.0 || !incl_zero && elt <= 0.0 || elt.is_nan() { - return false; - } - sum += elt; +use num_traits::Num; + +/// Implements univariate function bisection searching for criteria +/// ```text +/// smallest k such that f(k) >= z +/// ``` +/// Evaluates to `None` if +/// - provided interval has lower bound greater than upper bound +/// - function found not semi-monotone on the provided interval containing `z` +/// +/// Evaluates to `Some(k)`, where `k` satisfies the search criteria +pub fn integral_bisection_search( + f: impl Fn(&K) -> T, + z: T, + lb: K, + ub: K, +) -> Option { + if !(f(&lb)..=f(&ub)).contains(&z) { + return None; + } + let two = K::one() + K::one(); + let mut lb = lb; + let mut ub = ub; + loop { + let mid = (lb.clone() + ub.clone()) / two.clone(); + if !(f(&lb)..=f(&ub)).contains(&f(&mid)) { + // if f found not monotone on the interval + return None; + } else if f(&lb) == z { + return Some(lb); + } else if f(&ub) == z { + return Some(ub); + } else if (lb.clone() + K::one()) == ub { + // no more elements to search + return Some(ub); + } else if f(&mid) >= z { + ub = mid; + } else { + lb = mid; + } } - sum != 0.0 } #[macro_use] -#[cfg(all(test, feature = "nightly"))] +#[cfg(test)] pub mod test { - use super::is_valid_multinomial; - use crate::consts::ACC; + use super::*; use crate::distribution::{Continuous, ContinuousCDF, Discrete, DiscreteCDF}; #[macro_export] macro_rules! 
testing_boiler { - ($arg:ty, $dist:ty) => { - fn try_create(arg: $arg) -> $dist { - let n = <$dist>::new.call_once(arg); - assert!(n.is_ok()); - n.unwrap() + ($($arg_name:ident: $arg_ty:ty),+; $dist:ty; $dist_err:ty) => { + fn make_param_text($($arg_name: $arg_ty),+) -> String { + // "" + let mut param_text = String::new(); + + // "shape=10.0, rate=NaN, " + $( + param_text.push_str( + &format!( + "{}={:?}, ", + stringify!($arg_name), + $arg_name, + ) + ); + )+ + + // "shape=10.0, rate=NaN" (removes trailing comma and whitespace) + param_text.pop(); + param_text.pop(); + + param_text + } + + /// Creates and returns a distribution with the given parameters, + /// panicking if `::new` fails. + fn create_ok($($arg_name: $arg_ty),+) -> $dist { + match <$dist>::new($($arg_name),+) { + Ok(d) => d, + Err(e) => panic!( + "{}::new was expected to succeed, but failed for {} with error: '{}'", + stringify!($dist), + make_param_text($($arg_name),+), + e + ) + } } - fn bad_create_case(arg: $arg) { - let n = <$dist>::new.call(arg); - assert!(n.is_err()); + /// Returns the error when creating a distribution with the given parameters, + /// panicking if `::new` succeeds. + #[allow(dead_code)] + fn create_err($($arg_name: $arg_ty),+) -> $dist_err { + match <$dist>::new($($arg_name),+) { + Err(e) => e, + Ok(d) => panic!( + "{}::new was expected to fail, but succeeded for {} with result: {:?}", + stringify!($dist), + make_param_text($($arg_name),+), + d + ) + } } - fn get_value(arg: $arg, eval: F) -> T + /// Creates a distribution with the given parameters, calls the `get_fn` + /// function with the new distribution and returns the result of `get_fn`. + /// + /// Panics if `::new` fails. 
+ fn create_and_get($($arg_name: $arg_ty),+, get_fn: F) -> T where F: Fn($dist) -> T, { - let n = try_create(arg); - eval(n) + let n = create_ok($($arg_name),+); + get_fn(n) } - fn test_case(arg: $arg, expected: T, eval: F) + /// Creates a distribution with the given parameters, calls the `get_fn` + /// function with the new distribution and compares the result of `get_fn` + /// to `expected` exactly. + /// + /// Panics if `::new` fails. + #[allow(dead_code)] + fn test_exact($($arg_name: $arg_ty),+, expected: T, get_fn: F) where F: Fn($dist) -> T, - T: ::core::fmt::Debug + ::approx::RelativeEq, + T: ::core::cmp::PartialEq + ::core::fmt::Debug { - let x = get_value(arg, eval); - assert_relative_eq!(expected, x, max_relative = ACC); + let x = create_and_get($($arg_name),+, get_fn); + if x != expected { + panic!( + "Expected {:?}, got {:?} for {}", + expected, + x, + make_param_text($($arg_name),+) + ); + } } - #[allow(dead_code)] // This is not used by all distributions. - fn test_case_special(arg: $arg, expected: T, acc: f64, eval: F) + /// Gets a value for the given parameters by calling `create_and_get` + /// and compares it to `expected`. + /// + /// Allows relative error of up to [`crate::consts::ACC`]. + /// + /// Panics if `::new` fails. + #[allow(dead_code)] + fn test_relative($($arg_name: $arg_ty),+, expected: f64, get_fn: F) where - F: Fn($dist) -> T, - T: ::core::fmt::Debug + ::approx::AbsDiffEq, + F: Fn($dist) -> f64, + { + let x = create_and_get($($arg_name),+, get_fn); + let max_relative = $crate::consts::ACC; + + if !::approx::relative_eq!(expected, x, max_relative = max_relative) { + panic!( + "Expected {:?} to be almost equal to {:?} (max. relative error of {:?}), but wasn't for {}", + x, + expected, + max_relative, + make_param_text($($arg_name),+) + ); + } + } + + /// Gets a value for the given parameters by calling `create_and_get` + /// and compares it to `expected`. + /// + /// Allows absolute error of up to `acc`. 
+ /// + /// Panics if `::new` fails. + #[allow(dead_code)] + fn test_absolute($($arg_name: $arg_ty),+, expected: f64, acc: f64, get_fn: F) + where + F: Fn($dist) -> f64, + { + let x = create_and_get($($arg_name),+, get_fn); + + // abs_diff_eq! cannot handle infinities, so we manually accept them here + if expected.is_infinite() && x == expected { + return; + } + + if !::approx::abs_diff_eq!(expected, x, epsilon = acc) { + panic!( + "Expected {:?} to be almost equal to {:?} (max. absolute error of {:?}), but wasn't for {}", + x, + expected, + acc, + make_param_text($($arg_name),+) + ); + } + } + + /// Purposely fails creating a distribution with the given + /// parameters and compares the returned error to `expected`. + /// + /// Panics if `::new` succeeds. + #[allow(dead_code)] + fn test_create_err($($arg_name: $arg_ty),+, expected: $dist_err) { - let x = get_value(arg, eval); - assert_abs_diff_eq!(expected, x, epsilon = acc); + let err = create_err($($arg_name),+); + if err != expected { + panic!( + "{}::new was expected to fail with error {:?}, but failed with error {:?} for {}", + stringify!($dist), + expected, + err, + make_param_text($($arg_name),+) + ) + } } - #[allow(dead_code)] // This is not used by all distributions. - fn test_none(arg: $arg, eval: F) + /// Gets a value for the given parameters by calling `create_and_get` + /// and asserts that it is [`NAN`]. + /// + /// Panics if `::new` fails. + #[allow(dead_code)] + fn test_is_nan($($arg_name: $arg_ty),+, get_fn: F) + where + F: Fn($dist) -> f64 + { + let x = create_and_get($($arg_name),+, get_fn); + assert!(x.is_nan()); + } + + /// Gets a value for the given parameters by calling `create_and_get` + /// and asserts that it is [`None`]. + /// + /// Panics if `::new` fails. 
+ #[allow(dead_code)] + fn test_none($($arg_name: $arg_ty),+, get_fn: F) where F: Fn($dist) -> Option, - T: ::core::cmp::PartialEq + ::core::fmt::Debug, + T: ::core::fmt::Debug, { - let x = get_value(arg, eval); - assert_eq!(None, x); + let x = create_and_get($($arg_name),+, get_fn); + + if let Some(inner) = x { + panic!( + "Expected None, got {:?} for {}", + inner, + make_param_text($($arg_name),+) + ) + } + } + + /// Asserts that associated error type is Send and Sync + #[test] + fn test_error_is_sync_send() { + fn assert_sync_send() {} + assert_sync_send::<$dist_err>(); } }; } + pub mod boiler_tests { + use crate::distribution::{Beta, BetaError}; + use crate::statistics::*; + + testing_boiler!(shape_a: f64, shape_b: f64; Beta; BetaError); + + #[test] + fn create_ok_success() { + let b = create_ok(0.8, 1.2); + assert_eq!(b.shape_a(), 0.8); + assert_eq!(b.shape_b(), 1.2); + } + + #[test] + #[should_panic] + fn create_err_failure() { + create_err(0.8, 1.2); + } + + #[test] + fn create_err_success() { + let err = create_err(-0.5, 1.2); + assert_eq!(err, BetaError::ShapeAInvalid); + } + + #[test] + #[should_panic] + fn create_ok_failure() { + create_ok(-0.5, 1.2); + } + + #[test] + fn test_exact_success() { + test_exact(1.5, 1.5, 0.5, |dist| dist.mode().unwrap()); + } + + #[test] + #[should_panic] + fn test_exact_failure() { + test_exact(1.2, 1.4, 0.333333333333, |dist| dist.mode().unwrap()); + } + + #[test] + fn test_relative_success() { + test_relative(1.2, 1.4, 0.333333333333, |dist| dist.mode().unwrap()); + } + + #[test] + #[should_panic] + fn test_relative_failure() { + test_relative(1.2, 1.4, 0.333, |dist| dist.mode().unwrap()); + } + + #[test] + fn test_absolute_success() { + test_absolute(1.2, 1.4, 0.333333333333, 1e-12, |dist| dist.mode().unwrap()); + } + + #[test] + #[should_panic] + fn test_absolute_failure() { + test_absolute(1.2, 1.4, 0.333333333333, 1e-15, |dist| dist.mode().unwrap()); + } + + #[test] + fn test_create_err_success() { + 
test_create_err(0.0, 0.5, BetaError::ShapeAInvalid); + } + + #[test] + #[should_panic] + fn test_create_err_failure() { + test_create_err(0.0, 0.5, BetaError::BothShapesInfinite); + } + + #[test] + fn test_is_nan_success() { + // Not sure that any Beta API can return a NaN, so we force the issue + test_is_nan(0.8, 1.2, |_| f64::NAN); + } + + #[test] + #[should_panic] + fn test_is_nan_failure() { + test_is_nan(0.8, 1.2, |dist| dist.mean().unwrap()); + } + + #[test] + fn test_is_none_success() { + test_none(f64::INFINITY, 1.2, |dist| dist.entropy()); + } + + #[test] + #[should_panic] + fn test_is_none_failure() { + test_none(0.8, 1.2, |dist| dist.mean()); + } + } + /// cdf should be the integral of the pdf fn check_integrate_pdf_is_cdf + Continuous>( dist: &D, @@ -178,22 +456,24 @@ pub mod test { } #[test] - fn test_is_valid_multinomial() { - use std::f64; - - let invalid = [1.0, f64::NAN, 3.0]; - assert!(!is_valid_multinomial(&invalid, true)); - let invalid2 = [-2.0, 5.0, 1.0, 6.2]; - assert!(!is_valid_multinomial(&invalid2, true)); - let invalid3 = [0.0, 0.0, 0.0]; - assert!(!is_valid_multinomial(&invalid3, true)); - let valid = [5.2, 0.0, 1e-15, 1000000.12]; - assert!(is_valid_multinomial(&valid, true)); - } + fn test_integer_bisection() { + fn search(z: usize, data: &[usize]) -> Option { + integral_bisection_search(|idx: &usize| data[*idx], z, 0, data.len() - 1) + } - #[test] - fn test_is_valid_multinomial_no_zero() { - let invalid = [5.2, 0.0, 1e-15, 1000000.12]; - assert!(!is_valid_multinomial(&invalid, false)); + let needle = 3; + let data = (0..5) + .map(|n| if n >= needle { n + 1 } else { n }) + .collect::>(); + + for i in 0..(data.len()) { + assert_eq!(search(data[i], &data), Some(i),) + } + { + let infimum = search(needle, &data); + let found_element = search(needle + 1, &data); // 4 > needle && member of range + assert_eq!(found_element, Some(needle)); + assert_eq!(infimum, found_element) + } } } diff --git a/src/distribution/inverse_gamma.rs 
b/src/distribution/inverse_gamma.rs index 1cf69fa9..db101fd0 100644 --- a/src/distribution/inverse_gamma.rs +++ b/src/distribution/inverse_gamma.rs @@ -1,8 +1,6 @@ use crate::distribution::{Continuous, ContinuousCDF}; use crate::function::gamma; use crate::statistics::*; -use crate::{Result, StatsError}; -use rand::Rng; use std::f64; /// Implements the [Inverse @@ -20,12 +18,39 @@ use std::f64; /// assert!(prec::almost_eq(n.mean().unwrap(), 1.0, 1e-14)); /// assert_eq!(n.pdf(1.0), 0.07554920138253064); /// ``` -#[derive(Debug, Copy, Clone, PartialEq)] +#[derive(Copy, Clone, PartialEq, Debug)] pub struct InverseGamma { shape: f64, rate: f64, } +/// Represents the errors that can occur when creating an [`InverseGamma`]. +#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] +#[non_exhaustive] +pub enum InverseGammaError { + /// The shape is NaN, infinite, zero or less than zero. + ShapeInvalid, + + /// The rate is NaN, infinite, zero or less than zero. + RateInvalid, +} + +impl std::fmt::Display for InverseGammaError { + #[cfg_attr(coverage_nightly, coverage(off))] + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + InverseGammaError::ShapeInvalid => { + write!(f, "Shape is NaN, infinite, zero or less than zero") + } + InverseGammaError::RateInvalid => { + write!(f, "Rate is NaN, infinite, zero or less than zero") + } + } + } +} + +impl std::error::Error for InverseGammaError {} + impl InverseGamma { /// Constructs a new inverse gamma distribution with a shape (α) /// of `shape` and a rate (β) of `rate` @@ -46,16 +71,16 @@ impl InverseGamma { /// result = InverseGamma::new(0.0, 0.0); /// assert!(result.is_err()); /// ``` - pub fn new(shape: f64, rate: f64) -> Result { - let is_nan = shape.is_nan() || rate.is_nan(); - match (shape, rate, is_nan) { - (_, _, true) => Err(StatsError::BadParams), - (_, _, false) if shape <= 0.0 || rate <= 0.0 => Err(StatsError::BadParams), - (_, _, false) if shape.is_infinite() || rate.is_infinite() => { - 
Err(StatsError::BadParams) - } - (_, _, false) => Ok(InverseGamma { shape, rate }), + pub fn new(shape: f64, rate: f64) -> Result { + if shape.is_nan() || shape.is_infinite() || shape <= 0.0 { + return Err(InverseGammaError::ShapeInvalid); + } + + if rate.is_nan() || rate.is_infinite() || rate <= 0.0 { + return Err(InverseGammaError::RateInvalid); } + + Ok(InverseGamma { shape, rate }) } /// Returns the shape (α) of the inverse gamma distribution @@ -87,8 +112,15 @@ impl InverseGamma { } } +impl std::fmt::Display for InverseGamma { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Inv-Gamma({}, {})", self.shape, self.rate) + } +} + +#[cfg(feature = "rand")] impl ::rand::distributions::Distribution for InverseGamma { - fn sample(&self, r: &mut R) -> f64 { + fn sample(&self, r: &mut R) -> f64 { 1.0 / super::gamma::sample_unchecked(r, self.shape, self.rate) } } @@ -99,7 +131,7 @@ impl ContinuousCDF for InverseGamma { /// /// # Formula /// - /// ```ignore + /// ```text /// Γ(α, β / x) / Γ(α) /// ``` /// @@ -121,7 +153,7 @@ impl ContinuousCDF for InverseGamma { /// /// # Formula /// - /// ```ignore + /// ```text /// Γ(α, β / x) / Γ(α) /// ``` /// @@ -146,7 +178,7 @@ impl Min for InverseGamma { /// /// # Formula /// - /// ```ignore + /// ```text /// 0 /// ``` fn min(&self) -> f64 { @@ -161,8 +193,8 @@ impl Max for InverseGamma { /// /// # Formula /// - /// ```ignore - /// INF + /// ```text + /// f64::INFINITY /// ``` fn max(&self) -> f64 { f64::INFINITY @@ -178,7 +210,7 @@ impl Distribution for InverseGamma { /// /// # Formula /// - /// ```ignore + /// ```text /// β / (α - 1) /// ``` /// @@ -190,6 +222,7 @@ impl Distribution for InverseGamma { Some(self.rate / (self.shape - 1.0)) } } + /// Returns the variance of the inverse gamma distribution /// /// # None @@ -198,7 +231,7 @@ impl Distribution for InverseGamma { /// /// # Formula /// - /// ```ignore + /// ```text /// β^2 / ((α - 1)^2 * (α - 2)) /// ``` /// @@ -212,11 +245,12 @@ impl 
Distribution for InverseGamma { Some(val) } } + /// Returns the entropy of the inverse gamma distribution /// /// # Formula /// - /// ```ignore + /// ```text /// α + ln(β * Γ(α)) - (1 + α) * ψ(α) /// ``` /// @@ -227,6 +261,7 @@ impl Distribution for InverseGamma { - (1.0 + self.shape) * gamma::digamma(self.shape); Some(entr) } + /// Returns the skewness of the inverse gamma distribution /// /// # None @@ -235,7 +270,7 @@ impl Distribution for InverseGamma { /// /// # Formula /// - /// ```ignore + /// ```text /// 4 * sqrt(α - 2) / (α - 3) /// ``` /// @@ -254,7 +289,7 @@ impl Mode> for InverseGamma { /// /// # Formula /// - /// ```ignore + /// ```text /// β / (α + 1) /// ``` /// @@ -270,7 +305,7 @@ impl Continuous for InverseGamma { /// /// # Formula /// - /// ```ignore + /// ```text /// (β^α / Γ(α)) * x^(-α - 1) * e^(-β / x) /// ``` /// @@ -291,7 +326,7 @@ impl Continuous for InverseGamma { /// /// # Formula /// - /// ```ignore + /// ```text /// ln((β^α / Γ(α)) * x^(-α - 1) * e^(-β / x)) /// ``` /// @@ -302,179 +337,136 @@ impl Continuous for InverseGamma { } #[rustfmt::skip] -#[cfg(all(test, feature = "nightly"))] +#[cfg(test)] mod tests { - use crate::statistics::*; - use crate::distribution::{ContinuousCDF, Continuous, InverseGamma}; + use super::*; use crate::distribution::internal::*; - use crate::consts::ACC; + use crate::testing_boiler; - fn try_create(shape: f64, rate: f64) -> InverseGamma { - let n = InverseGamma::new(shape, rate); - assert!(n.is_ok()); - n.unwrap() - } - - fn create_case(shape: f64, rate: f64) { - let n = try_create(shape, rate); - assert_eq!(shape, n.shape()); - assert_eq!(rate, n.rate()); - } - - fn bad_create_case(shape: f64, rate: f64) { - let n = InverseGamma::new(shape, rate); - assert!(n.is_err()); - } - - fn get_value(shape: f64, rate: f64, eval: F) -> f64 - where F: Fn(InverseGamma) -> f64 - { - let n = try_create(shape, rate); - eval(n) - } - - fn test_case(shape: f64, rate: f64, expected: f64, eval: F) - where F: 
Fn(InverseGamma) -> f64 - { - let x = get_value(shape, rate, eval); - assert_eq!(expected, x); - } - - fn test_almost(shape: f64, rate: f64, expected: f64, acc: f64, eval: F) - where F: Fn(InverseGamma) -> f64 - { - let x = get_value(shape, rate, eval); - assert_almost_eq!(expected, x, acc); - } + testing_boiler!(shape: f64, rate: f64; InverseGamma; InverseGammaError); #[test] fn test_create() { - create_case(0.1, 0.1); - create_case(1.0, 1.0); + create_ok(0.1, 0.1); + create_ok(1.0, 1.0); } #[test] fn test_bad_create() { - bad_create_case(0.0, 1.0); - bad_create_case(-1.0, 1.0); - bad_create_case(-100.0, 1.0); - bad_create_case(f64::NEG_INFINITY, 1.0); - bad_create_case(f64::NAN, 1.0); - bad_create_case(1.0, 0.0); - bad_create_case(1.0, -1.0); - bad_create_case(1.0, -100.0); - bad_create_case(1.0, f64::NEG_INFINITY); - bad_create_case(1.0, f64::NAN); - bad_create_case(f64::INFINITY, 1.0); - bad_create_case(1.0, f64::INFINITY); - bad_create_case(f64::INFINITY, f64::INFINITY); + test_create_err(0.0, 1.0, InverseGammaError::ShapeInvalid); + test_create_err(1.0, -1.0, InverseGammaError::RateInvalid); + create_err(-1.0, 1.0); + create_err(-100.0, 1.0); + create_err(f64::NEG_INFINITY, 1.0); + create_err(f64::NAN, 1.0); + create_err(1.0, 0.0); + create_err(1.0, -100.0); + create_err(1.0, f64::NEG_INFINITY); + create_err(1.0, f64::NAN); + create_err(f64::INFINITY, 1.0); + create_err(1.0, f64::INFINITY); + create_err(f64::INFINITY, f64::INFINITY); } #[test] fn test_mean() { let mean = |x: InverseGamma| x.mean().unwrap(); - test_almost(1.1, 0.1, 1.0, 1e-14, mean); - test_almost(1.1, 1.0, 10.0, 1e-14, mean); + test_absolute(1.1, 0.1, 1.0, 1e-14, mean); + test_absolute(1.1, 1.0, 10.0, 1e-14, mean); } #[test] - #[should_panic] fn test_mean_with_shape_lte_1() { - let mean = |x: InverseGamma| x.mean().unwrap(); - get_value(0.1, 0.1, mean); + test_none(0.1, 0.1, |dist| dist.mean()); } #[test] fn test_variance() { let variance = |x: InverseGamma| x.variance().unwrap(); - 
test_almost(2.1, 0.1, 0.08264462809917355371901, 1e-15, variance); - test_almost(2.1, 1.0, 8.264462809917355371901, 1e-13, variance); + test_absolute(2.1, 0.1, 0.08264462809917355371901, 1e-15, variance); + test_absolute(2.1, 1.0, 8.264462809917355371901, 1e-13, variance); } #[test] - #[should_panic] fn test_variance_with_shape_lte_2() { - let variance = |x: InverseGamma| x.variance().unwrap(); - get_value(0.1, 0.1, variance); + test_none(0.1, 0.1, |dist| dist.variance()); } #[test] fn test_entropy() { let entropy = |x: InverseGamma| x.entropy().unwrap(); - test_almost(0.1, 0.1, 11.51625799319234475054, 1e-14, entropy); - test_almost(1.0, 1.0, 2.154431329803065721213, 1e-14, entropy); + test_absolute(0.1, 0.1, 11.51625799319234475054, 1e-14, entropy); + test_absolute(1.0, 1.0, 2.154431329803065721213, 1e-14, entropy); } #[test] fn test_skewness() { let skewness = |x: InverseGamma| x.skewness().unwrap(); - test_almost(3.1, 0.1, 41.95235392680606187966, 1e-13, skewness); - test_almost(3.1, 1.0, 41.95235392680606187966, 1e-13, skewness); - test_case(5.0, 0.1, 3.464101615137754587055, skewness); + test_absolute(3.1, 0.1, 41.95235392680606187966, 1e-13, skewness); + test_absolute(3.1, 1.0, 41.95235392680606187966, 1e-13, skewness); + test_exact(5.0, 0.1, 3.464101615137754587055, skewness); } #[test] - #[should_panic] fn test_skewness_with_shape_lte_3() { - let skewness = |x: InverseGamma| x.skewness().unwrap(); - get_value(0.1, 0.1, skewness); + test_none(0.1, 0.1, |dist| dist.skewness()); } #[test] fn test_mode() { let mode = |x: InverseGamma| x.mode().unwrap(); - test_case(0.1, 0.1, 0.09090909090909090909091, mode); - test_case(1.0, 1.0, 0.5, mode); + test_exact(0.1, 0.1, 0.09090909090909090909091, mode); + test_exact(1.0, 1.0, 0.5, mode); } #[test] fn test_min_max() { let min = |x: InverseGamma| x.min(); let max = |x: InverseGamma| x.max(); - test_case(1.0, 1.0, 0.0, min); - test_case(1.0, 1.0, f64::INFINITY, max); + test_exact(1.0, 1.0, 0.0, min); + test_exact(1.0, 
1.0, f64::INFINITY, max); } #[test] fn test_pdf() { let pdf = |arg: f64| move |x: InverseGamma| x.pdf(arg); - test_almost(0.1, 0.1, 0.0628591853882328004197, 1e-15, pdf(1.2)); - test_almost(0.1, 1.0, 0.0297426109178248997426, 1e-15, pdf(2.0)); - test_case(1.0, 0.1, 0.04157808822362745501024, pdf(1.5)); - test_case(1.0, 1.0, 0.3018043114632487660842, pdf(1.2)); + test_absolute(0.1, 0.1, 0.0628591853882328004197, 1e-15, pdf(1.2)); + test_absolute(0.1, 1.0, 0.0297426109178248997426, 1e-15, pdf(2.0)); + test_exact(1.0, 0.1, 0.04157808822362745501024, pdf(1.5)); + test_exact(1.0, 1.0, 0.3018043114632487660842, pdf(1.2)); } #[test] fn test_ln_pdf() { let ln_pdf = |arg: f64| move |x: InverseGamma| x.ln_pdf(arg); - test_almost(0.1, 0.1, 0.0628591853882328004197f64.ln(), 1e-15, ln_pdf(1.2)); - test_almost(0.1, 1.0, 0.0297426109178248997426f64.ln(), 1e-15, ln_pdf(2.0)); - test_case(1.0, 0.1, 0.04157808822362745501024f64.ln(), ln_pdf(1.5)); - test_case(1.0, 1.0, 0.3018043114632487660842f64.ln(), ln_pdf(1.2)); + test_absolute(0.1, 0.1, 0.0628591853882328004197f64.ln(), 1e-15, ln_pdf(1.2)); + test_absolute(0.1, 1.0, 0.0297426109178248997426f64.ln(), 1e-15, ln_pdf(2.0)); + test_exact(1.0, 0.1, 0.04157808822362745501024f64.ln(), ln_pdf(1.5)); + test_exact(1.0, 1.0, 0.3018043114632487660842f64.ln(), ln_pdf(1.2)); } #[test] fn test_cdf() { let cdf = |arg: f64| move |x: InverseGamma| x.cdf(arg); - test_almost(0.1, 0.1, 0.1862151961946054271994, 1e-14, cdf(1.2)); - test_almost(0.1, 1.0, 0.05859755410986647796141, 1e-14, cdf(2.0)); - test_case(1.0, 0.1, 0.9355069850316177377304, cdf(1.5)); - test_almost(1.0, 1.0, 0.4345982085070782231613, 1e-14, cdf(1.2)); + test_absolute(0.1, 0.1, 0.1862151961946054271994, 1e-14, cdf(1.2)); + test_absolute(0.1, 1.0, 0.05859755410986647796141, 1e-14, cdf(2.0)); + test_exact(1.0, 0.1, 0.9355069850316177377304, cdf(1.5)); + test_absolute(1.0, 1.0, 0.4345982085070782231613, 1e-14, cdf(1.2)); } #[test] fn test_sf() { let sf = |arg: f64| move |x: 
InverseGamma| x.sf(arg); - test_almost(0.1, 0.1, 0.8137848038053936, 1e-14, sf(1.2)); - test_almost(0.1, 1.0, 0.9414024458901327, 1e-14, sf(2.0)); - test_almost(1.0, 0.1, 0.0644930149683822, 1e-14, sf(1.5)); - test_almost(1.0, 1.0, 0.565401791492922, 1e-14, sf(1.2)); + test_absolute(0.1, 0.1, 0.8137848038053936, 1e-14, sf(1.2)); + test_absolute(0.1, 1.0, 0.9414024458901327, 1e-14, sf(2.0)); + test_absolute(1.0, 0.1, 0.0644930149683822, 1e-14, sf(1.5)); + test_absolute(1.0, 1.0, 0.565401791492922, 1e-14, sf(1.2)); } #[test] fn test_continuous() { - test::check_continuous_distribution(&try_create(1.0, 0.5), 0.0, 100.0); - test::check_continuous_distribution(&try_create(9.0, 2.0), 0.0, 100.0); + test::check_continuous_distribution(&create_ok(1.0, 0.5), 0.0, 100.0); + test::check_continuous_distribution(&create_ok(9.0, 2.0), 0.0, 100.0); } } diff --git a/src/distribution/laplace.rs b/src/distribution/laplace.rs index 2d3d5590..b54bbd9f 100644 --- a/src/distribution/laplace.rs +++ b/src/distribution/laplace.rs @@ -1,7 +1,5 @@ use crate::distribution::{Continuous, ContinuousCDF}; use crate::statistics::{Distribution, Max, Median, Min, Mode}; -use crate::{Result, StatsError}; -use rand::Rng; use std::f64; /// Implements the [Laplace](https://en.wikipedia.org/wiki/Laplace_distribution) @@ -17,12 +15,35 @@ use std::f64; /// assert_eq!(n.mode().unwrap(), 0.0); /// assert_eq!(n.pdf(1.0), 0.18393972058572117); /// ``` -#[derive(Debug, Copy, Clone, PartialEq)] +#[derive(Copy, Clone, PartialEq, Debug)] pub struct Laplace { location: f64, scale: f64, } +/// Represents the errors that can occur when creating a [`Laplace`]. +#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] +#[non_exhaustive] +pub enum LaplaceError { + /// The location is NaN. + LocationInvalid, + + /// The scale is NaN, zero or less than zero. 
+ ScaleInvalid, +} + +impl std::fmt::Display for LaplaceError { + #[cfg_attr(coverage_nightly, coverage(off))] + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + LaplaceError::LocationInvalid => write!(f, "Location is NaN"), + LaplaceError::ScaleInvalid => write!(f, "Scale is NaN, zero or less than zero"), + } + } +} + +impl std::error::Error for LaplaceError {} + impl Laplace { /// Constructs a new laplace distribution with the given /// location and scale. @@ -42,12 +63,16 @@ impl Laplace { /// result = Laplace::new(0.0, -1.0); /// assert!(result.is_err()); /// ``` - pub fn new(location: f64, scale: f64) -> Result { - if location.is_nan() || scale.is_nan() || scale <= 0.0 { - Err(StatsError::BadParams) - } else { - Ok(Laplace { location, scale }) + pub fn new(location: f64, scale: f64) -> Result { + if location.is_nan() { + return Err(LaplaceError::LocationInvalid); } + + if scale.is_nan() || scale <= 0.0 { + return Err(LaplaceError::ScaleInvalid); + } + + Ok(Laplace { location, scale }) } /// Returns the location of the laplace distribution @@ -79,8 +104,15 @@ impl Laplace { } } +impl std::fmt::Display for Laplace { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Laplace({}, {})", self.location, self.scale) + } +} + +#[cfg(feature = "rand")] impl ::rand::distributions::Distribution for Laplace { - fn sample(&self, rng: &mut R) -> f64 { + fn sample(&self, rng: &mut R) -> f64 { let x: f64 = rng.gen_range(-0.5..0.5); self.location - self.scale * x.signum() * (1. - 2. 
* x.abs()).ln() } @@ -92,7 +124,7 @@ impl ContinuousCDF for Laplace { /// /// # Formula /// - /// ```ignore + /// ```text /// (1 / 2) * (1 + signum(x - μ)) - signum(x - μ) * exp(-|x - μ| / b) /// ``` /// @@ -111,7 +143,7 @@ impl ContinuousCDF for Laplace { /// /// # Formula /// - /// ```ignore + /// ```text /// 1 - [(1 / 2) * (1 + signum(x - μ)) - signum(x - μ) * exp(-|x - μ| / b)] /// ``` /// @@ -131,11 +163,11 @@ impl ContinuousCDF for Laplace { /// # Formula /// /// if p <= 1/2 - /// ```ignore + /// ```text /// μ + b * ln(2p) /// ``` /// if p >= 1/2 - /// ```ignore + /// ```text /// μ - b * ln(2 - 2p) /// ``` /// @@ -158,7 +190,7 @@ impl Min for Laplace { /// /// # Formula /// - /// ```ignore + /// ```text /// NEG_INF /// ``` fn min(&self) -> f64 { @@ -172,8 +204,8 @@ impl Max for Laplace { /// /// # Formula /// - /// ```ignore - /// INF + /// ```text + /// f64::INFINITY /// ``` fn max(&self) -> f64 { f64::INFINITY @@ -185,7 +217,7 @@ impl Distribution for Laplace { /// /// # Formula /// - /// ```ignore + /// ```text /// μ /// ``` /// @@ -193,11 +225,12 @@ impl Distribution for Laplace { fn mean(&self) -> Option { Some(self.location) } + /// Returns the variance of the laplace distribution /// /// # Formula /// - /// ```ignore + /// ```text /// 2*b^2 /// ``` /// @@ -205,11 +238,12 @@ impl Distribution for Laplace { fn variance(&self) -> Option { Some(2. * self.scale * self.scale) } + /// Returns the entropy of the laplace distribution /// /// # Formula /// - /// ```ignore + /// ```text /// ln(2be) /// ``` /// @@ -217,11 +251,12 @@ impl Distribution for Laplace { fn entropy(&self) -> Option { Some((2. * self.scale).ln() + 1.) 
} + /// Returns the skewness of the laplace distribution /// /// # Formula /// - /// ```ignore + /// ```text /// 0 /// ``` fn skewness(&self) -> Option { @@ -234,7 +269,7 @@ impl Median for Laplace { /// /// # Formula /// - /// ```ignore + /// ```text /// μ /// ``` /// @@ -249,7 +284,7 @@ impl Mode> for Laplace { /// /// # Formula /// - /// ```ignore + /// ```text /// μ /// ``` /// @@ -265,7 +300,7 @@ impl Continuous for Laplace { /// /// # Formula /// - /// ```ignore + /// ```text /// (1 / 2b) * exp(-|x - μ| / b) /// ``` /// where `μ` is the location and `b` is the scale @@ -278,7 +313,7 @@ impl Continuous for Laplace { /// /// # Formula /// - /// ```ignore + /// ```text /// ln((1 / 2b) * exp(-|x - μ| / b)) /// ``` /// @@ -288,202 +323,171 @@ impl Continuous for Laplace { } } -#[cfg(all(test, feature = "nightly"))] +#[cfg(test)] mod tests { use super::*; - use core::f64::INFINITY as INF; - use rand::thread_rng; - fn try_create(location: f64, scale: f64) -> Laplace { - let n = Laplace::new(location, scale); - assert!(n.is_ok()); - n.unwrap() - } + use crate::testing_boiler; - fn bad_create_case(location: f64, scale: f64) { - let n = Laplace::new(location, scale); - assert!(n.is_err()); - } - - fn test_case(location: f64, scale: f64, expected: f64, eval: F) - where - F: Fn(Laplace) -> f64, - { - let n = try_create(location, scale); - let x = eval(n); - assert_eq!(expected, x); - } - - fn test_is_nan(location: f64, scale: f64, eval: F) - where - F: Fn(Laplace) -> f64, - { - let n = try_create(location, scale); - let x = eval(n); - assert!(x.is_nan()); - } - - fn test_almost(location: f64, scale: f64, expected: f64, acc: f64, eval: F) - where - F: Fn(Laplace) -> f64, - { - let n = try_create(location, scale); - let x = eval(n); - assert_almost_eq!(expected, x, acc); - } + testing_boiler!(location: f64, scale: f64; Laplace; LaplaceError); // A wrapper for the `assert_relative_eq!` macro from the approx crate. // // `rtol` is the accepable relative error. 
This function is for testing // relative tolerance *only*. It should not be used with `expected = 0`. // - fn test_rel_close(location: f64, scale: f64, expected: f64, rtol: f64, eval: F) + fn test_rel_close(location: f64, scale: f64, expected: f64, rtol: f64, get_fn: F) where F: Fn(Laplace) -> f64, { - let n = try_create(location, scale); - let x = eval(n); + let x = create_and_get(location, scale, get_fn); assert_relative_eq!(expected, x, epsilon = 0.0, max_relative = rtol); } #[test] fn test_create() { - try_create(1.0, 2.0); - try_create(-INF, 0.1); - try_create(-5.0 - 1.0, 1.0); - try_create(0.0, 5.0); - try_create(1.0, 7.0); - try_create(5.0, 10.0); - try_create(INF, INF); + create_ok(1.0, 2.0); + create_ok(f64::NEG_INFINITY, 0.1); + create_ok(-5.0 - 1.0, 1.0); + create_ok(0.0, 5.0); + create_ok(1.0, 7.0); + create_ok(5.0, 10.0); + create_ok(f64::INFINITY, f64::INFINITY); } #[test] fn test_bad_create() { - bad_create_case(2.0, -1.0); - bad_create_case(f64::NAN, 1.0); - bad_create_case(f64::NAN, -1.0); + test_create_err(2.0, -1.0, LaplaceError::ScaleInvalid); + test_create_err(f64::NAN, 1.0, LaplaceError::LocationInvalid); + create_err(f64::NAN, -1.0); } #[test] fn test_mean() { let mean = |x: Laplace| x.mean().unwrap(); - test_case(-INF, 0.1, -INF, mean); - test_case(-5.0 - 1.0, 1.0, -6.0, mean); - test_case(0.0, 5.0, 0.0, mean); - test_case(1.0, 10.0, 1.0, mean); - test_case(INF, INF, INF, mean); + test_exact(f64::NEG_INFINITY, 0.1, f64::NEG_INFINITY, mean); + test_exact(-5.0 - 1.0, 1.0, -6.0, mean); + test_exact(0.0, 5.0, 0.0, mean); + test_exact(1.0, 10.0, 1.0, mean); + test_exact(f64::INFINITY, f64::INFINITY, f64::INFINITY, mean); } #[test] fn test_variance() { let variance = |x: Laplace| x.variance().unwrap(); - test_almost(-INF, 0.1, 0.02, 1E-12, variance); - test_almost(-5.0 - 1.0, 1.0, 2.0, 1E-12, variance); - test_almost(0.0, 5.0, 50.0, 1E-12, variance); - test_almost(1.0, 7.0, 98.0, 1E-12, variance); - test_almost(5.0, 10.0, 200.0, 1E-12, variance); 
- test_almost(INF, INF, INF, 1E-12, variance); + test_absolute(f64::NEG_INFINITY, 0.1, 0.02, 1E-12, variance); + test_absolute(-5.0 - 1.0, 1.0, 2.0, 1E-12, variance); + test_absolute(0.0, 5.0, 50.0, 1E-12, variance); + test_absolute(1.0, 7.0, 98.0, 1E-12, variance); + test_absolute(5.0, 10.0, 200.0, 1E-12, variance); + test_absolute(f64::INFINITY, f64::INFINITY, f64::INFINITY, 1E-12, variance); } #[test] fn test_entropy() { let entropy = |x: Laplace| x.entropy().unwrap(); - test_almost(-INF, 0.1, (2.0 * f64::consts::E * 0.1).ln(), 1E-12, entropy); - test_almost(-6.0, 1.0, (2.0 * f64::consts::E).ln(), 1E-12, entropy); - test_almost(1.0, 7.0, (2.0 * f64::consts::E * 7.0).ln(), 1E-12, entropy); - test_almost(5., 10., (2. * f64::consts::E * 10.).ln(), 1E-12, entropy); - test_almost(INF, INF, INF, 1E-12, entropy); + test_absolute( + f64::NEG_INFINITY, + 0.1, + (2.0 * f64::consts::E * 0.1).ln(), + 1E-12, + entropy, + ); + test_absolute(-6.0, 1.0, (2.0 * f64::consts::E).ln(), 1E-12, entropy); + test_absolute(1.0, 7.0, (2.0 * f64::consts::E * 7.0).ln(), 1E-12, entropy); + test_absolute(5., 10., (2. 
* f64::consts::E * 10.).ln(), 1E-12, entropy); + test_absolute(f64::INFINITY, f64::INFINITY, f64::INFINITY, 1E-12, entropy); } #[test] fn test_skewness() { let skewness = |x: Laplace| x.skewness().unwrap(); - test_case(-INF, 0.1, 0.0, skewness); - test_case(-6.0, 1.0, 0.0, skewness); - test_case(1.0, 7.0, 0.0, skewness); - test_case(5.0, 10.0, 0.0, skewness); - test_case(INF, INF, 0.0, skewness); + test_exact(f64::NEG_INFINITY, 0.1, 0.0, skewness); + test_exact(-6.0, 1.0, 0.0, skewness); + test_exact(1.0, 7.0, 0.0, skewness); + test_exact(5.0, 10.0, 0.0, skewness); + test_exact(f64::INFINITY, f64::INFINITY, 0.0, skewness); } #[test] fn test_mode() { let mode = |x: Laplace| x.mode().unwrap(); - test_case(-INF, 0.1, -INF, mode); - test_case(-6.0, 1.0, -6.0, mode); - test_case(1.0, 7.0, 1.0, mode); - test_case(5.0, 10.0, 5.0, mode); - test_case(INF, INF, INF, mode); + test_exact(f64::NEG_INFINITY, 0.1, f64::NEG_INFINITY, mode); + test_exact(-6.0, 1.0, -6.0, mode); + test_exact(1.0, 7.0, 1.0, mode); + test_exact(5.0, 10.0, 5.0, mode); + test_exact(f64::INFINITY, f64::INFINITY, f64::INFINITY, mode); } #[test] fn test_median() { let median = |x: Laplace| x.median(); - test_case(-INF, 0.1, -INF, median); - test_case(-6.0, 1.0, -6.0, median); - test_case(1.0, 7.0, 1.0, median); - test_case(5.0, 10.0, 5.0, median); - test_case(INF, INF, INF, median); + test_exact(f64::NEG_INFINITY, 0.1, f64::NEG_INFINITY, median); + test_exact(-6.0, 1.0, -6.0, median); + test_exact(1.0, 7.0, 1.0, median); + test_exact(5.0, 10.0, 5.0, median); + test_exact(f64::INFINITY, f64::INFINITY, f64::INFINITY, median); } #[test] fn test_min() { - test_case(0.0, 1.0, -INF, |l| l.min()); + test_exact(0.0, 1.0, f64::NEG_INFINITY, |l| l.min()); } #[test] fn test_max() { - test_case(0.0, 1.0, INF, |l| l.max()); + test_exact(0.0, 1.0, f64::INFINITY, |l| l.max()); } #[test] fn test_density() { let pdf = |arg: f64| move |x: Laplace| x.pdf(arg); - test_almost(0.0, 0.1, 1.529511602509129e-06, 1E-12, pdf(1.5)); 
- test_almost(1.0, 0.1, 7.614989872356341e-08, 1E-12, pdf(2.8)); - test_almost(-1.0, 0.1, 3.8905661205668983e-19, 1E-12, pdf(-5.4)); - test_almost(5.0, 0.1, 5.056107463052243e-43, 1E-12, pdf(-4.9)); - test_almost(-5.0, 0.1, 1.9877248679543235e-30, 1E-12, pdf(2.0)); - test_almost(INF, 0.1, 0.0, 1E-12, pdf(5.5)); - test_almost(-INF, 0.1, 0.0, 1E-12, pdf(-0.0)); - test_almost(0.0, 1.0, 0.0, 1E-12, pdf(INF)); - test_almost(1.0, 1.0, 0.00915781944436709, 1E-12, pdf(5.0)); - test_almost(-1.0, 1.0, 0.5, 1E-12, pdf(-1.0)); - test_almost(5.0, 1.0, 0.0012393760883331792, 1E-12, pdf(-1.0)); - test_almost(-5.0, 1.0, 0.0002765421850739168, 1E-12, pdf(2.5)); - test_almost(INF, 0.1, 0.0, 1E-12, pdf(2.0)); - test_almost(-INF, 0.1, 0.0, 1E-12, pdf(15.0)); - test_almost(0.0, INF, 0.0, 1E-12, pdf(89.3)); - test_almost(1.0, INF, 0.0, 1E-12, pdf(-0.1)); - test_almost(-1.0, INF, 0.0, 1E-12, pdf(0.1)); - test_almost(5.0, INF, 0.0, 1E-12, pdf(-6.1)); - test_almost(-5.0, INF, 0.0, 1E-12, pdf(-10.0)); - test_is_nan(INF, INF, pdf(2.0)); - test_is_nan(-INF, INF, pdf(-5.1)); + test_absolute(0.0, 0.1, 1.529511602509129e-06, 1E-12, pdf(1.5)); + test_absolute(1.0, 0.1, 7.614989872356341e-08, 1E-12, pdf(2.8)); + test_absolute(-1.0, 0.1, 3.8905661205668983e-19, 1E-12, pdf(-5.4)); + test_absolute(5.0, 0.1, 5.056107463052243e-43, 1E-12, pdf(-4.9)); + test_absolute(-5.0, 0.1, 1.9877248679543235e-30, 1E-12, pdf(2.0)); + test_absolute(f64::INFINITY, 0.1, 0.0, 1E-12, pdf(5.5)); + test_absolute(f64::NEG_INFINITY, 0.1, 0.0, 1E-12, pdf(-0.0)); + test_absolute(0.0, 1.0, 0.0, 1E-12, pdf(f64::INFINITY)); + test_absolute(1.0, 1.0, 0.00915781944436709, 1E-12, pdf(5.0)); + test_absolute(-1.0, 1.0, 0.5, 1E-12, pdf(-1.0)); + test_absolute(5.0, 1.0, 0.0012393760883331792, 1E-12, pdf(-1.0)); + test_absolute(-5.0, 1.0, 0.0002765421850739168, 1E-12, pdf(2.5)); + test_absolute(f64::INFINITY, 0.1, 0.0, 1E-12, pdf(2.0)); + test_absolute(f64::NEG_INFINITY, 0.1, 0.0, 1E-12, pdf(15.0)); + test_absolute(0.0, f64::INFINITY, 
0.0, 1E-12, pdf(89.3)); + test_absolute(1.0, f64::INFINITY, 0.0, 1E-12, pdf(-0.1)); + test_absolute(-1.0, f64::INFINITY, 0.0, 1E-12, pdf(0.1)); + test_absolute(5.0, f64::INFINITY, 0.0, 1E-12, pdf(-6.1)); + test_absolute(-5.0, f64::INFINITY, 0.0, 1E-12, pdf(-10.0)); + test_is_nan(f64::INFINITY, f64::INFINITY, pdf(2.0)); + test_is_nan(f64::NEG_INFINITY, f64::INFINITY, pdf(-5.1)); } #[test] fn test_ln_density() { let ln_pdf = |arg: f64| move |x: Laplace| x.ln_pdf(arg); - test_almost(0.0, 0.1, -13.3905620875659, 1E-12, ln_pdf(1.5)); - test_almost(1.0, 0.1, -16.390562087565897, 1E-12, ln_pdf(2.8)); - test_almost(-1.0, 0.1, -42.39056208756591, 1E-12, ln_pdf(-5.4)); - test_almost(5.0, 0.1, -97.3905620875659, 1E-12, ln_pdf(-4.9)); - test_almost(-5.0, 0.1, -68.3905620875659, 1E-12, ln_pdf(2.0)); - test_case(INF, 0.1, -INF, ln_pdf(5.5)); - test_case(-INF, 0.1, -INF, ln_pdf(-0.0)); - test_case(0.0, 1.0, -INF, ln_pdf(INF)); - test_almost(1.0, 1.0, -4.693147180559945, 1E-12, ln_pdf(5.0)); - test_almost(-1.0, 1.0, -f64::consts::LN_2, 1E-12, ln_pdf(-1.0)); - test_almost(5.0, 1.0, -6.693147180559945, 1E-12, ln_pdf(-1.0)); - test_almost(-5.0, 1.0, -8.193147180559945, 1E-12, ln_pdf(2.5)); - test_case(INF, 0.1, -INF, ln_pdf(2.0)); - test_case(-INF, 0.1, -INF, ln_pdf(15.0)); - test_case(0.0, INF, -INF, ln_pdf(89.3)); - test_case(1.0, INF, -INF, ln_pdf(-0.1)); - test_case(-1.0, INF, -INF, ln_pdf(0.1)); - test_case(5.0, INF, -INF, ln_pdf(-6.1)); - test_case(-5.0, INF, -INF, ln_pdf(-10.0)); - test_is_nan(INF, INF, ln_pdf(2.0)); - test_is_nan(-INF, INF, ln_pdf(-5.1)); + test_absolute(0.0, 0.1, -13.3905620875659, 1E-12, ln_pdf(1.5)); + test_absolute(1.0, 0.1, -16.390562087565897, 1E-12, ln_pdf(2.8)); + test_absolute(-1.0, 0.1, -42.39056208756591, 1E-12, ln_pdf(-5.4)); + test_absolute(5.0, 0.1, -97.3905620875659, 1E-12, ln_pdf(-4.9)); + test_absolute(-5.0, 0.1, -68.3905620875659, 1E-12, ln_pdf(2.0)); + test_exact(f64::INFINITY, 0.1, f64::NEG_INFINITY, ln_pdf(5.5)); + 
test_exact(f64::NEG_INFINITY, 0.1, f64::NEG_INFINITY, ln_pdf(-0.0)); + test_exact(0.0, 1.0, f64::NEG_INFINITY, ln_pdf(f64::INFINITY)); + test_absolute(1.0, 1.0, -4.693147180559945, 1E-12, ln_pdf(5.0)); + test_absolute(-1.0, 1.0, -f64::consts::LN_2, 1E-12, ln_pdf(-1.0)); + test_absolute(5.0, 1.0, -6.693147180559945, 1E-12, ln_pdf(-1.0)); + test_absolute(-5.0, 1.0, -8.193147180559945, 1E-12, ln_pdf(2.5)); + test_exact(f64::INFINITY, 0.1, f64::NEG_INFINITY, ln_pdf(2.0)); + test_exact(f64::NEG_INFINITY, 0.1, f64::NEG_INFINITY, ln_pdf(15.0)); + test_exact(0.0, f64::INFINITY, f64::NEG_INFINITY, ln_pdf(89.3)); + test_exact(1.0, f64::INFINITY, f64::NEG_INFINITY, ln_pdf(-0.1)); + test_exact(-1.0, f64::INFINITY, f64::NEG_INFINITY, ln_pdf(0.1)); + test_exact(5.0, f64::INFINITY, f64::NEG_INFINITY, ln_pdf(-6.1)); + test_exact(-5.0, f64::INFINITY, f64::NEG_INFINITY, ln_pdf(-10.0)); + test_is_nan(f64::INFINITY, f64::INFINITY, ln_pdf(2.0)); + test_is_nan(f64::NEG_INFINITY, f64::INFINITY, ln_pdf(-5.1)); } #[test] @@ -546,23 +550,27 @@ mod tests { test_rel_close(loc, scale, expected, reltol, inverse_cdf(0.95)); } + #[cfg(feature = "rand")] #[test] fn test_sample() { use ::rand::distributions::Distribution; - let l = try_create(0.1, 0.5); + use ::rand::thread_rng; + + let l = create_ok(0.1, 0.5); l.sample(&mut thread_rng()); } + #[cfg(feature = "rand")] #[test] fn test_sample_distribution() { + use ::rand::distributions::Distribution; use ::rand::rngs::StdRng; use ::rand::SeedableRng; - use rand::distributions::Distribution; // sanity check sampling let location = 0.0; let scale = 1.0; - let n = try_create(location, scale); + let n = create_ok(location, scale); let trials = 10_000; let tolerance = 250; diff --git a/src/distribution/log_normal.rs b/src/distribution/log_normal.rs index 13854f6f..2cf9d7cb 100644 --- a/src/distribution/log_normal.rs +++ b/src/distribution/log_normal.rs @@ -1,8 +1,7 @@ +use crate::consts; use crate::distribution::{Continuous, ContinuousCDF}; use 
crate::function::erf; use crate::statistics::*; -use crate::{consts, Result, StatsError}; -use rand::Rng; use std::f64; /// Implements the @@ -20,12 +19,35 @@ use std::f64; /// assert_eq!(n.mean().unwrap(), (0.5f64).exp()); /// assert!(prec::almost_eq(n.pdf(1.0), 0.3989422804014326779399, 1e-16)); /// ``` -#[derive(Debug, Copy, Clone, PartialEq)] +#[derive(Copy, Clone, PartialEq, Debug)] pub struct LogNormal { location: f64, scale: f64, } +/// Represents the errors that can occur when creating a [`LogNormal`]. +#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] +#[non_exhaustive] +pub enum LogNormalError { + /// The location is NaN. + LocationInvalid, + + /// The scale is NaN, zero or less than zero. + ScaleInvalid, +} + +impl std::fmt::Display for LogNormalError { + #[cfg_attr(coverage_nightly, coverage(off))] + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + LogNormalError::LocationInvalid => write!(f, "Location is NaN"), + LogNormalError::ScaleInvalid => write!(f, "Scale is NaN, zero or less than zero"), + } + } +} + +impl std::error::Error for LogNormalError {} + impl LogNormal { /// Constructs a new log-normal distribution with a location of `location` /// and a scale of `scale` @@ -46,17 +68,28 @@ impl LogNormal { /// result = LogNormal::new(0.0, 0.0); /// assert!(result.is_err()); /// ``` - pub fn new(location: f64, scale: f64) -> Result { - if location.is_nan() || scale.is_nan() || scale <= 0.0 { - Err(StatsError::BadParams) - } else { - Ok(LogNormal { location, scale }) + pub fn new(location: f64, scale: f64) -> Result { + if location.is_nan() { + return Err(LogNormalError::LocationInvalid); + } + + if scale.is_nan() || scale <= 0.0 { + return Err(LogNormalError::ScaleInvalid); } + + Ok(LogNormal { location, scale }) } } +impl std::fmt::Display for LogNormal { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "LogNormal({}, {}^2)", self.location, self.scale) + } +} + +#[cfg(feature = 
"rand")] impl ::rand::distributions::Distribution for LogNormal { - fn sample(&self, rng: &mut R) -> f64 { + fn sample(&self, rng: &mut R) -> f64 { super::normal::sample_unchecked(rng, self.location, self.scale).exp() } } @@ -68,7 +101,7 @@ impl ContinuousCDF for LogNormal { /// /// # Formula /// - /// ```ignore + /// ```text /// (1 / 2) + (1 / 2) * erf((ln(x) - μ) / sqrt(2) * σ) /// ``` /// @@ -89,7 +122,7 @@ impl ContinuousCDF for LogNormal { /// /// # Formula /// - /// ```ignore + /// ```text /// (1 / 2) + (1 / 2) * erf(-(ln(x) - μ) / sqrt(2) * σ) /// ``` /// @@ -100,9 +133,9 @@ impl ContinuousCDF for LogNormal { /// the sign of the argument error function with respect to the cdf. /// /// the normal cdf Φ (and internal error function) as the following property: - /// ```ignore + /// ```text /// Φ(-x) + Φ(x) = 1 - /// Φ(-x) = 1 - Φ(x) + /// Φ(-x) = 1 - Φ(x) /// ``` fn sf(&self, x: f64) -> f64 { if x <= 0.0 { @@ -113,6 +146,33 @@ impl ContinuousCDF for LogNormal { 0.5 * erf::erfc((x.ln() - self.location) / (self.scale * f64::consts::SQRT_2)) } } + + /// Calculates the inverse cumulative distribution function for the + /// log-normal distribution at `p` + /// + /// # Panics + /// + /// If `p < 0.0` or `p > 1.0` + /// + /// # Formula + /// + /// ```text + /// μ - σ * sqrt(2) * erfc_inv(2p) + /// ``` + /// + /// where `μ` is the location, `σ` is the scale and `erfc_inv` is + /// the inverse of the complementary error function + fn inverse_cdf(&self, p: f64) -> f64 { + if p == 0.0 { + 0.0 + } else if p < 1.0 { + (self.location - (self.scale * f64::consts::SQRT_2 * erf::erfc_inv(2.0 * p))).exp() + } else if p == 1.0 { + f64::INFINITY + } else { + panic!("p must be within [0.0, 1.0]"); + } + } } impl Min for LogNormal { @@ -121,7 +181,7 @@ impl Min for LogNormal { /// /// # Formula /// - /// ```ignore + /// ```text /// 0 /// ``` fn min(&self) -> f64 { @@ -135,8 +195,8 @@ impl Max for LogNormal { /// /// # Formula /// - /// ```ignore - /// INF + /// ```text + /// 
f64::INFINITY /// ``` fn max(&self) -> f64 { f64::INFINITY @@ -148,7 +208,7 @@ impl Distribution for LogNormal { /// /// # Formula /// - /// ```ignore + /// ```text /// e^(μ + σ^2 / 2) /// ``` /// @@ -156,11 +216,12 @@ impl Distribution for LogNormal { fn mean(&self) -> Option { Some((self.location + self.scale * self.scale / 2.0).exp()) } + /// Returns the variance of the log-normal distribution /// /// # Formula /// - /// ```ignore + /// ```text /// (e^(σ^2) - 1) * e^(2μ + σ^2) /// ``` /// @@ -169,11 +230,12 @@ impl Distribution for LogNormal { let sigma2 = self.scale * self.scale; Some((sigma2.exp() - 1.0) * (self.location + self.location + sigma2).exp()) } + /// Returns the entropy of the log-normal distribution /// /// # Formula /// - /// ```ignore + /// ```text /// ln(σe^(μ + 1 / 2) * sqrt(2π)) /// ``` /// @@ -181,11 +243,12 @@ impl Distribution for LogNormal { fn entropy(&self) -> Option { Some(0.5 + self.scale.ln() + self.location + consts::LN_SQRT_2PI) } + /// Returns the skewness of the log-normal distribution /// /// # Formula /// - /// ```ignore + /// ```text /// (e^(σ^2) + 2) * sqrt(e^(σ^2) - 1) /// ``` /// @@ -201,7 +264,7 @@ impl Median for LogNormal { /// /// # Formula /// - /// ```ignore + /// ```text /// e^μ /// ``` /// @@ -216,7 +279,7 @@ impl Mode> for LogNormal { /// /// # Formula /// - /// ```ignore + /// ```text /// e^(μ - σ^2) /// ``` /// @@ -232,7 +295,7 @@ impl Continuous for LogNormal { /// /// # Formula /// - /// ```ignore + /// ```text /// (1 / xσ * sqrt(2π)) * e^(-((ln(x) - μ)^2) / 2σ^2) /// ``` /// @@ -251,7 +314,7 @@ impl Continuous for LogNormal { /// /// # Formula /// - /// ```ignore + /// ```text /// ln((1 / xσ * sqrt(2π)) * e^(-((ln(x) - μ)^2) / 2σ^2)) /// ``` /// @@ -267,357 +330,386 @@ impl Continuous for LogNormal { } #[rustfmt::skip] -#[cfg(all(test, feature = "nightly"))] +#[cfg(test)] mod tests { - use crate::statistics::*; - use crate::distribution::{ContinuousCDF, Continuous, LogNormal}; + use super::*; use 
crate::distribution::internal::*; - use crate::consts::ACC; + use crate::testing_boiler; - fn try_create(mean: f64, std_dev: f64) -> LogNormal { - let n = LogNormal::new(mean, std_dev); - assert!(n.is_ok()); - n.unwrap() - } - - fn bad_create_case(mean: f64, std_dev: f64) { - let n = LogNormal::new(mean, std_dev); - assert!(n.is_err()); - } - - fn get_value(mean: f64, std_dev: f64, eval: F) -> f64 - where F: Fn(LogNormal) -> f64 - { - let n = try_create(mean, std_dev); - eval(n) - } - - fn test_case(mean: f64, std_dev: f64, expected: f64, eval: F) - where F: Fn(LogNormal) -> f64 - { - let x = get_value(mean, std_dev, eval); - assert_eq!(expected, x); - } - - fn test_almost(mean: f64, std_dev: f64, expected: f64, acc: f64, eval: F) - where F: Fn(LogNormal) -> f64 - { - let x = get_value(mean, std_dev, eval); - assert_almost_eq!(expected, x, acc); - } + testing_boiler!(location: f64, scale: f64; LogNormal; LogNormalError); #[test] fn test_create() { - try_create(10.0, 0.1); - try_create(-5.0, 1.0); - try_create(0.0, 10.0); - try_create(10.0, 100.0); - try_create(-5.0, f64::INFINITY); + create_ok(10.0, 0.1); + create_ok(-5.0, 1.0); + create_ok(0.0, 10.0); + create_ok(10.0, 100.0); + create_ok(-5.0, f64::INFINITY); } #[test] fn test_bad_create() { - bad_create_case(0.0, 0.0); - bad_create_case(f64::NAN, 1.0); - bad_create_case(1.0, f64::NAN); - bad_create_case(f64::NAN, f64::NAN); - bad_create_case(1.0, -1.0); + test_create_err(f64::NAN, 1.0, LogNormalError::LocationInvalid); + test_create_err(1.0, f64::NAN, LogNormalError::ScaleInvalid); + create_err(0.0, 0.0); + create_err(f64::NAN, f64::NAN); + create_err(1.0, -1.0); } #[test] fn test_mean() { let mean = |x: LogNormal| x.mean().unwrap(); - test_case(-1.0, 0.1, 0.369723444544058982601, mean); - test_case(-1.0, 1.5, 1.133148453066826316829, mean); - test_case(-1.0, 2.5, 8.372897488127264663205, mean); - test_case(-1.0, 5.5, 1362729.18425285481771, mean); - test_case(-0.1, 0.1, 0.9093729344682314204933, mean); - 
test_case(-0.1, 1.5, 2.787095460565850768514, mean); - test_case(-0.1, 2.5, 20.59400471119602917533, mean); - test_almost(-0.1, 5.5, 3351772.941252693807591, 1e-9, mean); - test_case(0.1, 0.1, 1.110710610355705232259, mean); - test_case(0.1, 1.5, 3.40416608279081898632, mean); - test_almost(0.1, 2.5, 25.15357415581836182776, 1e-14, mean); - test_almost(0.1, 5.5, 4093864.715172665106863, 1e-8, mean); - test_almost(1.5, 0.1, 4.50415363028848413209, 1e-15, mean); - test_case(1.5, 1.5, 13.80457418606709491926, mean); - test_case(1.5, 2.5, 102.0027730826996844534, mean); - test_case(1.5, 5.5, 16601440.05723477471392, mean); - test_almost(2.5, 0.1, 12.24355896580102707724, 1e-14, mean); - test_almost(2.5, 1.5, 37.52472315960099891407, 1e-11, mean); - test_case(2.5, 2.5, 277.2722845231339804081, mean); - test_case(2.5, 5.5, 45127392.83383337999291, mean); - test_almost(5.5, 0.1, 245.9184556788219446833, 1e-13, mean); - test_case(5.5, 1.5, 753.7042125545612656606, mean); - test_case(5.5, 2.5, 5569.162708566004074422, mean); - test_case(5.5, 5.5, 906407915.0111549133446, mean); + test_exact(-1.0, 0.1, 0.369723444544058982601, mean); + test_exact(-1.0, 1.5, 1.133148453066826316829, mean); + test_exact(-1.0, 2.5, 8.372897488127264663205, mean); + test_exact(-1.0, 5.5, 1362729.18425285481771, mean); + test_exact(-0.1, 0.1, 0.9093729344682314204933, mean); + test_exact(-0.1, 1.5, 2.787095460565850768514, mean); + test_exact(-0.1, 2.5, 20.59400471119602917533, mean); + test_absolute(-0.1, 5.5, 3351772.941252693807591, 1e-9, mean); + test_exact(0.1, 0.1, 1.110710610355705232259, mean); + test_exact(0.1, 1.5, 3.40416608279081898632, mean); + test_absolute(0.1, 2.5, 25.15357415581836182776, 1e-14, mean); + test_absolute(0.1, 5.5, 4093864.715172665106863, 1e-8, mean); + test_absolute(1.5, 0.1, 4.50415363028848413209, 1e-15, mean); + test_exact(1.5, 1.5, 13.80457418606709491926, mean); + test_exact(1.5, 2.5, 102.0027730826996844534, mean); + test_exact(1.5, 5.5, 
16601440.05723477471392, mean); + test_absolute(2.5, 0.1, 12.24355896580102707724, 1e-14, mean); + test_absolute(2.5, 1.5, 37.52472315960099891407, 1e-11, mean); + test_exact(2.5, 2.5, 277.2722845231339804081, mean); + test_exact(2.5, 5.5, 45127392.83383337999291, mean); + test_absolute(5.5, 0.1, 245.9184556788219446833, 1e-13, mean); + test_exact(5.5, 1.5, 753.7042125545612656606, mean); + test_exact(5.5, 2.5, 5569.162708566004074422, mean); + test_exact(5.5, 5.5, 906407915.0111549133446, mean); } #[test] fn test_variance() { let variance = |x: LogNormal| x.variance().unwrap(); - test_almost(-1.0, 0.1, 0.001373811865368952608715, 1e-16, variance); - test_case(-1.0, 1.5, 10.898468544015731954, variance); - test_case(-1.0, 2.5, 36245.39726189994988081, variance); - test_almost(-1.0, 5.5, 2.5481629178024539E+25, 1e10, variance); - test_almost(-0.1, 0.1, 0.008311077467909703803238, 1e-16, variance); - test_case(-0.1, 1.5, 65.93189259328902509552, variance); - test_almost(-0.1, 2.5, 219271.8756420929704707, 1e-10, variance); - test_almost(-0.1, 5.5, 1.541548733459471E+26, 1e12, variance); - test_almost(0.1, 0.1, 0.01239867063063756838894, 1e-15, variance); - test_almost(0.1, 1.5, 98.35882573290010981464, 1e-13, variance); - test_almost(0.1, 2.5, 327115.1995809995715014, 1e-10, variance); - test_almost(0.1, 5.5, 2.299720473192458E+26, 1e12, variance); - test_almost(1.5, 0.1, 0.2038917589520099120699, 1e-14, variance); - test_almost(1.5, 1.5, 1617.476145997433210727, 1e-12, variance); - test_almost(1.5, 2.5, 5379293.910566451644527, 1e-9, variance); - test_almost(1.5, 5.5, 3.7818090853910142E+27, 1e12, variance); - test_almost(2.5, 0.1, 1.506567645006046841936, 1e-13, variance); - test_almost(2.5, 1.5, 11951.62198145717670088, 1e-11, variance); - test_case(2.5, 2.5, 39747904.47781154725843, variance); - test_almost(2.5, 5.5, 2.7943999487399818E+28, 1e13, variance); - test_almost(5.5, 0.1, 607.7927673399807484235, 1e-11, variance); - test_case(5.5, 1.5, 
4821628.436260521100027, variance); - test_case(5.5, 2.5, 16035449147.34799637823, variance); - test_case(5.5, 5.5, 1.127341399856331737823E+31, variance); + test_absolute(-1.0, 0.1, 0.001373811865368952608715, 1e-16, variance); + test_exact(-1.0, 1.5, 10.898468544015731954, variance); + test_exact(-1.0, 2.5, 36245.39726189994988081, variance); + test_absolute(-1.0, 5.5, 2.5481629178024539E+25, 1e10, variance); + test_absolute(-0.1, 0.1, 0.008311077467909703803238, 1e-16, variance); + test_exact(-0.1, 1.5, 65.93189259328902509552, variance); + test_absolute(-0.1, 2.5, 219271.8756420929704707, 1e-10, variance); + test_absolute(-0.1, 5.5, 1.541548733459471E+26, 1e12, variance); + test_absolute(0.1, 0.1, 0.01239867063063756838894, 1e-15, variance); + test_absolute(0.1, 1.5, 98.35882573290010981464, 1e-13, variance); + test_absolute(0.1, 2.5, 327115.1995809995715014, 1e-10, variance); + test_absolute(0.1, 5.5, 2.299720473192458E+26, 1e12, variance); + test_absolute(1.5, 0.1, 0.2038917589520099120699, 1e-14, variance); + test_absolute(1.5, 1.5, 1617.476145997433210727, 1e-12, variance); + test_absolute(1.5, 2.5, 5379293.910566451644527, 1e-9, variance); + test_absolute(1.5, 5.5, 3.7818090853910142E+27, 1e12, variance); + test_absolute(2.5, 0.1, 1.506567645006046841936, 1e-13, variance); + test_absolute(2.5, 1.5, 11951.62198145717670088, 1e-11, variance); + test_exact(2.5, 2.5, 39747904.47781154725843, variance); + test_absolute(2.5, 5.5, 2.7943999487399818E+28, 1e13, variance); + test_absolute(5.5, 0.1, 607.7927673399807484235, 1e-11, variance); + test_exact(5.5, 1.5, 4821628.436260521100027, variance); + test_exact(5.5, 2.5, 16035449147.34799637823, variance); + test_exact(5.5, 5.5, 1.127341399856331737823E+31, variance); } #[test] fn test_entropy() { let entropy = |x: LogNormal| x.entropy().unwrap(); - test_case(-1.0, 0.1, -1.8836465597893728867265104870209210873020761202386, entropy); - test_case(-1.0, 1.5, 0.82440364131283712375834285186996677643338789710028, 
entropy); - test_case(-1.0, 2.5, 1.335229265078827806963856948173628711311498693546, entropy); - test_case(-1.0, 5.5, 2.1236866254430979764250411929125703716076041932149, entropy); - test_almost(-0.1, 0.1, -0.9836465597893728922776256101467037894202344606927, 1e-15, entropy); - test_case(-0.1, 1.5, 1.7244036413128371182072277287441840743152295566462, entropy); - test_case(-0.1, 2.5, 2.2352292650788278014127418250478460091933403530919, entropy); - test_case(-0.1, 5.5, 3.0236866254430979708739260697867876694894458527608, entropy); - test_almost(0.1, 0.1, -0.7836465597893728811753953638951383851839177797845, 1e-15, entropy); - test_almost(0.1, 1.5, 1.9244036413128371293094579749957494785515462375544, 1e-15, entropy); - test_case(0.1, 2.5, 2.4352292650788278125149720712994114134296570340001, entropy); - test_case(0.1, 5.5, 3.223686625443097981976156316038353073725762533669, entropy); - test_almost(1.5, 0.1, 0.6163534402106271132734895129790789126979238797614, 1e-15, entropy); - test_case(1.5, 1.5, 3.3244036413128371237583428518699667764333878971003, entropy); - test_case(1.5, 2.5, 3.835229265078827806963856948173628711311498693546, entropy); - test_case(1.5, 5.5, 4.6236866254430979764250411929125703716076041932149, entropy); - test_case(2.5, 0.1, 1.6163534402106271132734895129790789126979238797614, entropy); - test_almost(2.5, 1.5, 4.3244036413128371237583428518699667764333878971003, 1e-15, entropy); - test_case(2.5, 2.5, 4.835229265078827806963856948173628711311498693546, entropy); - test_case(2.5, 5.5, 5.6236866254430979764250411929125703716076041932149, entropy); - test_case(5.5, 0.1, 4.6163534402106271132734895129790789126979238797614, entropy); - test_almost(5.5, 1.5, 7.3244036413128371237583428518699667764333878971003, 1e-15, entropy); - test_case(5.5, 2.5, 7.835229265078827806963856948173628711311498693546, entropy); - test_case(5.5, 5.5, 8.6236866254430979764250411929125703716076041932149, entropy); + test_exact(-1.0, 0.1, 
-1.8836465597893728867265104870209210873020761202386, entropy); + test_exact(-1.0, 1.5, 0.82440364131283712375834285186996677643338789710028, entropy); + test_exact(-1.0, 2.5, 1.335229265078827806963856948173628711311498693546, entropy); + test_exact(-1.0, 5.5, 2.1236866254430979764250411929125703716076041932149, entropy); + test_absolute(-0.1, 0.1, -0.9836465597893728922776256101467037894202344606927, 1e-15, entropy); + test_exact(-0.1, 1.5, 1.7244036413128371182072277287441840743152295566462, entropy); + test_exact(-0.1, 2.5, 2.2352292650788278014127418250478460091933403530919, entropy); + test_exact(-0.1, 5.5, 3.0236866254430979708739260697867876694894458527608, entropy); + test_absolute(0.1, 0.1, -0.7836465597893728811753953638951383851839177797845, 1e-15, entropy); + test_absolute(0.1, 1.5, 1.9244036413128371293094579749957494785515462375544, 1e-15, entropy); + test_exact(0.1, 2.5, 2.4352292650788278125149720712994114134296570340001, entropy); + test_exact(0.1, 5.5, 3.223686625443097981976156316038353073725762533669, entropy); + test_absolute(1.5, 0.1, 0.6163534402106271132734895129790789126979238797614, 1e-15, entropy); + test_exact(1.5, 1.5, 3.3244036413128371237583428518699667764333878971003, entropy); + test_exact(1.5, 2.5, 3.835229265078827806963856948173628711311498693546, entropy); + test_exact(1.5, 5.5, 4.6236866254430979764250411929125703716076041932149, entropy); + test_exact(2.5, 0.1, 1.6163534402106271132734895129790789126979238797614, entropy); + test_absolute(2.5, 1.5, 4.3244036413128371237583428518699667764333878971003, 1e-15, entropy); + test_exact(2.5, 2.5, 4.835229265078827806963856948173628711311498693546, entropy); + test_exact(2.5, 5.5, 5.6236866254430979764250411929125703716076041932149, entropy); + test_exact(5.5, 0.1, 4.6163534402106271132734895129790789126979238797614, entropy); + test_absolute(5.5, 1.5, 7.3244036413128371237583428518699667764333878971003, 1e-15, entropy); + test_exact(5.5, 2.5, 
7.835229265078827806963856948173628711311498693546, entropy); + test_exact(5.5, 5.5, 8.6236866254430979764250411929125703716076041932149, entropy); } #[test] fn test_skewness() { let skewness = |x: LogNormal| x.skewness().unwrap(); - test_almost(-1.0, 0.1, 0.30175909933883402945387113824982918009810212213629, 1e-14, skewness); - test_case(-1.0, 1.5, 33.46804679732172529147579024311650645764144530123, skewness); - test_almost(-1.0, 2.5, 11824.007933610287521341659465200553739278936344799, 1e-11, skewness); - test_almost(-1.0, 5.5, 50829064464591483629.132631635472412625371367420496, 1e4, skewness); - test_almost(-0.1, 0.1, 0.30175909933883402945387113824982918009810212213629, 1e-14, skewness); - test_case(-0.1, 1.5, 33.46804679732172529147579024311650645764144530123, skewness); - test_almost(-0.1, 2.5, 11824.007933610287521341659465200553739278936344799, 1e-11, skewness); - test_almost(-0.1, 5.5, 50829064464591483629.132631635472412625371367420496, 1e4, skewness); - test_almost(0.1, 0.1, 0.30175909933883402945387113824982918009810212213629, 1e-14, skewness); - test_case(0.1, 1.5, 33.46804679732172529147579024311650645764144530123, skewness); - test_almost(0.1, 2.5, 11824.007933610287521341659465200553739278936344799, 1e-11, skewness); - test_almost(0.1, 5.5, 50829064464591483629.132631635472412625371367420496, 1e4, skewness); - test_almost(1.5, 0.1, 0.30175909933883402945387113824982918009810212213629, 1e-14, skewness); - test_case(1.5, 1.5, 33.46804679732172529147579024311650645764144530123, skewness); - test_almost(1.5, 2.5, 11824.007933610287521341659465200553739278936344799, 1e-11, skewness); - test_almost(1.5, 5.5, 50829064464591483629.132631635472412625371367420496, 1e4, skewness); - test_almost(2.5, 0.1, 0.30175909933883402945387113824982918009810212213629, 1e-14, skewness); - test_case(2.5, 1.5, 33.46804679732172529147579024311650645764144530123, skewness); - test_almost(2.5, 2.5, 11824.007933610287521341659465200553739278936344799, 1e-11, skewness); - 
test_almost(2.5, 5.5, 50829064464591483629.132631635472412625371367420496, 1e4, skewness); - test_almost(5.5, 0.1, 0.30175909933883402945387113824982918009810212213629, 1e-14, skewness); - test_case(5.5, 1.5, 33.46804679732172529147579024311650645764144530123, skewness); - test_almost(5.5, 2.5, 11824.007933610287521341659465200553739278936344799, 1e-11, skewness); - test_almost(5.5, 5.5, 50829064464591483629.132631635472412625371367420496, 1e4, skewness); + test_absolute(-1.0, 0.1, 0.30175909933883402945387113824982918009810212213629, 1e-14, skewness); + test_exact(-1.0, 1.5, 33.46804679732172529147579024311650645764144530123, skewness); + test_absolute(-1.0, 2.5, 11824.007933610287521341659465200553739278936344799, 1e-11, skewness); + test_absolute(-1.0, 5.5, 50829064464591483629.132631635472412625371367420496, 1e4, skewness); + test_absolute(-0.1, 0.1, 0.30175909933883402945387113824982918009810212213629, 1e-14, skewness); + test_exact(-0.1, 1.5, 33.46804679732172529147579024311650645764144530123, skewness); + test_absolute(-0.1, 2.5, 11824.007933610287521341659465200553739278936344799, 1e-11, skewness); + test_absolute(-0.1, 5.5, 50829064464591483629.132631635472412625371367420496, 1e4, skewness); + test_absolute(0.1, 0.1, 0.30175909933883402945387113824982918009810212213629, 1e-14, skewness); + test_exact(0.1, 1.5, 33.46804679732172529147579024311650645764144530123, skewness); + test_absolute(0.1, 2.5, 11824.007933610287521341659465200553739278936344799, 1e-11, skewness); + test_absolute(0.1, 5.5, 50829064464591483629.132631635472412625371367420496, 1e4, skewness); + test_absolute(1.5, 0.1, 0.30175909933883402945387113824982918009810212213629, 1e-14, skewness); + test_exact(1.5, 1.5, 33.46804679732172529147579024311650645764144530123, skewness); + test_absolute(1.5, 2.5, 11824.007933610287521341659465200553739278936344799, 1e-11, skewness); + test_absolute(1.5, 5.5, 50829064464591483629.132631635472412625371367420496, 1e4, skewness); + test_absolute(2.5, 0.1, 
0.30175909933883402945387113824982918009810212213629, 1e-14, skewness); + test_exact(2.5, 1.5, 33.46804679732172529147579024311650645764144530123, skewness); + test_absolute(2.5, 2.5, 11824.007933610287521341659465200553739278936344799, 1e-11, skewness); + test_absolute(2.5, 5.5, 50829064464591483629.132631635472412625371367420496, 1e4, skewness); + test_absolute(5.5, 0.1, 0.30175909933883402945387113824982918009810212213629, 1e-14, skewness); + test_exact(5.5, 1.5, 33.46804679732172529147579024311650645764144530123, skewness); + test_absolute(5.5, 2.5, 11824.007933610287521341659465200553739278936344799, 1e-11, skewness); + test_absolute(5.5, 5.5, 50829064464591483629.132631635472412625371367420496, 1e4, skewness); } #[test] fn test_mode() { let mode = |x: LogNormal| x.mode().unwrap(); - test_case(-1.0, 0.1, 0.36421897957152331652213191863106773137983085909534, mode); - test_case(-1.0, 1.5, 0.03877420783172200988689983526759614326014406193602, mode); - test_case(-1.0, 2.5, 0.0007101743888425490635846003705775444086763023873619, mode); - test_case(-1.0, 5.5, 0.000000000000026810038677818032221548731163905979029274677187036, mode); - test_case(-0.1, 0.1, 0.89583413529652823774737070060865897390995185639633, mode); - test_case(-0.1, 1.5, 0.095369162215549610417813418326627245539514227574881, mode); - test_case(-0.1, 2.5, 0.0017467471362611196181003627521060283221112106850165, mode); - test_case(-0.1, 5.5, 0.00000000000006594205454219929159167575814655534255162059017114, mode); - test_case(0.1, 0.1, 1.0941742837052103542285651753780976842292770841345, mode); - test_case(0.1, 1.5, 0.11648415777349696821514223131929465848700730137808, mode); - test_case(0.1, 2.5, 0.0021334817700377079925027678518795817076296484352472, mode); - test_case(0.1, 5.5, 0.000000000000080541807296590798973741710866097756565304960216803, mode); - test_case(1.5, 0.1, 4.4370955190036645692996309927420381428715912422597, mode); - test_case(1.5, 1.5, 
0.47236655274101470713804655094326791297020357913648, mode); - test_case(1.5, 2.5, 0.008651695203120634177071503957250390848166331197708, mode); - test_case(1.5, 5.5, 0.00000000000032661313427874471360158184468030186601222739665225, mode); - test_case(2.5, 0.1, 12.061276120444720299113038763305617245808510584994, mode); - test_case(2.5, 1.5, 1.2840254166877414840734205680624364583362808652815, mode); - test_case(2.5, 2.5, 0.023517745856009108236151185100432939470067655273072, mode); - test_case(2.5, 5.5, 0.00000000000088782654784596584473099190326928541185172970391855, mode); - test_case(5.5, 0.1, 242.2572068579541371904816252345031593584721473492, mode); - test_case(5.5, 1.5, 25.790339917193062089080107669377221876655268848954, mode); - test_case(5.5, 2.5, 0.47236655274101470713804655094326791297020357913648, mode); - test_case(5.5, 5.5, 0.000000000017832472908146389493511850431527026413424899198327, mode); + test_exact(-1.0, 0.1, 0.36421897957152331652213191863106773137983085909534, mode); + test_exact(-1.0, 1.5, 0.03877420783172200988689983526759614326014406193602, mode); + test_exact(-1.0, 2.5, 0.0007101743888425490635846003705775444086763023873619, mode); + test_exact(-1.0, 5.5, 0.000000000000026810038677818032221548731163905979029274677187036, mode); + test_exact(-0.1, 0.1, 0.89583413529652823774737070060865897390995185639633, mode); + test_exact(-0.1, 1.5, 0.095369162215549610417813418326627245539514227574881, mode); + test_exact(-0.1, 2.5, 0.0017467471362611196181003627521060283221112106850165, mode); + test_exact(-0.1, 5.5, 0.00000000000006594205454219929159167575814655534255162059017114, mode); + test_exact(0.1, 0.1, 1.0941742837052103542285651753780976842292770841345, mode); + test_exact(0.1, 1.5, 0.11648415777349696821514223131929465848700730137808, mode); + test_exact(0.1, 2.5, 0.0021334817700377079925027678518795817076296484352472, mode); + test_exact(0.1, 5.5, 0.000000000000080541807296590798973741710866097756565304960216803, mode); + test_exact(1.5, 
0.1, 4.4370955190036645692996309927420381428715912422597, mode); + test_exact(1.5, 1.5, 0.47236655274101470713804655094326791297020357913648, mode); + test_exact(1.5, 2.5, 0.008651695203120634177071503957250390848166331197708, mode); + test_exact(1.5, 5.5, 0.00000000000032661313427874471360158184468030186601222739665225, mode); + test_exact(2.5, 0.1, 12.061276120444720299113038763305617245808510584994, mode); + test_exact(2.5, 1.5, 1.2840254166877414840734205680624364583362808652815, mode); + test_exact(2.5, 2.5, 0.023517745856009108236151185100432939470067655273072, mode); + test_exact(2.5, 5.5, 0.00000000000088782654784596584473099190326928541185172970391855, mode); + test_exact(5.5, 0.1, 242.2572068579541371904816252345031593584721473492, mode); + test_exact(5.5, 1.5, 25.790339917193062089080107669377221876655268848954, mode); + test_exact(5.5, 2.5, 0.47236655274101470713804655094326791297020357913648, mode); + test_exact(5.5, 5.5, 0.000000000017832472908146389493511850431527026413424899198327, mode); } #[test] fn test_median() { let median = |x: LogNormal| x.median(); - test_case(-1.0, 0.1, 0.36787944117144232159552377016146086744581113103177, median); - test_case(-1.0, 1.5, 0.36787944117144232159552377016146086744581113103177, median); - test_case(-1.0, 2.5, 0.36787944117144232159552377016146086744581113103177, median); - test_case(-1.0, 5.5, 0.36787944117144232159552377016146086744581113103177, median); - test_case(-0.1, 0.1, 0.90483741803595956814139238421693559530906465375738, median); - test_case(-0.1, 1.5, 0.90483741803595956814139238421693559530906465375738, median); - test_case(-0.1, 2.5, 0.90483741803595956814139238421693559530906465375738, median); - test_case(-0.1, 5.5, 0.90483741803595956814139238421693559530906465375738, median); - test_case(0.1, 0.1, 1.1051709180756476309466388234587796577416634163742, median); - test_case(0.1, 1.5, 1.1051709180756476309466388234587796577416634163742, median); - test_case(0.1, 2.5, 
1.1051709180756476309466388234587796577416634163742, median); - test_case(0.1, 5.5, 1.1051709180756476309466388234587796577416634163742, median); - test_case(1.5, 0.1, 4.4816890703380648226020554601192758190057498683697, median); - test_case(1.5, 1.5, 4.4816890703380648226020554601192758190057498683697, median); - test_case(1.5, 2.5, 4.4816890703380648226020554601192758190057498683697, median); - test_case(1.5, 5.5, 4.4816890703380648226020554601192758190057498683697, median); - test_case(2.5, 0.1, 12.182493960703473438070175951167966183182767790063, median); - test_case(2.5, 1.5, 12.182493960703473438070175951167966183182767790063, median); - test_case(2.5, 2.5, 12.182493960703473438070175951167966183182767790063, median); - test_case(2.5, 5.5, 12.182493960703473438070175951167966183182767790063, median); - test_case(5.5, 0.1, 244.6919322642203879151889495118393501842287101075, median); - test_case(5.5, 1.5, 244.6919322642203879151889495118393501842287101075, median); - test_case(5.5, 2.5, 244.6919322642203879151889495118393501842287101075, median); - test_case(5.5, 5.5, 244.6919322642203879151889495118393501842287101075, median); + test_exact(-1.0, 0.1, 0.36787944117144232159552377016146086744581113103177, median); + test_exact(-1.0, 1.5, 0.36787944117144232159552377016146086744581113103177, median); + test_exact(-1.0, 2.5, 0.36787944117144232159552377016146086744581113103177, median); + test_exact(-1.0, 5.5, 0.36787944117144232159552377016146086744581113103177, median); + test_exact(-0.1, 0.1, 0.90483741803595956814139238421693559530906465375738, median); + test_exact(-0.1, 1.5, 0.90483741803595956814139238421693559530906465375738, median); + test_exact(-0.1, 2.5, 0.90483741803595956814139238421693559530906465375738, median); + test_exact(-0.1, 5.5, 0.90483741803595956814139238421693559530906465375738, median); + test_exact(0.1, 0.1, 1.1051709180756476309466388234587796577416634163742, median); + test_exact(0.1, 1.5, 
1.1051709180756476309466388234587796577416634163742, median); + test_exact(0.1, 2.5, 1.1051709180756476309466388234587796577416634163742, median); + test_exact(0.1, 5.5, 1.1051709180756476309466388234587796577416634163742, median); + test_exact(1.5, 0.1, 4.4816890703380648226020554601192758190057498683697, median); + test_exact(1.5, 1.5, 4.4816890703380648226020554601192758190057498683697, median); + test_exact(1.5, 2.5, 4.4816890703380648226020554601192758190057498683697, median); + test_exact(1.5, 5.5, 4.4816890703380648226020554601192758190057498683697, median); + test_exact(2.5, 0.1, 12.182493960703473438070175951167966183182767790063, median); + test_exact(2.5, 1.5, 12.182493960703473438070175951167966183182767790063, median); + test_exact(2.5, 2.5, 12.182493960703473438070175951167966183182767790063, median); + test_exact(2.5, 5.5, 12.182493960703473438070175951167966183182767790063, median); + test_exact(5.5, 0.1, 244.6919322642203879151889495118393501842287101075, median); + test_exact(5.5, 1.5, 244.6919322642203879151889495118393501842287101075, median); + test_exact(5.5, 2.5, 244.6919322642203879151889495118393501842287101075, median); + test_exact(5.5, 5.5, 244.6919322642203879151889495118393501842287101075, median); } #[test] fn test_min_max() { let min = |x: LogNormal| x.min(); let max = |x: LogNormal| x.max(); - test_case(0.0, 0.1, 0.0, min); - test_case(-3.0, 10.0, 0.0, min); - test_case(0.0, 0.1, f64::INFINITY, max); - test_case(-3.0, 10.0, f64::INFINITY, max); + test_exact(0.0, 0.1, 0.0, min); + test_exact(-3.0, 10.0, 0.0, min); + test_exact(0.0, 0.1, f64::INFINITY, max); + test_exact(-3.0, 10.0, f64::INFINITY, max); } #[test] fn test_pdf() { let pdf = |arg: f64| move |x: LogNormal| x.pdf(arg); - test_almost(-0.1, 0.1, 1.7968349035073582236359415565799753846986440127816e-104, 1e-118, pdf(0.1)); - test_almost(-0.1, 0.1, 0.00000018288923328441197822391757965928083462391836798722, 1e-21, pdf(0.5)); - test_case(-0.1, 0.1, 
2.3363114904470413709866234247494393485647978367885, pdf(0.8)); - test_almost(-0.1, 1.5, 0.90492497850024368541682348133921492204585092983646, 1e-15, pdf(0.1)); - test_almost(-0.1, 1.5, 0.49191985207660942803818797602364034466489243416574, 1e-16, pdf(0.5)); - test_case(-0.1, 1.5, 0.33133347214343229148978298237579567194870525187207, pdf(0.8)); - test_case(-0.1, 2.5, 1.0824698632626565182080576574958317806389057196768, pdf(0.1)); - test_almost(-0.1, 2.5, 0.31029619474753883558901295436486123689563749784867, 1e-16, pdf(0.5)); - test_almost(-0.1, 2.5, 0.19922929916156673799861939824205622734205083805245, 1e-16, pdf(0.8)); + test_absolute(-0.1, 0.1, 1.7968349035073582236359415565799753846986440127816e-104, 1e-118, pdf(0.1)); + test_absolute(-0.1, 0.1, 0.00000018288923328441197822391757965928083462391836798722, 1e-21, pdf(0.5)); + test_exact(-0.1, 0.1, 2.3363114904470413709866234247494393485647978367885, pdf(0.8)); + test_absolute(-0.1, 1.5, 0.90492497850024368541682348133921492204585092983646, 1e-15, pdf(0.1)); + test_absolute(-0.1, 1.5, 0.49191985207660942803818797602364034466489243416574, 1e-16, pdf(0.5)); + test_exact(-0.1, 1.5, 0.33133347214343229148978298237579567194870525187207, pdf(0.8)); + test_exact(-0.1, 2.5, 1.0824698632626565182080576574958317806389057196768, pdf(0.1)); + test_absolute(-0.1, 2.5, 0.31029619474753883558901295436486123689563749784867, 1e-16, pdf(0.5)); + test_absolute(-0.1, 2.5, 0.19922929916156673799861939824205622734205083805245, 1e-16, pdf(0.8)); // Test removed because it was causing compiler issues (see issue 31407 for rust) -// test_almost(1.5, 0.1, 4.1070141770545881694056265342787422035256248474059e-313, 1e-322, pdf(0.1)); +// test_absolute(1.5, 0.1, 4.1070141770545881694056265342787422035256248474059e-313, 1e-322, pdf(0.1)); // - test_almost(1.5, 0.1, 2.8602688726477103843476657332784045661507239533567e-104, 1e-116, pdf(0.5)); - test_case(1.5, 0.1, 1.6670425710002183246335601541889400558525870482613e-64, pdf(0.8)); - test_almost(1.5, 
1.5, 0.10698412103361841220076392503406214751353235895732, 1e-16, pdf(0.1)); - test_almost(1.5, 1.5, 0.18266125308224685664142384493330155315630876975024, 1e-16, pdf(0.5)); - test_almost(1.5, 1.5, 0.17185785323404088913982425377565512294017306418953, 1e-16, pdf(0.8)); - test_almost(1.5, 2.5, 0.50186885259059181992025035649158160252576845315332, 1e-15, pdf(0.1)); - test_almost(1.5, 2.5, 0.21721369314437986034957451699565540205404697589349, 1e-16, pdf(0.5)); - test_case(1.5, 2.5, 0.15729636000661278918949298391170443742675565300598, pdf(0.8)); - test_case(2.5, 0.1, 5.6836826548848916385760779034504046896805825555997e-500, pdf(0.1)); - test_almost(2.5, 0.1, 3.1225608678589488061206338085285607881363155340377e-221, 1e-233, pdf(0.5)); - test_almost(2.5, 0.1, 4.6994713794671660918554320071312374073172560048297e-161, 1e-173, pdf(0.8)); - test_almost(2.5, 1.5, 0.015806486291412916772431170442330946677601577502353, 1e-16, pdf(0.1)); - test_almost(2.5, 1.5, 0.055184331257528847223852028950484131834529030116388, 1e-16, pdf(0.5)); - test_case(2.5, 1.5, 0.063982134749859504449658286955049840393511776984362, pdf(0.8)); - test_almost(2.5, 2.5, 0.25212505662402617595900822552548977822542300480086, 1e-15, pdf(0.1)); - test_almost(2.5, 2.5, 0.14117186955911792460646517002386088579088567275401, 1e-16, pdf(0.5)); - test_almost(2.5, 2.5, 0.11021452580363707866161369621432656293405065561317, 1e-16, pdf(0.8)); + test_absolute(1.5, 0.1, 2.8602688726477103843476657332784045661507239533567e-104, 1e-116, pdf(0.5)); + test_exact(1.5, 0.1, 1.6670425710002183246335601541889400558525870482613e-64, pdf(0.8)); + test_absolute(1.5, 1.5, 0.10698412103361841220076392503406214751353235895732, 1e-16, pdf(0.1)); + test_absolute(1.5, 1.5, 0.18266125308224685664142384493330155315630876975024, 1e-16, pdf(0.5)); + test_absolute(1.5, 1.5, 0.17185785323404088913982425377565512294017306418953, 1e-16, pdf(0.8)); + test_absolute(1.5, 2.5, 0.50186885259059181992025035649158160252576845315332, 1e-15, pdf(0.1)); + 
test_absolute(1.5, 2.5, 0.21721369314437986034957451699565540205404697589349, 1e-16, pdf(0.5)); + test_exact(1.5, 2.5, 0.15729636000661278918949298391170443742675565300598, pdf(0.8)); + test_exact(2.5, 0.1, 5.6836826548848916385760779034504046896805825555997e-500, pdf(0.1)); + test_absolute(2.5, 0.1, 3.1225608678589488061206338085285607881363155340377e-221, 1e-233, pdf(0.5)); + test_absolute(2.5, 0.1, 4.6994713794671660918554320071312374073172560048297e-161, 1e-173, pdf(0.8)); + test_absolute(2.5, 1.5, 0.015806486291412916772431170442330946677601577502353, 1e-16, pdf(0.1)); + test_absolute(2.5, 1.5, 0.055184331257528847223852028950484131834529030116388, 1e-16, pdf(0.5)); + test_exact(2.5, 1.5, 0.063982134749859504449658286955049840393511776984362, pdf(0.8)); + test_absolute(2.5, 2.5, 0.25212505662402617595900822552548977822542300480086, 1e-15, pdf(0.1)); + test_absolute(2.5, 2.5, 0.14117186955911792460646517002386088579088567275401, 1e-16, pdf(0.5)); + test_absolute(2.5, 2.5, 0.11021452580363707866161369621432656293405065561317, 1e-16, pdf(0.8)); } #[test] fn test_neg_pdf() { let pdf = |arg: f64| move |x: LogNormal| x.pdf(arg); - test_case(0.0, 1.0, 0.0, pdf(0.0)); + test_exact(0.0, 1.0, 0.0, pdf(0.0)); } #[test] fn test_ln_pdf() { let ln_pdf = |arg: f64| move |x: LogNormal| x.ln_pdf(arg); - test_case(-0.1, 0.1, -238.88282294119596467794686179588610665317241097599, ln_pdf(0.1)); - test_almost(-0.1, 0.1, -15.514385149961296196003163062199569075052113039686, 1e-14, ln_pdf(0.5)); - test_case(-0.1, 0.1, 0.84857339958981283964373051826407417105725729082041, ln_pdf(0.8)); - test_almost(-0.1, 1.5, -0.099903235403144611051953094864849327288457482212211, 1e-15, ln_pdf(0.1)); - test_almost(-0.1, 1.5, -0.70943947804316122682964396008813828577195771418027, 1e-15, ln_pdf(0.5)); - test_almost(-0.1, 1.5, -1.1046299420497998262946038709903250420774183529995, 1e-15, ln_pdf(0.8)); - test_almost(-0.1, 2.5, 0.07924534056485078867266307735371665927517517183681, 1e-16, ln_pdf(0.1)); - 
test_case(-0.1, 2.5, -1.1702279707433794860424967893989374511050637417043, ln_pdf(0.5)); - test_case(-0.1, 2.5, -1.6132988605030400828957768752511536087538109996183, ln_pdf(0.8)); - test_case(1.5, 0.1, -719.29643782024317312262673764204041218720576249741, ln_pdf(0.1)); - test_almost(1.5, 0.1, -238.41793403955250272430898754048547661932857086122, 1e-13, ln_pdf(0.5)); - test_case(1.5, 0.1, -146.85439481068371057247137024006716189469284256628, ln_pdf(0.8)); - test_almost(1.5, 1.5, -2.2350748570877992856465076624973458117562108140674, 1e-15, ln_pdf(0.1)); - test_almost(1.5, 1.5, -1.7001219175524556705452882616787223585705662860012, 1e-15, ln_pdf(0.5)); - test_almost(1.5, 1.5, -1.7610875785399045023354101841009649273236721172008, 1e-15, ln_pdf(0.8)); - test_almost(1.5, 2.5, -0.68941644324162489418137656699398207513321602763104, 1e-15, ln_pdf(0.1)); - test_case(1.5, 2.5, -1.5268736489667254857801287379715477173125628275598, ln_pdf(0.5)); - test_case(1.5, 2.5, -1.8496236096394777662704671479709839674424623547308, ln_pdf(0.8)); - test_almost(2.5, 0.1, -1149.5549471196476523788026360929146688367845019398, 1e-12, ln_pdf(0.1)); - test_almost(2.5, 0.1, -507.73265209554698134113704985174959301922196605736, 1e-12, ln_pdf(0.5)); - test_almost(2.5, 0.1, -369.16874994210463740474549611573497379941224077335, 1e-13, ln_pdf(0.8)); - test_almost(2.5, 1.5, -4.1473348984184862316495477617980296904955324113457, 1e-15, ln_pdf(0.1)); - test_almost(2.5, 1.5, -2.8970762200235424747307247601045786110485663457169, 1e-15, ln_pdf(0.5)); - test_case(2.5, 1.5, -2.7491513791239977024488074547907467152956602019989, ln_pdf(0.8)); - test_almost(2.5, 2.5, -1.3778300581206721947424710027422282714793718026513, 1e-15, ln_pdf(0.1)); - test_case(2.5, 2.5, -1.9577771978563167352868858774048559682046428490575, ln_pdf(0.5)); - test_case(2.5, 2.5, -2.2053265778497513183112901654193054111123780652581, ln_pdf(0.8)); + test_exact(-0.1, 0.1, -238.88282294119596467794686179588610665317241097599, ln_pdf(0.1)); + 
test_absolute(-0.1, 0.1, -15.514385149961296196003163062199569075052113039686, 1e-14, ln_pdf(0.5)); + test_exact(-0.1, 0.1, 0.84857339958981283964373051826407417105725729082041, ln_pdf(0.8)); + test_absolute(-0.1, 1.5, -0.099903235403144611051953094864849327288457482212211, 1e-15, ln_pdf(0.1)); + test_absolute(-0.1, 1.5, -0.70943947804316122682964396008813828577195771418027, 1e-15, ln_pdf(0.5)); + test_absolute(-0.1, 1.5, -1.1046299420497998262946038709903250420774183529995, 1e-15, ln_pdf(0.8)); + test_absolute(-0.1, 2.5, 0.07924534056485078867266307735371665927517517183681, 1e-16, ln_pdf(0.1)); + test_exact(-0.1, 2.5, -1.1702279707433794860424967893989374511050637417043, ln_pdf(0.5)); + test_exact(-0.1, 2.5, -1.6132988605030400828957768752511536087538109996183, ln_pdf(0.8)); + test_exact(1.5, 0.1, -719.29643782024317312262673764204041218720576249741, ln_pdf(0.1)); + test_absolute(1.5, 0.1, -238.41793403955250272430898754048547661932857086122, 1e-13, ln_pdf(0.5)); + test_exact(1.5, 0.1, -146.85439481068371057247137024006716189469284256628, ln_pdf(0.8)); + test_absolute(1.5, 1.5, -2.2350748570877992856465076624973458117562108140674, 1e-15, ln_pdf(0.1)); + test_absolute(1.5, 1.5, -1.7001219175524556705452882616787223585705662860012, 1e-15, ln_pdf(0.5)); + test_absolute(1.5, 1.5, -1.7610875785399045023354101841009649273236721172008, 1e-15, ln_pdf(0.8)); + test_absolute(1.5, 2.5, -0.68941644324162489418137656699398207513321602763104, 1e-15, ln_pdf(0.1)); + test_exact(1.5, 2.5, -1.5268736489667254857801287379715477173125628275598, ln_pdf(0.5)); + test_exact(1.5, 2.5, -1.8496236096394777662704671479709839674424623547308, ln_pdf(0.8)); + test_absolute(2.5, 0.1, -1149.5549471196476523788026360929146688367845019398, 1e-12, ln_pdf(0.1)); + test_absolute(2.5, 0.1, -507.73265209554698134113704985174959301922196605736, 1e-12, ln_pdf(0.5)); + test_absolute(2.5, 0.1, -369.16874994210463740474549611573497379941224077335, 1e-13, ln_pdf(0.8)); + test_absolute(2.5, 1.5, 
-4.1473348984184862316495477617980296904955324113457, 1e-15, ln_pdf(0.1)); + test_absolute(2.5, 1.5, -2.8970762200235424747307247601045786110485663457169, 1e-15, ln_pdf(0.5)); + test_exact(2.5, 1.5, -2.7491513791239977024488074547907467152956602019989, ln_pdf(0.8)); + test_absolute(2.5, 2.5, -1.3778300581206721947424710027422282714793718026513, 1e-15, ln_pdf(0.1)); + test_exact(2.5, 2.5, -1.9577771978563167352868858774048559682046428490575, ln_pdf(0.5)); + test_exact(2.5, 2.5, -2.2053265778497513183112901654193054111123780652581, ln_pdf(0.8)); } #[test] fn test_neg_ln_pdf() { let ln_pdf = |arg: f64| move |x: LogNormal| x.ln_pdf(arg); - test_case(0.0, 1.0, f64::NEG_INFINITY, ln_pdf(0.0)); + test_exact(0.0, 1.0, f64::NEG_INFINITY, ln_pdf(0.0)); } #[test] fn test_cdf() { - let cdf = |arg: f64| move |x: LogNormal| x.cdf(arg); - test_almost(-0.1, 0.1, 0.0, 1e-107, cdf(0.1)); - test_almost(-0.1, 0.1, 0.0000000015011556178148777579869633555518882664666520593658, 1e-19, cdf(0.5)); - test_almost(-0.1, 0.1, 0.10908001076375810900224507908874442583171381706127, 1e-11, cdf(0.8)); - test_almost(-0.1, 1.5, 0.070999149762464508991968731574953594549291668468349, 1e-11, cdf(0.1)); - test_case(-0.1, 1.5, 0.34626224992888089297789445771047690175505847991946, cdf(0.5)); - test_case(-0.1, 1.5, 0.46728530589487698517090261668589508746353129242404, cdf(0.8)); - test_almost(-0.1, 2.5, 0.18914969879695093477606645992572208111152994999076, 1e-10, cdf(0.1)); - test_case(-0.1, 2.5, 0.40622798321378106125020505907901206714868922279347, cdf(0.5)); - test_case(-0.1, 2.5, 0.48035707589956665425068652807400957345208517749893, cdf(0.8)); - test_almost(1.5, 0.1, 0.0, 1e-315, cdf(0.1)); - test_almost(1.5, 0.1, 0.0, 1e-106, cdf(0.5)); - test_almost(1.5, 0.1, 0.0, 1e-66, cdf(0.8)); - test_almost(1.5, 1.5, 0.005621455876973168709588070988239748831823850202953, 1e-12, cdf(0.1)); - test_almost(1.5, 1.5, 0.07185716187918271235246980951571040808235628115265, 1e-11, cdf(0.5)); - test_almost(1.5, 1.5, 
0.12532699044614938400496547188720940854423187977236, 1e-11, cdf(0.8)); - test_almost(1.5, 2.5, 0.064125647996943514411570834861724406903677144126117, 1e-11, cdf(0.1)); - test_almost(1.5, 2.5, 0.19017302281590810871719754032332631806011441356498, 1e-10, cdf(0.5)); - test_almost(1.5, 2.5, 0.24533064397555500690927047163085419096928289095201, 1e-16, cdf(0.8)); - test_case(2.5, 0.1, 0.0, cdf(0.1)); - test_almost(2.5, 0.1, 0.0, 1e-223, cdf(0.5)); - test_almost(2.5, 0.1, 0.0, 1e-162, cdf(0.8)); - test_almost(2.5, 1.5, 0.00068304052220788502001572635016579586444611070077399, 1e-13, cdf(0.1)); - test_almost(2.5, 1.5, 0.016636862816580533038130583128179878924863968664206, 1e-12, cdf(0.5)); - test_almost(2.5, 1.5, 0.034729001282904174941366974418836262996834852343018, 1e-11, cdf(0.8)); - test_almost(2.5, 2.5, 0.027363708266690978870139978537188410215717307180775, 1e-11, cdf(0.1)); - test_almost(2.5, 2.5, 0.10075543423327634536450625420610429181921642201567, 1e-11, cdf(0.5)); - test_almost(2.5, 2.5, 0.13802019192453118732001307556787218421918336849121, 1e-11, cdf(0.8)); + cdf_tests(false); + } + + #[test] + fn test_inverse_cdf() { + cdf_tests(true) + } + + // we can reuse the (input, output) pairs from the CDF unit test + // and verify that passing an 'output' to .inverse_cdf gives 'input', + // except in cases where output would be 0.0 (the inverse_cdf is defined to + // always give 0.0 in this case). 
+ fn cdf_tests(inverse: bool) { + let f = |arg: f64| move |x: LogNormal| if inverse { + x.inverse_cdf(arg) + } else { + x.cdf(arg) + }; + + // given some cdf_input and cdf_output, returns a tuple (input, output) where + // input is what we will provide to cdf/inverse_cdf, and output is the expected + // return value + let arrange_input_output = |cdf_input: f64, cdf_output: f64| { + if inverse { + (cdf_output, cdf_input) + } else { + (cdf_input, cdf_output) + } + }; + + // calls test_absolute after re-arranging the input/output arguments, passing f(input) as the evaluator + let almost = |mean: f64, std_dev: f64, cdf_input: f64, cdf_output: f64, acc: f64| { + let (input, output) = arrange_input_output(cdf_input, cdf_output); + test_absolute(mean, std_dev, output, acc, f(input)); + }; + + // calls test_exact after re-arranging the input/output arguments, passing f(input) as the evaluator + let case = |mean: f64, std_dev: f64, cdf_input: f64, cdf_output: f64| { + let (input, output) = arrange_input_output(cdf_input, cdf_output); + test_exact(mean, std_dev, output, f(input)); + }; + + // we skip cases where the CDF outputs 0.0 when testing the inverse CDF because + // there are multiple inputs to the CDF which give an answer of 0.0, so checking that + // feeding 0.0 to the inverse cdf returns the original input is not a valid test; + // the inverse cdf for log-normal is defined to give answer 0.0 for input 0.0 + if inverse { + case(-0.1, 0.1, 0.0, 0.0); + } + + if !inverse { + almost(-0.1, 0.1, 0.1, 0.0, 1e-107); + } + almost(-0.1, 0.1, 0.5, 0.0000000015011556178148777579869633555518882664666520593658, 1e-16); + almost(-0.1, 0.1, 0.8, 0.10908001076375810900224507908874442583171381706127, 1e-11); + almost(-0.1, 1.5, 0.1, 0.070999149762464508991968731574953594549291668468349, 1e-11); + case(-0.1, 1.5, 0.5, 0.34626224992888089297789445771047690175505847991946); + case(-0.1, 1.5, 0.8, 0.46728530589487698517090261668589508746353129242404); + almost(-0.1, 2.5, 0.1,
0.18914969879695093477606645992572208111152994999076, 1e-10); + case(-0.1, 2.5, 0.5, 0.40622798321378106125020505907901206714868922279347); + case(-0.1, 2.5, 0.8, 0.48035707589956665425068652807400957345208517749893); + + // input to inverse would be 0.0 + if !inverse { + almost(1.5, 0.1, 0.1, 0.0, 1e-315); + almost(1.5, 0.1, 0.5, 0.0, 1e-106); + almost(1.5, 0.1, 0.8, 0.0, 1e-66); + } + + almost(1.5, 1.5, 0.1, 0.005621455876973168709588070988239748831823850202953, 1e-12); + almost(1.5, 1.5, 0.8, 0.12532699044614938400496547188720940854423187977236, 1e-11); + almost(1.5, 2.5, 0.1, 0.064125647996943514411570834861724406903677144126117, 1e-11); + almost(1.5, 2.5, 0.5, 0.19017302281590810871719754032332631806011441356498, 1e-10); + almost(1.5, 2.5, 0.8, 0.24533064397555500690927047163085419096928289095201, 1e-16); + + // input to inverse would be 0.0 + if !inverse { + case(2.5, 0.1, 0.1, 0.0); + almost(2.5, 0.1, 0.5, 0.0, 1e-223); + almost(2.5, 0.1, 0.8, 0.0, 1e-162); + } + + almost(2.5, 1.5, 0.1, 0.00068304052220788502001572635016579586444611070077399, 1e-13); + almost(2.5, 1.5, 0.5, 0.016636862816580533038130583128179878924863968664206, 1e-12); + almost(2.5, 1.5, 0.8, 0.034729001282904174941366974418836262996834852343018, 1e-11); + almost(2.5, 2.5, 0.1, 0.027363708266690978870139978537188410215717307180775, 1e-11); + almost(2.5, 2.5, 0.5, 0.10075543423327634536450625420610429181921642201567, 1e-11); + almost(2.5, 2.5, 0.8, 0.13802019192453118732001307556787218421918336849121, 1e-11); } #[test] @@ -625,34 +717,34 @@ mod tests { let sf = |arg: f64| move |x: LogNormal| x.sf(arg); // Wolfram Alpha:: SurvivalFunction[ LogNormalDistribution(-0.1, 0.1), 0.1] - test_almost(-0.1, 0.1, 1.0, 1e-107, sf(0.1)); + test_absolute(-0.1, 0.1, 1.0, 1e-107, sf(0.1)); // Wolfram Alpha:: SurvivalFunction[ LogNormalDistribution(-0.1, 0.1), 0.8] - test_almost(-0.1, 0.1, 0.890919989231123, 1e-14, sf(0.8)); + test_absolute(-0.1, 0.1, 0.890919989231123, 1e-14, sf(0.8)); // Wolfram Alpha:: 
SurvivalFunction[LogNormalDistribution[1.5, 1], 0.8] - test_almost(1.5, 1.0, 0.957568715612642, 1e-14, sf(0.8)); + test_absolute(1.5, 1.0, 0.957568715612642, 1e-14, sf(0.8)); // Wolfram Alpha:: SurvivalFunction[ LogNormalDistribution(2.5, 1.5), 0.1] - test_almost(2.5, 1.5, 0.9993169594777358, 1e-14, sf(0.1)); + test_absolute(2.5, 1.5, 0.9993169594777358, 1e-14, sf(0.1)); } #[test] fn test_neg_cdf() { let cdf = |arg: f64| move |x: LogNormal| x.cdf(arg); - test_case(0.0, 1.0, 0.0, cdf(0.0)); + test_exact(0.0, 1.0, 0.0, cdf(0.0)); } #[test] fn test_neg_sf() { let sf = |arg: f64| move |x: LogNormal| x.sf(arg); - test_case(0.0, 1.0, 1.0, sf(0.0)); + test_exact(0.0, 1.0, 1.0, sf(0.0)); } #[test] fn test_continuous() { - test::check_continuous_distribution(&try_create(0.0, 0.25), 0.0, 10.0); - test::check_continuous_distribution(&try_create(0.0, 0.5), 0.0, 10.0); + test::check_continuous_distribution(&create_ok(0.0, 0.25), 0.0, 10.0); + test::check_continuous_distribution(&create_ok(0.0, 0.5), 0.0, 10.0); } } diff --git a/src/distribution/mod.rs b/src/distribution/mod.rs index 5ff966eb..249c92da 100644 --- a/src/distribution/mod.rs +++ b/src/distribution/mod.rs @@ -2,39 +2,46 @@ //! and provides //! concrete implementations for a variety of distributions. 
use super::statistics::{Max, Min}; -use ::num_traits::{float::Float, Bounded, Num}; +use ::num_traits::{Float, Num}; +use num_traits::NumAssignOps; pub use self::bernoulli::Bernoulli; -pub use self::beta::Beta; -pub use self::binomial::Binomial; -pub use self::categorical::Categorical; -pub use self::cauchy::Cauchy; -pub use self::chi::Chi; +pub use self::beta::{Beta, BetaError}; +pub use self::binomial::{Binomial, BinomialError}; +pub use self::categorical::{Categorical, CategoricalError}; +pub use self::cauchy::{Cauchy, CauchyError}; +pub use self::chi::{Chi, ChiError}; pub use self::chi_squared::ChiSquared; -pub use self::dirac::Dirac; -pub use self::dirichlet::Dirichlet; -pub use self::discrete_uniform::DiscreteUniform; +pub use self::dirac::{Dirac, DiracError}; +#[cfg(feature = "nalgebra")] +pub use self::dirichlet::{Dirichlet, DirichletError}; +pub use self::discrete_uniform::{DiscreteUniform, DiscreteUniformError}; pub use self::empirical::Empirical; pub use self::erlang::Erlang; -pub use self::exponential::Exp; -pub use self::fisher_snedecor::FisherSnedecor; -pub use self::gamma::Gamma; -pub use self::geometric::Geometric; -pub use self::hypergeometric::Hypergeometric; -pub use self::inverse_gamma::InverseGamma; -pub use self::laplace::Laplace; -pub use self::log_normal::LogNormal; -pub use self::multinomial::Multinomial; -pub use self::multivariate_normal::MultivariateNormal; +pub use self::exponential::{Exp, ExpError}; +pub use self::fisher_snedecor::{FisherSnedecor, FisherSnedecorError}; +pub use self::gamma::{Gamma, GammaError}; +pub use self::geometric::{Geometric, GeometricError}; +pub use self::hypergeometric::{Hypergeometric, HypergeometricError}; +pub use self::inverse_gamma::{InverseGamma, InverseGammaError}; +pub use self::laplace::{Laplace, LaplaceError}; +pub use self::log_normal::{LogNormal, LogNormalError}; +#[cfg(feature = "nalgebra")] +pub use self::multinomial::{Multinomial, MultinomialError}; +#[cfg(feature = "nalgebra")] +pub use 
self::multivariate_normal::{MultivariateNormal, MultivariateNormalError}; +#[cfg(feature = "nalgebra")] pub use self::multivariate_normal_diag::MultivariateNormalDiag; -pub use self::negative_binomial::NegativeBinomial; -pub use self::normal::Normal; -pub use self::pareto::Pareto; -pub use self::poisson::Poisson; -pub use self::students_t::StudentsT; -pub use self::triangular::Triangular; -pub use self::uniform::Uniform; -pub use self::weibull::Weibull; +#[cfg(feature = "nalgebra")] +pub use self::multivariate_students_t::{MultivariateStudent, MultivariateStudentError}; +pub use self::negative_binomial::{NegativeBinomial, NegativeBinomialError}; +pub use self::normal::{Normal, NormalError}; +pub use self::pareto::{Pareto, ParetoError}; +pub use self::poisson::{Poisson, PoissonError}; +pub use self::students_t::{StudentsT, StudentsTError}; +pub use self::triangular::{Triangular, TriangularError}; +pub use self::uniform::{Uniform, UniformError}; +pub use self::weibull::{Weibull, WeibullError}; mod bernoulli; mod beta; @@ -44,6 +51,7 @@ mod cauchy; mod chi; mod chi_squared; mod dirac; +#[cfg(feature = "nalgebra")] mod dirichlet; mod discrete_uniform; mod empirical; @@ -58,9 +66,14 @@ mod internal; mod inverse_gamma; mod laplace; mod log_normal; +#[cfg(feature = "nalgebra")] mod multinomial; +#[cfg(feature = "nalgebra")] mod multivariate_normal; +#[cfg(feature = "nalgebra")] mod multivariate_normal_diag; +#[cfg(feature = "nalgebra")] +mod multivariate_students_t; mod negative_binomial; mod normal; mod pareto; @@ -69,11 +82,11 @@ mod students_t; mod triangular; mod uniform; mod weibull; +#[cfg(feature = "rand")] mod ziggurat; +#[cfg(feature = "rand")] mod ziggurat_tables; -use crate::Result; - /// The `ContinuousCDF` trait is used to specify an interface for univariate /// distributions for which cdf float arguments are sensible. 
pub trait ContinuousCDF: Min + Max { @@ -103,7 +116,9 @@ pub trait ContinuousCDF: Min + Max { /// let n = Uniform::new(0.0, 1.0).unwrap(); /// assert_eq!(0.5, n.sf(0.5)); /// ``` - fn sf(&self, x: K) -> T; + fn sf(&self, x: K) -> T { + T::one() - self.cdf(x) + } /// Due to issues with rounding and floating-point accuracy the default /// implementation may be ill-behaved. @@ -111,6 +126,8 @@ pub trait ContinuousCDF: Min + Max { /// Performs a binary search on the domain of `cdf` to obtain an approximation /// of `F^-1(p) := inf { x | F(x) >= p }`. Needless to say, performance may /// be lacking. + #[doc(alias = "quantile function")] + #[doc(alias = "quantile")] fn inverse_cdf(&self, p: T) -> K { if p == T::zero() { return self.min(); @@ -143,7 +160,9 @@ pub trait ContinuousCDF: Min + Max { /// The `DiscreteCDF` trait is used to specify an interface for univariate /// discrete distributions. -pub trait DiscreteCDF: Min + Max { +pub trait DiscreteCDF: + Min + Max +{ /// Returns the cumulative distribution function calculated /// at `x` for a given distribution. May panic depending /// on the implementor. @@ -169,33 +188,32 @@ pub trait DiscreteCDF: Min + Max { /// let n = DiscreteUniform::new(1, 10).unwrap(); /// assert_eq!(0.4, n.sf(6)); /// ``` - fn sf(&self, x: K) -> T; + fn sf(&self, x: K) -> T { + T::one() - self.cdf(x) + } /// Due to issues with rounding and floating-point accuracy the default implementation may be ill-behaved. /// Specialized inverse cdfs should be used whenever possible.
+ /// + /// # Panics + /// this default impl panics if provided `p` not on interval [0.0, 1.0] fn inverse_cdf(&self, p: T) -> K { - // TODO: fix integer implementation if p == T::zero() { return self.min(); - }; - if p == T::one() { + } else if p == T::one() { return self.max(); - }; - let two = K::one() + K::one(); - let mut high = two.clone(); - let mut low = K::min_value(); - while self.cdf(high.clone()) < p { - high = high.clone() + high.clone(); + } else if !(T::zero()..=T::one()).contains(&p) { + panic!("p must be on [0, 1]") } - while high != low { - let mid = (high.clone() + low.clone()) / two.clone(); - if self.cdf(mid.clone()) >= p { - high = mid; - } else { - low = mid; - } + + let two = K::one() + K::one(); + let mut ub = two.clone(); + let lb = self.min(); + while self.cdf(ub.clone()) < p { + ub *= two.clone(); } - high + + internal::integral_bisection_search(|p| self.cdf(p.clone()), p, lb, ub).unwrap() } } diff --git a/src/distribution/multinomial.rs b/src/distribution/multinomial.rs index a4f0524d..d6304214 100644 --- a/src/distribution/multinomial.rs +++ b/src/distribution/multinomial.rs @@ -1,9 +1,7 @@ use crate::distribution::Discrete; use crate::function::factorial; use crate::statistics::*; -use crate::{Result, StatsError}; -use ::nalgebra::{DMatrix, DVector}; -use rand::Rng; +use nalgebra::{DVector, Dim, Dyn, OMatrix, OVector}; /// Implements the /// [Multinomial](https://en.wikipedia.org/wiki/Multinomial_distribution) @@ -16,18 +14,54 @@ use rand::Rng; /// ``` /// use statrs::distribution::Multinomial; /// use statrs::statistics::MeanN; -/// use nalgebra::DVector; +/// use nalgebra::vector; /// -/// let n = Multinomial::new(&[0.3, 0.7], 5).unwrap(); -/// assert_eq!(n.mean().unwrap(), DVector::from_vec(vec![1.5, 3.5])); +/// let n = Multinomial::new_from_nalgebra(vector![0.3, 0.7], 5).unwrap(); +/// assert_eq!(n.mean().unwrap(), (vector![1.5, 3.5])); /// ``` #[derive(Debug, Clone, PartialEq)] -pub struct Multinomial { - p: Vec, +pub struct 
Multinomial +where + D: Dim, + nalgebra::DefaultAllocator: nalgebra::allocator::Allocator, +{ + /// normalized probabilities for each species + p: OVector, + /// count of trials n: u64, } -impl Multinomial { +/// Represents the errors that can occur when creating a [`Multinomial`]. +#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] +#[non_exhaustive] +pub enum MultinomialError { + /// Fewer than two probabilities. + NotEnoughProbabilities, + + /// The sum of all probabilities is zero. + ProbabilitySumZero, + + /// At least one probability is NaN, infinite or less than zero. + ProbabilityInvalid, +} + +impl std::fmt::Display for MultinomialError { + #[cfg_attr(coverage_nightly, coverage(off))] + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + MultinomialError::NotEnoughProbabilities => write!(f, "Fewer than two probabilities"), + MultinomialError::ProbabilitySumZero => write!(f, "The probabilities sum up to zero"), + MultinomialError::ProbabilityInvalid => write!( + f, + "At least one probability is NaN, infinity or less than zero" + ), + } + } +} + +impl std::error::Error for MultinomialError {} + +impl Multinomial { /// Constructs a new multinomial distribution with probabilities `p` /// and `n` number of trials. 
/// @@ -45,18 +79,42 @@ impl Multinomial { /// ``` /// use statrs::distribution::Multinomial; /// - /// let mut result = Multinomial::new(&[0.0, 1.0, 2.0], 3); + /// let mut result = Multinomial::new(vec![0.0, 1.0, 2.0], 3); /// assert!(result.is_ok()); /// - /// result = Multinomial::new(&[0.0, -1.0, 2.0], 3); + /// result = Multinomial::new(vec![0.0, -1.0, 2.0], 3); /// assert!(result.is_err()); /// ``` - pub fn new(p: &[f64], n: u64) -> Result { - if !super::internal::is_valid_multinomial(p, true) { - Err(StatsError::BadParams) - } else { - Ok(Multinomial { p: p.to_vec(), n }) + pub fn new(p: Vec, n: u64) -> Result { + Self::new_from_nalgebra(p.into(), n) + } +} + +impl Multinomial +where + D: Dim, + nalgebra::DefaultAllocator: nalgebra::allocator::Allocator, +{ + pub fn new_from_nalgebra(mut p: OVector, n: u64) -> Result { + if p.len() < 2 { + return Err(MultinomialError::NotEnoughProbabilities); } + + let mut sum = 0.0; + for &val in &p { + if val.is_nan() || val < 0.0 { + return Err(MultinomialError::ProbabilityInvalid); + } + + sum += val; + } + + if sum == 0.0 { + return Err(MultinomialError::ProbabilitySumZero); + } + + p.unscale_mut(p.lp_norm(1)); + Ok(Self { p, n }) } /// Returns the probabilities of the multinomial @@ -66,11 +124,12 @@ impl Multinomial { /// /// ``` /// use statrs::distribution::Multinomial; + /// use nalgebra::dvector; /// - /// let n = Multinomial::new(&[0.0, 1.0, 2.0], 3).unwrap(); - /// assert_eq!(n.p(), [0.0, 1.0, 2.0]); + /// let n = Multinomial::new(vec![0.0, 1.0, 2.0], 3).unwrap(); + /// assert_eq!(*n.p(), dvector![0.0, 1.0/3.0, 2.0/3.0]); /// ``` - pub fn p(&self) -> &[f64] { + pub fn p(&self) -> &OVector { &self.p } @@ -82,7 +141,7 @@ impl Multinomial { /// ``` /// use statrs::distribution::Multinomial; /// - /// let n = Multinomial::new(&[0.0, 1.0, 2.0], 3).unwrap(); + /// let n = Multinomial::new(vec![0.0, 1.0, 2.0], 3).unwrap(); /// assert_eq!(n.n(), 3); /// ``` pub fn n(&self) -> u64 { @@ -90,10 +149,27 @@ impl Multinomial 
{ } } -impl ::rand::distributions::Distribution> for Multinomial { - fn sample(&self, rng: &mut R) -> Vec { - let p_cdf = super::categorical::prob_mass_to_cdf(self.p()); - let mut res = vec![0.0; self.p.len()]; +impl std::fmt::Display for Multinomial +where + D: Dim, + nalgebra::DefaultAllocator: nalgebra::allocator::Allocator, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Multinom({:#?},{})", self.p, self.n) + } +} + +#[cfg(feature = "rand")] +impl ::rand::distributions::Distribution> for Multinomial +where + D: Dim, + nalgebra::DefaultAllocator: nalgebra::allocator::Allocator, +{ + fn sample(&self, rng: &mut R) -> OVector { + use nalgebra::Const; + + let p_cdf = super::categorical::prob_mass_to_cdf(self.p().as_slice()); + let mut res = OVector::zeros_generic(self.p.shape_generic().0, Const::<1>); for _ in 0..self.n { let i = super::categorical::sample_unchecked(rng, &p_cdf); let el = res.get_mut(i as usize).unwrap(); @@ -103,12 +179,16 @@ impl ::rand::distributions::Distribution> for Multinomial { } } -impl MeanN> for Multinomial { +impl MeanN> for Multinomial +where + D: Dim, + nalgebra::DefaultAllocator: nalgebra::allocator::Allocator, +{ /// Returns the mean of the multinomial distribution /// /// # Formula /// - /// ```ignore + /// ```text /// n * p_i for i in 1...k /// ``` /// @@ -121,24 +201,37 @@ impl MeanN> for Multinomial { } } -impl VarianceN> for Multinomial { +impl VarianceN> for Multinomial +where + D: Dim, + nalgebra::DefaultAllocator: + nalgebra::allocator::Allocator + nalgebra::allocator::Allocator, +{ /// Returns the variance of the multinomial distribution /// /// # Formula /// - /// ```ignore + /// ```text /// n * p_i * (1 - p_i) for i in 1...k /// ``` /// /// where `n` is the number of trials, `p_i` is the `i`th probability, /// and `k` is the total number of probabilities - fn variance(&self) -> Option> { - let cov: Vec<_> = self - .p - .iter() - .map(|x| x * self.n as f64 * (1.0 - x)) - .collect(); - 
Some(DMatrix::from_diagonal(&DVector::from_vec(cov))) + fn variance(&self) -> Option> { + let mut cov = OMatrix::from_diagonal(&self.p.map(|x| x * (1.0 - x))); + let mut offdiag = |x: usize, y: usize| { + let elt = -self.p[x] * self.p[y]; + // only the upper-triangle entry (row y < column x) is written here; the lower triangle is mirrored after the loops + cov[(y, x)] = elt; + }; + + for i in 0..self.p.len() { + for j in 0..i { + offdiag(i, j); + } + } + cov.fill_lower_triangle_with_upper_triangle(); + Some(cov.scale(self.n as f64)) } } @@ -147,7 +240,7 @@ impl VarianceN> for Multinomial { // /// // /// # Formula // /// -// /// ```ignore +// /// ```text // /// (1 - 2 * p_i) / (n * p_i * (1 - p_i)) for i in 1...k // /// ``` // /// @@ -163,7 +256,12 @@ impl VarianceN> for Multinomial { // } // } -impl<'a> Discrete<&'a [u64], f64> for Multinomial { +impl<'a, D> Discrete<&'a OVector, f64> for Multinomial +where + D: Dim, + nalgebra::DefaultAllocator: + nalgebra::allocator::Allocator + nalgebra::allocator::Allocator, +{ /// Calculates the probability mass function for the multinomial /// distribution /// with the given `x`'s corresponding to the probabilities for this @@ -171,26 +269,25 @@ impl<'a> Discrete<&'a [u64], f64> for Multinomial { /// /// # Panics /// - /// If the elements in `x` do not sum to `n` or if the length of `x` is not - /// equivalent to the length of `p` + /// If length of `x` is not equal to length of `p` /// /// # Formula /// - /// ```ignore + /// ```text /// (n! / x_1!...x_k!)
* p_i^x_i for i in 1...k /// ``` /// /// where `n` is the number of trials, `p_i` is the `i`th probability, /// `x_i` is the `i`th `x` value, and `k` is the total number of /// probabilities - fn pmf(&self, x: &[u64]) -> f64 { + fn pmf(&self, x: &OVector) -> f64 { if self.p.len() != x.len() { panic!("Expected x and p to have equal lengths."); } if x.iter().sum::() != self.n { return 0.0; } - let coeff = factorial::multinomial(self.n, x); + let coeff = factorial::multinomial(self.n, x.as_slice()); let val = coeff * self .p @@ -207,26 +304,25 @@ impl<'a> Discrete<&'a [u64], f64> for Multinomial { /// /// # Panics /// - /// If the elements in `x` do not sum to `n` or if the length of `x` is not - /// equivalent to the length of `p` + /// If length of `x` is not equal to length of `p` /// /// # Formula /// - /// ```ignore + /// ```text /// ln((n! / x_1!...x_k!) * p_i^x_i) for i in 1...k /// ``` /// /// where `n` is the number of trials, `p_i` is the `i`th probability, /// `x_i` is the `i`th `x` value, and `k` is the total number of /// probabilities - fn ln_pmf(&self, x: &[u64]) -> f64 { + fn ln_pmf(&self, x: &OVector) -> f64 { if self.p.len() != x.len() { panic!("Expected x and p to have equal lengths."); } if x.iter().sum::() != self.n { return f64::NEG_INFINITY; } - let coeff = factorial::multinomial(self.n, x).ln(); + let coeff = factorial::multinomial(self.n, x.as_slice()).ln(); let val = coeff + self .p @@ -238,143 +334,221 @@ impl<'a> Discrete<&'a [u64], f64> for Multinomial { } } -// TODO: fix tests -// #[rustfmt::skip] -// #[cfg(test)] -// mod tests { -// use crate::statistics::*; -// use crate::distribution::{Discrete, Multinomial}; -// use crate::consts::ACC; - -// fn try_create(p: &[f64], n: u64) -> Multinomial { -// let dist = Multinomial::new(p, n); -// assert!(dist.is_ok()); -// dist.unwrap() -// } +#[rustfmt::skip] +#[cfg(test)] +mod tests { + use crate::{ + distribution::{Discrete, Multinomial, MultinomialError}, + statistics::{MeanN, VarianceN}, + }; 
+ use nalgebra::{dmatrix, dvector, vector, DimMin, Dyn, OVector}; + use std::fmt::{Debug, Display}; -// fn create_case(p: &[f64], n: u64) { -// let dist = try_create(p, n); -// assert_eq!(dist.p(), p); -// assert_eq!(dist.n(), n); -// } + fn try_create(p: OVector, n: u64) -> Multinomial + where + D: DimMin, + nalgebra::DefaultAllocator: nalgebra::allocator::Allocator, + { + let mvn = Multinomial::new_from_nalgebra(p, n); + assert!(mvn.is_ok()); + mvn.unwrap() + } -// fn bad_create_case(p: &[f64], n: u64) { -// let dist = Multinomial::new(p, n); -// assert!(dist.is_err()); -// } + fn bad_create_case(p: OVector, n: u64) -> MultinomialError + where + D: DimMin, + nalgebra::DefaultAllocator: nalgebra::allocator::Allocator, + { + let dd = Multinomial::new_from_nalgebra(p, n); + assert!(dd.is_err()); + dd.unwrap_err() + } -// fn test_case(p: &[f64], n: u64, expected: &[f64], eval: F) -// where F: Fn(Multinomial) -> Vec -// { -// let dist = try_create(p, n); -// let x = eval(dist); -// assert_eq!(*expected, *x); -// } + fn test_almost(p: OVector, n: u64, expected: T, acc: f64, eval: F) + where + T: Debug + Display + approx::RelativeEq, + F: FnOnce(Multinomial) -> T, + D: DimMin, + nalgebra::DefaultAllocator: nalgebra::allocator::Allocator, + { + let dd = try_create(p, n); + let x = eval(dd); + assert_relative_eq!(expected, x, epsilon = acc); + } -// fn test_almost(p: &[f64], n: u64, expected: &[f64], acc: f64, eval: F) -// where F: Fn(Multinomial) -> Vec -// { -// let dist = try_create(p, n); -// let x = eval(dist); -// assert_eq!(expected.len(), x.len()); -// for i in 0..expected.len() { -// assert_almost_eq!(expected[i], x[i], acc); -// } -// } + #[test] + fn test_create() { + assert_relative_eq!( + *try_create(vector![1.0, 1.0, 1.0], 4).p(), + vector![1.0 / 3.0, 1.0 / 3.0, 1.0 / 3.0] + ); + try_create(dvector![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0], 4); + } -// fn test_almost_sr(p: &[f64], n: u64, expected: f64, acc:f64, eval: F) -// where F: Fn(Multinomial) -> 
f64 -// { -// let dist = try_create(p, n); -// let x = eval(dist); -// assert_almost_eq!(expected, x, acc); -// } + #[test] + fn test_bad_create() { + assert_eq!( + bad_create_case(vector![0.5], 4), + MultinomialError::NotEnoughProbabilities, + ); -// #[test] -// fn test_create() { -// create_case(&[1.0, 1.0, 1.0], 4); -// create_case(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0], 4); -// } + assert_eq!( + bad_create_case(vector![-1.0, 2.0], 4), + MultinomialError::ProbabilityInvalid, + ); -// #[test] -// fn test_bad_create() { -// bad_create_case(&[-1.0, 1.0], 4); -// bad_create_case(&[0.0, 0.0], 4); -// } + assert_eq!( + bad_create_case(vector![0.0, 0.0], 4), + MultinomialError::ProbabilitySumZero, + ); + assert_eq!( + bad_create_case(vector![1.0, f64::NAN], 4), + MultinomialError::ProbabilityInvalid, + ); + } -// #[test] -// fn test_mean() { -// let mean = |x: Multinomial| x.mean().unwrap(); -// test_case(&[0.3, 0.7], 5, &[1.5, 3.5], mean); -// test_case(&[0.1, 0.3, 0.6], 10, &[1.0, 3.0, 6.0], mean); -// test_case(&[0.15, 0.35, 0.3, 0.2], 20, &[3.0, 7.0, 6.0, 4.0], mean); -// } + #[test] + fn test_mean() { + let mean = |x: Multinomial<_>| x.mean().unwrap(); + test_almost(dvector![0.3, 0.7], 5, dvector![1.5, 3.5], 1e-12, mean); + test_almost( + dvector![0.1, 0.3, 0.6], + 10, + dvector![1.0, 3.0, 6.0], + 1e-12, + mean, + ); + test_almost( + dvector![1.0, 3.0, 6.0], + 10, + dvector![1.0, 3.0, 6.0], + 1e-12, + mean, + ); + test_almost( + dvector![0.15, 0.35, 0.3, 0.2], + 20, + dvector![3.0, 7.0, 6.0, 4.0], + 1e-12, + mean, + ); + } -// #[test] -// fn test_variance() { -// let variance = |x: Multinomial| x.variance().unwrap(); -// test_almost(&[0.3, 0.7], 5, &[1.05, 1.05], 1e-15, variance); -// test_almost(&[0.1, 0.3, 0.6], 10, &[0.9, 2.1, 2.4], 1e-15, variance); -// test_almost(&[0.15, 0.35, 0.3, 0.2], 20, &[2.55, 4.55, 4.2, 3.2], 1e-15, variance); -// } + #[test] + fn test_variance() { + let variance = |x: Multinomial<_>| x.variance().unwrap(); + test_almost( + 
dvector![0.3, 0.7], + 5, + dmatrix![1.05, -1.05; + -1.05, 1.05], + 1e-15, + variance, + ); + test_almost( + dvector![0.1, 0.3, 0.6], + 10, + dmatrix![0.9, -0.3, -0.6; + -0.3, 2.1, -1.8; + -0.6, -1.8, 2.4; + ], + 1e-15, + variance, + ); + test_almost( + dvector![0.15, 0.35, 0.3, 0.2], + 20, + dmatrix![2.55, -1.05, -0.90, -0.60; + -1.05, 4.55, -2.10, -1.40; + -0.90, -2.10, 4.20, -1.20; + -0.60, -1.40, -1.20, 3.20; + ], + 1e-15, + variance, + ); + } -// // #[test] -// // fn test_skewness() { -// // let skewness = |x: Multinomial| x.skewness().unwrap(); -// // test_almost(&[0.3, 0.7], 5, &[0.390360029179413, -0.390360029179413], 1e-15, skewness); -// // test_almost(&[0.1, 0.3, 0.6], 10, &[0.843274042711568, 0.276026223736942, -0.12909944487358], 1e-15, skewness); -// // test_almost(&[0.15, 0.35, 0.3, 0.2], 20, &[0.438357003759605, 0.140642169281549, 0.195180014589707, 0.335410196624968], 1e-15, skewness); -// // } - -// #[test] -// fn test_pmf() { -// let pmf = |arg: &[u64]| move |x: Multinomial| x.pmf(arg); -// test_almost_sr(&[0.3, 0.7], 10, 0.121060821, 1e-15, pmf(&[1, 9])); -// test_almost_sr(&[0.1, 0.3, 0.6], 10, 0.105815808, 1e-15, pmf(&[1, 3, 6])); -// test_almost_sr(&[0.15, 0.35, 0.3, 0.2], 10, 0.000145152, 1e-15, pmf(&[1, 1, 1, 7])); -// } + // // #[test] + // // fn test_skewness() { + // // let skewness = |x: Multinomial| x.skewness().unwrap(); + // // test_almost(&[0.3, 0.7], 5, &[0.390360029179413, -0.390360029179413], 1e-15, skewness); + // // test_almost(&[0.1, 0.3, 0.6], 10, &[0.843274042711568, 0.276026223736942, -0.12909944487358], 1e-15, skewness); + // // test_almost(&[0.15, 0.35, 0.3, 0.2], 20, &[0.438357003759605, 0.140642169281549, 0.195180014589707, 0.335410196624968], 1e-15, skewness); + // // } -// #[test] -// #[should_panic] -// fn test_pmf_x_wrong_length() { -// let pmf = |arg: &[u64]| move |x: Multinomial| x.pmf(arg); -// let n = Multinomial::new(&[0.3, 0.7], 10).unwrap(); -// n.pmf(&[1]); -// } + #[test] + fn test_pmf() { + let pmf = |arg: 
OVector| move |x: Multinomial<_>| x.pmf(&arg); + test_almost( + dvector![0.3, 0.7], + 10, + 0.121060821, + 1e-15, + pmf(dvector![1, 9]), + ); + test_almost( + dvector![0.1, 0.3, 0.6], + 10, + 0.105815808, + 1e-15, + pmf(dvector![1, 3, 6]), + ); + test_almost( + dvector![0.15, 0.35, 0.3, 0.2], + 10, + 0.000145152, + 1e-15, + pmf(dvector![1, 1, 1, 7]), + ); + } -// #[test] -// #[should_panic] -// fn test_pmf_x_wrong_sum() { -// let pmf = |arg: &[u64]| move |x: Multinomial| x.pmf(arg); -// let n = Multinomial::new(&[0.3, 0.7], 10).unwrap(); -// n.pmf(&[1, 3]); -// } + #[test] + fn test_error_is_sync_send() { + fn assert_sync_send() {} + assert_sync_send::(); + } -// #[test] -// fn test_ln_pmf() { -// let large_p = &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]; -// let n = Multinomial::new(large_p, 45).unwrap(); -// let x = &[1, 2, 3, 4, 5, 6, 7, 8, 9]; -// assert_almost_eq!(n.pmf(x).ln(), n.ln_pmf(x), 1e-13); -// let n2 = Multinomial::new(large_p, 18).unwrap(); -// let x2 = &[1, 1, 1, 2, 2, 2, 3, 3, 3]; -// assert_almost_eq!(n2.pmf(x2).ln(), n2.ln_pmf(x2), 1e-13); -// let n3 = Multinomial::new(large_p, 51).unwrap(); -// let x3 = &[5, 6, 7, 8, 7, 6, 5, 4, 3]; -// assert_almost_eq!(n3.pmf(x3).ln(), n3.ln_pmf(x3), 1e-13); -// } + // #[test] + // #[should_panic] + // fn test_pmf_x_wrong_length() { + // let pmf = |arg: &[u64]| move |x: Multinomial| x.pmf(arg); + // let n = Multinomial::new(&[0.3, 0.7], 10).unwrap(); + // n.pmf(&[1]); + // } -// #[test] -// #[should_panic] -// fn test_ln_pmf_x_wrong_length() { -// let n = Multinomial::new(&[0.3, 0.7], 10).unwrap(); -// n.ln_pmf(&[1]); -// } + // #[test] + // #[should_panic] + // fn test_pmf_x_wrong_sum() { + // let pmf = |arg: &[u64]| move |x: Multinomial| x.pmf(arg); + // let n = Multinomial::new(&[0.3, 0.7], 10).unwrap(); + // n.pmf(&[1, 3]); + // } -// #[test] -// #[should_panic] -// fn test_ln_pmf_x_wrong_sum() { -// let n = Multinomial::new(&[0.3, 0.7], 10).unwrap(); -// n.ln_pmf(&[1, 3]); -// } -// } + // #[test] + 
// fn test_ln_pmf() { + // let large_p = &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]; + // let n = Multinomial::new(large_p, 45).unwrap(); + // let x = &[1, 2, 3, 4, 5, 6, 7, 8, 9]; + // assert_almost_eq!(n.pmf(x).ln(), n.ln_pmf(x), 1e-13); + // let n2 = Multinomial::new(large_p, 18).unwrap(); + // let x2 = &[1, 1, 1, 2, 2, 2, 3, 3, 3]; + // assert_almost_eq!(n2.pmf(x2).ln(), n2.ln_pmf(x2), 1e-13); + // let n3 = Multinomial::new(large_p, 51).unwrap(); + // let x3 = &[5, 6, 7, 8, 7, 6, 5, 4, 3]; + // assert_almost_eq!(n3.pmf(x3).ln(), n3.ln_pmf(x3), 1e-13); + // } + + // #[test] + // #[should_panic] + // fn test_ln_pmf_x_wrong_length() { + // let n = Multinomial::new(&[0.3, 0.7], 10).unwrap(); + // n.ln_pmf(&[1]); + // } + + // #[test] + // #[should_panic] + // fn test_ln_pmf_x_wrong_sum() { + // let n = Multinomial::new(&[0.3, 0.7], 10).unwrap(); + // n.ln_pmf(&[1, 3]); + // } +} diff --git a/src/distribution/multivariate_normal.rs b/src/distribution/multivariate_normal.rs index ab168020..b336d9db 100644 --- a/src/distribution/multivariate_normal.rs +++ b/src/distribution/multivariate_normal.rs @@ -1,16 +1,82 @@ use crate::distribution::Continuous; -use crate::distribution::Normal; use crate::statistics::{Max, MeanN, Min, Mode, VarianceN}; -use crate::{Result, StatsError}; -use nalgebra::{ - base::allocator::Allocator, base::dimension::DimName, Cholesky, DefaultAllocator, Dim, DimMin, - LU, U1, -}; -use nalgebra::{DMatrix, DVector}; -use rand::Rng; +use crate::StatsError; +use nalgebra::{Cholesky, Const, DMatrix, DVector, Dim, DimMin, Dyn, OMatrix, OVector}; use std::f64; use std::f64::consts::{E, PI}; +/// computes both the normalization and exponential argument in the normal distribution +/// # Errors +/// will error on dimension mismatch +pub(super) fn density_normalization_and_exponential( + mu: &OVector, + cov: &OMatrix, + precision: &OMatrix, + x: &OVector, +) -> std::result::Result<(f64, f64), StatsError> +where + D: DimMin, + nalgebra::DefaultAllocator: 
nalgebra::allocator::Allocator + + nalgebra::allocator::Allocator + + nalgebra::allocator::Allocator<(usize, usize), D>, +{ + Ok(( + density_distribution_pdf_const(mu, cov)?, + density_distribution_exponential(mu, precision, x)?, + )) +} + +/// computes the argument of the exponential term in the normal distribution +/// ```text +/// ``` +/// # Errors +/// will error on dimension mismatch +#[inline] +pub(super) fn density_distribution_exponential( + mu: &OVector, + precision: &OMatrix, + x: &OVector, +) -> std::result::Result +where + D: Dim, + nalgebra::DefaultAllocator: + nalgebra::allocator::Allocator + nalgebra::allocator::Allocator, +{ + if x.shape_generic().0 != precision.shape_generic().0 + || x.shape_generic().0 != mu.shape_generic().0 + || !precision.is_square() + { + return Err(StatsError::ContainersMustBeSameLength); + } + let dv = x - mu; + let exp_term: f64 = -0.5 * (precision * &dv).dot(&dv); + Ok(exp_term) + // TODO update to dimension mismatch error +} + +/// computes the argument of the normalization term in the normal distribution +/// # Errors +/// will error on dimension mismatch +#[inline] +pub(super) fn density_distribution_pdf_const( + mu: &OVector, + cov: &OMatrix, +) -> std::result::Result +where + D: DimMin, + nalgebra::DefaultAllocator: nalgebra::allocator::Allocator + + nalgebra::allocator::Allocator + + nalgebra::allocator::Allocator<(usize, usize), D>, +{ + if cov.shape_generic().0 != mu.shape_generic().0 || !cov.is_square() { + return Err(StatsError::ContainersMustBeSameLength); + } + let cov_det = cov.determinant(); + Ok(((2. 
* PI).powi(mu.nrows() as i32) * cov_det.abs()) + .recip() + .sqrt()) +} + /// Implements the [Multivariate Normal](https://en.wikipedia.org/wiki/Multivariate_normal_distribution) /// distribution using the "nalgebra" crate for matrix operations /// @@ -18,72 +84,140 @@ use std::f64::consts::{E, PI}; /// /// ``` /// use statrs::distribution::{MultivariateNormal, Continuous}; -/// use nalgebra::{DVector, DMatrix}; +/// use nalgebra::{matrix, vector}; /// use statrs::statistics::{MeanN, VarianceN}; /// -/// let mvn = MultivariateNormal::new(vec![0., 0.], vec![1., 0., 0., 1.]).unwrap(); -/// assert_eq!(mvn.mean().unwrap(), DVector::from_vec(vec![0., 0.])); -/// assert_eq!(mvn.variance().unwrap(), DMatrix::from_vec(2, 2, vec![1., 0., 0., 1.])); -/// assert_eq!(mvn.pdf(&DVector::from_vec(vec![1., 1.])), 0.05854983152431917); +/// let mvn = MultivariateNormal::new_from_nalgebra(vector![0., 0.], matrix![1., 0.; 0., 1.]).unwrap(); +/// assert_eq!(mvn.mean().unwrap(), vector![0., 0.]); +/// assert_eq!(mvn.variance().unwrap(), matrix![1., 0.; 0., 1.]); +/// assert_eq!(mvn.pdf(&vector![1., 1.]), 0.05854983152431917); /// ``` -#[derive(Debug, Clone, PartialEq)] -pub struct MultivariateNormal { - dim: usize, - cov_chol_decomp: DMatrix, - mu: DVector, - cov: DMatrix, - precision: DMatrix, +#[derive(Clone, PartialEq, Debug)] +pub struct MultivariateNormal +where + D: Dim, + nalgebra::DefaultAllocator: + nalgebra::allocator::Allocator + nalgebra::allocator::Allocator, +{ + cov_chol_decomp: OMatrix, + mu: OVector, + cov: OMatrix, + precision: OMatrix, pdf_const: f64, } -impl MultivariateNormal { - /// Constructs a new multivariate normal distribution with a mean of `mean` +/// Represents the errors that can occur when creating a [`MultivariateNormal`]. +#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] +#[non_exhaustive] +pub enum MultivariateNormalError { + /// The covariance matrix is asymmetric or contains a NaN. + CovInvalid, + + /// The mean vector contains a NaN. 
+ MeanInvalid, + + /// The amount of rows in the vector of means is not equal to the amount + /// of rows in the covariance matrix. + DimensionMismatch, + + /// After all other validation, computing the Cholesky decomposition failed. + CholeskyFailed, +} + +impl std::fmt::Display for MultivariateNormalError { + #[cfg_attr(coverage_nightly, coverage(off))] + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + MultivariateNormalError::CovInvalid => { + write!(f, "Covariance matrix is asymmetric or contains a NaN") + } + MultivariateNormalError::MeanInvalid => write!(f, "Mean vector contains a NaN"), + MultivariateNormalError::DimensionMismatch => write!( + f, + "Mean vector and covariance matrix do not have the same number of rows" + ), + MultivariateNormalError::CholeskyFailed => { + write!(f, "Computing the Cholesky decomposition failed") + } + } + } +} + +impl std::error::Error for MultivariateNormalError {} + +impl MultivariateNormal { + /// Constructs a new multivariate normal distribution with a mean of `mean` /// and covariance matrix `cov` /// /// # Errors /// /// Returns an error if the given covariance matrix is not /// symmetric or positive-definite - pub fn new(mean: Vec, cov: Vec) -> Result { + pub fn new(mean: Vec, cov: Vec) -> Result { let mean = DVector::from_vec(mean); let cov = DMatrix::from_vec(mean.len(), mean.len(), cov); - let dim = mean.len(); - // Check that the provided covariance matrix is symmetric - if cov.lower_triangle() != cov.upper_triangle().transpose() - // Check that mean and covariance do not contain NaN - || mean.iter().any(|f| f.is_nan()) + MultivariateNormal::new_from_nalgebra(mean, cov) + } +} + +impl MultivariateNormal +where + D: DimMin, + nalgebra::DefaultAllocator: nalgebra::allocator::Allocator + + nalgebra::allocator::Allocator + + nalgebra::allocator::Allocator<(usize, usize), D>, +{ + /// Constructs a new multivariate normal distribution with a mean of `mean` + /// and covariance matrix 
`cov` using `nalgebra` `OVector` and `OMatrix` + /// instead of `Vec` + /// + /// # Errors + /// + /// Returns an error if the given covariance matrix is not + /// symmetric or positive-definite + pub fn new_from_nalgebra( + mean: OVector, + cov: OMatrix, + ) -> Result { + if mean.iter().any(|f| f.is_nan()) { + return Err(MultivariateNormalError::MeanInvalid); + } + + if !cov.is_square() + || cov.lower_triangle() != cov.upper_triangle().transpose() || cov.iter().any(|f| f.is_nan()) - // Check that the dimensions match - || mean.nrows() != cov.nrows() || cov.nrows() != cov.ncols() { - return Err(StatsError::BadParams); + return Err(MultivariateNormalError::CovInvalid); } - let cov_det = cov.determinant(); - let pdf_const = ((2. * PI).powi(mean.nrows() as i32) * cov_det.abs()) - .recip() - .sqrt(); + + // Compare number of rows + if mean.shape_generic().0 != cov.shape_generic().0 { + return Err(MultivariateNormalError::DimensionMismatch); + } + // Store the Cholesky decomposition of the covariance matrix // for sampling match Cholesky::new(cov.clone()) { - None => Err(StatsError::BadParams), + None => Err(MultivariateNormalError::CholeskyFailed), Some(cholesky_decomp) => { let precision = cholesky_decomp.inverse(); Ok(MultivariateNormal { - dim, + // .unwrap() because prerequisites are already checked above + pdf_const: density_distribution_pdf_const(&mean, &cov).unwrap(), cov_chol_decomp: cholesky_decomp.unpack(), mu: mean, cov, precision, - pdf_const, }) } } } + /// Returns the entropy of the multivariate normal distribution /// /// # Formula /// - /// ```ignore + /// ```text /// (1 / 2) * ln(det(2 * π * e * Σ)) /// ``` /// @@ -100,257 +234,480 @@ impl MultivariateNormal { } } -impl ::rand::distributions::Distribution> for MultivariateNormal { +impl std::fmt::Display for MultivariateNormal +where + D: Dim, + nalgebra::DefaultAllocator: + nalgebra::allocator::Allocator + nalgebra::allocator::Allocator, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> 
std::fmt::Result { + write!(f, "N({}, {})", &self.mu, &self.cov) + } +} + +#[cfg(feature = "rand")] +impl ::rand::distributions::Distribution> for MultivariateNormal +where + D: Dim, + nalgebra::DefaultAllocator: + nalgebra::allocator::Allocator + nalgebra::allocator::Allocator, +{ /// Samples from the multivariate normal distribution /// /// # Formula + /// ```text /// L * Z + μ + /// ``` /// /// where `L` is the Cholesky decomposition of the covariance matrix, /// `Z` is a vector of normally distributed random variables, and /// `μ` is the mean vector - fn sample(&self, rng: &mut R) -> DVector { - let d = Normal::new(0., 1.).unwrap(); - let z = DVector::::from_distribution(self.dim, &d, rng); + fn sample(&self, rng: &mut R) -> OVector { + let d = crate::distribution::Normal::new(0., 1.).unwrap(); + let z = OVector::from_distribution_generic(self.mu.shape_generic().0, Const::<1>, &d, rng); (&self.cov_chol_decomp * z) + &self.mu } } -impl Min> for MultivariateNormal { +impl Min> for MultivariateNormal +where + D: Dim, + nalgebra::DefaultAllocator: + nalgebra::allocator::Allocator + nalgebra::allocator::Allocator, +{ /// Returns the minimum value in the domain of the /// multivariate normal distribution represented by a real vector - fn min(&self) -> DVector { - DVector::from_vec(vec![f64::NEG_INFINITY; self.dim]) + fn min(&self) -> OVector { + OMatrix::repeat_generic(self.mu.shape_generic().0, Const::<1>, f64::NEG_INFINITY) } } -impl Max> for MultivariateNormal { +impl Max> for MultivariateNormal +where + D: Dim, + nalgebra::DefaultAllocator: + nalgebra::allocator::Allocator + nalgebra::allocator::Allocator, +{ /// Returns the maximum value in the domain of the /// multivariate normal distribution represented by a real vector - fn max(&self) -> DVector { - DVector::from_vec(vec![f64::INFINITY; self.dim]) + fn max(&self) -> OVector { + OMatrix::repeat_generic(self.mu.shape_generic().0, Const::<1>, f64::INFINITY) } } -impl MeanN> for MultivariateNormal { +impl MeanN> 
for MultivariateNormal +where + D: Dim, + nalgebra::DefaultAllocator: + nalgebra::allocator::Allocator + nalgebra::allocator::Allocator, +{ /// Returns the mean of the normal distribution /// /// # Remarks /// /// This is the same mean used to construct the distribution - fn mean(&self) -> Option> { - let mut vec = vec![]; - for elt in self.mu.clone().into_iter() { - vec.push(*elt); - } - Some(DVector::from_vec(vec)) + fn mean(&self) -> Option> { + Some(self.mu.clone()) } } -impl VarianceN> for MultivariateNormal { +impl VarianceN> for MultivariateNormal +where + D: Dim, + nalgebra::DefaultAllocator: + nalgebra::allocator::Allocator + nalgebra::allocator::Allocator, +{ /// Returns the covariance matrix of the multivariate normal distribution - fn variance(&self) -> Option> { + fn variance(&self) -> Option> { Some(self.cov.clone()) } } -impl Mode> for MultivariateNormal { +impl Mode> for MultivariateNormal +where + D: Dim, + nalgebra::DefaultAllocator: + nalgebra::allocator::Allocator + nalgebra::allocator::Allocator, +{ /// Returns the mode of the multivariate normal distribution /// /// # Formula /// - /// ```ignore + /// ```text /// μ /// ``` /// /// where `μ` is the mean - fn mode(&self) -> DVector { + fn mode(&self) -> OVector { self.mu.clone() } } -impl<'a> Continuous<&'a DVector, f64> for MultivariateNormal { +impl<'a, D> Continuous<&'a OVector, f64> for MultivariateNormal +where + D: Dim, + nalgebra::DefaultAllocator: + nalgebra::allocator::Allocator + nalgebra::allocator::Allocator, +{ /// Calculates the probability density function for the multivariate /// normal distribution at `x` /// /// # Formula /// - /// ```ignore + /// ```text /// (2 * π) ^ (-k / 2) * det(Σ) ^ (1 / 2) * e ^ ( -(1 / 2) * transpose(x - μ) * inv(Σ) * (x - μ)) /// ``` /// /// where `μ` is the mean, `inv(Σ)` is the precision matrix, `det(Σ)` is the determinant /// of the covariance matrix, and `k` is the dimension of the distribution - fn pdf(&self, x: &'a DVector) -> f64 { - let dv = x 
- &self.mu; - let exp_term = -0.5 - * *(&dv.transpose() * &self.precision * &dv) - .get((0, 0)) - .unwrap(); - self.pdf_const * exp_term.exp() + fn pdf(&self, x: &OVector) -> f64 { + self.pdf_const + * density_distribution_exponential(&self.mu, &self.precision, x) + .unwrap() + .exp() } + /// Calculates the log probability density function for the multivariate /// normal distribution at `x`. Equivalent to pdf(x).ln(). - fn ln_pdf(&self, x: &'a DVector) -> f64 { - let dv = x - &self.mu; - let exp_term = -0.5 - * *(&dv.transpose() * &self.precision * &dv) - .get((0, 0)) - .unwrap(); - self.pdf_const.ln() + exp_term + fn ln_pdf(&self, x: &OVector) -> f64 { + self.pdf_const.ln() + + density_distribution_exponential(&self.mu, &self.precision, x).unwrap() } } #[rustfmt::skip] -#[cfg(all(test, feature = "nightly"))] +#[cfg(test)] mod tests { - use crate::distribution::{Continuous, MultivariateNormal}; - use crate::statistics::*; - use crate::consts::ACC; use core::fmt::Debug; - use nalgebra::base::allocator::Allocator; - use nalgebra::{ - DefaultAllocator, Dim, DimMin, DimName, Matrix2, Matrix3, Vector2, Vector3, - U1, U2, + + use nalgebra::{dmatrix, dvector, matrix, vector, DimMin, OMatrix, OVector}; + + use crate::{ + distribution::{Continuous, MultivariateNormal}, + statistics::{Max, MeanN, Min, Mode, VarianceN}, }; - fn try_create(mean: Vec, covariance: Vec) -> MultivariateNormal + use super::MultivariateNormalError; + + fn try_create(mean: OVector, covariance: OMatrix) -> MultivariateNormal + where + D: DimMin, + nalgebra::DefaultAllocator: nalgebra::allocator::Allocator + + nalgebra::allocator::Allocator + + nalgebra::allocator::Allocator<(usize, usize), D>, { - let mvn = MultivariateNormal::new(mean, covariance); + let mvn = MultivariateNormal::new_from_nalgebra(mean, covariance); assert!(mvn.is_ok()); mvn.unwrap() } - fn create_case(mean: Vec, covariance: Vec) + fn create_case(mean: OVector, covariance: OMatrix) + where + D: DimMin, + nalgebra::DefaultAllocator: 
nalgebra::allocator::Allocator + + nalgebra::allocator::Allocator + + nalgebra::allocator::Allocator<(usize, usize), D>, { let mvn = try_create(mean.clone(), covariance.clone()); - assert_eq!(DVector::from_vec(mean.clone()), mvn.mean().unwrap()); - assert_eq!(DMatrix::from_vec(mean.len(), mean.len(), covariance), mvn.variance().unwrap()); + assert_eq!(mean, mvn.mean().unwrap()); + assert_eq!(covariance, mvn.variance().unwrap()); } - fn bad_create_case(mean: Vec, covariance: Vec) + fn bad_create_case(mean: OVector, covariance: OMatrix) + where + D: DimMin, + nalgebra::DefaultAllocator: nalgebra::allocator::Allocator + + nalgebra::allocator::Allocator + + nalgebra::allocator::Allocator<(usize, usize), D>, { - let mvn = MultivariateNormal::new(mean, covariance); + let mvn = MultivariateNormal::new_from_nalgebra(mean, covariance); assert!(mvn.is_err()); } - fn test_case(mean: Vec, covariance: Vec, expected: T, eval: F) - where + fn test_case( + mean: OVector, covariance: OMatrix, expected: T, eval: F, + ) where T: Debug + PartialEq, - F: FnOnce(MultivariateNormal) -> T, + F: FnOnce(MultivariateNormal) -> T, + D: DimMin, + nalgebra::DefaultAllocator: nalgebra::allocator::Allocator + + nalgebra::allocator::Allocator + + nalgebra::allocator::Allocator<(usize, usize), D>, { let mvn = try_create(mean, covariance); let x = eval(mvn); assert_eq!(expected, x); } - fn test_almost( - mean: Vec, - covariance: Vec, - expected: f64, - acc: f64, - eval: F, + fn test_almost( + mean: OVector, covariance: OMatrix, expected: f64, acc: f64, eval: F, ) where - F: FnOnce(MultivariateNormal) -> f64, + F: FnOnce(MultivariateNormal) -> f64, + D: DimMin, + nalgebra::DefaultAllocator: nalgebra::allocator::Allocator + + nalgebra::allocator::Allocator + + nalgebra::allocator::Allocator<(usize, usize), D>, { let mvn = try_create(mean, covariance); let x = eval(mvn); assert_almost_eq!(expected, x, acc); } - use super::*; - - macro_rules! 
dvec { - ($($x:expr),*) => (DVector::from_vec(vec![$($x),*])); - } - - macro_rules! mat2 { - ($x11:expr, $x12:expr, $x21:expr, $x22:expr) => (DMatrix::from_vec(2,2,vec![$x11, $x12, $x21, $x22])); - } - - // macro_rules! mat3 { - // ($x11:expr, $x12:expr, $x13:expr, $x21:expr, $x22:expr, $x23:expr, $x31:expr, $x32:expr, $x33:expr) => (DMatrix::from_vec(3,3,vec![$x11, $x12, $x13, $x21, $x22, $x23, $x31, $x32, $x33])); - // } - #[test] fn test_create() { - create_case(vec![0., 0.], vec![1., 0., 0., 1.]); - create_case(vec![10., 5.], vec![2., 1., 1., 2.]); - create_case(vec![4., 5., 6.], vec![2., 1., 0., 1., 2., 1., 0., 1., 2.]); - create_case(vec![0., f64::INFINITY], vec![1., 0., 0., 1.]); - create_case(vec![0., 0.], vec![f64::INFINITY, 0., 0., f64::INFINITY]); + create_case(vector![0., 0.], matrix![1., 0.; 0., 1.]); + create_case(vector![10., 5.], matrix![2., 1.; 1., 2.]); + create_case( + vector![4., 5., 6.], + matrix![2., 1., 0.; 1., 2., 1.; 0., 1., 2.], + ); + create_case(dvector![0., f64::INFINITY], dmatrix![1., 0.; 0., 1.]); + create_case( + dvector![0., 0.], + dmatrix![f64::INFINITY, 0.; 0., f64::INFINITY], + ); } #[test] fn test_bad_create() { // Covariance not symmetric - bad_create_case(vec![0., 0.], vec![1., 1., 0., 1.]); + bad_create_case(vector![0., 0.], matrix![1., 1.; 0., 1.]); // Covariance not positive-definite - bad_create_case(vec![0., 0.], vec![1., 2., 2., 1.]); + bad_create_case(vector![0., 0.], matrix![1., 2.; 2., 1.]); // NaN in mean - bad_create_case(vec![0., f64::NAN], vec![1., 0., 0., 1.]); + bad_create_case(dvector![0., f64::NAN], dmatrix![1., 0.; 0., 1.]); // NaN in Covariance Matrix - bad_create_case(vec![0., 0.], vec![1., 0., 0., f64::NAN]); + bad_create_case(dvector![0., 0.], dmatrix![1., 0.; 0., f64::NAN]); } #[test] fn test_variance() { - let variance = |x: MultivariateNormal| x.variance().unwrap(); - test_case(vec![0., 0.], vec![1., 0., 0., 1.], mat2![1., 0., 0., 1.], variance); - test_case(vec![0., 0.], vec![f64::INFINITY, 0., 0., 
f64::INFINITY], mat2![f64::INFINITY, 0., 0., f64::INFINITY], variance); + let variance = |x: MultivariateNormal<_>| x.variance().unwrap(); + test_case( + vector![0., 0.], + matrix![1., 0.; 0., 1.], + matrix![1., 0.; 0., 1.], + variance, + ); + test_case( + vector![0., 0.], + matrix![f64::INFINITY, 0.; 0., f64::INFINITY], + matrix![f64::INFINITY, 0.; 0., f64::INFINITY], + variance, + ); } #[test] fn test_entropy() { - let entropy = |x: MultivariateNormal| x.entropy().unwrap(); - test_case(vec![0., 0.], vec![1., 0., 0., 1.], 2.8378770664093453, entropy); - test_case(vec![0., 0.], vec![1., 0.5, 0.5, 1.], 2.694036030183455, entropy); - test_case(vec![0., 0.], vec![f64::INFINITY, 0., 0., f64::INFINITY], f64::INFINITY, entropy); + let entropy = |x: MultivariateNormal<_>| x.entropy().unwrap(); + test_case( + dvector![0., 0.], + dmatrix![1., 0.; 0., 1.], + 2.8378770664093453, + entropy, + ); + test_case( + dvector![0., 0.], + dmatrix![1., 0.5; 0.5, 1.], + 2.694036030183455, + entropy, + ); + test_case( + dvector![0., 0.], + dmatrix![f64::INFINITY, 0.; 0., f64::INFINITY], + f64::INFINITY, + entropy, + ); } #[test] fn test_mode() { - let mode = |x: MultivariateNormal| x.mode(); - test_case(vec![0., 0.], vec![1., 0., 0., 1.], dvec![0., 0.], mode); - test_case(vec![f64::INFINITY, f64::INFINITY], vec![1., 0., 0., 1.], dvec![f64::INFINITY, f64::INFINITY], mode); + let mode = |x: MultivariateNormal<_>| x.mode(); + test_case( + vector![0., 0.], + matrix![1., 0.; 0., 1.], + vector![0., 0.], + mode, + ); + test_case( + vector![f64::INFINITY, f64::INFINITY], + matrix![1., 0.; 0., 1.], + vector![f64::INFINITY, f64::INFINITY], + mode, + ); } #[test] fn test_min_max() { - let min = |x: MultivariateNormal| x.min(); - let max = |x: MultivariateNormal| x.max(); - test_case(vec![0., 0.], vec![1., 0., 0., 1.], dvec![f64::NEG_INFINITY, f64::NEG_INFINITY], min); - test_case(vec![0., 0.], vec![1., 0., 0., 1.], dvec![f64::INFINITY, f64::INFINITY], max); - test_case(vec![10., 1.], vec![1., 0., 
0., 1.], dvec![f64::NEG_INFINITY, f64::NEG_INFINITY], min); - test_case(vec![-3., 5.], vec![1., 0., 0., 1.], dvec![f64::INFINITY, f64::INFINITY], max); + let min = |x: MultivariateNormal<_>| x.min(); + let max = |x: MultivariateNormal<_>| x.max(); + test_case( + dvector![0., 0.], + dmatrix![1., 0.; 0., 1.], + dvector![f64::NEG_INFINITY, f64::NEG_INFINITY], + min, + ); + test_case( + dvector![0., 0.], + dmatrix![1., 0.; 0., 1.], + dvector![f64::INFINITY, f64::INFINITY], + max, + ); + test_case( + dvector![10., 1.], + dmatrix![1., 0.; 0., 1.], + dvector![f64::NEG_INFINITY, f64::NEG_INFINITY], + min, + ); + test_case( + dvector![-3., 5.], + dmatrix![1., 0.; 0., 1.], + dvector![f64::INFINITY, f64::INFINITY], + max, + ); } #[test] fn test_pdf() { - let pdf = |arg: DVector| move |x: MultivariateNormal| x.pdf(&arg); - test_case(vec![0., 0.], vec![1., 0., 0., 1.], 0.05854983152431917, pdf(dvec![1., 1.])); - test_almost(vec![0., 0.], vec![1., 0., 0., 1.], 0.013064233284684921, 1e-15, pdf(dvec![1., 2.])); - test_almost(vec![0., 0.], vec![1., 0., 0., 1.], 1.8618676045881531e-23, 1e-35, pdf(dvec![1., 10.])); - test_almost(vec![0., 0.], vec![1., 0., 0., 1.], 5.920684802611216e-45, 1e-58, pdf(dvec![10., 10.])); - test_almost(vec![0., 0.], vec![1., 0.9, 0.9, 1.], 1.6576716577547003e-05, 1e-18, pdf(dvec![1., -1.])); - test_almost(vec![0., 0.], vec![1., 0.99, 0.99, 1.], 4.1970621773477824e-44, 1e-54, pdf(dvec![1., -1.])); - test_almost(vec![0.5, -0.2], vec![2.0, 0.3, 0.3, 0.5], 0.0013075203140666656, 1e-15, pdf(dvec![2., 2.])); - test_case(vec![0., 0.], vec![f64::INFINITY, 0., 0., f64::INFINITY], 0.0, pdf(dvec![10., 10.])); - test_case(vec![0., 0.], vec![f64::INFINITY, 0., 0., f64::INFINITY], 0.0, pdf(dvec![100., 100.])); + let pdf = |arg| move |x: MultivariateNormal<_>| x.pdf(&arg); + test_case( + vector![0., 0.], + matrix![1., 0.; 0., 1.], + 0.05854983152431917, + pdf(vector![1., 1.]), + ); + test_almost( + vector![0., 0.], + matrix![1., 0.; 0., 1.], + 0.013064233284684921, + 
1e-15, + pdf(vector![1., 2.]), + ); + test_almost( + vector![0., 0.], + matrix![1., 0.; 0., 1.], + 1.8618676045881531e-23, + 1e-35, + pdf(vector![1., 10.]), + ); + test_almost( + vector![0., 0.], + matrix![1., 0.; 0., 1.], + 5.920684802611216e-45, + 1e-58, + pdf(vector![10., 10.]), + ); + test_almost( + vector![0., 0.], + matrix![1., 0.9; 0.9, 1.], + 1.6576716577547003e-05, + 1e-18, + pdf(vector![1., -1.]), + ); + test_almost( + vector![0., 0.], + matrix![1., 0.99; 0.99, 1.], + 4.1970621773477824e-44, + 1e-54, + pdf(vector![1., -1.]), + ); + test_almost( + vector![0.5, -0.2], + matrix![2.0, 0.3; 0.3, 0.5], + 0.0013075203140666656, + 1e-15, + pdf(vector![2., 2.]), + ); + test_case( + vector![0., 0.], + matrix![f64::INFINITY, 0.; 0., f64::INFINITY], + 0.0, + pdf(vector![10., 10.]), + ); + test_case( + vector![0., 0.], + matrix![f64::INFINITY, 0.; 0., f64::INFINITY], + 0.0, + pdf(vector![100., 100.]), + ); } #[test] fn test_ln_pdf() { - let ln_pdf = |arg: DVector<_>| move |x: MultivariateNormal| x.ln_pdf(&arg); - test_case(vec![0., 0.], vec![1., 0., 0., 1.], (0.05854983152431917f64).ln(), ln_pdf(dvec![1., 1.])); - test_almost(vec![0., 0.], vec![1., 0., 0., 1.], (0.013064233284684921f64).ln(), 1e-15, ln_pdf(dvec![1., 2.])); - test_almost(vec![0., 0.], vec![1., 0., 0., 1.], (1.8618676045881531e-23f64).ln(), 1e-15, ln_pdf(dvec![1., 10.])); - test_almost(vec![0., 0.], vec![1., 0., 0., 1.], (5.920684802611216e-45f64).ln(), 1e-15, ln_pdf(dvec![10., 10.])); - test_almost(vec![0., 0.], vec![1., 0.9, 0.9, 1.], (1.6576716577547003e-05f64).ln(), 1e-14, ln_pdf(dvec![1., -1.])); - test_almost(vec![0., 0.], vec![1., 0.99, 0.99, 1.], (4.1970621773477824e-44f64).ln(), 1e-12, ln_pdf(dvec![1., -1.])); - test_almost(vec![0.5, -0.2], vec![2.0, 0.3, 0.3, 0.5], (0.0013075203140666656f64).ln(), 1e-15, ln_pdf(dvec![2., 2.])); - test_case(vec![0., 0.], vec![f64::INFINITY, 0., 0., f64::INFINITY], f64::NEG_INFINITY, ln_pdf(dvec![10., 10.])); - test_case(vec![0., 0.], vec![f64::INFINITY, 0., 0., 
f64::INFINITY], f64::NEG_INFINITY, ln_pdf(dvec![100., 100.])); + let ln_pdf = |arg| move |x: MultivariateNormal<_>| x.ln_pdf(&arg); + test_case( + dvector![0., 0.], + dmatrix![1., 0.; 0., 1.], + (0.05854983152431917f64).ln(), + ln_pdf(dvector![1., 1.]), + ); + test_almost( + dvector![0., 0.], + dmatrix![1., 0.; 0., 1.], + (0.013064233284684921f64).ln(), + 1e-15, + ln_pdf(dvector![1., 2.]), + ); + test_almost( + dvector![0., 0.], + dmatrix![1., 0.; 0., 1.], + (1.8618676045881531e-23f64).ln(), + 1e-15, + ln_pdf(dvector![1., 10.]), + ); + test_almost( + dvector![0., 0.], + dmatrix![1., 0.; 0., 1.], + (5.920684802611216e-45f64).ln(), + 1e-15, + ln_pdf(dvector![10., 10.]), + ); + test_almost( + dvector![0., 0.], + dmatrix![1., 0.9; 0.9, 1.], + (1.6576716577547003e-05f64).ln(), + 1e-14, + ln_pdf(dvector![1., -1.]), + ); + test_almost( + dvector![0., 0.], + dmatrix![1., 0.99; 0.99, 1.], + (4.1970621773477824e-44f64).ln(), + 1e-12, + ln_pdf(dvector![1., -1.]), + ); + test_almost( + dvector![0.5, -0.2], + dmatrix![2.0, 0.3; 0.3, 0.5], + (0.0013075203140666656f64).ln(), + 1e-15, + ln_pdf(dvector![2., 2.]), + ); + test_case( + dvector![0., 0.], + dmatrix![f64::INFINITY, 0.; 0., f64::INFINITY], + f64::NEG_INFINITY, + ln_pdf(dvector![10., 10.]), + ); + test_case( + dvector![0., 0.], + dmatrix![f64::INFINITY, 0.; 0., f64::INFINITY], + f64::NEG_INFINITY, + ln_pdf(dvector![100., 100.]), + ); + } + + #[test] + #[should_panic] + fn test_pdf_mismatched_arg_size() { + let mvn = MultivariateNormal::new(vec![0., 0.], vec![1., 0., 0., 1.,]).unwrap(); + mvn.pdf(&vec![1.].into()); // x.size != mu.size + } + + #[test] + fn test_error_is_sync_send() { + fn assert_sync_send() {} + assert_sync_send::(); } } diff --git a/src/distribution/multivariate_students_t.rs b/src/distribution/multivariate_students_t.rs new file mode 100644 index 00000000..73e8f8f2 --- /dev/null +++ b/src/distribution/multivariate_students_t.rs @@ -0,0 +1,625 @@ +use crate::distribution::Continuous; +use 
crate::function::gamma; +use crate::statistics::{Max, MeanN, Min, Mode, VarianceN}; +use nalgebra::{Cholesky, Const, DMatrix, Dim, DimMin, Dyn, OMatrix, OVector}; +use std::f64::consts::PI; + +/// Implements the [Multivariate Student's t-distribution](https://en.wikipedia.org/wiki/Multivariate_t-distribution) +/// distribution using the "nalgebra" crate for matrix operations. +/// +/// Assumes all the marginal distributions have the same degree of freedom, ν. +/// +/// # Examples +/// +/// ``` +/// use statrs::distribution::{MultivariateStudent, Continuous}; +/// use nalgebra::{DVector, DMatrix}; +/// use statrs::statistics::{MeanN, VarianceN}; +/// +/// let mvs = MultivariateStudent::new(vec![0., 0.], vec![1., 0., 0., 1.], 4.).unwrap(); +/// assert_eq!(mvs.mean().unwrap(), DVector::from_vec(vec![0., 0.])); +/// assert_eq!(mvs.variance().unwrap(), DMatrix::from_vec(2, 2, vec![2., 0., 0., 2.])); +/// assert_eq!(mvs.pdf(&DVector::from_vec(vec![1., 1.])), 0.04715702017537655); +/// ``` +#[derive(Debug, Clone, PartialEq)] +pub struct MultivariateStudent +where + D: Dim, + nalgebra::DefaultAllocator: + nalgebra::allocator::Allocator + nalgebra::allocator::Allocator, +{ + scale_chol_decomp: OMatrix, + location: OVector, + scale: OMatrix, + freedom: f64, + precision: OMatrix, + ln_pdf_const: f64, +} + +/// Represents the errors that can occur when creating a [`MultivariateStudent`]. +#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] +#[non_exhaustive] +pub enum MultivariateStudentError { + /// The scale matrix is asymmetric or contains a NaN. + ScaleInvalid, + + /// The location vector contains a NaN. + LocationInvalid, + + /// The degrees of freedom are NaN, zero or less than zero. + FreedomInvalid, + + /// The amount of rows in the location vector is not equal to the amount + /// of rows in the scale matrix. + DimensionMismatch, + + /// After all other validation, computing the Cholesky decomposition failed. 
+ /// This means that the scale matrix is not definite-positive. + CholeskyFailed, +} + +impl std::fmt::Display for MultivariateStudentError { + #[cfg_attr(coverage_nightly, coverage(off))] + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + MultivariateStudentError::ScaleInvalid => { + write!(f, "Scale matrix is asymmetric or contains a NaN") + } + MultivariateStudentError::LocationInvalid => { + write!(f, "Location vector contains a NaN") + } + MultivariateStudentError::FreedomInvalid => { + write!(f, "Degrees of freedom are NaN, zero or less than zero") + } + MultivariateStudentError::DimensionMismatch => write!( + f, + "Location vector and scale matrix do not have the same number of rows" + ), + MultivariateStudentError::CholeskyFailed => { + write!(f, "Computing the Cholesky decomposition failed") + } + } + } +} + +impl std::error::Error for MultivariateStudentError {} + +impl MultivariateStudent { + /// Constructs a new multivariate students t distribution with a location of `location`, + /// scale matrix `scale` and `freedom` degrees of freedom. + /// + /// # Errors + /// + /// Returns `StatsError::BadParams` if the scale matrix is not symmetric-positive + /// definite and `StatsError::ArgMustBePositive` if freedom is non-positive. + pub fn new( + location: Vec, + scale: Vec, + freedom: f64, + ) -> Result { + let dim = location.len(); + Self::new_from_nalgebra(location.into(), DMatrix::from_vec(dim, dim, scale), freedom) + } + + /// Returns the dimension of the distribution. 
+ pub fn dim(&self) -> usize { + self.location.len() + } +} + +impl MultivariateStudent +where + D: DimMin, + nalgebra::DefaultAllocator: nalgebra::allocator::Allocator + + nalgebra::allocator::Allocator + + nalgebra::allocator::Allocator<(usize, usize), D>, +{ + pub fn new_from_nalgebra( + location: OVector, + scale: OMatrix, + freedom: f64, + ) -> Result { + let dim = location.len(); + + if location.iter().any(|f| f.is_nan()) { + return Err(MultivariateStudentError::LocationInvalid); + } + + if !scale.is_square() + || scale.lower_triangle() != scale.upper_triangle().transpose() + || scale.iter().any(|f| f.is_nan()) + { + return Err(MultivariateStudentError::ScaleInvalid); + } + + if freedom.is_nan() || freedom <= 0.0 { + return Err(MultivariateStudentError::FreedomInvalid); + } + + if location.nrows() != scale.nrows() { + return Err(MultivariateStudentError::DimensionMismatch); + } + + let scale_det = scale.determinant(); + let ln_pdf_const = gamma::ln_gamma(0.5 * (freedom + dim as f64)) + - gamma::ln_gamma(0.5 * freedom) + - 0.5 * (dim as f64) * (freedom * PI).ln() + - 0.5 * scale_det.ln(); + + match Cholesky::new(scale.clone()) { + None => Err(MultivariateStudentError::CholeskyFailed), + Some(cholesky_decomp) => { + let precision = cholesky_decomp.inverse(); + Ok(MultivariateStudent { + scale_chol_decomp: cholesky_decomp.unpack(), + location, + scale, + freedom, + precision, + ln_pdf_const, + }) + } + } + } + + /// Returns the cholesky decomposiiton matrix of the scale matrix. + /// + /// Returns A where Σ = AAᵀ. + pub fn scale_chol_decomp(&self) -> &OMatrix { + &self.scale_chol_decomp + } + + /// Returns the location of the distribution. + pub fn location(&self) -> &OVector { + &self.location + } + + /// Returns the scale matrix of the distribution. + pub fn scale(&self) -> &OMatrix { + &self.scale + } + + /// Returns the degrees of freedom of the distribution. 
+ pub fn freedom(&self) -> f64 { + self.freedom + } + + /// Returns the precision matrix, the inverse of the scale matrix. + pub fn precision(&self) -> &OMatrix { + &self.precision + } + + /// Returns the logarithm of the constant part of the probability + /// density function. + pub fn ln_pdf_const(&self) -> f64 { + self.ln_pdf_const + } +} + +#[cfg(feature = "rand")] +impl ::rand::distributions::Distribution> for MultivariateStudent +where + D: Dim, + nalgebra::DefaultAllocator: + nalgebra::allocator::Allocator + nalgebra::allocator::Allocator, +{ + /// Samples from the multivariate student distribution + /// + /// # Formula + /// + /// ```math + /// W ⋅ L ⋅ Z + μ + /// ``` + /// + /// where `W` has √(ν/Sν) distribution, Sν has Chi-squared + /// distribution with ν degrees of freedom, + /// `L` is the Cholesky decomposition of the scale matrix, + /// `Z` is a vector of normally distributed random variables, and + /// `μ` is the location vector + fn sample(&self, rng: &mut R) -> OVector { + use crate::distribution::{ChiSquared, Normal}; + + let d = Normal::new(0., 1.).unwrap(); + let s = ChiSquared::new(self.freedom).unwrap(); + let w = (self.freedom / s.sample(rng)).sqrt(); + let (r, c) = self.location.shape_generic(); + let z = OVector::::from_distribution_generic(r, c, &d, rng); + (w * &self.scale_chol_decomp * z) + &self.location + } +} + +impl Min> for MultivariateStudent +where + D: Dim, + nalgebra::DefaultAllocator: + nalgebra::allocator::Allocator + nalgebra::allocator::Allocator, +{ + /// Returns the minimum value in the domain of the + /// multivariate student distribution represented by a real vector + fn min(&self) -> OVector { + OMatrix::repeat_generic( + self.location.shape_generic().0, + Const::<1>, + f64::NEG_INFINITY, + ) + } +} + +impl Max> for MultivariateStudent +where + D: Dim, + nalgebra::DefaultAllocator: + nalgebra::allocator::Allocator + nalgebra::allocator::Allocator, +{ + /// Returns the maximum value in the domain of the + /// multivariate
student distribution represented by a real vector + fn max(&self) -> OVector { + OMatrix::repeat_generic(self.location.shape_generic().0, Const::<1>, f64::INFINITY) + } +} + +impl MeanN> for MultivariateStudent +where + D: Dim, + nalgebra::DefaultAllocator: + nalgebra::allocator::Allocator + nalgebra::allocator::Allocator, +{ + /// Returns the mean of the student distribution. + /// + /// # Remarks + /// + /// This is the location used to construct the distribution; it is only + /// defined if the degrees of freedom is larger than 1. + fn mean(&self) -> Option> { + if self.freedom > 1. { + Some(self.location.clone()) + } else { + None + } + } +} + +impl VarianceN> for MultivariateStudent +where + D: Dim, + nalgebra::DefaultAllocator: + nalgebra::allocator::Allocator + nalgebra::allocator::Allocator, +{ + /// Returns the covariance matrix of the multivariate student distribution. + /// + /// # Formula + /// + /// ```math + /// Σ ⋅ ν / (ν - 2) + /// ``` + /// + /// where `Σ` is the scale matrix and `ν` is the degrees of freedom. + /// Only defined if freedom is larger than 2. + fn variance(&self) -> Option> { + if self.freedom > 2. { + Some(self.scale.clone() * self.freedom / (self.freedom - 2.)) + } else { + None + } + } +} + +impl Mode> for MultivariateStudent +where + D: Dim, + nalgebra::DefaultAllocator: + nalgebra::allocator::Allocator + nalgebra::allocator::Allocator, +{ + /// Returns the mode of the multivariate student distribution. + /// + /// # Formula + /// + /// ```math + /// μ + /// ``` + /// + /// where `μ` is the location. + fn mode(&self) -> OVector { + self.location.clone() + } +} + +impl<'a, D> Continuous<&'a OVector, f64> for MultivariateStudent +where + D: Dim + DimMin, + nalgebra::DefaultAllocator: nalgebra::allocator::Allocator + + nalgebra::allocator::Allocator + + nalgebra::allocator::Allocator<(usize, usize), D>, +{ + /// Calculates the probability density function for the multivariate + /// student distribution at `x`.
+ /// + /// # Formula + /// + /// ```math + /// Γ((ν+p)/2) / [Γ(ν/2) ((ν * π)^p det(Σ))^(1 / 2)] * [1 + 1/ν (x - μ)ᵀ inv(Σ) (x - μ)]^(-(ν+p)/2) + /// ``` + /// + /// where `ν` is the degrees of freedom, `μ` is the location, `Γ` + /// is the Gamma function, `inv(Σ)` + /// is the precision matrix, `det(Σ)` is the determinant + /// of the scale matrix, and `p` is the dimension of the distribution. + fn pdf(&self, x: &'a OVector) -> f64 { + if self.freedom.is_infinite() { + use super::multivariate_normal::density_normalization_and_exponential; + let (pdf_const, exp_arg) = density_normalization_and_exponential( + &self.location, + &self.scale, + &self.precision, + x, + ) + .unwrap(); + return pdf_const * exp_arg.exp(); + } + + let dv = x - &self.location; + let exp_arg: f64 = (&self.precision * &dv).dot(&dv); + let base_term = 1. + exp_arg / self.freedom; + self.ln_pdf_const.exp() * base_term.powf(-(self.freedom + self.location.len() as f64) / 2.) + } + + /// Calculates the log probability density function for the multivariate + /// student distribution at `x`. Equivalent to pdf(x).ln(). + fn ln_pdf(&self, x: &'a OVector) -> f64 { + if self.freedom.is_infinite() { + use super::multivariate_normal::density_normalization_and_exponential; + let (pdf_const, exp_arg) = density_normalization_and_exponential( + &self.location, + &self.scale, + &self.precision, + x, + ) + .unwrap(); + return pdf_const.ln() + exp_arg; + } + + let dv = x - &self.location; + let exp_arg: f64 = (&self.precision * &dv).dot(&dv); + let base_term = 1. + exp_arg / self.freedom; + self.ln_pdf_const - (self.freedom + self.location.len() as f64) / 2.
* base_term.ln() + } +} + +#[rustfmt::skip] +#[cfg(test)] +mod tests { + use core::fmt::Debug; + + use approx::RelativeEq; + use nalgebra::{DMatrix, DVector, Dyn, OMatrix, OVector, U1, U2}; + + use crate::{ + distribution::{Continuous, MultivariateStudent, MultivariateNormal}, + statistics::{Max, MeanN, Min, Mode, VarianceN}, + }; + + use super::MultivariateStudentError; + + fn try_create(location: Vec, scale: Vec, freedom: f64) -> MultivariateStudent + { + let mvs = MultivariateStudent::new(location, scale, freedom); + assert!(mvs.is_ok()); + mvs.unwrap() + } + + fn create_case(location: Vec, scale: Vec, freedom: f64) + { + let mvs = try_create(location.clone(), scale.clone(), freedom); + assert_eq!(DMatrix::from_vec(location.len(), location.len(), scale), mvs.scale); + assert_eq!(DVector::from_vec(location), mvs.location); + } + + fn bad_create_case(location: Vec, scale: Vec, freedom: f64) + { + let mvs = MultivariateStudent::new(location, scale, freedom); + assert!(mvs.is_err()); + } + + fn test_case(location: Vec, scale: Vec, freedom: f64, expected: T, eval: F) + where + T: Debug + PartialEq, + F: FnOnce(MultivariateStudent) -> T, + { + let mvs = try_create(location, scale, freedom); + let x = eval(mvs); + assert_eq!(expected, x); + } + + fn test_almost( + location: Vec, + scale: Vec, + freedom: f64, + expected: f64, + acc: f64, + eval: F, + ) where + F: FnOnce(MultivariateStudent) -> f64, + { + let mvs = try_create(location, scale, freedom); + let x = eval(mvs); + assert_almost_eq!(expected, x, acc); + } + + fn test_almost_multivariate_normal( + location: Vec, + scale: Vec, + freedom: f64, + acc: f64, + x: DVector, + eval_mvs: F1, + eval_mvn: F2, + ) where + F1: FnOnce(MultivariateStudent, DVector) -> f64, + F2: FnOnce(MultivariateNormal, DVector) -> f64, + { + let mvs = try_create(location.clone(), scale.clone(), freedom); + let mvn0 = MultivariateNormal::new(location, scale); + assert!(mvn0.is_ok()); + let mvn = mvn0.unwrap(); + let mvs_x = eval_mvs(mvs, 
x.clone()); + let mvn_x = eval_mvn(mvn, x.clone()); + assert!(mvs_x.relative_eq(&mvn_x, acc, acc), "mvn: {mvn_x} =/=\nmvs: {mvs_x}"); + // assert_relative_eq!(mvs_x, mvn_x, acc); + } + + + macro_rules! dvec { + ($($x:expr),*) => (DVector::from_vec(vec![$($x),*])); + } + + macro_rules! mat2 { + ($x11:expr, $x12:expr, $x21:expr, $x22:expr) => (DMatrix::from_vec(2,2,vec![$x11, $x12, $x21, $x22])); + } + + // macro_rules! mat3 { + // ($x11:expr, $x12:expr, $x13:expr, $x21:expr, $x22:expr, $x23:expr, $x31:expr, $x32:expr, $x33:expr) => (DMatrix::from_vec(3,3,vec![$x11, $x12, $x13, $x21, $x22, $x23, $x31, $x32, $x33])); + // } + + #[test] + fn test_create() { + create_case(vec![0., 0.], vec![1., 0., 0., 1.], 1.); + create_case(vec![10., 5.], vec![2., 1., 1., 2.], 3.); + create_case(vec![4., 5., 6.], vec![2., 1., 0., 1., 2., 1., 0., 1., 2.], 14.); + create_case(vec![0., f64::INFINITY], vec![1., 0., 0., 1.], f64::INFINITY); + create_case(vec![0., 0.], vec![f64::INFINITY, 0., 0., f64::INFINITY], 0.1); + } + + #[test] + fn test_bad_create() { + // scale not symmetric. + bad_create_case(vec![0., 0.], vec![1., 1., 0., 1.], 1.); + // scale not positive-definite. + bad_create_case(vec![0., 0.], vec![1., 2., 2., 1.], 1.); + // NaN in location. + bad_create_case(vec![0., f64::NAN], vec![1., 0., 0., 1.], 1.); + // NaN in scale Matrix. + bad_create_case(vec![0., 0.], vec![1., 0., 0., f64::NAN], 1.); + // NaN in freedom. + bad_create_case(vec![0., 0.], vec![1., 0., 0., 1.], f64::NAN); + // Non-positive freedom. + bad_create_case(vec![0., 0.], vec![1., 0., 0., 1.], 0.); + bad_create_case(vec![0., 0.], vec![1., 0., 0., 1.], -1.); + } + + #[test] + fn test_variance() { + let variance = |x: MultivariateStudent| x.variance().unwrap(); + test_case(vec![0., 0.], vec![1., 0., 0., 1.], 3., 3. 
* mat2![1., 0., 0., 1.], variance); + test_case(vec![0., 0.], vec![f64::INFINITY, 0., 0., f64::INFINITY], 3., mat2![f64::INFINITY, 0., 0., f64::INFINITY], variance); + } + + // Variance is only defined for freedom > 2. + #[test] + fn test_bad_variance() { + let variance = |x: MultivariateStudent| x.variance(); + test_case(vec![0., 0.], vec![1., 0., 0., 1.], 2., None, variance); + } + + #[test] + fn test_mode() { + let mode = |x: MultivariateStudent| x.mode(); + test_case(vec![0., 0.], vec![1., 0., 0., 1.], 1., dvec![0., 0.], mode); + test_case(vec![f64::INFINITY, f64::INFINITY], vec![1., 0., 0., 1.], 1., dvec![f64::INFINITY, f64::INFINITY], mode); + } + + #[test] + fn test_mean() { + let mean = |x: MultivariateStudent| x.mean().unwrap(); + test_case(vec![0., 0.], vec![1., 0., 0., 1.], 2., dvec![0., 0.], mean); + test_case(vec![-1., 1., 3.], vec![1., 0., 0.5, 0., 2.0, 0., 0.5, 0., 3.0], 2., dvec![-1., 1., 3.], mean); + } + + // Mean is only defined if freedom > 1. + #[test] + fn test_bad_mean() { + let mean = |x: MultivariateStudent| x.mean(); + test_case(vec![0., 0.], vec![1., 0., 0., 1.], 1., None, mean); + } + + #[test] + fn test_min_max() { + let min = |x: MultivariateStudent| x.min(); + let max = |x: MultivariateStudent| x.max(); + test_case(vec![0., 0.], vec![1., 0., 0., 1.], 1., dvec![f64::NEG_INFINITY, f64::NEG_INFINITY], min); + test_case(vec![0., 0.], vec![1., 0., 0., 1.], 1., dvec![f64::INFINITY, f64::INFINITY], max); + test_case(vec![10., 1.], vec![1., 0., 0., 1.], 1., dvec![f64::NEG_INFINITY, f64::NEG_INFINITY], min); + test_case(vec![-3., 5.], vec![1., 0., 0., 1.], 1., dvec![f64::INFINITY, f64::INFINITY], max); + } + + #[test] + fn test_pdf() { + let pdf = |arg: DVector| move |x: MultivariateStudent| x.pdf(&arg); + test_almost(vec![0., 0.], vec![1., 0., 0., 1.], 4., 0.047157020175376416, 1e-15, pdf(dvec![1., 1.])); + test_almost(vec![0., 0.], vec![1., 0., 0., 1.], 4., 0.013972450422333741737457302178882, 1e-15, pdf(dvec![1., 2.])); + 
test_almost(vec![0., 0.], vec![1., 0., 0., 1.], 2., 0.012992240252399619, 1e-17, pdf(dvec![1., 2.])); + test_almost(vec![2., 1.], vec![5., 0., 0., 1.], 2.5, 2.639780816598878e-5, 1e-19, pdf(dvec![1., 10.])); + test_almost(vec![-1., 0.], vec![2., 1., 1., 6.], 1.5, 6.438051574348526e-5, 1e-19, pdf(dvec![10., 10.])); + // These three are crossed checked against both python's scipy.multivariate_t.pdf and octave's mvtpdf. + test_almost(vec![-1., 1., 50.], vec![1., 0.5, 0.25, 0.5, 1., -0.1, 0.25, -0.1, 1.], 8., 6.960998836915657e-16, 1e-30, pdf(dvec![0.9718, 0.1298, 0.8134])); + test_almost(vec![-1., 1., 50.], vec![1., 0.5, 0.25, 0.5, 1., -0.1, 0.25, -0.1, 1.], 8., 7.369987979187023e-16, 1e-30, pdf(dvec![0.4922, 0.5522, 0.7185])); + test_almost(vec![-1., 1., 50.], vec![1., 0.5, 0.25, 0.5, 1., -0.1, 0.25, -0.1, 1.], 8.,6.951631724511314e-16, 1e-30, pdf(dvec![0.3020, 0.1491, 0.5008])); + test_case(vec![-1., 0.], vec![f64::INFINITY, 0., 0., f64::INFINITY], 10., 0., pdf(dvec![10., 10.])); + } + + #[test] + fn test_ln_pdf() { + let ln_pdf = |arg: DVector| move |x: MultivariateStudent| x.ln_pdf(&arg); + test_almost(vec![0., 0.], vec![1., 0., 0., 1.], 4., -3.0542723907338383, 1e-14, ln_pdf(dvec![1., 1.])); + test_almost(vec![0., 0.], vec![1., 0., 0., 1.], 2., -4.3434030034000815, 1e-14, ln_pdf(dvec![1., 2.])); + test_almost(vec![2., 1.], vec![5., 0., 0., 1.], 2.5, -10.542229575274265, 1e-14, ln_pdf(dvec![1., 10.])); + test_almost(vec![-1., 0.], vec![2., 1., 1., 6.], 1.5, -9.650699521198622, 1e-14, ln_pdf(dvec![10., 10.])); + // test_case(vec![-1., 0.], vec![f64::INFINITY, 0., 0., f64::INFINITY], 10., f64::NEG_INFINITY, ln_pdf(dvec![10., 10.])); + } + + #[test] + fn test_pdf_freedom_large() { + let pdf_mvs = |mv: MultivariateStudent, arg: DVector| mv.pdf(&arg); + let pdf_mvn = |mv: MultivariateNormal, arg: DVector| mv.pdf(&arg); + test_almost_multivariate_normal(vec![0., 0.,], vec![1., 0., 0., 1.], 1e5, 1e-6, dvec![1., 1.], pdf_mvs, pdf_mvn); + 
test_almost_multivariate_normal(vec![0., 0.,], vec![1., 0., 0., 1.], 1e10, 1e-7, dvec![1., 1.], pdf_mvs, pdf_mvn); + test_almost_multivariate_normal(vec![0., 0.,], vec![1., 0., 0., 1.], f64::INFINITY, 1e-300, dvec![1., 1.], pdf_mvs, pdf_mvn); + test_almost_multivariate_normal(vec![5., -1.,], vec![1., 0.99, 0.99, 1.], f64::INFINITY, 1e-300, dvec![5., 1.], pdf_mvs, pdf_mvn); + } + #[test] + fn test_ln_pdf_freedom_large() { + let pdf_mvs = |mv: MultivariateStudent, arg: DVector| mv.ln_pdf(&arg); + let pdf_mvn = |mv: MultivariateNormal, arg: DVector| mv.ln_pdf(&arg); + test_almost_multivariate_normal(vec![0., 0.,], vec![1., 0., 0., 1.], 1e5, 1e-5, dvec![1., 1.], pdf_mvs, pdf_mvn); + test_almost_multivariate_normal(vec![0., 0.,], vec![1., 0., 0., 1.], 1e10, 5e-6, dvec![1., 1.], pdf_mvs, pdf_mvn); + test_almost_multivariate_normal(vec![0., 0.,], vec![1., 0., 0., 1.], f64::INFINITY, 1e-300, dvec![1., 1.], pdf_mvs, pdf_mvn); + test_almost_multivariate_normal(vec![0., 0.,], vec![1., 0.99, 0.99, 1.], f64::INFINITY, 1e-300, dvec![1., 1.], pdf_mvs, pdf_mvn); + } + + #[test] + fn test_immut_field_access() { + // init as Dyn + let mvs = MultivariateStudent::new(vec![1., 1.], vec![1., 0., 0., 1.], 2.) 
+ .expect("hard coded valid construction"); + assert_eq!(mvs.freedom(), 2.); + assert_relative_eq!(mvs.ln_pdf_const(), std::f64::consts::TAU.recip().ln(), epsilon = 1e-15); + + // compare to static + assert_eq!(mvs.dim(), 2); + assert!(mvs.location().eq(&OVector::::new(1., 1.))); + assert!(mvs.scale().eq(&OMatrix::::identity())); + assert!(mvs.precision().eq(&OMatrix::::identity())); + assert!(mvs.scale_chol_decomp().eq(&OMatrix::::identity())); + + // compare to Dyn + assert_eq!(mvs.location(),&OVector::::from_element_generic(Dyn(2), U1, 1.)); + assert_eq!(mvs.scale(), &OMatrix::::identity(2, 2)); + assert_eq!(mvs.precision(), &OMatrix::::identity(2, 2)); + assert_eq!(mvs.scale_chol_decomp(), &OMatrix::::identity(2, 2)); + } + + #[test] + fn test_error_is_sync_send() { + fn assert_sync_send() {} + assert_sync_send::(); + } +} diff --git a/src/distribution/negative_binomial.rs b/src/distribution/negative_binomial.rs index a9ed077a..29e22eee 100644 --- a/src/distribution/negative_binomial.rs +++ b/src/distribution/negative_binomial.rs @@ -1,8 +1,6 @@ -use crate::distribution::{self, poisson, Discrete, DiscreteCDF}; +use crate::distribution::{Discrete, DiscreteCDF}; use crate::function::{beta, gamma}; use crate::statistics::*; -use crate::{Result, StatsError}; -use rand::Rng; use std::f64; /// Implements the @@ -35,12 +33,35 @@ use std::f64; /// assert!(almost_eq(r.pmf(0), 0.0625, 1e-8)); /// assert!(almost_eq(r.pmf(3), 0.15625, 1e-8)); /// ``` -#[derive(Debug, Copy, Clone, PartialEq)] +#[derive(Copy, Clone, PartialEq, Debug)] pub struct NegativeBinomial { r: f64, p: f64, } +/// Represents the errors that can occur when creating a [`NegativeBinomial`]. +#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] +#[non_exhaustive] +pub enum NegativeBinomialError { + /// `r` is NaN or less than zero. + RInvalid, + + /// `p` is NaN or not in `[0, 1]`. 
+ PInvalid, +} + +impl std::fmt::Display for NegativeBinomialError { + #[cfg_attr(coverage_nightly, coverage(off))] + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + NegativeBinomialError::RInvalid => write!(f, "r is NaN or less than zero"), + NegativeBinomialError::PInvalid => write!(f, "p is NaN or not in [0, 1]"), + } + } +} + +impl std::error::Error for NegativeBinomialError {} + impl NegativeBinomial { /// Constructs a new negative binomial distribution with parameters `r` /// and `p`. When `r` is an integer, the negative binomial distribution @@ -64,12 +85,16 @@ impl NegativeBinomial { /// result = NegativeBinomial::new(-0.5, 5.0); /// assert!(result.is_err()); /// ``` - pub fn new(r: f64, p: f64) -> Result { - if p.is_nan() || p < 0.0 || p > 1.0 || r.is_nan() || r < 0.0 { - Err(StatsError::BadParams) - } else { - Ok(NegativeBinomial { r, p }) + pub fn new(r: f64, p: f64) -> Result { + if r.is_nan() || r < 0.0 { + return Err(NegativeBinomialError::RInvalid); } + + if p.is_nan() || !(0.0..=1.0).contains(&p) { + return Err(NegativeBinomialError::PInvalid); + } + + Ok(NegativeBinomial { r, p }) } /// Returns the probability of success `p` of a single @@ -104,9 +129,18 @@ impl NegativeBinomial { } } +impl std::fmt::Display for NegativeBinomial { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "NB({},{})", self.r, self.p) + } +} + +#[cfg(feature = "rand")] impl ::rand::distributions::Distribution for NegativeBinomial { - fn sample(&self, r: &mut R) -> u64 { - let lambda = distribution::gamma::sample_unchecked(r, self.r, (1.0 - self.p) / self.p); + fn sample(&self, r: &mut R) -> u64 { + use crate::distribution::{gamma, poisson}; + + let lambda = gamma::sample_unchecked(r, self.r, (1.0 - self.p) / self.p); poisson::sample_unchecked(r, lambda).floor() as u64 } } @@ -117,7 +151,7 @@ impl DiscreteCDF for NegativeBinomial { /// /// # Formula /// - /// ```ignore + /// ```text /// I_(p)(r, x+1) /// ``` 
/// @@ -137,7 +171,7 @@ impl DiscreteCDF for NegativeBinomial { /// /// # Formula /// - /// ```ignore + /// ```text /// I_(1-p)(x+1, r) /// ``` /// @@ -154,7 +188,7 @@ impl Min for NegativeBinomial { /// /// # Formula /// - /// ```ignore + /// ```text /// 0 /// ``` fn min(&self) -> u64 { @@ -169,11 +203,11 @@ impl Max for NegativeBinomial { /// /// # Formula /// - /// ```ignore + /// ```text /// u64::MAX /// ``` fn max(&self) -> u64 { - std::u64::MAX + u64::MAX } } @@ -182,27 +216,29 @@ impl DiscreteDistribution for NegativeBinomial { /// /// # Formula /// - /// ```ignore + /// ```text /// r * (1-p) / p /// ``` fn mean(&self) -> Option { Some(self.r * (1.0 - self.p) / self.p) } + /// Returns the variance of the negative binomial distribution. /// /// # Formula /// - /// ```ignore + /// ```text /// r * (1-p) / p^2 /// ``` fn variance(&self) -> Option { Some(self.r * (1.0 - self.p) / (self.p * self.p)) } + /// Returns the skewness of the negative binomial distribution. /// /// # Formula /// - /// ```ignore + /// ```text /// (2-p) / sqrt(r * (1-p)) /// ``` fn skewness(&self) -> Option { @@ -215,7 +251,7 @@ impl Mode> for NegativeBinomial { /// /// # Formula /// - /// ```ignore + /// ```text /// if r > 1 then /// floor((r - 1) * (1-p / p)) /// else @@ -239,13 +275,13 @@ impl Discrete for NegativeBinomial { /// /// When `r` is an integer, the formula is: /// - /// ```ignore + /// ```text /// (x + r - 1 choose x) * (1 - p)^x * p^r /// ``` /// /// The general formula for real `r` is: /// - /// ```ignore + /// ```text /// Γ(r + x)/(Γ(r) * Γ(x + 1)) * (1 - p)^x * p^r /// ``` /// @@ -261,13 +297,13 @@ impl Discrete for NegativeBinomial { /// /// When `r` is an integer, the formula is: /// - /// ```ignore + /// ```text /// ln((x + r - 1 choose x) * (1 - p)^x * p^r) /// ``` /// /// The general formula for real `r` is: /// - /// ```ignore + /// ```text /// ln(Γ(r + x)/(Γ(r) * Γ(x + 1)) * (1 - p)^x * p^r) /// ``` /// @@ -281,225 +317,173 @@ impl Discrete for NegativeBinomial { } 
#[rustfmt::skip] -#[cfg(all(test, feature = "nightly"))] +#[cfg(test)] mod tests { - use std::fmt::Debug; - use crate::statistics::*; - use crate::distribution::{DiscreteCDF, Discrete, NegativeBinomial}; + use super::*; use crate::distribution::internal::test; - use crate::consts::ACC; - - fn try_create(r: f64, p: f64) -> NegativeBinomial { - let r = NegativeBinomial::new(r, p); - assert!(r.is_ok()); - r.unwrap() - } - - fn create_case(r: f64, p: f64) { - let dist = try_create(r, p); - assert_eq!(p, dist.p()); - assert_eq!(r, dist.r()); - } - - fn bad_create_case(r: f64, p: f64) { - let r = NegativeBinomial::new(r, p); - assert!(r.is_err()); - } - - fn get_value(r: f64, p: f64, eval: F) -> T - where T: PartialEq + Debug, - F: Fn(NegativeBinomial) -> T - { - let r = try_create(r, p); - eval(r) - } - - fn test_case(r: f64, p: f64, expected: T, eval: F) - where T: PartialEq + Debug, - F: Fn(NegativeBinomial) -> T - { - let x = get_value(r, p, eval); - assert_eq!(expected, x); - } - + use crate::testing_boiler; - fn test_case_or_nan(r: f64, p: f64, expected: f64, eval: F) - where F: Fn(NegativeBinomial) -> f64 - { - let x = get_value(r, p, eval); - if expected.is_nan() { - assert!(x.is_nan()) - } - else { - assert_eq!(expected, x); - } - } - fn test_almost(r: f64, p: f64, expected: f64, acc: f64, eval: F) - where F: Fn(NegativeBinomial) -> f64 - { - let x = get_value(r, p, eval); - assert_almost_eq!(expected, x, acc); - } + testing_boiler!(r: f64, p: f64; NegativeBinomial; NegativeBinomialError); #[test] fn test_create() { - create_case(0.0, 0.0); - create_case(0.3, 0.4); - create_case(1.0, 0.3); + create_ok(0.0, 0.0); + create_ok(0.3, 0.4); + create_ok(1.0, 0.3); } #[test] fn test_bad_create() { - bad_create_case(f64::NAN, 1.0); - bad_create_case(0.0, f64::NAN); - bad_create_case(-1.0, 1.0); - bad_create_case(2.0, 2.0); + test_create_err(f64::NAN, 1.0, NegativeBinomialError::RInvalid); + test_create_err(0.0, f64::NAN, NegativeBinomialError::PInvalid); + 
create_err(-1.0, 1.0); + create_err(2.0, 2.0); } #[test] fn test_mean() { let mean = |x: NegativeBinomial| x.mean().unwrap(); - test_case(4.0, 0.0, f64::INFINITY, mean); - test_almost(3.0, 0.3, 7.0, 1e-15 , mean); - test_case(2.0, 1.0, 0.0, mean); + test_exact(4.0, 0.0, f64::INFINITY, mean); + test_absolute(3.0, 0.3, 7.0, 1e-15 , mean); + test_exact(2.0, 1.0, 0.0, mean); } #[test] fn test_variance() { let variance = |x: NegativeBinomial| x.variance().unwrap(); - test_case(4.0, 0.0, f64::INFINITY, variance); - test_almost(3.0, 0.3, 23.333333333333, 1e-12, variance); - test_case(2.0, 1.0, 0.0, variance); + test_exact(4.0, 0.0, f64::INFINITY, variance); + test_absolute(3.0, 0.3, 23.333333333333, 1e-12, variance); + test_exact(2.0, 1.0, 0.0, variance); } #[test] fn test_skewness() { let skewness = |x: NegativeBinomial| x.skewness().unwrap(); - test_case(0.0, 0.0, f64::INFINITY, skewness); - test_almost(0.1, 0.3, 6.425396041, 1e-09, skewness); - test_case(1.0, 1.0, f64::INFINITY, skewness); + test_exact(0.0, 0.0, f64::INFINITY, skewness); + test_absolute(0.1, 0.3, 6.425396041, 1e-09, skewness); + test_exact(1.0, 1.0, f64::INFINITY, skewness); } #[test] fn test_mode() { let mode = |x: NegativeBinomial| x.mode().unwrap(); - test_case(0.0, 0.0, 0.0, mode); - test_case(0.3, 0.0, 0.0, mode); - test_case(1.0, 1.0, 0.0, mode); - test_case(10.0, 0.01, 891.0, mode); + test_exact(0.0, 0.0, 0.0, mode); + test_exact(0.3, 0.0, 0.0, mode); + test_exact(1.0, 1.0, 0.0, mode); + test_exact(10.0, 0.01, 891.0, mode); } #[test] fn test_min_max() { let min = |x: NegativeBinomial| x.min(); let max = |x: NegativeBinomial| x.max(); - test_case(1.0, 0.5, 0, min); - test_case(1.0, 0.3, std::u64::MAX, max); + test_exact(1.0, 0.5, 0, min); + test_exact(1.0, 0.3, u64::MAX, max); } #[test] fn test_pmf() { let pmf = |arg: u64| move |x: NegativeBinomial| x.pmf(arg); - test_almost(4.0, 0.5, 0.0625, 1e-8, pmf(0)); - test_almost(4.0, 0.5, 0.15625, 1e-8, pmf(3)); - test_case(1.0, 0.0, 0.0, pmf(0)); - 
test_case(1.0, 0.0, 0.0, pmf(1)); - test_almost(3.0, 0.2, 0.008, 1e-15, pmf(0)); - test_almost(3.0, 0.2, 0.0192, 1e-15, pmf(1)); - test_almost(3.0, 0.2, 0.04096, 1e-15, pmf(3)); - test_almost(10.0, 0.2, 1.024e-07, 1e-07, pmf(0)); - test_almost(10.0, 0.2, 8.192e-07, 1e-07, pmf(1)); - test_almost(10.0, 0.2, 0.001015706852, 1e-07, pmf(10)); - test_almost(1.0, 0.3, 0.3, 1e-15, pmf(0)); - test_almost(1.0, 0.3, 0.21, 1e-15, pmf(1)); - test_almost(3.0, 0.3, 0.027, 1e-15, pmf(0)); - test_case(0.3, 1.0, 0.0, pmf(1)); - test_case(0.3, 1.0, 0.0, pmf(3)); - test_case_or_nan(0.3, 1.0, f64::NAN, pmf(0)); - test_case(0.3, 1.0, 0.0, pmf(1)); - test_case(0.3, 1.0, 0.0, pmf(10)); - test_case_or_nan(1.0, 1.0, f64::NAN, pmf(0)); - test_case(1.0, 1.0, 0.0, pmf(1)); - test_case_or_nan(3.0, 1.0, f64::NAN, pmf(0)); - test_case(3.0, 1.0, 0.0, pmf(1)); - test_case(3.0, 1.0, 0.0, pmf(3)); - test_case_or_nan(10.0, 1.0, f64::NAN, pmf(0)); - test_case(10.0, 1.0, 0.0, pmf(1)); - test_case(10.0, 1.0, 0.0, pmf(10)); + test_absolute(4.0, 0.5, 0.0625, 1e-8, pmf(0)); + test_absolute(4.0, 0.5, 0.15625, 1e-8, pmf(3)); + test_exact(1.0, 0.0, 0.0, pmf(0)); + test_exact(1.0, 0.0, 0.0, pmf(1)); + test_absolute(3.0, 0.2, 0.008, 1e-15, pmf(0)); + test_absolute(3.0, 0.2, 0.0192, 1e-15, pmf(1)); + test_absolute(3.0, 0.2, 0.04096, 1e-15, pmf(3)); + test_absolute(10.0, 0.2, 1.024e-07, 1e-07, pmf(0)); + test_absolute(10.0, 0.2, 8.192e-07, 1e-07, pmf(1)); + test_absolute(10.0, 0.2, 0.001015706852, 1e-07, pmf(10)); + test_absolute(1.0, 0.3, 0.3, 1e-15, pmf(0)); + test_absolute(1.0, 0.3, 0.21, 1e-15, pmf(1)); + test_absolute(3.0, 0.3, 0.027, 1e-15, pmf(0)); + test_exact(0.3, 1.0, 0.0, pmf(1)); + test_exact(0.3, 1.0, 0.0, pmf(3)); + test_is_nan(0.3, 1.0, pmf(0)); + test_exact(0.3, 1.0, 0.0, pmf(1)); + test_exact(0.3, 1.0, 0.0, pmf(10)); + test_is_nan(1.0, 1.0, pmf(0)); + test_exact(1.0, 1.0, 0.0, pmf(1)); + test_is_nan(3.0, 1.0, pmf(0)); + test_exact(3.0, 1.0, 0.0, pmf(1)); + test_exact(3.0, 1.0, 0.0, pmf(3)); + 
test_is_nan(10.0, 1.0, pmf(0)); + test_exact(10.0, 1.0, 0.0, pmf(1)); + test_exact(10.0, 1.0, 0.0, pmf(10)); } #[test] fn test_ln_pmf() { let ln_pmf = |arg: u64| move |x: NegativeBinomial| x.ln_pmf(arg); - test_case(1.0, 0.0, f64::NEG_INFINITY, ln_pmf(0)); - test_case(1.0, 0.0, f64::NEG_INFINITY, ln_pmf(1)); - test_almost(3.0, 0.2, -4.828313737, 1e-08, ln_pmf(0)); - test_almost(3.0, 0.2, -3.952845, 1e-08, ln_pmf(1)); - test_almost(3.0, 0.2, -3.195159298, 1e-08, ln_pmf(3)); - test_almost(10.0, 0.2, -16.09437912, 1e-08, ln_pmf(0)); - test_almost(10.0, 0.2, -14.01493758, 1e-08, ln_pmf(1)); - test_almost(10.0, 0.2, -6.892170503, 1e-08, ln_pmf(10)); - test_almost(1.0, 0.3, -1.203972804, 1e-08, ln_pmf(0)); - test_almost(1.0, 0.3, -1.560647748, 1e-08, ln_pmf(1)); - test_almost(3.0, 0.3, -3.611918413, 1e-08, ln_pmf(0)); - test_case(0.3, 1.0, f64::NEG_INFINITY, ln_pmf(1)); - test_case(0.3, 1.0, f64::NEG_INFINITY, ln_pmf(3)); - test_case_or_nan(0.3, 1.0, f64::NAN, ln_pmf(0)); - test_case(0.3, 1.0, f64::NEG_INFINITY, ln_pmf(1)); - test_case(0.3, 1.0, f64::NEG_INFINITY, ln_pmf(10)); - test_case_or_nan(1.0, 1.0, f64::NAN, ln_pmf(0)); - test_case(1.0, 1.0, f64::NEG_INFINITY, ln_pmf(1)); - test_case_or_nan(3.0, 1.0, f64::NAN, ln_pmf(0)); - test_case(3.0, 1.0, f64::NEG_INFINITY, ln_pmf(1)); - test_case(3.0, 1.0, f64::NEG_INFINITY, ln_pmf(3)); - test_case_or_nan(10.0, 1.0, f64::NAN, ln_pmf(0)); - test_case(10.0, 1.0, f64::NEG_INFINITY, ln_pmf(1)); - test_case(10.0, 1.0, f64::NEG_INFINITY, ln_pmf(10)); + test_exact(1.0, 0.0, f64::NEG_INFINITY, ln_pmf(0)); + test_exact(1.0, 0.0, f64::NEG_INFINITY, ln_pmf(1)); + test_absolute(3.0, 0.2, -4.828313737, 1e-08, ln_pmf(0)); + test_absolute(3.0, 0.2, -3.952845, 1e-08, ln_pmf(1)); + test_absolute(3.0, 0.2, -3.195159298, 1e-08, ln_pmf(3)); + test_absolute(10.0, 0.2, -16.09437912, 1e-08, ln_pmf(0)); + test_absolute(10.0, 0.2, -14.01493758, 1e-08, ln_pmf(1)); + test_absolute(10.0, 0.2, -6.892170503, 1e-08, ln_pmf(10)); + test_absolute(1.0, 0.3, 
-1.203972804, 1e-08, ln_pmf(0)); + test_absolute(1.0, 0.3, -1.560647748, 1e-08, ln_pmf(1)); + test_absolute(3.0, 0.3, -3.611918413, 1e-08, ln_pmf(0)); + test_exact(0.3, 1.0, f64::NEG_INFINITY, ln_pmf(1)); + test_exact(0.3, 1.0, f64::NEG_INFINITY, ln_pmf(3)); + test_is_nan(0.3, 1.0, ln_pmf(0)); + test_exact(0.3, 1.0, f64::NEG_INFINITY, ln_pmf(1)); + test_exact(0.3, 1.0, f64::NEG_INFINITY, ln_pmf(10)); + test_is_nan(1.0, 1.0, ln_pmf(0)); + test_exact(1.0, 1.0, f64::NEG_INFINITY, ln_pmf(1)); + test_is_nan(3.0, 1.0, ln_pmf(0)); + test_exact(3.0, 1.0, f64::NEG_INFINITY, ln_pmf(1)); + test_exact(3.0, 1.0, f64::NEG_INFINITY, ln_pmf(3)); + test_is_nan(10.0, 1.0, ln_pmf(0)); + test_exact(10.0, 1.0, f64::NEG_INFINITY, ln_pmf(1)); + test_exact(10.0, 1.0, f64::NEG_INFINITY, ln_pmf(10)); } #[test] fn test_cdf() { let cdf = |arg: u64| move |x: NegativeBinomial| x.cdf(arg); - test_almost(1.0, 0.3, 0.3, 1e-08, cdf(0)); - test_almost(1.0, 0.3, 0.51, 1e-08, cdf(1)); - test_almost(1.0, 0.3, 0.83193, 1e-08, cdf(4)); - test_almost(1.0, 0.3, 0.9802267326, 1e-08, cdf(10)); - test_case(1.0, 1.0, 1.0, cdf(0)); - test_case(1.0, 1.0, 1.0, cdf(1)); - test_almost(10.0, 0.75, 0.05631351471, 1e-08, cdf(0)); - test_almost(10.0, 0.75, 0.1970973015, 1e-08, cdf(1)); - test_almost(10.0, 0.75, 0.9960578583, 1e-08, cdf(10)); + test_absolute(1.0, 0.3, 0.3, 1e-08, cdf(0)); + test_absolute(1.0, 0.3, 0.51, 1e-08, cdf(1)); + test_absolute(1.0, 0.3, 0.83193, 1e-08, cdf(4)); + test_absolute(1.0, 0.3, 0.9802267326, 1e-08, cdf(10)); + test_exact(1.0, 1.0, 1.0, cdf(0)); + test_exact(1.0, 1.0, 1.0, cdf(1)); + test_absolute(10.0, 0.75, 0.05631351471, 1e-08, cdf(0)); + test_absolute(10.0, 0.75, 0.1970973015, 1e-08, cdf(1)); + test_absolute(10.0, 0.75, 0.9960578583, 1e-08, cdf(10)); } #[test] fn test_sf() { let sf = |arg: u64| move |x: NegativeBinomial| x.sf(arg); - test_almost(1.0, 0.3, 0.7, 1e-08, sf(0)); - test_almost(1.0, 0.3, 0.49, 1e-08, sf(1)); - test_almost(1.0, 0.3, 0.1680699999999986, 1e-08, sf(4)); - 
test_almost(1.0, 0.3, 0.019773267430000074, 1e-08, sf(10)); - test_case(1.0, 1.0, 0.0, sf(0)); - test_case(1.0, 1.0, 0.0, sf(1)); - test_almost(10.0, 0.75, 0.9436864852905275, 1e-08, sf(0)); - test_almost(10.0, 0.75, 0.8029026985168456, 1e-08, sf(1)); - test_almost(10.0, 0.75, 0.003942141664083465, 1e-08, sf(10)); + test_absolute(1.0, 0.3, 0.7, 1e-08, sf(0)); + test_absolute(1.0, 0.3, 0.49, 1e-08, sf(1)); + test_absolute(1.0, 0.3, 0.1680699999999986, 1e-08, sf(4)); + test_absolute(1.0, 0.3, 0.019773267430000074, 1e-08, sf(10)); + test_exact(1.0, 1.0, 0.0, sf(0)); + test_exact(1.0, 1.0, 0.0, sf(1)); + test_absolute(10.0, 0.75, 0.9436864852905275, 1e-08, sf(0)); + test_absolute(10.0, 0.75, 0.8029026985168456, 1e-08, sf(1)); + test_absolute(10.0, 0.75, 0.003942141664083465, 1e-08, sf(10)); } #[test] fn test_cdf_upper_bound() { let cdf = |arg: u64| move |x: NegativeBinomial| x.cdf(arg); - test_case(3.0, 0.5, 1.0, cdf(100)); + test_exact(3.0, 0.5, 1.0, cdf(100)); } #[test] fn test_discrete() { - test::check_discrete_distribution(&try_create(5.0, 0.3), 35); - test::check_discrete_distribution(&try_create(10.0, 0.7), 21); + test::check_discrete_distribution(&create_ok(5.0, 0.3), 35); + test::check_discrete_distribution(&create_ok(10.0, 0.7), 21); } #[test] fn test_sf_upper_bound() { let sf = |arg: u64| move |x: NegativeBinomial| x.sf(arg); - test_almost(3.0, 0.5, 5.282409836586059e-28, 1e-28, sf(100)); + test_absolute(3.0, 0.5, 5.282409836586059e-28, 1e-28, sf(100)); } } diff --git a/src/distribution/normal.rs b/src/distribution/normal.rs index dabb7915..a264af50 100644 --- a/src/distribution/normal.rs +++ b/src/distribution/normal.rs @@ -1,8 +1,7 @@ -use crate::distribution::{ziggurat, Continuous, ContinuousCDF}; +use crate::consts; +use crate::distribution::{Continuous, ContinuousCDF}; use crate::function::erf; use crate::statistics::*; -use crate::{consts, Result, StatsError}; -use rand::Rng; use std::f64; /// Implements the 
[Normal](https://en.wikipedia.org/wiki/Normal_distribution) @@ -18,12 +17,37 @@ use std::f64; /// assert_eq!(n.mean().unwrap(), 0.0); /// assert_eq!(n.pdf(1.0), 0.2419707245191433497978); /// ``` -#[derive(Debug, Copy, Clone, PartialEq)] +#[derive(Copy, Clone, PartialEq, Debug)] pub struct Normal { mean: f64, std_dev: f64, } +/// Represents the errors that can occur when creating a [`Normal`]. +#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] +#[non_exhaustive] +pub enum NormalError { + /// The mean is NaN. + MeanInvalid, + + /// The standard deviation is NaN, zero or less than zero. + StandardDeviationInvalid, +} + +impl std::fmt::Display for NormalError { + #[cfg_attr(coverage_nightly, coverage(off))] + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + NormalError::MeanInvalid => write!(f, "Mean is NaN"), + NormalError::StandardDeviationInvalid => { + write!(f, "Standard deviation is NaN, zero or less than zero") + } + } + } +} + +impl std::error::Error for NormalError {} + impl Normal { /// Constructs a new normal distribution with a mean of `mean` /// and a standard deviation of `std_dev` @@ -44,17 +68,46 @@ impl Normal { /// result = Normal::new(0.0, 0.0); /// assert!(result.is_err()); /// ``` - pub fn new(mean: f64, std_dev: f64) -> Result { - if mean.is_nan() || std_dev.is_nan() || std_dev <= 0.0 { - Err(StatsError::BadParams) - } else { - Ok(Normal { mean, std_dev }) + pub fn new(mean: f64, std_dev: f64) -> Result { + if mean.is_nan() { + return Err(NormalError::MeanInvalid); + } + + if std_dev.is_nan() || std_dev <= 0.0 { + return Err(NormalError::StandardDeviationInvalid); + } + + Ok(Normal { mean, std_dev }) + } + + /// Constructs a new standard normal distribution with a mean of 0 + /// and a standard deviation of 1. 
+ /// + /// # Examples + /// + /// ``` + /// use statrs::distribution::Normal; + /// + /// let mut result = Normal::standard(); + /// ``` + pub fn standard() -> Normal { + Normal { + mean: 0.0, + std_dev: 1.0, } } } +impl std::fmt::Display for Normal { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "N({},{})", self.mean, self.std_dev) + } +} + +#[cfg(feature = "rand")] impl ::rand::distributions::Distribution for Normal { - fn sample(&self, rng: &mut R) -> f64 { + fn sample(&self, rng: &mut R) -> f64 { sample_unchecked(rng, self.mean, self.std_dev) } } @@ -65,7 +118,7 @@ impl ContinuousCDF for Normal { /// /// # Formula /// - /// ```ignore + /// ```text /// (1 / 2) * (1 + erf((x - μ) / (σ * sqrt(2)))) /// ``` /// @@ -80,7 +133,7 @@ impl ContinuousCDF for Normal { /// /// # Formula /// - /// ```ignore + /// ```text /// (1 / 2) * (1 + erf(-(x - μ) / (σ * sqrt(2)))) /// ``` /// @@ -91,16 +144,17 @@ impl ContinuousCDF for Normal { /// the sign of the argument error function with respect to the cdf. /// /// the normal cdf Φ (and internal error function) as the following property: - /// ```ignore + /// ```text /// Φ(-x) + Φ(x) = 1 - /// Φ(-x) = 1 - Φ(x) + /// Φ(-x) = 1 - Φ(x) /// ``` fn sf(&self, x: f64) -> f64 { sf_unchecked(x, self.mean, self.std_dev) } /// Calculates the inverse cumulative distribution function for the - /// normal distribution at `x` + /// normal distribution at `x`. + /// In other languages, such as R, this is known as the quantile function.
/// /// # Panics /// @@ -108,7 +162,7 @@ impl ContinuousCDF for Normal { /// /// # Formula /// - /// ```ignore + /// ```text /// μ - sqrt(2) * σ * erfc_inv(2x) /// ``` /// @@ -129,8 +183,8 @@ impl Min for Normal { /// /// # Formula /// - /// ```ignore - /// -INF + /// ```text + /// f64::NEG_INFINITY /// ``` fn min(&self) -> f64 { f64::NEG_INFINITY @@ -143,8 +197,8 @@ impl Max for Normal { /// /// # Formula /// - /// ```ignore - /// INF + /// ```text + /// f64::INFINITY /// ``` fn max(&self) -> f64 { f64::INFINITY @@ -160,11 +214,12 @@ impl Distribution for Normal { fn mean(&self) -> Option { Some(self.mean) } + /// Returns the variance of the normal distribution /// /// # Formula /// - /// ```ignore + /// ```text /// σ^2 /// ``` /// @@ -172,11 +227,19 @@ impl Distribution for Normal { fn variance(&self) -> Option { Some(self.std_dev * self.std_dev) } + + /// Returns the standard deviation of the normal distribution + /// # Remarks + /// This is the same standard deviation used to construct the distribution + fn std_dev(&self) -> Option { + Some(self.std_dev) + } + /// Returns the entropy of the normal distribution /// /// # Formula /// - /// ```ignore + /// ```text /// (1 / 2) * ln(2σ^2 * π * e) /// ``` /// @@ -184,11 +247,12 @@ impl Distribution for Normal { fn entropy(&self) -> Option { Some(self.std_dev.ln() + consts::LN_SQRT_2PIE) } + /// Returns the skewness of the normal distribution /// /// # Formula /// - /// ```ignore + /// ```text /// 0 /// ``` fn skewness(&self) -> Option { @@ -201,7 +265,7 @@ impl Median for Normal { /// /// # Formula /// - /// ```ignore + /// ```text /// μ /// ``` /// @@ -216,7 +280,7 @@ impl Mode> for Normal { /// /// # Formula /// - /// ```ignore + /// ```text /// μ /// ``` /// @@ -232,7 +296,7 @@ impl Continuous for Normal { /// /// # Formula /// - /// ```ignore + /// ```text /// (1 / sqrt(2σ^2 * π)) * e^(-(x - μ)^2 / 2σ^2) /// ``` /// @@ -247,7 +311,7 @@ impl Continuous for Normal { /// /// # Formula /// - /// ```ignore + /// 
```text /// ln((1 / sqrt(2σ^2 * π)) * e^(-(x - μ)^2 / 2σ^2)) /// ``` /// @@ -283,229 +347,221 @@ pub fn ln_pdf_unchecked(x: f64, mean: f64, std_dev: f64) -> f64 { (-0.5 * d * d) - consts::LN_SQRT_2PI - std_dev.ln() } +#[cfg(feature = "rand")] /// draws a sample from a normal distribution using the Box-Muller algorithm -pub fn sample_unchecked(rng: &mut R, mean: f64, std_dev: f64) -> f64 { +pub fn sample_unchecked(rng: &mut R, mean: f64, std_dev: f64) -> f64 { + use crate::distribution::ziggurat; + mean + std_dev * ziggurat::sample_std_normal(rng) } +impl std::default::Default for Normal { + /// Returns the standard normal distribution with a mean of 0 + /// and a standard deviation of 1. + fn default() -> Self { + Self::standard() + } +} + #[rustfmt::skip] -#[cfg(all(test, feature = "nightly"))] +#[cfg(test)] mod tests { - use crate::statistics::*; - use crate::distribution::{ContinuousCDF, Continuous, Normal}; + use super::*; use crate::distribution::internal::*; - use crate::consts::ACC; + use crate::testing_boiler; - fn try_create(mean: f64, std_dev: f64) -> Normal { - let n = Normal::new(mean, std_dev); - assert!(n.is_ok()); - n.unwrap() - } - - fn create_case(mean: f64, std_dev: f64) { - let n = try_create(mean, std_dev); - assert_eq!(mean, n.mean().unwrap()); - assert_eq!(std_dev, n.std_dev().unwrap()); - } - - fn bad_create_case(mean: f64, std_dev: f64) { - let n = Normal::new(mean, std_dev); - assert!(n.is_err()); - } - - fn test_case(mean: f64, std_dev: f64, expected: f64, eval: F) - where F: Fn(Normal) -> f64 - { - let n = try_create(mean, std_dev); - let x = eval(n); - assert_eq!(expected, x); - } - - fn test_almost(mean: f64, std_dev: f64, expected: f64, acc: f64, eval: F) - where F: Fn(Normal) -> f64 - { - let n = try_create(mean, std_dev); - let x = eval(n); - assert_almost_eq!(expected, x, acc); - } + testing_boiler!(mean: f64, std_dev: f64; Normal; NormalError); #[test] fn test_create() { - create_case(10.0, 0.1); - create_case(-5.0, 1.0); - 
create_case(0.0, 10.0); - create_case(10.0, 100.0); - create_case(-5.0, f64::INFINITY); + create_ok(10.0, 0.1); + create_ok(-5.0, 1.0); + create_ok(0.0, 10.0); + create_ok(10.0, 100.0); + create_ok(-5.0, f64::INFINITY); } #[test] fn test_bad_create() { - bad_create_case(0.0, 0.0); - bad_create_case(f64::NAN, 1.0); - bad_create_case(1.0, f64::NAN); - bad_create_case(f64::NAN, f64::NAN); - bad_create_case(1.0, -1.0); + test_create_err(f64::NAN, 1.0, NormalError::MeanInvalid); + test_create_err(1.0, f64::NAN, NormalError::StandardDeviationInvalid); + create_err(0.0, 0.0); + create_err(f64::NAN, f64::NAN); + create_err(1.0, -1.0); } #[test] fn test_variance() { let variance = |x: Normal| x.variance().unwrap(); - test_case(0.0, 0.1, 0.1 * 0.1, variance); - test_case(0.0, 1.0, 1.0, variance); - test_case(0.0, 10.0, 100.0, variance); - test_case(0.0, f64::INFINITY, f64::INFINITY, variance); + test_exact(0.0, 0.1, 0.1 * 0.1, variance); + test_exact(0.0, 1.0, 1.0, variance); + test_exact(0.0, 10.0, 100.0, variance); + test_exact(0.0, f64::INFINITY, f64::INFINITY, variance); } #[test] fn test_entropy() { let entropy = |x: Normal| x.entropy().unwrap(); - test_almost(0.0, 0.1, -0.8836465597893729422377, 1e-15, entropy); - test_case(0.0, 1.0, 1.41893853320467274178, entropy); - test_case(0.0, 10.0, 3.721523626198718425798, entropy); - test_case(0.0, f64::INFINITY, f64::INFINITY, entropy); + test_absolute(0.0, 0.1, -0.8836465597893729422377, 1e-15, entropy); + test_exact(0.0, 1.0, 1.41893853320467274178, entropy); + test_exact(0.0, 10.0, 3.721523626198718425798, entropy); + test_exact(0.0, f64::INFINITY, f64::INFINITY, entropy); } #[test] fn test_skewness() { let skewness = |x: Normal| x.skewness().unwrap(); - test_case(0.0, 0.1, 0.0, skewness); - test_case(4.0, 1.0, 0.0, skewness); - test_case(0.3, 10.0, 0.0, skewness); - test_case(0.0, f64::INFINITY, 0.0, skewness); + test_exact(0.0, 0.1, 0.0, skewness); + test_exact(4.0, 1.0, 0.0, skewness); + test_exact(0.3, 10.0, 0.0, 
skewness); + test_exact(0.0, f64::INFINITY, 0.0, skewness); } #[test] fn test_mode() { let mode = |x: Normal| x.mode().unwrap(); - test_case(-0.0, 1.0, 0.0, mode); - test_case(0.0, 1.0, 0.0, mode); - test_case(0.1, 1.0, 0.1, mode); - test_case(1.0, 1.0, 1.0, mode); - test_case(-10.0, 1.0, -10.0, mode); - test_case(f64::INFINITY, 1.0, f64::INFINITY, mode); + test_exact(-0.0, 1.0, 0.0, mode); + test_exact(0.0, 1.0, 0.0, mode); + test_exact(0.1, 1.0, 0.1, mode); + test_exact(1.0, 1.0, 1.0, mode); + test_exact(-10.0, 1.0, -10.0, mode); + test_exact(f64::INFINITY, 1.0, f64::INFINITY, mode); } #[test] fn test_median() { let median = |x: Normal| x.median(); - test_case(-0.0, 1.0, 0.0, median); - test_case(0.0, 1.0, 0.0, median); - test_case(0.1, 1.0, 0.1, median); - test_case(1.0, 1.0, 1.0, median); - test_case(-0.0, 1.0, -0.0, median); - test_case(f64::INFINITY, 1.0, f64::INFINITY, median); + test_exact(-0.0, 1.0, 0.0, median); + test_exact(0.0, 1.0, 0.0, median); + test_exact(0.1, 1.0, 0.1, median); + test_exact(1.0, 1.0, 1.0, median); + test_exact(-0.0, 1.0, -0.0, median); + test_exact(f64::INFINITY, 1.0, f64::INFINITY, median); } #[test] fn test_min_max() { let min = |x: Normal| x.min(); let max = |x: Normal| x.max(); - test_case(0.0, 0.1, f64::NEG_INFINITY, min); - test_case(-3.0, 10.0, f64::NEG_INFINITY, min); - test_case(0.0, 0.1, f64::INFINITY, max); - test_case(-3.0, 10.0, f64::INFINITY, max); + test_exact(0.0, 0.1, f64::NEG_INFINITY, min); + test_exact(-3.0, 10.0, f64::NEG_INFINITY, min); + test_exact(0.0, 0.1, f64::INFINITY, max); + test_exact(-3.0, 10.0, f64::INFINITY, max); } #[test] fn test_pdf() { let pdf = |arg: f64| move |x: Normal| x.pdf(arg); - test_almost(10.0, 0.1, 5.530709549844416159162E-49, 1e-64, pdf(8.5)); - test_almost(10.0, 0.1, 0.5399096651318805195056, 1e-14, pdf(9.8)); - test_almost(10.0, 0.1, 3.989422804014326779399, 1e-15, pdf(10.0)); - test_almost(10.0, 0.1, 0.5399096651318805195056, 1e-14, pdf(10.2)); - test_almost(10.0, 0.1, 
5.530709549844416159162E-49, 1e-64, pdf(11.5)); - test_case(-5.0, 1.0, 1.486719514734297707908E-6, pdf(-10.0)); - test_case(-5.0, 1.0, 0.01752830049356853736216, pdf(-7.5)); - test_almost(-5.0, 1.0, 0.3989422804014326779399, 1e-16, pdf(-5.0)); - test_case(-5.0, 1.0, 0.01752830049356853736216, pdf(-2.5)); - test_case(-5.0, 1.0, 1.486719514734297707908E-6, pdf(0.0)); - test_case(0.0, 10.0, 0.03520653267642994777747, pdf(-5.0)); - test_almost(0.0, 10.0, 0.03866681168028492069412, 1e-17, pdf(-2.5)); - test_almost(0.0, 10.0, 0.03989422804014326779399, 1e-17, pdf(0.0)); - test_almost(0.0, 10.0, 0.03866681168028492069412, 1e-17, pdf(2.5)); - test_case(0.0, 10.0, 0.03520653267642994777747, pdf(5.0)); - test_almost(10.0, 100.0, 4.398359598042719404845E-4, 1e-19, pdf(-200.0)); - test_case(10.0, 100.0, 0.002178521770325505313831, pdf(-100.0)); - test_case(10.0, 100.0, 0.003969525474770117655105, pdf(0.0)); - test_almost(10.0, 100.0, 0.002660852498987548218204, 1e-18, pdf(100.0)); - test_case(10.0, 100.0, 6.561581477467659126534E-4, pdf(200.0)); - test_case(-5.0, f64::INFINITY, 0.0, pdf(-5.0)); - test_case(-5.0, f64::INFINITY, 0.0, pdf(0.0)); - test_case(-5.0, f64::INFINITY, 0.0, pdf(100.0)); + test_absolute(10.0, 0.1, 5.530709549844416159162E-49, 1e-64, pdf(8.5)); + test_absolute(10.0, 0.1, 0.5399096651318805195056, 1e-14, pdf(9.8)); + test_absolute(10.0, 0.1, 3.989422804014326779399, 1e-15, pdf(10.0)); + test_absolute(10.0, 0.1, 0.5399096651318805195056, 1e-14, pdf(10.2)); + test_absolute(10.0, 0.1, 5.530709549844416159162E-49, 1e-64, pdf(11.5)); + test_exact(-5.0, 1.0, 1.486719514734297707908E-6, pdf(-10.0)); + test_exact(-5.0, 1.0, 0.01752830049356853736216, pdf(-7.5)); + test_absolute(-5.0, 1.0, 0.3989422804014326779399, 1e-16, pdf(-5.0)); + test_exact(-5.0, 1.0, 0.01752830049356853736216, pdf(-2.5)); + test_exact(-5.0, 1.0, 1.486719514734297707908E-6, pdf(0.0)); + test_exact(0.0, 10.0, 0.03520653267642994777747, pdf(-5.0)); + test_absolute(0.0, 10.0, 
0.03866681168028492069412, 1e-17, pdf(-2.5)); + test_absolute(0.0, 10.0, 0.03989422804014326779399, 1e-17, pdf(0.0)); + test_absolute(0.0, 10.0, 0.03866681168028492069412, 1e-17, pdf(2.5)); + test_exact(0.0, 10.0, 0.03520653267642994777747, pdf(5.0)); + test_absolute(10.0, 100.0, 4.398359598042719404845E-4, 1e-19, pdf(-200.0)); + test_exact(10.0, 100.0, 0.002178521770325505313831, pdf(-100.0)); + test_exact(10.0, 100.0, 0.003969525474770117655105, pdf(0.0)); + test_absolute(10.0, 100.0, 0.002660852498987548218204, 1e-18, pdf(100.0)); + test_exact(10.0, 100.0, 6.561581477467659126534E-4, pdf(200.0)); + test_exact(-5.0, f64::INFINITY, 0.0, pdf(-5.0)); + test_exact(-5.0, f64::INFINITY, 0.0, pdf(0.0)); + test_exact(-5.0, f64::INFINITY, 0.0, pdf(100.0)); } #[test] fn test_ln_pdf() { let ln_pdf = |arg: f64| move |x: Normal| x.ln_pdf(arg); - test_almost(10.0, 0.1, (5.530709549844416159162E-49f64).ln(), 1e-13, ln_pdf(8.5)); - test_almost(10.0, 0.1, (0.5399096651318805195056f64).ln(), 1e-13, ln_pdf(9.8)); - test_almost(10.0, 0.1, (3.989422804014326779399f64).ln(), 1e-15, ln_pdf(10.0)); - test_almost(10.0, 0.1, (0.5399096651318805195056f64).ln(), 1e-13, ln_pdf(10.2)); - test_almost(10.0, 0.1, (5.530709549844416159162E-49f64).ln(), 1e-13, ln_pdf(11.5)); - test_case(-5.0, 1.0, (1.486719514734297707908E-6f64).ln(), ln_pdf(-10.0)); - test_case(-5.0, 1.0, (0.01752830049356853736216f64).ln(), ln_pdf(-7.5)); - test_almost(-5.0, 1.0, (0.3989422804014326779399f64).ln(), 1e-15, ln_pdf(-5.0)); - test_case(-5.0, 1.0, (0.01752830049356853736216f64).ln(), ln_pdf(-2.5)); - test_case(-5.0, 1.0, (1.486719514734297707908E-6f64).ln(), ln_pdf(0.0)); - test_case(0.0, 10.0, (0.03520653267642994777747f64).ln(), ln_pdf(-5.0)); - test_case(0.0, 10.0, (0.03866681168028492069412f64).ln(), ln_pdf(-2.5)); - test_case(0.0, 10.0, (0.03989422804014326779399f64).ln(), ln_pdf(0.0)); - test_case(0.0, 10.0, (0.03866681168028492069412f64).ln(), ln_pdf(2.5)); - test_case(0.0, 10.0, 
(0.03520653267642994777747f64).ln(), ln_pdf(5.0)); - test_case(10.0, 100.0, (4.398359598042719404845E-4f64).ln(), ln_pdf(-200.0)); - test_case(10.0, 100.0, (0.002178521770325505313831f64).ln(), ln_pdf(-100.0)); - test_almost(10.0, 100.0, (0.003969525474770117655105f64).ln(),1e-15, ln_pdf(0.0)); - test_almost(10.0, 100.0, (0.002660852498987548218204f64).ln(), 1e-15, ln_pdf(100.0)); - test_almost(10.0, 100.0, (6.561581477467659126534E-4f64).ln(), 1e-15, ln_pdf(200.0)); - test_case(-5.0, f64::INFINITY, f64::NEG_INFINITY, ln_pdf(-5.0)); - test_case(-5.0, f64::INFINITY, f64::NEG_INFINITY, ln_pdf(0.0)); - test_case(-5.0, f64::INFINITY, f64::NEG_INFINITY, ln_pdf(100.0)); + test_absolute(10.0, 0.1, (5.530709549844416159162E-49f64).ln(), 1e-13, ln_pdf(8.5)); + test_absolute(10.0, 0.1, (0.5399096651318805195056f64).ln(), 1e-13, ln_pdf(9.8)); + test_absolute(10.0, 0.1, (3.989422804014326779399f64).ln(), 1e-15, ln_pdf(10.0)); + test_absolute(10.0, 0.1, (0.5399096651318805195056f64).ln(), 1e-13, ln_pdf(10.2)); + test_absolute(10.0, 0.1, (5.530709549844416159162E-49f64).ln(), 1e-13, ln_pdf(11.5)); + test_exact(-5.0, 1.0, (1.486719514734297707908E-6f64).ln(), ln_pdf(-10.0)); + test_exact(-5.0, 1.0, (0.01752830049356853736216f64).ln(), ln_pdf(-7.5)); + test_absolute(-5.0, 1.0, (0.3989422804014326779399f64).ln(), 1e-15, ln_pdf(-5.0)); + test_exact(-5.0, 1.0, (0.01752830049356853736216f64).ln(), ln_pdf(-2.5)); + test_exact(-5.0, 1.0, (1.486719514734297707908E-6f64).ln(), ln_pdf(0.0)); + test_exact(0.0, 10.0, (0.03520653267642994777747f64).ln(), ln_pdf(-5.0)); + test_exact(0.0, 10.0, (0.03866681168028492069412f64).ln(), ln_pdf(-2.5)); + test_exact(0.0, 10.0, (0.03989422804014326779399f64).ln(), ln_pdf(0.0)); + test_exact(0.0, 10.0, (0.03866681168028492069412f64).ln(), ln_pdf(2.5)); + test_exact(0.0, 10.0, (0.03520653267642994777747f64).ln(), ln_pdf(5.0)); + test_exact(10.0, 100.0, (4.398359598042719404845E-4f64).ln(), ln_pdf(-200.0)); + test_exact(10.0, 100.0, 
(0.002178521770325505313831f64).ln(), ln_pdf(-100.0)); + test_absolute(10.0, 100.0, (0.003969525474770117655105f64).ln(),1e-15, ln_pdf(0.0)); + test_absolute(10.0, 100.0, (0.002660852498987548218204f64).ln(), 1e-15, ln_pdf(100.0)); + test_absolute(10.0, 100.0, (6.561581477467659126534E-4f64).ln(), 1e-15, ln_pdf(200.0)); + test_exact(-5.0, f64::INFINITY, f64::NEG_INFINITY, ln_pdf(-5.0)); + test_exact(-5.0, f64::INFINITY, f64::NEG_INFINITY, ln_pdf(0.0)); + test_exact(-5.0, f64::INFINITY, f64::NEG_INFINITY, ln_pdf(100.0)); } #[test] fn test_cdf() { let cdf = |arg: f64| move |x: Normal| x.cdf(arg); - test_case(5.0, 2.0, 0.0, cdf(f64::NEG_INFINITY)); - test_almost(5.0, 2.0, 0.0000002866515718, 1e-16, cdf(-5.0)); - test_almost(5.0, 2.0, 0.0002326290790, 1e-13, cdf(-2.0)); - test_almost(5.0, 2.0, 0.006209665325, 1e-12, cdf(0.0)); - test_case(5.0, 2.0, 0.30853753872598689636229538939166226011639782444542207, cdf(4.0)); - test_case(5.0, 2.0, 0.5, cdf(5.0)); - test_case(5.0, 2.0, 0.69146246127401310363770461060833773988360217555457859, cdf(6.0)); - test_almost(5.0, 2.0, 0.993790334674, 1e-12, cdf(10.0)); + test_exact(5.0, 2.0, 0.0, cdf(f64::NEG_INFINITY)); + test_absolute(5.0, 2.0, 0.0000002866515718, 1e-16, cdf(-5.0)); + test_absolute(5.0, 2.0, 0.0002326290790, 1e-13, cdf(-2.0)); + test_absolute(5.0, 2.0, 0.006209665325, 1e-12, cdf(0.0)); + test_exact(5.0, 2.0, 0.30853753872598689636229538939166226011639782444542207, cdf(4.0)); + test_exact(5.0, 2.0, 0.5, cdf(5.0)); + test_exact(5.0, 2.0, 0.69146246127401310363770461060833773988360217555457859, cdf(6.0)); + test_absolute(5.0, 2.0, 0.993790334674, 1e-12, cdf(10.0)); } #[test] fn test_sf() { let sf = |arg: f64| move |x: Normal| x.sf(arg); - test_case(5.0, 2.0, 1.0, sf(f64::NEG_INFINITY)); - test_almost(5.0, 2.0, 0.9999997133484281, 1e-16, sf(-5.0)); - test_almost(5.0, 2.0, 0.9997673709209455, 1e-13, sf(-2.0)); - test_almost(5.0, 2.0, 0.9937903346744879, 1e-12, sf(0.0)); - test_case(5.0, 2.0, 0.6914624612740131, sf(4.0)); - 
test_case(5.0, 2.0, 0.5, sf(5.0)); - test_case(5.0, 2.0, 0.3085375387259869, sf(6.0)); - test_almost(5.0, 2.0, 0.006209665325512148, 1e-12, sf(10.0)); + test_exact(5.0, 2.0, 1.0, sf(f64::NEG_INFINITY)); + test_absolute(5.0, 2.0, 0.9999997133484281, 1e-16, sf(-5.0)); + test_absolute(5.0, 2.0, 0.9997673709209455, 1e-13, sf(-2.0)); + test_absolute(5.0, 2.0, 0.9937903346744879, 1e-12, sf(0.0)); + test_exact(5.0, 2.0, 0.6914624612740131, sf(4.0)); + test_exact(5.0, 2.0, 0.5, sf(5.0)); + test_exact(5.0, 2.0, 0.3085375387259869, sf(6.0)); + test_absolute(5.0, 2.0, 0.006209665325512148, 1e-12, sf(10.0)); } #[test] fn test_continuous() { - test::check_continuous_distribution(&try_create(0.0, 1.0), -10.0, 10.0); - test::check_continuous_distribution(&try_create(20.0, 0.5), 10.0, 30.0); + test::check_continuous_distribution(&create_ok(0.0, 1.0), -10.0, 10.0); + test::check_continuous_distribution(&create_ok(20.0, 0.5), 10.0, 30.0); } #[test] fn test_inverse_cdf() { let inverse_cdf = |arg: f64| move |x: Normal| x.inverse_cdf(arg); - test_case(5.0, 2.0, f64::NEG_INFINITY, inverse_cdf( 0.0)); - test_almost(5.0, 2.0, -5.0, 1e-14, inverse_cdf(0.00000028665157187919391167375233287464535385442301361187883)); - test_almost(5.0, 2.0, -2.0, 1e-14, inverse_cdf(0.0002326290790355250363499258867279847735487493358890356)); - test_almost(5.0, 2.0, -0.0, 1e-14, inverse_cdf(0.0062096653257761351669781045741922211278977469230927036)); - test_almost(5.0, 2.0, 0.0, 1e-14, inverse_cdf(0.0062096653257761351669781045741922211278977469230927036)); - test_almost(5.0, 2.0, 4.0, 1e-14, inverse_cdf(0.30853753872598689636229538939166226011639782444542207)); - test_almost(5.0, 2.0, 5.0, 1e-14, inverse_cdf(0.5)); - test_almost(5.0, 2.0, 6.0, 1e-14, inverse_cdf(0.69146246127401310363770461060833773988360217555457859)); - test_almost(5.0, 2.0, 10.0, 1e-14, inverse_cdf(0.9937903346742238648330218954258077788721022530769078)); - test_case(5.0, 2.0, f64::INFINITY, inverse_cdf(1.0)); + test_exact(5.0, 2.0, 
f64::NEG_INFINITY, inverse_cdf( 0.0)); + test_absolute(5.0, 2.0, -5.0, 1e-14, inverse_cdf(0.00000028665157187919391167375233287464535385442301361187883)); + test_absolute(5.0, 2.0, -2.0, 1e-14, inverse_cdf(0.0002326290790355250363499258867279847735487493358890356)); + test_absolute(5.0, 2.0, -0.0, 1e-14, inverse_cdf(0.0062096653257761351669781045741922211278977469230927036)); + test_absolute(5.0, 2.0, 0.0, 1e-14, inverse_cdf(0.0062096653257761351669781045741922211278977469230927036)); + test_absolute(5.0, 2.0, 4.0, 1e-14, inverse_cdf(0.30853753872598689636229538939166226011639782444542207)); + test_absolute(5.0, 2.0, 5.0, 1e-14, inverse_cdf(0.5)); + test_absolute(5.0, 2.0, 6.0, 1e-14, inverse_cdf(0.69146246127401310363770461060833773988360217555457859)); + test_absolute(5.0, 2.0, 10.0, 1e-14, inverse_cdf(0.9937903346742238648330218954258077788721022530769078)); + test_exact(5.0, 2.0, f64::INFINITY, inverse_cdf(1.0)); + } + + #[test] + fn test_default() { + let n = Normal::default(); + + let n_mean = n.mean().unwrap(); + let n_std = n.std_dev().unwrap(); + + // Check that the mean of the distribution is close to 0 + assert_almost_eq!(n_mean, 0.0, 1e-15); + // Check that the standard deviation of the distribution is close to 1 + assert_almost_eq!(n_std, 1.0, 1e-15); } } diff --git a/src/distribution/pareto.rs b/src/distribution/pareto.rs index e59595a2..1de73d84 100644 --- a/src/distribution/pareto.rs +++ b/src/distribution/pareto.rs @@ -1,8 +1,5 @@ use crate::distribution::{Continuous, ContinuousCDF}; use crate::statistics::*; -use crate::{Result, StatsError}; -use rand::distributions::OpenClosed01; -use rand::Rng; use std::f64; /// Implements the [Pareto](https://en.wikipedia.org/wiki/Pareto_distribution) @@ -19,12 +16,35 @@ use std::f64; /// assert_eq!(p.mean().unwrap(), 2.0); /// assert!(prec::almost_eq(p.pdf(2.0), 0.25, 1e-15)); /// ``` -#[derive(Debug, Copy, Clone, PartialEq)] +#[derive(Copy, Clone, PartialEq, Debug)] pub struct Pareto { scale: f64, shape: f64, 
} +/// Represents the errors that can occur when creating a [`Pareto`]. +#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] +#[non_exhaustive] +pub enum ParetoError { + /// The scale is NaN, zero or less than zero. + ScaleInvalid, + + /// The shape is NaN, zero or less than zero. + ShapeInvalid, +} + +impl std::fmt::Display for ParetoError { + #[cfg_attr(coverage_nightly, coverage(off))] + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + ParetoError::ScaleInvalid => write!(f, "Scale is NaN, zero, or less than zero"), + ParetoError::ShapeInvalid => write!(f, "Shape is NaN, zero, or less than zero"), + } + } +} + +impl std::error::Error for ParetoError {} + impl Pareto { /// Constructs a new Pareto distribution with scale `scale`, and `shape` /// shape. @@ -45,13 +65,16 @@ impl Pareto { /// result = Pareto::new(0.0, 0.0); /// assert!(result.is_err()); /// ``` - pub fn new(scale: f64, shape: f64) -> Result { - let is_nan = scale.is_nan() || shape.is_nan(); - if is_nan || scale <= 0.0 || shape <= 0.0 { - Err(StatsError::BadParams) - } else { - Ok(Pareto { scale, shape }) + pub fn new(scale: f64, shape: f64) -> Result { + if scale.is_nan() || scale <= 0.0 { + return Err(ParetoError::ScaleInvalid); } + + if shape.is_nan() || shape <= 0.0 { + return Err(ParetoError::ShapeInvalid); + } + + Ok(Pareto { scale, shape }) } /// Returns the scale of the Pareto distribution @@ -83,8 +106,17 @@ impl Pareto { } } +impl std::fmt::Display for Pareto { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Pareto({},{})", self.scale, self.shape) + } +} + +#[cfg(feature = "rand")] impl ::rand::distributions::Distribution for Pareto { - fn sample(&self, rng: &mut R) -> f64 { + fn sample(&self, rng: &mut R) -> f64 { + use rand::distributions::OpenClosed01; + // Inverse transform sampling let u: f64 = rng.sample(OpenClosed01); self.scale * u.powf(-1.0 / self.shape) @@ -97,7 +129,7 @@ impl ContinuousCDF for Pareto { /// /// # 
Formula /// - /// ```ignore + /// ```text /// if x < x_m { /// 0 /// } else { @@ -119,7 +151,7 @@ impl ContinuousCDF for Pareto { /// /// # Formula /// - /// ```ignore + /// ```text /// if x < x_m { /// 1 /// } else { @@ -135,6 +167,24 @@ impl ContinuousCDF for Pareto { (self.scale / x).powf(self.shape) } } + + /// Calculates the inverse cumulative distribution function for the Pareto + /// distribution at `x` + /// + /// # Formula + /// + /// ```text + /// x_m / (1 - x)^(1 / α) + /// ``` + /// + /// where `x_m` is the scale and `α` is the shape + fn inverse_cdf(&self, p: f64) -> f64 { + if !(0.0..=1.0).contains(&p) { + panic!("x must be in [0, 1]"); + } else { + self.scale * (1.0 - p).powf(-1.0 / self.shape) + } + } } impl Min for Pareto { @@ -143,7 +193,7 @@ impl Min for Pareto { /// /// # Formula /// - /// ```ignore + /// ```text /// x_m /// ``` /// @@ -159,8 +209,8 @@ impl Max for Pareto { /// /// # Formula /// - /// ```ignore - /// INF + /// ```text + /// f64::INFINITY /// ``` fn max(&self) -> f64 { f64::INFINITY @@ -172,9 +222,9 @@ impl Distribution for Pareto { /// /// # Formula /// - /// ```ignore + /// ```text /// if α <= 1 { - /// INF + /// f64::INFINITY /// } else { /// (α * x_m)/(α - 1) /// } @@ -188,13 +238,14 @@ impl Distribution for Pareto { Some((self.shape * self.scale) / (self.shape - 1.0)) } } + /// Returns the variance of the Pareto distribution /// /// # Formula /// - /// ```ignore + /// ```text /// if α <= 2 { - /// INF + /// f64::INFINITY /// } else { /// (x_m/(α - 1))^2 * (α/(α - 2)) /// } @@ -209,11 +260,12 @@ impl Distribution for Pareto { Some(a * a * self.shape / (self.shape - 2.0)) } } + /// Returns the entropy for the Pareto distribution /// /// # Formula /// - /// ```ignore + /// ```text /// ln(α/x_m) - 1/α - 1 /// ``` /// @@ -221,6 +273,7 @@ impl Distribution for Pareto { fn entropy(&self) -> Option { Some(self.shape.ln() - self.scale.ln() - (1.0 / self.shape) - 1.0) } + /// Returns the skewness of the Pareto distribution /// /// # 
Panics @@ -231,7 +284,7 @@ impl Distribution for Pareto { /// /// # Formula /// - /// ```ignore + /// ```text /// (2*(α + 1)/(α - 3))*sqrt((α - 2)/α) /// ``` /// @@ -253,7 +306,7 @@ impl Median for Pareto { /// /// # Formula /// - /// ```ignore + /// ```text /// x_m*2^(1/α) /// ``` /// @@ -268,7 +321,7 @@ impl Mode> for Pareto { /// /// # Formula /// - /// ```ignore + /// ```text /// x_m /// ``` /// @@ -284,7 +337,7 @@ impl Continuous for Pareto { /// /// # Formula /// - /// ```ignore + /// ```text /// if x < x_m { /// 0 /// } else { @@ -306,9 +359,9 @@ impl Continuous for Pareto { /// /// # Formula /// - /// ```ignore + /// ```text /// if x < x_m { - /// -INF + /// f64::NEG_INFINITY /// } else { /// ln(α) + α*ln(x_m) - (α + 1)*ln(x) /// } @@ -325,206 +378,175 @@ impl Continuous for Pareto { } #[rustfmt::skip] -#[cfg(all(test, feature = "nightly"))] +#[cfg(test)] mod tests { - use crate::statistics::*; - use crate::distribution::{ContinuousCDF, Continuous, Pareto}; + use super::*; use crate::distribution::internal::*; - use crate::consts::ACC; - - fn try_create(scale: f64, shape: f64) -> Pareto { - let p = Pareto::new(scale, shape); - assert!(p.is_ok()); - p.unwrap() - } + use crate::testing_boiler; - fn create_case(scale: f64, shape: f64) { - let p = try_create(scale, shape); - assert_eq!(scale, p.scale()); - assert_eq!(shape, p.shape()); - } - - fn bad_create_case(scale: f64, shape: f64) { - let p = Pareto::new(scale, shape); - assert!(p.is_err()); - } - - fn get_value(scale: f64, shape: f64, eval: F) -> T - where F: Fn(Pareto) -> T - { - let p = try_create(scale, shape); - eval(p) - } - - fn test_case(scale: f64, shape: f64, expected: f64, eval: F) - where F: Fn(Pareto) -> f64 - { - let x = get_value(scale, shape, eval); - assert_eq!(expected, x); - } - - fn test_almost(scale: f64, shape: f64, expected: f64, acc: f64, eval: F) - where F: Fn(Pareto) -> f64 - { - let p = try_create(scale, shape); - let x = eval(p); - assert_almost_eq!(expected, x, acc); - } + 
testing_boiler!(scale: f64, shape: f64; Pareto; ParetoError); #[test] fn test_create() { - create_case(10.0, 0.1); - create_case(5.0, 1.0); - create_case(0.1, 10.0); - create_case(10.0, 100.0); - create_case(1.0, f64::INFINITY); - create_case(f64::INFINITY, f64::INFINITY); + create_ok(10.0, 0.1); + create_ok(5.0, 1.0); + create_ok(0.1, 10.0); + create_ok(10.0, 100.0); + create_ok(1.0, f64::INFINITY); + create_ok(f64::INFINITY, f64::INFINITY); } #[test] fn test_bad_create() { - bad_create_case(0.0, 0.0); - bad_create_case(1.0, -1.0); - bad_create_case(-1.0, 1.0); - bad_create_case(-1.0, -1.0); - bad_create_case(f64::NAN, 1.0); - bad_create_case(1.0, f64::NAN); - bad_create_case(f64::NAN, f64::NAN); + test_create_err(1.0, -1.0, ParetoError::ShapeInvalid); + test_create_err(-1.0, 1.0, ParetoError::ScaleInvalid); + create_err(0.0, 0.0); + create_err(-1.0, -1.0); + create_err(f64::NAN, 1.0); + create_err(1.0, f64::NAN); + create_err(f64::NAN, f64::NAN); } #[test] fn test_variance() { let variance = |x: Pareto| x.variance().unwrap(); - test_case(1.0, 3.0, 0.75, variance); - test_almost(10.0, 10.0, 125.0 / 81.0, 1e-13, variance); + test_exact(1.0, 3.0, 0.75, variance); + test_absolute(10.0, 10.0, 125.0 / 81.0, 1e-13, variance); } #[test] - #[should_panic] fn test_variance_degen() { - let variance = |x: Pareto| x.variance().unwrap(); - test_case(1.0, 1.0, f64::INFINITY, variance); // shape <= 2.0 + test_none(1.0, 1.0, |dist| dist.variance()); // shape <= 2.0 } #[test] fn test_entropy() { let entropy = |x: Pareto| x.entropy().unwrap(); - test_case(0.1, 0.1, -11.0, entropy); - test_case(1.0, 1.0, -2.0, entropy); - test_case(10.0, 10.0, -1.1, entropy); - test_case(3.0, 1.0, -2.0 - 3f64.ln(), entropy); - test_case(1.0, 3.0, -4.0/3.0 + 3f64.ln(), entropy); + test_exact(0.1, 0.1, -11.0, entropy); + test_exact(1.0, 1.0, -2.0, entropy); + test_exact(10.0, 10.0, -1.1, entropy); + test_exact(3.0, 1.0, -2.0 - 3f64.ln(), entropy); + test_exact(1.0, 3.0, -4.0/3.0 + 3f64.ln(), entropy); 
} #[test] fn test_skewness() { let skewness = |x: Pareto| x.skewness().unwrap(); - test_case(1.0, 4.0, 5.0*2f64.sqrt(), skewness); - test_case(1.0, 100.0, (707.0/485.0)*2f64.sqrt(), skewness); + test_exact(1.0, 4.0, 5.0*2f64.sqrt(), skewness); + test_exact(1.0, 100.0, (707.0/485.0)*2f64.sqrt(), skewness); } #[test] - #[should_panic] fn test_skewness_invalid_shape() { - let skewness = |x: Pareto| x.skewness().unwrap(); - get_value(1.0, 3.0, skewness); + test_none(1.0, 3.0, |dist| dist.skewness()); } #[test] fn test_mode() { let mode = |x: Pareto| x.mode().unwrap(); - test_case(0.1, 1.0, 0.1, mode); - test_case(2.0, 1.0, 2.0, mode); - test_case(10.0, f64::INFINITY, 10.0, mode); - test_case(f64::INFINITY, 1.0, f64::INFINITY, mode); + test_exact(0.1, 1.0, 0.1, mode); + test_exact(2.0, 1.0, 2.0, mode); + test_exact(10.0, f64::INFINITY, 10.0, mode); + test_exact(f64::INFINITY, 1.0, f64::INFINITY, mode); } #[test] fn test_median() { let median = |x: Pareto| x.median(); - test_case(0.1, 0.1, 102.4, median); - test_case(1.0, 1.0, 2.0, median); - test_case(10.0, 10.0, 10.0*2f64.powf(0.1), median); - test_case(3.0, 0.5, 12.0, median); - test_case(10.0, f64::INFINITY, 10.0, median); + test_exact(0.1, 0.1, 102.4, median); + test_exact(1.0, 1.0, 2.0, median); + test_exact(10.0, 10.0, 10.0*2f64.powf(0.1), median); + test_exact(3.0, 0.5, 12.0, median); + test_exact(10.0, f64::INFINITY, 10.0, median); } #[test] fn test_min_max() { let min = |x: Pareto| x.min(); let max = |x: Pareto| x.max(); - test_case(0.2, f64::INFINITY, 0.2, min); - test_case(10.0, f64::INFINITY, 10.0, min); - test_case(f64::INFINITY, 1.0, f64::INFINITY, min); - test_case(1.0, 0.1, f64::INFINITY, max); - test_case(3.0, 10.0, f64::INFINITY, max); + test_exact(0.2, f64::INFINITY, 0.2, min); + test_exact(10.0, f64::INFINITY, 10.0, min); + test_exact(f64::INFINITY, 1.0, f64::INFINITY, min); + test_exact(1.0, 0.1, f64::INFINITY, max); + test_exact(3.0, 10.0, f64::INFINITY, max); } #[test] fn test_pdf() { let pdf = 
|arg: f64| move |x: Pareto| x.pdf(arg); - test_case(1.0, 1.0, 0.0, pdf(0.1)); - test_case(1.0, 1.0, 1.0, pdf(1.0)); - test_case(1.0, 1.0, 4.0/9.0, pdf(1.5)); - test_case(1.0, 1.0, 1.0/25.0, pdf(5.0)); - test_case(1.0, 1.0, 1.0/2500.0, pdf(50.0)); - test_case(1.0, 4.0, 4.0, pdf(1.0)); - test_case(1.0, 4.0, 128.0/243.0, pdf(1.5)); - test_case(1.0, 4.0, 1.0/78125000.0, pdf(50.0)); - test_case(3.0, 2.0, 2.0/3.0, pdf(3.0)); - test_case(3.0, 2.0, 18.0/125.0, pdf(5.0)); - test_almost(25.0, 100.0, 1.5777218104420236e-30, 1e-50, pdf(50.0)); - test_almost(100.0, 25.0, 6.6003546737276816e-6, 1e-16, pdf(150.0)); - test_case(1.0, 2.0, 0.0, pdf(f64::INFINITY)); + test_exact(1.0, 1.0, 0.0, pdf(0.1)); + test_exact(1.0, 1.0, 1.0, pdf(1.0)); + test_exact(1.0, 1.0, 4.0/9.0, pdf(1.5)); + test_exact(1.0, 1.0, 1.0/25.0, pdf(5.0)); + test_exact(1.0, 1.0, 1.0/2500.0, pdf(50.0)); + test_exact(1.0, 4.0, 4.0, pdf(1.0)); + test_exact(1.0, 4.0, 128.0/243.0, pdf(1.5)); + test_exact(1.0, 4.0, 1.0/78125000.0, pdf(50.0)); + test_exact(3.0, 2.0, 2.0/3.0, pdf(3.0)); + test_exact(3.0, 2.0, 18.0/125.0, pdf(5.0)); + test_absolute(25.0, 100.0, 1.5777218104420236e-30, 1e-50, pdf(50.0)); + test_absolute(100.0, 25.0, 6.6003546737276816e-6, 1e-16, pdf(150.0)); + test_exact(1.0, 2.0, 0.0, pdf(f64::INFINITY)); } #[test] fn test_ln_pdf() { let ln_pdf = |arg: f64| move |x: Pareto| x.ln_pdf(arg); - test_case(1.0, 1.0, f64::NEG_INFINITY, ln_pdf(0.1)); - test_case(1.0, 1.0, 0.0, ln_pdf(1.0)); - test_almost(1.0, 1.0, 4f64.ln() - 9f64.ln(), 1e-14, ln_pdf(1.5)); - test_almost(1.0, 1.0, -(25f64.ln()), 1e-14, ln_pdf(5.0)); - test_almost(1.0, 1.0, -(2500f64.ln()), 1e-14, ln_pdf(50.0)); - test_almost(1.0, 4.0, 4f64.ln(), 1e-14, ln_pdf(1.0)); - test_almost(1.0, 4.0, 128f64.ln() - 243f64.ln(), 1e-14, ln_pdf(1.5)); - test_almost(1.0, 4.0, -(78125000f64.ln()), 1e-14, ln_pdf(50.0)); - test_almost(3.0, 2.0, 2f64.ln() - 3f64.ln(), 1e-14, ln_pdf(3.0)); - test_almost(3.0, 2.0, 18f64.ln() - 125f64.ln(), 1e-14, ln_pdf(5.0)); - 
test_almost(25.0, 100.0, 1.5777218104420236e-30f64.ln(), 1e-12, ln_pdf(50.0)); - test_almost(100.0, 25.0, 6.6003546737276816e-6f64.ln(), 1e-12, ln_pdf(150.0)); - test_case(1.0, 2.0, f64::NEG_INFINITY, ln_pdf(f64::INFINITY)); + test_exact(1.0, 1.0, f64::NEG_INFINITY, ln_pdf(0.1)); + test_exact(1.0, 1.0, 0.0, ln_pdf(1.0)); + test_absolute(1.0, 1.0, 4f64.ln() - 9f64.ln(), 1e-14, ln_pdf(1.5)); + test_absolute(1.0, 1.0, -(25f64.ln()), 1e-14, ln_pdf(5.0)); + test_absolute(1.0, 1.0, -(2500f64.ln()), 1e-14, ln_pdf(50.0)); + test_absolute(1.0, 4.0, 4f64.ln(), 1e-14, ln_pdf(1.0)); + test_absolute(1.0, 4.0, 128f64.ln() - 243f64.ln(), 1e-14, ln_pdf(1.5)); + test_absolute(1.0, 4.0, -(78125000f64.ln()), 1e-14, ln_pdf(50.0)); + test_absolute(3.0, 2.0, 2f64.ln() - 3f64.ln(), 1e-14, ln_pdf(3.0)); + test_absolute(3.0, 2.0, 18f64.ln() - 125f64.ln(), 1e-14, ln_pdf(5.0)); + test_absolute(25.0, 100.0, 1.5777218104420236e-30f64.ln(), 1e-12, ln_pdf(50.0)); + test_absolute(100.0, 25.0, 6.6003546737276816e-6f64.ln(), 1e-12, ln_pdf(150.0)); + test_exact(1.0, 2.0, f64::NEG_INFINITY, ln_pdf(f64::INFINITY)); } #[test] fn test_cdf() { let cdf = |arg: f64| move |x: Pareto| x.cdf(arg); - test_case(0.1, 0.1, 0.0, cdf(0.1)); - test_case(1.0, 1.0, 0.0, cdf(1.0)); - test_case(5.0, 5.0, 0.0, cdf(2.0)); - test_case(7.0, 7.0, 0.9176457, cdf(10.0)); - test_case(10.0, 10.0, 50700551.0/60466176.0, cdf(12.0)); - test_case(5.0, 1.0, 0.5, cdf(10.0)); - test_case(3.0, 10.0, 1023.0/1024.0, cdf(6.0)); - test_case(1.0, 1.0, 1.0, cdf(f64::INFINITY)); + test_exact(0.1, 0.1, 0.0, cdf(0.1)); + test_exact(1.0, 1.0, 0.0, cdf(1.0)); + test_exact(5.0, 5.0, 0.0, cdf(2.0)); + test_exact(7.0, 7.0, 0.9176457, cdf(10.0)); + test_exact(10.0, 10.0, 50700551.0/60466176.0, cdf(12.0)); + test_exact(5.0, 1.0, 0.5, cdf(10.0)); + test_exact(3.0, 10.0, 1023.0/1024.0, cdf(6.0)); + test_exact(1.0, 1.0, 1.0, cdf(f64::INFINITY)); } #[test] fn test_sf() { let sf = |arg: f64| move |x: Pareto| x.sf(arg); - test_case(0.1, 0.1, 1.0, sf(0.1)); - 
test_case(1.0, 1.0, 1.0, sf(1.0)); - test_case(5.0, 5.0, 1.0, sf(2.0)); - test_almost(7.0, 7.0, 0.08235429999999999, 1e-14, sf(10.0)); - test_almost(10.0, 10.0, 0.16150558288984573, 1e14, sf(12.0)); - test_case(5.0, 1.0, 0.5, sf(10.0)); - test_almost(3.0, 10.0, 0.0009765625, 1e-14, sf(6.0)); - test_case(1.0, 1.0, 0.0, sf(f64::INFINITY)); + test_exact(0.1, 0.1, 1.0, sf(0.1)); + test_exact(1.0, 1.0, 1.0, sf(1.0)); + test_exact(5.0, 5.0, 1.0, sf(2.0)); + test_absolute(7.0, 7.0, 0.08235429999999999, 1e-14, sf(10.0)); + test_absolute(10.0, 10.0, 0.16150558288984573, 1e-14, sf(12.0)); + test_exact(5.0, 1.0, 0.5, sf(10.0)); + test_absolute(3.0, 10.0, 0.0009765625, 1e-14, sf(6.0)); + test_exact(1.0, 1.0, 0.0, sf(f64::INFINITY)); + } + + #[test] + fn test_inverse_cdf() { + let func = |arg: f64| move |x: Pareto| x.inverse_cdf(x.cdf(arg)); + test_exact(0.1, 0.1, 0.1, func(0.1)); + test_exact(1.0, 1.0, 1.0, func(1.0)); + test_exact(7.0, 7.0, 10.0, func(10.0)); + test_exact(10.0, 10.0, 12.0, func(12.0)); + test_exact(5.0, 1.0, 10.0, func(10.0)); + test_exact(3.0, 10.0, 6.0, func(6.0)); } #[test] fn test_continuous() { - test::check_continuous_distribution(&try_create(1.0, 10.0), 1.0, 10.0); - test::check_continuous_distribution(&try_create(0.1, 2.0), 0.1, 100.0); + test::check_continuous_distribution(&create_ok(1.0, 10.0), 1.0, 10.0); + test::check_continuous_distribution(&create_ok(0.1, 2.0), 0.1, 100.0); } } diff --git a/src/distribution/poisson.rs b/src/distribution/poisson.rs index 0c1b2379..b3f1ebab 100644 --- a/src/distribution/poisson.rs +++ b/src/distribution/poisson.rs @@ -1,10 +1,7 @@ use crate::distribution::{Discrete, DiscreteCDF}; use crate::function::{factorial, gamma}; use crate::statistics::*; -use crate::{Result, StatsError}; -use rand::Rng; use std::f64; -use std::u64; /// Implements the [Poisson](https://en.wikipedia.org/wiki/Poisson_distribution) /// distribution @@ -20,11 +17,30 @@ use std::u64; /// assert_eq!(n.mean().unwrap(), 1.0); /// 
assert!(prec::almost_eq(n.pmf(1), 0.367879441171442, 1e-15)); /// ``` -#[derive(Debug, Copy, Clone, PartialEq)] +#[derive(Copy, Clone, PartialEq, Debug)] pub struct Poisson { lambda: f64, } +/// Represents the errors that can occur when creating a [`Poisson`]. +#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] +#[non_exhaustive] +pub enum PoissonError { + /// The lambda is NaN, zero or less than zero. + LambdaInvalid, +} + +impl std::fmt::Display for PoissonError { + #[cfg_attr(coverage_nightly, coverage(off))] + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + PoissonError::LambdaInvalid => write!(f, "Lambda is NaN, zero or less than zero"), + } + } +} + +impl std::error::Error for PoissonError {} + impl Poisson { /// Constructs a new poisson distribution with a rate (λ) /// of `lambda` @@ -44,9 +60,9 @@ impl Poisson { /// result = Poisson::new(0.0); /// assert!(result.is_err()); /// ``` - pub fn new(lambda: f64) -> Result { + pub fn new(lambda: f64) -> Result { if lambda.is_nan() || lambda <= 0.0 { - Err(StatsError::BadParams) + Err(PoissonError::LambdaInvalid) } else { Ok(Poisson { lambda }) } @@ -67,13 +83,20 @@ impl Poisson { } } +impl std::fmt::Display for Poisson { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Pois({})", self.lambda) + } +} + +#[cfg(feature = "rand")] impl ::rand::distributions::Distribution for Poisson { /// Generates one sample from the Poisson distribution either by /// Knuth's method if lambda < 30.0 or Rejection method PA by /// A. C. Atkinson from the Journal of the Royal Statistical Society /// Series C (Applied Statistics) Vol. 28 No. 1. (1979) pp. 
29 - 35 /// otherwise - fn sample(&self, rng: &mut R) -> f64 { + fn sample(&self, rng: &mut R) -> f64 { sample_unchecked(rng, self.lambda) } } @@ -84,7 +107,7 @@ impl DiscreteCDF for Poisson { /// /// # Formula /// - /// ```ignore + /// ```text /// P(x + 1, λ) /// ``` /// @@ -98,7 +121,7 @@ impl DiscreteCDF for Poisson { /// /// # Formula /// - /// ```ignore + /// ```text /// P(x + 1, λ) /// ``` /// @@ -114,7 +137,7 @@ impl Min for Poisson { /// /// # Formula /// - /// ```ignore + /// ```text /// 0 /// ``` fn min(&self) -> u64 { @@ -128,7 +151,7 @@ impl Max for Poisson { /// /// # Formula /// - /// ```ignore + /// ```text /// 2^63 - 1 /// ``` fn max(&self) -> u64 { @@ -141,7 +164,7 @@ impl Distribution for Poisson { /// /// # Formula /// - /// ```ignore + /// ```text /// λ /// ``` /// @@ -149,11 +172,12 @@ impl Distribution for Poisson { fn mean(&self) -> Option { Some(self.lambda) } + /// Returns the variance of the poisson distribution /// /// # Formula /// - /// ```ignore + /// ```text /// λ /// ``` /// @@ -161,11 +185,12 @@ impl Distribution for Poisson { fn variance(&self) -> Option { Some(self.lambda) } + /// Returns the entropy of the poisson distribution /// /// # Formula /// - /// ```ignore + /// ```text /// (1 / 2) * ln(2πeλ) - 1 / (12λ) - 1 / (24λ^2) - 19 / (360λ^3) /// ``` /// @@ -178,11 +203,12 @@ impl Distribution for Poisson { - 19.0 / (360.0 * self.lambda * self.lambda * self.lambda), ) } + /// Returns the skewness of the poisson distribution /// /// # Formula /// - /// ```ignore + /// ```text /// λ^(-1/2) /// ``` /// @@ -197,7 +223,7 @@ impl Median for Poisson { /// /// # Formula /// - /// ```ignore + /// ```text /// floor(λ + 1 / 3 - 0.02 / λ) /// ``` /// @@ -212,7 +238,7 @@ impl Mode> for Poisson { /// /// # Formula /// - /// ```ignore + /// ```text /// floor(λ) /// ``` /// @@ -228,13 +254,13 @@ impl Discrete for Poisson { /// /// # Formula /// - /// ```ignore - /// (λ^k * e^(-λ)) / x! + /// ```text + /// (λ^x * e^(-λ)) / x! 
/// ``` /// /// where `λ` is the rate fn pmf(&self, x: u64) -> f64 { - (-self.lambda + x as f64 * self.lambda.ln() - factorial::ln_factorial(x as u64)).exp() + (-self.lambda + x as f64 * self.lambda.ln() - factorial::ln_factorial(x)).exp() } /// Calculates the log probability mass function for the poisson @@ -243,13 +269,13 @@ impl Discrete for Poisson { /// /// # Formula /// - /// ```ignore - /// ln((λ^k * e^(-λ)) / x!) + /// ```text + /// ln((λ^x * e^(-λ)) / x!) /// ``` /// /// where `λ` is the rate fn ln_pmf(&self, x: u64) -> f64 { - -self.lambda + x as f64 * self.lambda.ln() - factorial::ln_factorial(x as u64) + -self.lambda + x as f64 * self.lambda.ln() - factorial::ln_factorial(x) } } /// Generates one sample from the Poisson distribution either by @@ -257,7 +283,8 @@ impl Discrete for Poisson { /// A. C. Atkinson from the Journal of the Royal Statistical Society /// Series C (Applied Statistics) Vol. 28 No. 1. (1979) pp. 29 - 35 /// otherwise -pub fn sample_unchecked(rng: &mut R, lambda: f64) -> f64 { +#[cfg(feature = "rand")] +pub fn sample_unchecked(rng: &mut R, lambda: f64) -> f64 { if lambda < 30.0 { let limit = (-lambda).exp(); let mut count = 0.0; @@ -294,186 +321,147 @@ pub fn sample_unchecked(rng: &mut R, lambda: f64) -> f64 { } #[rustfmt::skip] -#[cfg(all(test, feature = "nightly"))] +#[cfg(test)] mod tests { - use std::fmt::Debug; - use crate::statistics::*; - use crate::distribution::{DiscreteCDF, Discrete, Poisson}; + use super::*; use crate::distribution::internal::*; - use crate::consts::ACC; - - fn try_create(lambda: f64) -> Poisson { - let n = Poisson::new(lambda); - assert!(n.is_ok()); - n.unwrap() - } - - fn create_case(lambda: f64) { - let n = try_create(lambda); - assert_eq!(lambda, n.lambda()); - } - - fn bad_create_case(lambda: f64) { - let n = Poisson::new(lambda); - assert!(n.is_err()); - } - - fn get_value(lambda: f64, eval: F) -> T - where T: PartialEq + Debug, - F: Fn(Poisson) -> T - { - let n = try_create(lambda); - eval(n) - } + 
use crate::testing_boiler; - fn test_case(lambda: f64, expected: T, eval: F) - where T: PartialEq + Debug, - F: Fn(Poisson) -> T - { - let x = get_value(lambda, eval); - assert_eq!(expected, x); - } - - fn test_almost(lambda: f64, expected: f64, acc: f64, eval: F) - where F: Fn(Poisson) -> f64 - { - let x = get_value(lambda, eval); - assert_almost_eq!(expected, x, acc); - } + testing_boiler!(lambda: f64; Poisson; PoissonError); #[test] fn test_create() { - create_case(1.5); - create_case(5.4); - create_case(10.8); + create_ok(1.5); + create_ok(5.4); + create_ok(10.8); } #[test] fn test_bad_create() { - bad_create_case(f64::NAN); - bad_create_case(-1.5); - bad_create_case(0.0); + create_err(f64::NAN); + create_err(-1.5); + create_err(0.0); } #[test] fn test_mean() { let mean = |x: Poisson| x.mean().unwrap(); - test_case(1.5, 1.5, mean); - test_case(5.4, 5.4, mean); - test_case(10.8, 10.8, mean); + test_exact(1.5, 1.5, mean); + test_exact(5.4, 5.4, mean); + test_exact(10.8, 10.8, mean); } #[test] fn test_variance() { let variance = |x: Poisson| x.variance().unwrap(); - test_case(1.5, 1.5, variance); - test_case(5.4, 5.4, variance); - test_case(10.8, 10.8, variance); + test_exact(1.5, 1.5, variance); + test_exact(5.4, 5.4, variance); + test_exact(10.8, 10.8, variance); } #[test] fn test_entropy() { let entropy = |x: Poisson| x.entropy().unwrap(); - test_almost(1.5, 1.531959153102376331946, 1e-15, entropy); - test_almost(5.4, 2.244941839577643504608, 1e-15, entropy); - test_case(10.8, 2.600596429676975222694, entropy); + test_absolute(1.5, 1.531959153102376331946, 1e-15, entropy); + test_absolute(5.4, 2.244941839577643504608, 1e-15, entropy); + test_exact(10.8, 2.600596429676975222694, entropy); } #[test] fn test_skewness() { let skewness = |x: Poisson| x.skewness().unwrap(); - test_almost(1.5, 0.8164965809277260327324, 1e-15, skewness); - test_almost(5.4, 0.4303314829119352094644, 1e-16, skewness); - test_almost(10.8, 0.3042903097250922852539, 1e-16, skewness); + 
test_absolute(1.5, 0.8164965809277260327324, 1e-15, skewness); + test_absolute(5.4, 0.4303314829119352094644, 1e-16, skewness); + test_absolute(10.8, 0.3042903097250922852539, 1e-16, skewness); } #[test] fn test_median() { let median = |x: Poisson| x.median(); - test_case(1.5, 1.0, median); - test_case(5.4, 5.0, median); - test_case(10.8, 11.0, median); + test_exact(1.5, 1.0, median); + test_exact(5.4, 5.0, median); + test_exact(10.8, 11.0, median); } #[test] fn test_mode() { let mode = |x: Poisson| x.mode().unwrap(); - test_case(1.5, 1, mode); - test_case(5.4, 5, mode); - test_case(10.8, 10, mode); + test_exact(1.5, 1, mode); + test_exact(5.4, 5, mode); + test_exact(10.8, 10, mode); } #[test] fn test_min_max() { let min = |x: Poisson| x.min(); let max = |x: Poisson| x.max(); - test_case(1.5, 0, min); - test_case(5.4, 0, min); - test_case(10.8, 0, min); - test_case(1.5, u64::MAX, max); - test_case(5.4, u64::MAX, max); - test_case(10.8, u64::MAX, max); + test_exact(1.5, 0, min); + test_exact(5.4, 0, min); + test_exact(10.8, 0, min); + test_exact(1.5, u64::MAX, max); + test_exact(5.4, u64::MAX, max); + test_exact(10.8, u64::MAX, max); } #[test] fn test_pmf() { let pmf = |arg: u64| move |x: Poisson| x.pmf(arg); - test_almost(1.5, 0.334695240222645000000000000000, 1e-15, pmf(1)); - test_almost(1.5, 0.000003545747740570180000000000, 1e-20, pmf(10)); - test_almost(1.5, 0.000000000000000304971208961018, 1e-30, pmf(20)); - test_almost(5.4, 0.024389537090108400000000000000, 1e-17, pmf(1)); - test_almost(5.4, 0.026241240591792300000000000000, 1e-16, pmf(10)); - test_almost(5.4, 0.000000825202200316548000000000, 1e-20, pmf(20)); - test_almost(10.8, 0.000220314636840657000000000000, 1e-18, pmf(1)); - test_almost(10.8, 0.121365183659420000000000000000, 1e-15, pmf(10)); - test_almost(10.8, 0.003908139778574110000000000000, 1e-16, pmf(20)); + test_absolute(1.5, 0.334695240222645000000000000000, 1e-15, pmf(1)); + test_absolute(1.5, 0.000003545747740570180000000000, 1e-20, 
pmf(10)); + test_absolute(1.5, 0.000000000000000304971208961018, 1e-30, pmf(20)); + test_absolute(5.4, 0.024389537090108400000000000000, 1e-17, pmf(1)); + test_absolute(5.4, 0.026241240591792300000000000000, 1e-16, pmf(10)); + test_absolute(5.4, 0.000000825202200316548000000000, 1e-20, pmf(20)); + test_absolute(10.8, 0.000220314636840657000000000000, 1e-18, pmf(1)); + test_absolute(10.8, 0.121365183659420000000000000000, 1e-15, pmf(10)); + test_absolute(10.8, 0.003908139778574110000000000000, 1e-16, pmf(20)); } #[test] fn test_ln_pmf() { let ln_pmf = |arg: u64| move |x: Poisson| x.ln_pmf(arg); - test_almost(1.5, -1.09453489189183485135413967177, 1e-15, ln_pmf(1)); - test_almost(1.5, -12.5497614919938728510400000000, 1e-14, ln_pmf(10)); - test_almost(1.5, -35.7263142985901000000000000000, 1e-13, ln_pmf(20)); - test_case(5.4, -3.71360104642977159156055355910, ln_pmf(1)); - test_almost(5.4, -3.64042303737322774736223038530, 1e-15, ln_pmf(10)); - test_almost(5.4, -14.0076373893489089949388000000, 1e-14, ln_pmf(20)); - test_almost(10.8, -8.42045386586982559781714423000, 1e-14, ln_pmf(1)); - test_almost(10.8, -2.10895123177378079525424989992, 1e-14, ln_pmf(10)); - test_almost(10.8, -5.54469377815000936289610059500, 1e-14, ln_pmf(20)); + test_absolute(1.5, -1.09453489189183485135413967177, 1e-15, ln_pmf(1)); + test_absolute(1.5, -12.5497614919938728510400000000, 1e-14, ln_pmf(10)); + test_absolute(1.5, -35.7263142985901000000000000000, 1e-13, ln_pmf(20)); + test_exact(5.4, -3.71360104642977159156055355910, ln_pmf(1)); + test_absolute(5.4, -3.64042303737322774736223038530, 1e-15, ln_pmf(10)); + test_absolute(5.4, -14.0076373893489089949388000000, 1e-14, ln_pmf(20)); + test_absolute(10.8, -8.42045386586982559781714423000, 1e-14, ln_pmf(1)); + test_absolute(10.8, -2.10895123177378079525424989992, 1e-14, ln_pmf(10)); + test_absolute(10.8, -5.54469377815000936289610059500, 1e-14, ln_pmf(20)); } #[test] fn test_cdf() { let cdf = |arg: u64| move |x: Poisson| x.cdf(arg); - 
test_almost(1.5, 0.5578254003710750000000, 1e-15, cdf(1)); - test_almost(1.5, 0.9999994482467640000000, 1e-15, cdf(10)); - test_case(1.5, 1.0, cdf(20)); - test_almost(5.4, 0.0289061180327211000000, 1e-16, cdf(1)); - test_almost(5.4, 0.9774863006897650000000, 1e-15, cdf(10)); - test_almost(5.4, 0.9999997199928290000000, 1e-15, cdf(20)); - test_almost(10.8, 0.0002407141402518290000, 1e-16, cdf(1)); - test_almost(10.8, 0.4839692359955690000000, 1e-15, cdf(10)); - test_almost(10.8, 0.9961800769608090000000, 1e-15, cdf(20)); + test_absolute(1.5, 0.5578254003710750000000, 1e-15, cdf(1)); + test_absolute(1.5, 0.9999994482467640000000, 1e-15, cdf(10)); + test_exact(1.5, 1.0, cdf(20)); + test_absolute(5.4, 0.0289061180327211000000, 1e-16, cdf(1)); + test_absolute(5.4, 0.9774863006897650000000, 1e-15, cdf(10)); + test_absolute(5.4, 0.9999997199928290000000, 1e-15, cdf(20)); + test_absolute(10.8, 0.0002407141402518290000, 1e-16, cdf(1)); + test_absolute(10.8, 0.4839692359955690000000, 1e-15, cdf(10)); + test_absolute(10.8, 0.9961800769608090000000, 1e-15, cdf(20)); } #[test] fn test_sf() { let sf = |arg: u64| move |x: Poisson| x.sf(arg); - test_almost(1.5, 0.44217459962892536, 1e-15, sf(1)); - test_almost(1.5, 0.0000005517532358246565, 1e-15, sf(10)); - test_almost(1.5, 2.3372210700347092e-17, 1e-15, sf(20)); - test_almost(5.4, 0.971093881967279, 1e-16, sf(1)); - test_almost(5.4, 0.022513699310235582, 1e-15, sf(10)); - test_almost(5.4, 0.0000002800071708975261, 1e-15, sf(20)); - test_almost(10.8, 0.9997592858597482, 1e-16, sf(1)); - test_almost(10.8, 0.5160307640044303, 1e-15, sf(10)); - test_almost(10.8, 0.003819923039191422, 1e-15, sf(20)); + test_absolute(1.5, 0.44217459962892536, 1e-15, sf(1)); + test_absolute(1.5, 0.0000005517532358246565, 1e-15, sf(10)); + test_absolute(1.5, 2.3372210700347092e-17, 1e-15, sf(20)); + test_absolute(5.4, 0.971093881967279, 1e-16, sf(1)); + test_absolute(5.4, 0.022513699310235582, 1e-15, sf(10)); + test_absolute(5.4, 
0.0000002800071708975261, 1e-15, sf(20)); + test_absolute(10.8, 0.9997592858597482, 1e-16, sf(1)); + test_absolute(10.8, 0.5160307640044303, 1e-15, sf(10)); + test_absolute(10.8, 0.003819923039191422, 1e-15, sf(20)); } #[test] fn test_discrete() { - test::check_discrete_distribution(&try_create(0.3), 10); - test::check_discrete_distribution(&try_create(4.5), 30); + test::check_discrete_distribution(&create_ok(0.3), 10); + test::check_discrete_distribution(&create_ok(4.5), 30); } } diff --git a/src/distribution/students_t.rs b/src/distribution/students_t.rs index 277647bf..117fbc87 100644 --- a/src/distribution/students_t.rs +++ b/src/distribution/students_t.rs @@ -1,9 +1,6 @@ use crate::distribution::{Continuous, ContinuousCDF}; use crate::function::{beta, gamma}; -use crate::is_zero; use crate::statistics::*; -use crate::{Result, StatsError}; -use rand::Rng; use std::f64; /// Implements the [Student's @@ -20,22 +17,50 @@ use std::f64; /// assert_eq!(n.mean().unwrap(), 0.0); /// assert!(prec::almost_eq(n.pdf(0.0), 0.353553390593274, 1e-15)); /// ``` -#[derive(Debug, Copy, Clone, PartialEq)] +#[derive(Copy, Clone, PartialEq, Debug)] pub struct StudentsT { location: f64, scale: f64, freedom: f64, } +/// Represents the errors that can occur when creating a [`StudentsT`]. +#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] +#[non_exhaustive] +pub enum StudentsTError { + /// The location is NaN. + LocationInvalid, + + /// The scale is NaN, zero or less than zero. + ScaleInvalid, + + /// The degrees of freedom are NaN, zero or less than zero. 
+ FreedomInvalid, +} + +impl std::fmt::Display for StudentsTError { + #[cfg_attr(coverage_nightly, coverage(off))] + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + StudentsTError::LocationInvalid => write!(f, "Location is NaN"), + StudentsTError::ScaleInvalid => write!(f, "Scale is NaN, zero or less than zero"), + StudentsTError::FreedomInvalid => { + write!(f, "Degrees of freedom are NaN, zero or less than zero") + } + } + } +} + +impl std::error::Error for StudentsTError {} + impl StudentsT { /// Constructs a new student's t-distribution with location `location`, - /// scale `scale`, - /// and `freedom` freedom. + /// scale `scale`, and `freedom` freedom. /// /// # Errors /// /// Returns an error if any of `location`, `scale`, or `freedom` are `NaN`. - /// Returns an error if `scale <= 0.0` or `freedom <= 0.0` + /// Returns an error if `scale <= 0.0` or `freedom <= 0.0`. /// /// # Examples /// @@ -48,17 +73,24 @@ impl StudentsT { /// result = StudentsT::new(0.0, 0.0, 0.0); /// assert!(result.is_err()); /// ``` - pub fn new(location: f64, scale: f64, freedom: f64) -> Result { - let is_nan = location.is_nan() || scale.is_nan() || freedom.is_nan(); - if is_nan || scale <= 0.0 || freedom <= 0.0 { - Err(StatsError::BadParams) - } else { - Ok(StudentsT { - location, - scale, - freedom, - }) + pub fn new(location: f64, scale: f64, freedom: f64) -> Result { + if location.is_nan() { + return Err(StudentsTError::LocationInvalid); + } + + if scale.is_nan() || scale <= 0.0 { + return Err(StudentsTError::ScaleInvalid); + } + + if freedom.is_nan() || freedom <= 0.0 { + return Err(StudentsTError::FreedomInvalid); } + + Ok(StudentsT { + location, + scale, + freedom, + }) } /// Returns the location of the student's t-distribution @@ -104,8 +136,15 @@ impl StudentsT { } } +impl std::fmt::Display for StudentsT { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "t_{}({},{})", self.freedom, self.location, 
self.scale) + } +} + +#[cfg(feature = "rand")] impl ::rand::distributions::Distribution for StudentsT { - fn sample(&self, r: &mut R) -> f64 { + fn sample(&self, r: &mut R) -> f64 { // based on method 2, section 5 in chapter 9 of L. Devroye's // "Non-Uniform Random Variate Generation" let gamma = super::gamma::sample_unchecked(r, 0.5 * self.freedom, 0.5); @@ -124,7 +163,7 @@ impl ContinuousCDF for StudentsT { /// /// # Formula /// - /// ```ignore + /// ```text /// if x < μ { /// (1 / 2) * I(t, v / 2, 1 / 2) /// } else { @@ -156,7 +195,7 @@ impl ContinuousCDF for StudentsT { /// /// # Formula /// - /// ```ignore + /// ```text /// if x < μ { /// 1 - (1 / 2) * I(t, v / 2, 1 / 2) /// } else { @@ -209,8 +248,8 @@ impl Min for StudentsT { /// /// # Formula /// - /// ```ignore - /// -INF + /// ```text + /// f64::NEG_INFINITY /// ``` fn min(&self) -> f64 { f64::NEG_INFINITY @@ -223,8 +262,8 @@ impl Max for StudentsT { /// /// # Formula /// - /// ```ignore - /// INF + /// ```text + /// f64::INFINITY /// ``` fn max(&self) -> f64 { f64::INFINITY @@ -240,7 +279,7 @@ impl Distribution for StudentsT { /// /// # Formula /// - /// ```ignore + /// ```text /// μ /// ``` /// @@ -252,6 +291,7 @@ impl Distribution for StudentsT { Some(self.location) } } + /// Returns the variance of the student's t-distribution /// /// # None @@ -260,8 +300,8 @@ impl Distribution for StudentsT { /// /// # Formula /// - /// ```ignore - /// if v == INF { + /// ```text + /// if v == f64::INFINITY { /// Some(σ^2) /// } else if freedom > 2.0 { /// Some(v * σ^2 / (v - 2)) @@ -280,11 +320,12 @@ impl Distribution for StudentsT { None } } + /// Returns the entropy for the student's t-distribution /// /// # Formula /// - /// ```ignore + /// ```text /// - ln(σ) + (v + 1) / 2 * (ψ((v + 1) / 2) - ψ(v / 2)) + ln(sqrt(v) * B(v / 2, 1 / /// 2)) /// ``` @@ -301,6 +342,7 @@ impl Distribution for StudentsT { + (self.freedom.sqrt() * beta::beta(self.freedom / 2.0, 0.5)).ln(); Some(result + shift) } + /// Returns the 
skewness of the student's t-distribution /// /// # None @@ -309,7 +351,7 @@ impl Distribution for StudentsT { /// /// # Formula /// - /// ```ignore + /// ```text /// 0 /// ``` fn skewness(&self) -> Option { @@ -326,7 +368,7 @@ impl Median for StudentsT { /// /// # Formula /// - /// ```ignore + /// ```text /// μ /// ``` /// @@ -341,7 +383,7 @@ impl Mode> for StudentsT { /// /// # Formula /// - /// ```ignore + /// ```text /// μ /// ``` /// @@ -358,7 +400,7 @@ impl Continuous for StudentsT { /// /// # Formula /// - /// ```ignore + /// ```text /// Γ((v + 1) / 2) / (sqrt(vπ) * Γ(v / 2) * σ) * (1 + k^2 / v)^(-1 / 2 * (v /// + 1)) /// ``` @@ -387,7 +429,7 @@ impl Continuous for StudentsT { /// /// # Formula /// - /// ```ignore + /// ```text /// ln(Γ((v + 1) / 2) / (sqrt(vπ) * Γ(v / 2) * σ) * (1 + k^2 / v)^(-1 / 2 * /// (v + 1))) /// ``` @@ -411,23 +453,21 @@ impl Continuous for StudentsT { } } -#[cfg(all(test, feature = "nightly"))] +#[cfg(test)] mod tests { + use super::*; use crate::consts::ACC; use crate::distribution::internal::*; - use crate::distribution::{Continuous, ContinuousCDF, StudentsT}; - use crate::statistics::*; use crate::testing_boiler; - use std::panic; - testing_boiler!((f64, f64, f64), StudentsT); + testing_boiler!(location: f64, scale: f64, freedom: f64; StudentsT; StudentsTError); #[test] fn test_create() { - try_create((0.0, 0.1, 1.0)); - try_create((0.0, 1.0, 1.0)); - try_create((-5.0, 1.0, 3.0)); - try_create((10.0, 10.0, f64::INFINITY)); + create_ok(0.0, 0.1, 1.0); + create_ok(0.0, 1.0, 1.0); + create_ok(-5.0, 1.0, 3.0); + create_ok(10.0, 10.0, f64::INFINITY); } // #[test] @@ -438,192 +478,191 @@ mod tests { #[test] fn test_bad_create() { - bad_create_case((f64::NAN, 1.0, 1.0)); - bad_create_case((0.0, f64::NAN, 1.0)); - bad_create_case((0.0, 1.0, f64::NAN)); - bad_create_case((0.0, -10.0, 1.0)); - bad_create_case((0.0, 10.0, -1.0)); + let invalid = [ + (f64::NAN, 1.0, 1.0, StudentsTError::LocationInvalid), + (0.0, f64::NAN, 1.0, 
StudentsTError::ScaleInvalid), + (0.0, 1.0, f64::NAN, StudentsTError::FreedomInvalid), + (0.0, -10.0, 1.0, StudentsTError::ScaleInvalid), + (0.0, 10.0, -1.0, StudentsTError::FreedomInvalid), + ]; + + for (l, s, f, err) in invalid { + test_create_err(l, s, f, err); + } } #[test] fn test_mean() { let mean = |x: StudentsT| x.mean().unwrap(); - test_case((0.0, 1.0, 3.0), 0.0, mean); - test_case((0.0, 10.0, 2.0), 0.0, mean); - test_case((0.0, 10.0, f64::INFINITY), 0.0, mean); - test_case((-5.0, 100.0, 1.5), -5.0, mean); + test_relative(0.0, 1.0, 3.0, 0.0, mean); + test_relative(0.0, 10.0, 2.0, 0.0, mean); + test_relative(0.0, 10.0, f64::INFINITY, 0.0, mean); + test_relative(-5.0, 100.0, 1.5, -5.0, mean); let mean = |x: StudentsT| x.mean(); - test_none((0.0, 1.0, 1.0), mean); - test_none((0.0, 0.1, 1.0), mean); - test_none((0.0, 10.0, 1.0), mean); - test_none((10.0, 1.0, 1.0), mean); - test_none((0.0, f64::INFINITY, 1.0), mean); + test_none(0.0, 1.0, 1.0, mean); + test_none(0.0, 0.1, 1.0, mean); + test_none(0.0, 10.0, 1.0, mean); + test_none(10.0, 1.0, 1.0, mean); + test_none(0.0, f64::INFINITY, 1.0, mean); } #[test] - #[should_panic] fn test_mean_freedom_lte_1() { - let mean = |x: StudentsT| x.mean().unwrap(); - get_value((1.0, 1.0, 0.5), mean); + test_none(1.0, 1.0, 0.5, |dist| dist.mean()); } #[test] fn test_variance() { let variance = |x: StudentsT| x.variance().unwrap(); - test_case((0.0, 1.0, 3.0), 3.0, variance); - test_case((0.0, 10.0, 2.5), 500.0, variance); - test_case((10.0, 1.0, 2.5), 5.0, variance); + test_relative(0.0, 1.0, 3.0, 3.0, variance); + test_relative(0.0, 10.0, 2.5, 500.0, variance); + test_relative(10.0, 1.0, 2.5, 5.0, variance); let variance = |x: StudentsT| x.variance(); - test_none((0.0, 10.0, 2.0), variance); - test_none((0.0, 1.0, 1.0), variance); - test_none((0.0, 0.1, 1.0), variance); - test_none((0.0, 10.0, 1.0), variance); - test_none((10.0, 1.0, 1.0), variance); - test_none((-5.0, 100.0, 1.5), variance); - test_none((0.0, f64::INFINITY, 
1.0), variance); + test_none(0.0, 10.0, 2.0, variance); + test_none(0.0, 1.0, 1.0, variance); + test_none(0.0, 0.1, 1.0, variance); + test_none(0.0, 10.0, 1.0, variance); + test_none(10.0, 1.0, 1.0, variance); + test_none(-5.0, 100.0, 1.5, variance); + test_none(0.0, f64::INFINITY, 1.0, variance); } #[test] - #[should_panic] fn test_variance_freedom_lte1() { - let variance = |x: StudentsT| x.variance().unwrap(); - get_value((1.0, 1.0, 0.5), variance); + test_none(1.0, 1.0, 0.5, |dist| dist.variance()); } // TODO: valid skewness tests #[test] - #[should_panic] fn test_skewness_freedom_lte_3() { - let skewness = |x: StudentsT| x.skewness().unwrap(); - get_value((1.0, 1.0, 1.0), skewness); + test_none(1.0, 1.0, 1.0, |dist| dist.skewness()); } #[test] fn test_mode() { let mode = |x: StudentsT| x.mode().unwrap(); - test_case((0.0, 1.0, 1.0), 0.0, mode); - test_case((0.0, 0.1, 1.0), 0.0, mode); - test_case((0.0, 1.0, 3.0), 0.0, mode); - test_case((0.0, 10.0, 1.0), 0.0, mode); - test_case((0.0, 10.0, 2.0), 0.0, mode); - test_case((0.0, 10.0, 2.5), 0.0, mode); - test_case((0.0, 10.0, f64::INFINITY), 0.0, mode); - test_case((10.0, 1.0, 1.0), 10.0, mode); - test_case((10.0, 1.0, 2.5), 10.0, mode); - test_case((-5.0, 100.0, 1.5), -5.0, mode); - test_case((0.0, f64::INFINITY, 1.0), 0.0, mode); + test_relative(0.0, 1.0, 1.0, 0.0, mode); + test_relative(0.0, 0.1, 1.0, 0.0, mode); + test_relative(0.0, 1.0, 3.0, 0.0, mode); + test_relative(0.0, 10.0, 1.0, 0.0, mode); + test_relative(0.0, 10.0, 2.0, 0.0, mode); + test_relative(0.0, 10.0, 2.5, 0.0, mode); + test_relative(0.0, 10.0, f64::INFINITY, 0.0, mode); + test_relative(10.0, 1.0, 1.0, 10.0, mode); + test_relative(10.0, 1.0, 2.5, 10.0, mode); + test_relative(-5.0, 100.0, 1.5, -5.0, mode); + test_relative(0.0, f64::INFINITY, 1.0, 0.0, mode); } #[test] fn test_median() { let median = |x: StudentsT| x.median(); - test_case((0.0, 1.0, 1.0), 0.0, median); - test_case((0.0, 0.1, 1.0), 0.0, median); - test_case((0.0, 1.0, 3.0), 0.0, 
median); - test_case((0.0, 10.0, 1.0), 0.0, median); - test_case((0.0, 10.0, 2.0), 0.0, median); - test_case((0.0, 10.0, 2.5), 0.0, median); - test_case((0.0, 10.0, f64::INFINITY), 0.0, median); - test_case((10.0, 1.0, 1.0), 10.0, median); - test_case((10.0, 1.0, 2.5), 10.0, median); - test_case((-5.0, 100.0, 1.5), -5.0, median); - test_case((0.0, f64::INFINITY, 1.0), 0.0, median); + test_relative(0.0, 1.0, 1.0, 0.0, median); + test_relative(0.0, 0.1, 1.0, 0.0, median); + test_relative(0.0, 1.0, 3.0, 0.0, median); + test_relative(0.0, 10.0, 1.0, 0.0, median); + test_relative(0.0, 10.0, 2.0, 0.0, median); + test_relative(0.0, 10.0, 2.5, 0.0, median); + test_relative(0.0, 10.0, f64::INFINITY, 0.0, median); + test_relative(10.0, 1.0, 1.0, 10.0, median); + test_relative(10.0, 1.0, 2.5, 10.0, median); + test_relative(-5.0, 100.0, 1.5, -5.0, median); + test_relative(0.0, f64::INFINITY, 1.0, 0.0, median); } #[test] fn test_min_max() { let min = |x: StudentsT| x.min(); let max = |x: StudentsT| x.max(); - test_case((0.0, 1.0, 1.0), f64::NEG_INFINITY, min); - test_case((2.5, 100.0, 1.5), f64::NEG_INFINITY, min); - test_case((10.0, f64::INFINITY, 3.5), f64::NEG_INFINITY, min); - test_case((0.0, 1.0, 1.0), f64::INFINITY, max); - test_case((2.5, 100.0, 1.5), f64::INFINITY, max); - test_case((10.0, f64::INFINITY, 5.5), f64::INFINITY, max); + test_relative(0.0, 1.0, 1.0, f64::NEG_INFINITY, min); + test_relative(2.5, 100.0, 1.5, f64::NEG_INFINITY, min); + test_relative(10.0, f64::INFINITY, 3.5, f64::NEG_INFINITY, min); + test_relative(0.0, 1.0, 1.0, f64::INFINITY, max); + test_relative(2.5, 100.0, 1.5, f64::INFINITY, max); + test_relative(10.0, f64::INFINITY, 5.5, f64::INFINITY, max); } #[test] fn test_pdf() { let pdf = |arg: f64| move |x: StudentsT| x.pdf(arg); - test_case((0.0, 1.0, 1.0), 0.318309886183791, pdf(0.0)); - test_case((0.0, 1.0, 1.0), 0.159154943091895, pdf(1.0)); - test_case((0.0, 1.0, 1.0), 0.159154943091895, pdf(-1.0)); - test_case((0.0, 1.0, 1.0), 
0.063661977236758, pdf(2.0)); - test_case((0.0, 1.0, 1.0), 0.063661977236758, pdf(-2.0)); - test_case((0.0, 1.0, 2.0), 0.353553390593274, pdf(0.0)); - test_case((0.0, 1.0, 2.0), 0.192450089729875, pdf(1.0)); - test_case((0.0, 1.0, 2.0), 0.192450089729875, pdf(-1.0)); - test_case((0.0, 1.0, 2.0), 0.068041381743977, pdf(2.0)); - test_case((0.0, 1.0, 2.0), 0.068041381743977, pdf(-2.0)); - test_case((0.0, 1.0, f64::INFINITY), 0.398942280401433, pdf(0.0)); - test_case((0.0, 1.0, f64::INFINITY), 0.241970724519143, pdf(1.0)); - test_case((0.0, 1.0, f64::INFINITY), 0.053990966513188, pdf(2.0)); + test_relative(0.0, 1.0, 1.0, std::f64::consts::FRAC_1_PI, pdf(0.0)); + test_relative(0.0, 1.0, 1.0, 0.159154943091895, pdf(1.0)); + test_relative(0.0, 1.0, 1.0, 0.159154943091895, pdf(-1.0)); + test_relative(0.0, 1.0, 1.0, 0.063661977236758, pdf(2.0)); + test_relative(0.0, 1.0, 1.0, 0.063661977236758, pdf(-2.0)); + test_relative(0.0, 1.0, 2.0, 0.353553390593274, pdf(0.0)); + test_relative(0.0, 1.0, 2.0, 0.192450089729875, pdf(1.0)); + test_relative(0.0, 1.0, 2.0, 0.192450089729875, pdf(-1.0)); + test_relative(0.0, 1.0, 2.0, 0.068041381743977, pdf(2.0)); + test_relative(0.0, 1.0, 2.0, 0.068041381743977, pdf(-2.0)); + test_relative(0.0, 1.0, f64::INFINITY, 0.398942280401433, pdf(0.0)); + test_relative(0.0, 1.0, f64::INFINITY, 0.241970724519143, pdf(1.0)); + test_relative(0.0, 1.0, f64::INFINITY, 0.053990966513188, pdf(2.0)); } #[test] fn test_ln_pdf() { let ln_pdf = |arg: f64| move |x: StudentsT| x.ln_pdf(arg); - test_case((0.0, 1.0, 1.0), -1.144729885849399, ln_pdf(0.0)); - test_case((0.0, 1.0, 1.0), -1.837877066409348, ln_pdf(1.0)); - test_case((0.0, 1.0, 1.0), -1.837877066409348, ln_pdf(-1.0)); - test_case((0.0, 1.0, 1.0), -2.754167798283503, ln_pdf(2.0)); - test_case((0.0, 1.0, 1.0), -2.754167798283503, ln_pdf(-2.0)); - test_case((0.0, 1.0, 2.0), -1.039720770839917, ln_pdf(0.0)); - test_case((0.0, 1.0, 2.0), -1.647918433002166, ln_pdf(1.0)); - test_case((0.0, 1.0, 2.0), 
-1.647918433002166, ln_pdf(-1.0)); - test_case((0.0, 1.0, 2.0), -2.687639203842085, ln_pdf(2.0)); - test_case((0.0, 1.0, 2.0), -2.687639203842085, ln_pdf(-2.0)); - test_case((0.0, 1.0, f64::INFINITY), -0.918938533204672, ln_pdf(0.0)); - test_case((0.0, 1.0, f64::INFINITY), -1.418938533204674, ln_pdf(1.0)); - test_case((0.0, 1.0, f64::INFINITY), -2.918938533204674, ln_pdf(2.0)); + test_relative(0.0, 1.0, 1.0, -1.144729885849399, ln_pdf(0.0)); + test_relative(0.0, 1.0, 1.0, -1.837877066409348, ln_pdf(1.0)); + test_relative(0.0, 1.0, 1.0, -1.837877066409348, ln_pdf(-1.0)); + test_relative(0.0, 1.0, 1.0, -2.754167798283503, ln_pdf(2.0)); + test_relative(0.0, 1.0, 1.0, -2.754167798283503, ln_pdf(-2.0)); + test_relative(0.0, 1.0, 2.0, -1.039720770839917, ln_pdf(0.0)); + test_relative(0.0, 1.0, 2.0, -1.647918433002166, ln_pdf(1.0)); + test_relative(0.0, 1.0, 2.0, -1.647918433002166, ln_pdf(-1.0)); + test_relative(0.0, 1.0, 2.0, -2.687639203842085, ln_pdf(2.0)); + test_relative(0.0, 1.0, 2.0, -2.687639203842085, ln_pdf(-2.0)); + test_relative(0.0, 1.0, f64::INFINITY, -0.918938533204672, ln_pdf(0.0)); + test_relative(0.0, 1.0, f64::INFINITY, -1.418938533204674, ln_pdf(1.0)); + test_relative(0.0, 1.0, f64::INFINITY, -2.918938533204674, ln_pdf(2.0)); } #[test] fn test_cdf() { let cdf = |arg: f64| move |x: StudentsT| x.cdf(arg); - test_case((0.0, 1.0, 1.0), 0.5, cdf(0.0)); - test_case((0.0, 1.0, 1.0), 0.75, cdf(1.0)); - test_case((0.0, 1.0, 1.0), 0.25, cdf(-1.0)); - test_case((0.0, 1.0, 1.0), 0.852416382349567, cdf(2.0)); - test_case((0.0, 1.0, 1.0), 0.147583617650433, cdf(-2.0)); - test_case((0.0, 1.0, 2.0), 0.5, cdf(0.0)); - test_case((0.0, 1.0, 2.0), 0.788675134594813, cdf(1.0)); - test_case((0.0, 1.0, 2.0), 0.211324865405187, cdf(-1.0)); - test_case((0.0, 1.0, 2.0), 0.908248290463863, cdf(2.0)); - test_case((0.0, 1.0, 2.0), 0.091751709536137, cdf(-2.0)); - test_case((0.0, 1.0, f64::INFINITY), 0.5, cdf(0.0)); + test_relative(0.0, 1.0, 1.0, 0.5, cdf(0.0)); + 
test_relative(0.0, 1.0, 1.0, 0.75, cdf(1.0)); + test_relative(0.0, 1.0, 1.0, 0.25, cdf(-1.0)); + test_relative(0.0, 1.0, 1.0, 0.852416382349567, cdf(2.0)); + test_relative(0.0, 1.0, 1.0, 0.147583617650433, cdf(-2.0)); + test_relative(0.0, 1.0, 2.0, 0.5, cdf(0.0)); + test_relative(0.0, 1.0, 2.0, 0.788675134594813, cdf(1.0)); + test_relative(0.0, 1.0, 2.0, 0.211324865405187, cdf(-1.0)); + test_relative(0.0, 1.0, 2.0, 0.908248290463863, cdf(2.0)); + test_relative(0.0, 1.0, 2.0, 0.091751709536137, cdf(-2.0)); + test_relative(0.0, 1.0, f64::INFINITY, 0.5, cdf(0.0)); // TODO: these are curiously low accuracy and should be re-examined - test_case((0.0, 1.0, f64::INFINITY), 0.841344746068543, cdf(1.0)); - test_case((0.0, 1.0, f64::INFINITY), 0.977249868051821, cdf(2.0)); + test_relative(0.0, 1.0, f64::INFINITY, 0.841344746068543, cdf(1.0)); + test_relative(0.0, 1.0, f64::INFINITY, 0.977249868051821, cdf(2.0)); } - #[test] fn test_sf() { let sf = |arg: f64| move |x: StudentsT| x.sf(arg); - test_case((0.0, 1.0, 1.0), 0.5, sf(0.0)); - test_case((0.0, 1.0, 1.0), 0.25, sf(1.0)); - test_case((0.0, 1.0, 1.0), 0.75, sf(-1.0)); - test_case((0.0, 1.0, 1.0), 0.147583617650433, sf(2.0)); - test_case((0.0, 1.0, 1.0), 0.852416382349566, sf(-2.0)); - test_case((0.0, 1.0, 2.0), 0.5, sf(0.0)); - test_case((0.0, 1.0, 2.0), 0.211324865405186, sf(1.0)); - test_case((0.0, 1.0, 2.0), 0.788675134594813, sf(-1.0)); - test_case((0.0, 1.0, 2.0), 0.091751709536137, sf(2.0)); - test_case((0.0, 1.0, 2.0), 0.908248290463862, sf(-2.0)); - test_case((0.0, 1.0, f64::INFINITY), 0.5, sf(0.0)); + test_relative(0.0, 1.0, 1.0, 0.5, sf(0.0)); + test_relative(0.0, 1.0, 1.0, 0.25, sf(1.0)); + test_relative(0.0, 1.0, 1.0, 0.75, sf(-1.0)); + test_relative(0.0, 1.0, 1.0, 0.147583617650433, sf(2.0)); + test_relative(0.0, 1.0, 1.0, 0.852416382349566, sf(-2.0)); + test_relative(0.0, 1.0, 2.0, 0.5, sf(0.0)); + test_relative(0.0, 1.0, 2.0, 0.211324865405186, sf(1.0)); + test_relative(0.0, 1.0, 2.0, 0.788675134594813, 
sf(-1.0)); + test_relative(0.0, 1.0, 2.0, 0.091751709536137, sf(2.0)); + test_relative(0.0, 1.0, 2.0, 0.908248290463862, sf(-2.0)); + test_relative(0.0, 1.0, f64::INFINITY, 0.5, sf(0.0)); // TODO: these are curiously low accuracy and should be re-examined - test_case((0.0, 1.0, f64::INFINITY), 0.158655253945057, sf(1.0)); - test_case((0.0, 1.0, f64::INFINITY), 0.022750131947162, sf(2.0)); + test_relative(0.0, 1.0, f64::INFINITY, 0.158655253945057, sf(1.0)); + test_relative(0.0, 1.0, f64::INFINITY, 0.022750131947162, sf(2.0)); } #[test] fn test_continuous() { - test::check_continuous_distribution(&try_create((0.0, 1.0, 3.0)), -30.0, 30.0); - test::check_continuous_distribution(&try_create((0.0, 1.0, 10.0)), -10.0, 10.0); - test::check_continuous_distribution(&try_create((20.0, 0.5, 10.0)), 10.0, 30.0); + test::check_continuous_distribution(&create_ok(0.0, 1.0, 3.0), -30.0, 30.0); + test::check_continuous_distribution(&create_ok(0.0, 1.0, 10.0), -10.0, 10.0); + test::check_continuous_distribution(&create_ok(20.0, 0.5, 10.0), 10.0, 30.0); } #[test] @@ -766,6 +805,8 @@ mod tests { test(0.9, 011.0, 1.363); test(0.95, 011.0, 1.796); test(0.975, 011.0, 2.201); + // 2.718 is roughly equal to E + #[allow(clippy::approx_constant)] test(0.99, 011.0, 2.718); test(0.995, 011.0, 3.106); test(0.9975, 011.0, 3.497); @@ -1096,7 +1137,7 @@ mod tests { // for p in ps: // q = t.invcdf(p, df) // print(f"({p:5.3f}, {df:5.1f}, {float(q)}),") - // + #[rustfmt::skip] let invcdf_data = [ // p df inverse_cdf(p, df) (0.001, 1.0, -318.30883898555044), @@ -1150,12 +1191,12 @@ mod tests { #[test] fn test_inv_cdf_p0() { let d = StudentsT::new(0.0, 1.0, 12.0).unwrap(); - assert_eq!(d.inverse_cdf(0.0), std::f64::NEG_INFINITY); + assert_eq!(d.inverse_cdf(0.0), f64::NEG_INFINITY); } #[test] fn test_inv_cdf_p1() { let d = StudentsT::new(0.0, 1.0, 12.0).unwrap(); - assert_eq!(d.inverse_cdf(1.0), std::f64::INFINITY); + assert_eq!(d.inverse_cdf(1.0), f64::INFINITY); } } diff --git 
a/src/distribution/triangular.rs b/src/distribution/triangular.rs index 068f6c9a..1a83be2c 100644 --- a/src/distribution/triangular.rs +++ b/src/distribution/triangular.rs @@ -1,7 +1,5 @@ use crate::distribution::{Continuous, ContinuousCDF}; use crate::statistics::*; -use crate::{Result, StatsError}; -use rand::Rng; use std::f64; /// Implements the @@ -18,13 +16,50 @@ use std::f64; /// assert_eq!(n.mean().unwrap(), 7.5 / 3.0); /// assert_eq!(n.pdf(2.5), 5.0 / 12.5); /// ``` -#[derive(Debug, Copy, Clone, PartialEq)] +#[derive(Copy, Clone, PartialEq, Debug)] pub struct Triangular { min: f64, max: f64, mode: f64, } +/// Represents the errors that can occur when creating a [`Triangular`]. +#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] +#[non_exhaustive] +pub enum TriangularError { + /// The minimum is NaN or infinite. + MinInvalid, + + /// The maximum is NaN or infinite. + MaxInvalid, + + /// The mode is NaN or infinite. + ModeInvalid, + + /// The mode is less than the minimum or greater than the maximum. + ModeOutOfRange, + + /// The minimum equals the maximum. + MinEqualsMax, +} + +impl std::fmt::Display for TriangularError { + #[cfg_attr(coverage_nightly, coverage(off))] + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + TriangularError::MinInvalid => write!(f, "Minimum is NaN or infinite."), + TriangularError::MaxInvalid => write!(f, "Maximum is NaN or infinite."), + TriangularError::ModeInvalid => write!(f, "Mode is NaN or infinite."), + TriangularError::ModeOutOfRange => { + write!(f, "Mode is less than minimum or greater than maximum") + } + TriangularError::MinEqualsMax => write!(f, "Minimum equals Maximum"), + } + } +} + +impl std::error::Error for TriangularError {} + impl Triangular { /// Constructs a new triangular distribution with a minimum of `min`, /// maximum of `max`, and a mode of `mode`. 
@@ -45,22 +80,40 @@ impl Triangular { /// result = Triangular::new(2.5, 1.5, 0.0); /// assert!(result.is_err()); /// ``` - pub fn new(min: f64, max: f64, mode: f64) -> Result { - if !min.is_finite() || !max.is_finite() || !mode.is_finite() { - return Err(StatsError::BadParams); + pub fn new(min: f64, max: f64, mode: f64) -> Result { + if !min.is_finite() { + return Err(TriangularError::MinInvalid); + } + + if !max.is_finite() { + return Err(TriangularError::MaxInvalid); + } + + if !mode.is_finite() { + return Err(TriangularError::ModeInvalid); } + if max < mode || mode < min { - return Err(StatsError::BadParams); + return Err(TriangularError::ModeOutOfRange); } - if ulps_eq!(max, min, max_ulps = 0) { - return Err(StatsError::BadParams); + + if min == max { + return Err(TriangularError::MinEqualsMax); } + Ok(Triangular { min, max, mode }) } } +impl std::fmt::Display for Triangular { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Triangular([{},{}], {})", self.min, self.max, self.mode) + } +} + +#[cfg(feature = "rand")] impl ::rand::distributions::Distribution for Triangular { - fn sample(&self, rng: &mut R) -> f64 { + fn sample(&self, rng: &mut R) -> f64 { sample_unchecked(rng, self.min, self.max, self.mode) } } @@ -72,13 +125,13 @@ impl ContinuousCDF for Triangular { /// /// # Formula /// - /// ```ignore + /// ```text /// if x == min { /// 0 /// } if min < x <= mode { /// (x - min)^2 / ((max - min) * (mode - min)) /// } else if mode < x < max { - /// 1 - (max - min)^2 / ((max - min) * (max - mode)) + /// 1 - (max - x)^2 / ((max - min) * (max - mode)) /// } else { /// 1 /// } @@ -103,7 +156,7 @@ impl ContinuousCDF for Triangular { /// /// # Formula /// - /// ```ignore + /// ```text /// if x == min { /// 1 /// } if min < x <= mode { @@ -128,6 +181,34 @@ impl ContinuousCDF for Triangular { 0.0 } } + + /// Calculates the inverse cumulative distribution function for the triangular + /// distribution + /// at `p` + /// + /// # Formula
+ /// + /// ```text + /// if p < (mode - min) / (max - min) { + /// min + ((max - min) * (mode - min) * p)^(1 / 2) + /// } else { + /// max - ((max - min) * (max - mode) * (1 - p))^(1 / 2) + /// } + /// ``` + fn inverse_cdf(&self, p: f64) -> f64 { + let a = self.min; + let b = self.max; + let c = self.mode; + if !(0.0..=1.0).contains(&p) { + panic!("x must be in [0, 1]"); + } + + if p < (c - a) / (b - a) { + a + ((c - a) * (b - a) * p).sqrt() + } else { + b - ((b - a) * (b - c) * (1.0 - p)).sqrt() + } + } } impl Min for Triangular { @@ -159,17 +240,18 @@ impl Distribution for Triangular { /// /// # Formula /// - /// ```ignore + /// ```text /// (min + max + mode) / 3 /// ``` fn mean(&self) -> Option { Some((self.min + self.max + self.mode) / 3.0) } + /// Returns the variance of the triangular distribution /// /// # Formula /// - /// ```ignore + /// ```text /// (min^2 + max^2 + mode^2 - min * max - min * mode - max * mode) / 18 /// ``` fn variance(&self) -> Option { @@ -178,21 +260,23 @@ impl Distribution for Triangular { let c = self.mode; Some((a * a + b * b + c * c - a * b - a * c - b * c) / 18.0) } + /// Returns the entropy of the triangular distribution /// /// # Formula /// - /// ```ignore + /// ```text /// 1 / 2 + ln((max - min) / 2) /// ``` fn entropy(&self) -> Option { Some(0.5 + ((self.max - self.min) / 2.0).ln()) } + /// Returns the skewness of the triangular distribution /// /// # Formula /// - /// ```ignore + /// ```text /// (sqrt(2) * (min + max - 2 * mode) * (2 * min - max - mode) * (min - 2 * /// max + mode)) / /// ( 5 * (min^2 + max^2 + mode^2 - min * max - min * mode - max * mode)^(3 @@ -213,7 +297,7 @@ impl Median for Triangular { /// /// # Formula /// - /// ```ignore + /// ```text /// if mode >= (min + max) / 2 { /// min + sqrt((max - min) * (mode - min) / 2) /// } else { @@ -237,7 +321,7 @@ impl Mode> for Triangular { /// /// # Formula /// - /// ```ignore + /// ```text /// mode /// ``` fn mode(&self) -> Option { @@ -252,7 +336,7 @@ impl
Continuous for Triangular { /// /// # Formula /// - /// ```ignore + /// ```text /// if x < min { /// 0 /// } else if min <= x <= mode { @@ -282,7 +366,7 @@ impl Continuous for Triangular { /// /// # Formula /// - /// ```ignore + /// ```text /// ln( if x < min { /// 0 /// } else if min <= x <= mode { @@ -298,7 +382,8 @@ impl Continuous for Triangular { } } -fn sample_unchecked(rng: &mut R, min: f64, max: f64, mode: f64) -> f64 { +#[cfg(feature = "rand")] +fn sample_unchecked(rng: &mut R, min: f64, max: f64, mode: f64) -> f64 { let f: f64 = rng.gen(); if f < (mode - min) / (max - min) { min + (f * (max - min) * (mode - min)).sqrt() @@ -308,231 +393,211 @@ fn sample_unchecked(rng: &mut R, min: f64, max: f64, mode: f64) } #[rustfmt::skip] -#[cfg(all(test, feature = "nightly"))] +#[cfg(test)] mod tests { - use std::fmt::Debug; - use crate::statistics::*; - use crate::distribution::{ContinuousCDF, Continuous, Triangular}; + use super::*; use crate::distribution::internal::*; - use crate::consts::ACC; - - fn try_create(min: f64, max: f64, mode: f64) -> Triangular { - let n = Triangular::new(min, max, mode); - assert!(n.is_ok()); - n.unwrap() - } - - fn create_case(min: f64, max: f64, mode: f64) { - let n = try_create(min, max, mode); - assert_eq!(n.min(), min); - assert_eq!(n.max(), max); - assert_eq!(n.mode().unwrap(), mode); - } - - fn bad_create_case(min: f64, max: f64, mode: f64) { - let n = Triangular::new(min, max, mode); - assert!(n.is_err()); - } - - fn get_value(min: f64, max: f64, mode: f64, eval: F) -> T - where T: PartialEq + Debug, - F: Fn(Triangular) -> T - { - let n = try_create(min, max, mode); - eval(n) - } + use crate::testing_boiler; - fn test_case(min: f64, max: f64, mode: f64, expected: f64, eval: F) - where F: Fn(Triangular) -> f64 - { - let x = get_value(min, max, mode, eval); - assert_eq!(expected, x); - } - - fn test_almost(min: f64, max: f64, mode: f64, expected: f64, acc: f64, eval: F) - where F: Fn(Triangular) -> f64 - { - let x = 
get_value(min, max, mode, eval); - assert_almost_eq!(expected, x, acc); - } + testing_boiler!(min: f64, max: f64, mode: f64; Triangular; TriangularError); #[test] fn test_create() { - create_case(-1.0, 1.0, 0.0); - create_case(1.0, 2.0, 1.0); - create_case(5.0, 25.0, 25.0); - create_case(1.0e-5, 1.0e5, 1.0e-3); - create_case(0.0, 1.0, 0.9); - create_case(-4.0, -0.5, -2.0); - create_case(-13.039, 8.42, 1.17); + create_ok(-1.0, 1.0, 0.0); + create_ok(1.0, 2.0, 1.0); + create_ok(5.0, 25.0, 25.0); + create_ok(1.0e-5, 1.0e5, 1.0e-3); + create_ok(0.0, 1.0, 0.9); + create_ok(-4.0, -0.5, -2.0); + create_ok(-13.039, 8.42, 1.17); } #[test] fn test_bad_create() { - bad_create_case(0.0, 0.0, 0.0); - bad_create_case(0.0, 1.0, -0.1); - bad_create_case(0.0, 1.0, 1.1); - bad_create_case(0.0, -1.0, 0.5); - bad_create_case(2.0, 1.0, 1.5); - bad_create_case(f64::NAN, 1.0, 0.5); - bad_create_case(0.2, f64::NAN, 0.5); - bad_create_case(0.5, 1.0, f64::NAN); - bad_create_case(f64::NAN, f64::NAN, f64::NAN); - bad_create_case(f64::NEG_INFINITY, 1.0, 0.5); - bad_create_case(0.0, f64::INFINITY, 0.5); + let invalid = [ + (0.0, 0.0, 0.0, TriangularError::MinEqualsMax), + (0.0, 1.0, -0.1, TriangularError::ModeOutOfRange), + (0.0, 1.0, 1.1, TriangularError::ModeOutOfRange), + (0.0, -1.0, 0.5, TriangularError::ModeOutOfRange), + (2.0, 1.0, 1.5, TriangularError::ModeOutOfRange), + (f64::NAN, 1.0, 0.5, TriangularError::MinInvalid), + (0.2, f64::NAN, 0.5, TriangularError::MaxInvalid), + (0.5, 1.0, f64::NAN, TriangularError::ModeInvalid), + (f64::NAN, f64::NAN, f64::NAN, TriangularError::MinInvalid), + (f64::NEG_INFINITY, 1.0, 0.5, TriangularError::MinInvalid), + (0.0, f64::INFINITY, 0.5, TriangularError::MaxInvalid), + ]; + + for (min, max, mode, err) in invalid { + test_create_err(min, max, mode, err); + } } #[test] fn test_variance() { let variance = |x: Triangular| x.variance().unwrap(); - test_case(0.0, 1.0, 0.5, 0.75 / 18.0, variance); - test_case(0.0, 1.0, 0.75, 0.8125 / 18.0, variance); - 
test_case(-5.0, 8.0, -3.5, 151.75 / 18.0, variance); - test_case(-5.0, 8.0, 5.0, 139.0 / 18.0, variance); - test_case(-5.0, -3.0, -4.0, 3.0 / 18.0, variance); - test_case(15.0, 134.0, 21.0, 13483.0 / 18.0, variance); + test_exact(0.0, 1.0, 0.5, 0.75 / 18.0, variance); + test_exact(0.0, 1.0, 0.75, 0.8125 / 18.0, variance); + test_exact(-5.0, 8.0, -3.5, 151.75 / 18.0, variance); + test_exact(-5.0, 8.0, 5.0, 139.0 / 18.0, variance); + test_exact(-5.0, -3.0, -4.0, 3.0 / 18.0, variance); + test_exact(15.0, 134.0, 21.0, 13483.0 / 18.0, variance); } #[test] fn test_entropy() { let entropy = |x: Triangular| x.entropy().unwrap(); - test_almost(0.0, 1.0, 0.5, -0.1931471805599453094172, 1e-16, entropy); - test_almost(0.0, 1.0, 0.75, -0.1931471805599453094172, 1e-16, entropy); - test_case(-5.0, 8.0, -3.5, 2.371802176901591426636, entropy); - test_case(-5.0, 8.0, 5.0, 2.371802176901591426636, entropy); - test_case(-5.0, -3.0, -4.0, 0.5, entropy); - test_case(15.0, 134.0, 21.0, 4.585976312551584075938, entropy); + test_absolute(0.0, 1.0, 0.5, -0.1931471805599453094172, 1e-16, entropy); + test_absolute(0.0, 1.0, 0.75, -0.1931471805599453094172, 1e-16, entropy); + test_exact(-5.0, 8.0, -3.5, 2.371802176901591426636, entropy); + test_exact(-5.0, 8.0, 5.0, 2.371802176901591426636, entropy); + test_exact(-5.0, -3.0, -4.0, 0.5, entropy); + test_exact(15.0, 134.0, 21.0, 4.585976312551584075938, entropy); } #[test] fn test_skewness() { let skewness = |x: Triangular| x.skewness().unwrap(); - test_case(0.0, 1.0, 0.5, 0.0, skewness); - test_case(0.0, 1.0, 0.75, -0.4224039833745502226059, skewness); - test_case(-5.0, 8.0, -3.5, 0.5375093589712976359809, skewness); - test_case(-5.0, 8.0, 5.0, -0.4445991743012595633537, skewness); - test_case(-5.0, -3.0, -4.0, 0.0, skewness); - test_case(15.0, 134.0, 21.0, 0.5605920922751860613217, skewness); + test_exact(0.0, 1.0, 0.5, 0.0, skewness); + test_exact(0.0, 1.0, 0.75, -0.4224039833745502226059, skewness); + test_exact(-5.0, 8.0, -3.5, 
0.5375093589712976359809, skewness); + test_exact(-5.0, 8.0, 5.0, -0.4445991743012595633537, skewness); + test_exact(-5.0, -3.0, -4.0, 0.0, skewness); + test_exact(15.0, 134.0, 21.0, 0.5605920922751860613217, skewness); } #[test] fn test_mode() { let mode = |x: Triangular| x.mode().unwrap(); - test_case(0.0, 1.0, 0.5, 0.5, mode); - test_case(0.0, 1.0, 0.75, 0.75, mode); - test_case(-5.0, 8.0, -3.5, -3.5, mode); - test_case(-5.0, 8.0, 5.0, 5.0, mode); - test_case(-5.0, -3.0, -4.0, -4.0, mode); - test_case(15.0, 134.0, 21.0, 21.0, mode); + test_exact(0.0, 1.0, 0.5, 0.5, mode); + test_exact(0.0, 1.0, 0.75, 0.75, mode); + test_exact(-5.0, 8.0, -3.5, -3.5, mode); + test_exact(-5.0, 8.0, 5.0, 5.0, mode); + test_exact(-5.0, -3.0, -4.0, -4.0, mode); + test_exact(15.0, 134.0, 21.0, 21.0, mode); } #[test] fn test_median() { let median = |x: Triangular| x.median(); - test_case(0.0, 1.0, 0.5, 0.5, median); - test_case(0.0, 1.0, 0.75, 0.6123724356957945245493, median); - test_almost(-5.0, 8.0, -3.5, -0.6458082328952913226724, 1e-15, median); - test_almost(-5.0, 8.0, 5.0, 3.062257748298549652367, 1e-15, median); - test_case(-5.0, -3.0, -4.0, -4.0, median); - test_almost(15.0, 134.0, 21.0, 52.00304883716712238797, 1e-14, median); + test_exact(0.0, 1.0, 0.5, 0.5, median); + test_exact(0.0, 1.0, 0.75, 0.6123724356957945245493, median); + test_absolute(-5.0, 8.0, -3.5, -0.6458082328952913226724, 1e-15, median); + test_absolute(-5.0, 8.0, 5.0, 3.062257748298549652367, 1e-15, median); + test_exact(-5.0, -3.0, -4.0, -4.0, median); + test_absolute(15.0, 134.0, 21.0, 52.00304883716712238797, 1e-14, median); } #[test] fn test_pdf() { let pdf = |arg: f64| move |x: Triangular| x.pdf(arg); - test_case(0.0, 1.0, 0.5, 0.0, pdf(-1.0)); - test_case(0.0, 1.0, 0.5, 0.0, pdf(1.1)); - test_case(0.0, 1.0, 0.5, 1.0, pdf(0.25)); - test_case(0.0, 1.0, 0.5, 2.0, pdf(0.5)); - test_case(0.0, 1.0, 0.5, 1.0, pdf(0.75)); - test_case(-5.0, 8.0, -3.5, 0.0, pdf(-5.1)); - test_case(-5.0, 8.0, -3.5, 0.0, 
pdf(8.1)); - test_case(-5.0, 8.0, -3.5, 0.1025641025641025641026, pdf(-4.0)); - test_case(-5.0, 8.0, -3.5, 0.1538461538461538461538, pdf(-3.5)); - test_case(-5.0, 8.0, -3.5, 0.05351170568561872909699, pdf(4.0)); - test_case(-5.0, -3.0, -4.0, 0.0, pdf(-5.1)); - test_case(-5.0, -3.0, -4.0, 0.0, pdf(-2.9)); - test_case(-5.0, -3.0, -4.0, 0.5, pdf(-4.5)); - test_case(-5.0, -3.0, -4.0, 1.0, pdf(-4.0)); - test_case(-5.0, -3.0, -4.0, 0.5, pdf(-3.5)); + test_exact(0.0, 1.0, 0.5, 0.0, pdf(-1.0)); + test_exact(0.0, 1.0, 0.5, 0.0, pdf(1.1)); + test_exact(0.0, 1.0, 0.5, 1.0, pdf(0.25)); + test_exact(0.0, 1.0, 0.5, 2.0, pdf(0.5)); + test_exact(0.0, 1.0, 0.5, 1.0, pdf(0.75)); + test_exact(-5.0, 8.0, -3.5, 0.0, pdf(-5.1)); + test_exact(-5.0, 8.0, -3.5, 0.0, pdf(8.1)); + test_exact(-5.0, 8.0, -3.5, 0.1025641025641025641026, pdf(-4.0)); + test_exact(-5.0, 8.0, -3.5, 0.1538461538461538461538, pdf(-3.5)); + test_exact(-5.0, 8.0, -3.5, 0.05351170568561872909699, pdf(4.0)); + test_exact(-5.0, -3.0, -4.0, 0.0, pdf(-5.1)); + test_exact(-5.0, -3.0, -4.0, 0.0, pdf(-2.9)); + test_exact(-5.0, -3.0, -4.0, 0.5, pdf(-4.5)); + test_exact(-5.0, -3.0, -4.0, 1.0, pdf(-4.0)); + test_exact(-5.0, -3.0, -4.0, 0.5, pdf(-3.5)); } #[test] fn test_ln_pdf() { let ln_pdf = |arg: f64| move |x: Triangular| x.ln_pdf(arg); - test_case(0.0, 1.0, 0.5, f64::NEG_INFINITY, ln_pdf(-1.0)); - test_case(0.0, 1.0, 0.5, f64::NEG_INFINITY, ln_pdf(1.1)); - test_case(0.0, 1.0, 0.5, 0.0, ln_pdf(0.25)); - test_case(0.0, 1.0, 0.5, 2f64.ln(), ln_pdf(0.5)); - test_case(0.0, 1.0, 0.5, 0.0, ln_pdf(0.75)); - test_case(-5.0, 8.0, -3.5, f64::NEG_INFINITY, ln_pdf(-5.1)); - test_case(-5.0, 8.0, -3.5, f64::NEG_INFINITY, ln_pdf(8.1)); - test_case(-5.0, 8.0, -3.5, 0.1025641025641025641026f64.ln(), ln_pdf(-4.0)); - test_case(-5.0, 8.0, -3.5, 0.1538461538461538461538f64.ln(), ln_pdf(-3.5)); - test_case(-5.0, 8.0, -3.5, 0.05351170568561872909699f64.ln(), ln_pdf(4.0)); - test_case(-5.0, -3.0, -4.0, f64::NEG_INFINITY, ln_pdf(-5.1)); - 
test_case(-5.0, -3.0, -4.0, f64::NEG_INFINITY, ln_pdf(-2.9)); - test_case(-5.0, -3.0, -4.0, 0.5f64.ln(), ln_pdf(-4.5)); - test_case(-5.0, -3.0, -4.0, 0.0, ln_pdf(-4.0)); - test_case(-5.0, -3.0, -4.0, 0.5f64.ln(), ln_pdf(-3.5)); + test_exact(0.0, 1.0, 0.5, f64::NEG_INFINITY, ln_pdf(-1.0)); + test_exact(0.0, 1.0, 0.5, f64::NEG_INFINITY, ln_pdf(1.1)); + test_exact(0.0, 1.0, 0.5, 0.0, ln_pdf(0.25)); + test_exact(0.0, 1.0, 0.5, 2f64.ln(), ln_pdf(0.5)); + test_exact(0.0, 1.0, 0.5, 0.0, ln_pdf(0.75)); + test_exact(-5.0, 8.0, -3.5, f64::NEG_INFINITY, ln_pdf(-5.1)); + test_exact(-5.0, 8.0, -3.5, f64::NEG_INFINITY, ln_pdf(8.1)); + test_exact(-5.0, 8.0, -3.5, 0.1025641025641025641026f64.ln(), ln_pdf(-4.0)); + test_exact(-5.0, 8.0, -3.5, 0.1538461538461538461538f64.ln(), ln_pdf(-3.5)); + test_exact(-5.0, 8.0, -3.5, 0.05351170568561872909699f64.ln(), ln_pdf(4.0)); + test_exact(-5.0, -3.0, -4.0, f64::NEG_INFINITY, ln_pdf(-5.1)); + test_exact(-5.0, -3.0, -4.0, f64::NEG_INFINITY, ln_pdf(-2.9)); + test_exact(-5.0, -3.0, -4.0, 0.5f64.ln(), ln_pdf(-4.5)); + test_exact(-5.0, -3.0, -4.0, 0.0, ln_pdf(-4.0)); + test_exact(-5.0, -3.0, -4.0, 0.5f64.ln(), ln_pdf(-3.5)); } #[test] fn test_cdf() { let cdf = |arg: f64| move |x: Triangular| x.cdf(arg); - test_case(0.0, 1.0, 0.5, 0.125, cdf(0.25)); - test_case(0.0, 1.0, 0.5, 0.5, cdf(0.5)); - test_case(0.0, 1.0, 0.5, 0.875, cdf(0.75)); - test_case(-5.0, 8.0, -3.5, 0.05128205128205128205128, cdf(-4.0)); - test_case(-5.0, 8.0, -3.5, 0.1153846153846153846154, cdf(-3.5)); - test_case(-5.0, 8.0, -3.5, 0.892976588628762541806, cdf(4.0)); - test_case(-5.0, -3.0, -4.0, 0.125, cdf(-4.5)); - test_case(-5.0, -3.0, -4.0, 0.5, cdf(-4.0)); - test_case(-5.0, -3.0, -4.0, 0.875, cdf(-3.5)); + test_exact(0.0, 1.0, 0.5, 0.125, cdf(0.25)); + test_exact(0.0, 1.0, 0.5, 0.5, cdf(0.5)); + test_exact(0.0, 1.0, 0.5, 0.875, cdf(0.75)); + test_exact(-5.0, 8.0, -3.5, 0.05128205128205128205128, cdf(-4.0)); + test_exact(-5.0, 8.0, -3.5, 0.1153846153846153846154, cdf(-3.5)); + 
test_exact(-5.0, 8.0, -3.5, 0.892976588628762541806, cdf(4.0)); + test_exact(-5.0, -3.0, -4.0, 0.125, cdf(-4.5)); + test_exact(-5.0, -3.0, -4.0, 0.5, cdf(-4.0)); + test_exact(-5.0, -3.0, -4.0, 0.875, cdf(-3.5)); } #[test] fn test_cdf_lower_bound() { let cdf = |arg: f64| move |x: Triangular| x.cdf(arg); - test_case(0.0, 3.0, 1.5, 0.0, cdf(-1.0)); + test_exact(0.0, 3.0, 1.5, 0.0, cdf(-1.0)); } #[test] fn test_cdf_upper_bound() { let cdf = |arg: f64| move |x: Triangular| x.cdf(arg); - test_case(0.0, 3.0, 1.5, 1.0, cdf(5.0)); + test_exact(0.0, 3.0, 1.5, 1.0, cdf(5.0)); } #[test] fn test_sf() { let sf = |arg: f64| move |x: Triangular| x.sf(arg); - test_case(0.0, 1.0, 0.5, 0.875, sf(0.25)); - test_case(0.0, 1.0, 0.5, 0.5, sf(0.5)); - test_case(0.0, 1.0, 0.5, 0.125, sf(0.75)); - test_case(-5.0, 8.0, -3.5, 0.9487179487179487, sf(-4.0)); - test_case(-5.0, 8.0, -3.5, 0.8846153846153846, sf(-3.5)); - test_case(-5.0, 8.0, -3.5, 0.10702341137123746, sf(4.0)); - test_case(-5.0, -3.0, -4.0, 0.875, sf(-4.5)); - test_case(-5.0, -3.0, -4.0, 0.5, sf(-4.0)); - test_case(-5.0, -3.0, -4.0, 0.125, sf(-3.5)); + test_exact(0.0, 1.0, 0.5, 0.875, sf(0.25)); + test_exact(0.0, 1.0, 0.5, 0.5, sf(0.5)); + test_exact(0.0, 1.0, 0.5, 0.125, sf(0.75)); + test_exact(-5.0, 8.0, -3.5, 0.9487179487179487, sf(-4.0)); + test_exact(-5.0, 8.0, -3.5, 0.8846153846153846, sf(-3.5)); + test_exact(-5.0, 8.0, -3.5, 0.10702341137123746, sf(4.0)); + test_exact(-5.0, -3.0, -4.0, 0.875, sf(-4.5)); + test_exact(-5.0, -3.0, -4.0, 0.5, sf(-4.0)); + test_exact(-5.0, -3.0, -4.0, 0.125, sf(-3.5)); } #[test] fn test_sf_lower_bound() { let sf = |arg: f64| move |x: Triangular| x.sf(arg); - test_case(0.0, 3.0, 1.5, 1.0, sf(-1.0)); + test_exact(0.0, 3.0, 1.5, 1.0, sf(-1.0)); } #[test] fn test_sf_upper_bound() { let sf = |arg: f64| move |x: Triangular| x.sf(arg); - test_case(0.0, 3.0, 1.5, 0.0, sf(5.0)); + test_exact(0.0, 3.0, 1.5, 0.0, sf(5.0)); + } + + #[test] + fn test_inverse_cdf() { + let func = |arg: f64| move |x: 
Triangular| x.inverse_cdf(x.cdf(arg)); + test_absolute(0.0, 1.0, 0.5, 0.25, 1e-15, func(0.25)); + test_absolute(0.0, 1.0, 0.5, 0.5, 1e-15, func(0.5)); + test_absolute(0.0, 1.0, 0.5, 0.75, 1e-15, func(0.75)); + test_absolute(-5.0, 8.0, -3.5, -4.0, 1e-15, func(-4.0)); + test_absolute(-5.0, 8.0, -3.5, -3.5, 1e-15, func(-3.5)); + test_absolute(-5.0, 8.0, -3.5, 4.0, 1e-15, func(4.0)); + test_absolute(-5.0, -3.0, -4.0, -4.5, 1e-15, func(-4.5)); + test_absolute(-5.0, -3.0, -4.0, -4.0, 1e-15, func(-4.0)); + test_absolute(-5.0, -3.0, -4.0, -3.5, 1e-15, func(-3.5)); } #[test] fn test_continuous() { - test::check_continuous_distribution(&try_create(-5.0, 5.0, 0.0), -5.0, 5.0); - test::check_continuous_distribution(&try_create(-15.0, -2.0, -3.0), -15.0, -2.0); + test::check_continuous_distribution(&create_ok(-5.0, 5.0, 0.0), -5.0, 5.0); + test::check_continuous_distribution(&create_ok(-15.0, -2.0, -3.0), -15.0, -2.0); } } diff --git a/src/distribution/uniform.rs b/src/distribution/uniform.rs index bca6590b..3d637734 100644 --- a/src/distribution/uniform.rs +++ b/src/distribution/uniform.rs @@ -1,9 +1,7 @@ use crate::distribution::{Continuous, ContinuousCDF}; use crate::statistics::*; -use crate::{Result, StatsError}; -use rand::distributions::Uniform as RandUniform; -use rand::Rng; use std::f64; +use std::fmt::Debug; /// Implements the [Continuous /// Uniform](https://en.wikipedia.org/wiki/Uniform_distribution_(continuous)) @@ -25,13 +23,43 @@ pub struct Uniform { max: f64, } +/// Represents the errors that can occur when creating a [`Uniform`]. +#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] +#[non_exhaustive] +pub enum UniformError { + /// The minimum is NaN or infinite. + MinInvalid, + + /// The maximum is NaN or infinite. + MaxInvalid, + + /// The maximum is not greater than the minimum. 
+ MaxNotGreaterThanMin, +} + +impl std::fmt::Display for UniformError { + #[cfg_attr(coverage_nightly, coverage(off))] + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + UniformError::MinInvalid => write!(f, "Minimum is NaN or infinite"), + UniformError::MaxInvalid => write!(f, "Maximum is NaN or infinite"), + UniformError::MaxNotGreaterThanMin => { + write!(f, "Maximum is not greater than the minimum") + } + } + } +} + +impl std::error::Error for UniformError {} + impl Uniform { /// Constructs a new uniform distribution with a min of `min` and a max - /// of `max` + /// of `max`. /// /// # Errors /// - /// Returns an error if `min` or `max` are `NaN` + /// Returns an error if `min` or `max` are `NaN` or infinite. + /// Returns an error if `min >= max`. /// /// # Examples /// @@ -44,19 +72,57 @@ impl Uniform { /// /// result = Uniform::new(f64::NAN, f64::NAN); /// assert!(result.is_err()); + /// + /// result = Uniform::new(f64::NEG_INFINITY, 1.0); + /// assert!(result.is_err()); /// ``` - pub fn new(min: f64, max: f64) -> Result { - if min > max || min.is_nan() || max.is_nan() { - Err(StatsError::BadParams) - } else { + pub fn new(min: f64, max: f64) -> Result { + if !min.is_finite() { + return Err(UniformError::MinInvalid); + } + + if !max.is_finite() { + return Err(UniformError::MaxInvalid); + } + + if min < max { Ok(Uniform { min, max }) + } else { + Err(UniformError::MaxNotGreaterThanMin) } } + + /// Constructs a new standard uniform distribution with + /// a lower bound 0 and an upper bound of 1. 
+ /// + /// # Examples + /// + /// ``` + /// use statrs::distribution::Uniform; + /// + /// let uniform = Uniform::standard(); + /// ``` + pub fn standard() -> Self { + Self { min: 0.0, max: 1.0 } + } +} + +impl Default for Uniform { + fn default() -> Self { + Self::standard() + } +} + +impl std::fmt::Display for Uniform { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Uni([{},{}])", self.min, self.max) + } } +#[cfg(feature = "rand")] impl ::rand::distributions::Distribution for Uniform { - fn sample(&self, rng: &mut R) -> f64 { - let d = RandUniform::new_inclusive(self.min, self.max); + fn sample(&self, rng: &mut R) -> f64 { + let d = rand::distributions::Uniform::new_inclusive(self.min, self.max); rng.sample(d) } } @@ -68,7 +134,7 @@ impl ContinuousCDF for Uniform { /// /// # Formula /// - /// ```ignore + /// ```text /// (x - min) / (max - min) /// ``` fn cdf(&self, x: f64) -> f64 { @@ -86,7 +152,7 @@ impl ContinuousCDF for Uniform { /// /// # Formula /// - /// ```ignore + /// ```text /// (max - x) / (max - min) /// ``` fn sf(&self, x: f64) -> f64 { @@ -94,14 +160,23 @@ impl ContinuousCDF for Uniform { 1.0 } else if x >= self.max { 0.0 - } else if x.is_infinite() && self.max.is_infinite() { - 0.0 - } else if self.max.is_infinite() { - 1.0 } else { (self.max - x) / (self.max - self.min) } } + + /// Finds the value of `x` where `F(x) = p`; panics if `p` is outside `[0, 1]` + fn inverse_cdf(&self, p: f64) -> f64 { + if !(0.0..=1.0).contains(&p) { + panic!("p must be in [0, 1], was {}", p); + } else if p == 0.0 { + self.min + } else if p == 1.0 { + self.max + } else { + (self.max - self.min) * p + self.min + } + } } impl Min for Uniform { @@ -121,37 +196,40 @@ impl Distribution for Uniform { /// /// # Formula /// - /// ```ignore + /// ```text /// (min + max) / 2 /// ``` fn mean(&self) -> Option { Some((self.min + self.max) / 2.0) } + /// Returns the variance for the continuous uniform distribution /// /// # Formula /// - /// ```ignore + /// ```text /// (max -
min)^2 / 12 /// ``` fn variance(&self) -> Option { Some((self.max - self.min) * (self.max - self.min) / 12.0) } + /// Returns the entropy for the continuous uniform distribution /// /// # Formula /// - /// ```ignore + /// ```text /// ln(max - min) /// ``` fn entropy(&self) -> Option { Some((self.max - self.min).ln()) } + /// Returns the skewness for the continuous uniform distribution /// /// # Formula /// - /// ```ignore + /// ```text /// 0 /// ``` fn skewness(&self) -> Option { @@ -164,7 +242,7 @@ impl Median for Uniform { /// /// # Formula /// - /// ```ignore + /// ```text /// (min + max) / 2 /// ``` fn median(&self) -> f64 { @@ -182,7 +260,7 @@ impl Mode> for Uniform { /// /// # Formula /// - /// ```ignore + /// ```text /// N/A // (max + min) / 2 for the middle element /// ``` fn mode(&self) -> Option { @@ -200,7 +278,7 @@ impl Continuous for Uniform { /// /// # Formula /// - /// ```ignore + /// ```text /// 1 / (max - min) /// ``` fn pdf(&self, x: f64) -> f64 { @@ -221,7 +299,7 @@ impl Continuous for Uniform { /// /// # Formula /// - /// ```ignore + /// ```text /// ln(1 / (max - min)) /// ``` fn ln_pdf(&self, x: f64) -> f64 { @@ -234,234 +312,189 @@ impl Continuous for Uniform { } #[rustfmt::skip] -#[cfg(all(test, feature = "nightly"))] +#[cfg(test)] mod tests { - use crate::statistics::*; - use crate::distribution::{ContinuousCDF, Continuous, Uniform}; + use super::*; use crate::distribution::internal::*; - use crate::consts::ACC; - - fn try_create(min: f64, max: f64) -> Uniform { - let n = Uniform::new(min, max); - assert!(n.is_ok()); - n.unwrap() - } - - fn create_case(min: f64, max: f64) { - let n = try_create(min, max); - assert_eq!(n.min(), min); - assert_eq!(n.max(), max); - } - - fn bad_create_case(min: f64, max: f64) { - let n = Uniform::new(min, max); - assert!(n.is_err()); - } - - fn get_value(min: f64, max: f64, eval: F) -> f64 - where F: Fn(Uniform) -> f64 - { - let n = try_create(min, max); - eval(n) - } - - fn test_case(min: f64, max: f64, 
expected: f64, eval: F) - where F: Fn(Uniform) -> f64 - { - - let x = get_value(min, max, eval); - assert_eq!(expected, x); - } + use crate::testing_boiler; - fn test_almost(min: f64, max: f64, expected: f64, acc: f64, eval: F) - where F: Fn(Uniform) -> f64 - { - - let x = get_value(min, max, eval); - assert_almost_eq!(expected, x, acc); - } + testing_boiler!(min: f64, max: f64; Uniform; UniformError); #[test] fn test_create() { - create_case(0.0, 0.0); - create_case(0.0, 0.1); - create_case(0.0, 1.0); - create_case(10.0, 10.0); - create_case(-5.0, 11.0); - create_case(-5.0, 100.0); + create_ok(0.0, 0.1); + create_ok(0.0, 1.0); + create_ok(-5.0, 11.0); + create_ok(-5.0, 100.0); } #[test] fn test_bad_create() { - bad_create_case(f64::NAN, 1.0); - bad_create_case(1.0, f64::NAN); - bad_create_case(f64::NAN, f64::NAN); - bad_create_case(1.0, 0.0); + let invalid = [ + (0.0, 0.0, UniformError::MaxNotGreaterThanMin), + (f64::NAN, 1.0, UniformError::MinInvalid), + (1.0, f64::NAN, UniformError::MaxInvalid), + (f64::NAN, f64::NAN, UniformError::MinInvalid), + (0.0, f64::INFINITY, UniformError::MaxInvalid), + (1.0, 0.0, UniformError::MaxNotGreaterThanMin), + ]; + + for (min, max, err) in invalid { + test_create_err(min, max, err); + } } #[test] fn test_variance() { let variance = |x: Uniform| x.variance().unwrap(); - test_case(-0.0, 2.0, 1.0 / 3.0, variance); - test_case(0.0, 2.0, 1.0 / 3.0, variance); - test_almost(0.1, 4.0, 1.2675, 1e-15, variance); - test_case(10.0, 11.0, 1.0 / 12.0, variance); - test_case(0.0, f64::INFINITY, f64::INFINITY, variance); + test_exact(-0.0, 2.0, 1.0 / 3.0, variance); + test_exact(0.0, 2.0, 1.0 / 3.0, variance); + test_absolute(0.1, 4.0, 1.2675, 1e-15, variance); + test_exact(10.0, 11.0, 1.0 / 12.0, variance); } #[test] fn test_entropy() { let entropy = |x: Uniform| x.entropy().unwrap(); - test_case(-0.0, 2.0, 0.6931471805599453094172, entropy); - test_case(0.0, 2.0, 0.6931471805599453094172, entropy); - test_almost(0.1, 4.0, 
1.360976553135600743431, 1e-15, entropy); - test_case(1.0, 10.0, 2.19722457733621938279, entropy); - test_case(10.0, 11.0, 0.0, entropy); - test_case(0.0, f64::INFINITY, f64::INFINITY, entropy); + test_exact(-0.0, 2.0, 0.6931471805599453094172, entropy); + test_exact(0.0, 2.0, 0.6931471805599453094172, entropy); + test_absolute(0.1, 4.0, 1.360976553135600743431, 1e-15, entropy); + test_exact(1.0, 10.0, 2.19722457733621938279, entropy); + test_exact(10.0, 11.0, 0.0, entropy); } #[test] fn test_skewness() { let skewness = |x: Uniform| x.skewness().unwrap(); - test_case(-0.0, 2.0, 0.0, skewness); - test_case(0.0, 2.0, 0.0, skewness); - test_case(0.1, 4.0, 0.0, skewness); - test_case(1.0, 10.0, 0.0, skewness); - test_case(10.0, 11.0, 0.0, skewness); - test_case(0.0, f64::INFINITY, 0.0, skewness); + test_exact(-0.0, 2.0, 0.0, skewness); + test_exact(0.0, 2.0, 0.0, skewness); + test_exact(0.1, 4.0, 0.0, skewness); + test_exact(1.0, 10.0, 0.0, skewness); + test_exact(10.0, 11.0, 0.0, skewness); } #[test] fn test_mode() { let mode = |x: Uniform| x.mode().unwrap(); - test_case(-0.0, 2.0, 1.0, mode); - test_case(0.0, 2.0, 1.0, mode); - test_case(0.1, 4.0, 2.05, mode); - test_case(1.0, 10.0, 5.5, mode); - test_case(10.0, 11.0, 10.5, mode); - test_case(0.0, f64::INFINITY, f64::INFINITY, mode); + test_exact(-0.0, 2.0, 1.0, mode); + test_exact(0.0, 2.0, 1.0, mode); + test_exact(0.1, 4.0, 2.05, mode); + test_exact(1.0, 10.0, 5.5, mode); + test_exact(10.0, 11.0, 10.5, mode); } #[test] fn test_median() { let median = |x: Uniform| x.median(); - test_case(-0.0, 2.0, 1.0, median); - test_case(0.0, 2.0, 1.0, median); - test_case(0.1, 4.0, 2.05, median); - test_case(1.0, 10.0, 5.5, median); - test_case(10.0, 11.0, 10.5, median); - test_case(0.0, f64::INFINITY, f64::INFINITY, median); + test_exact(-0.0, 2.0, 1.0, median); + test_exact(0.0, 2.0, 1.0, median); + test_exact(0.1, 4.0, 2.05, median); + test_exact(1.0, 10.0, 5.5, median); + test_exact(10.0, 11.0, 10.5, median); } #[test] fn 
test_pdf() { let pdf = |arg: f64| move |x: Uniform| x.pdf(arg); - test_case(0.0, 0.0, 0.0, pdf(-5.0)); - test_case(0.0, 0.0, f64::INFINITY, pdf(0.0)); - test_case(0.0, 0.0, 0.0, pdf(5.0)); - test_case(0.0, 0.1, 0.0, pdf(-5.0)); - test_case(0.0, 0.1, 10.0, pdf(0.05)); - test_case(0.0, 0.1, 0.0, pdf(5.0)); - test_case(0.0, 1.0, 0.0, pdf(-5.0)); - test_case(0.0, 1.0, 1.0, pdf(0.5)); - test_case(0.0, 0.1, 0.0, pdf(5.0)); - test_case(0.0, 10.0, 0.0, pdf(-5.0)); - test_case(0.0, 10.0, 0.1, pdf(1.0)); - test_case(0.0, 10.0, 0.1, pdf(5.0)); - test_case(0.0, 10.0, 0.0, pdf(11.0)); - test_case(-5.0, 100.0, 0.0, pdf(-10.0)); - test_case(-5.0, 100.0, 0.009523809523809523809524, pdf(-5.0)); - test_case(-5.0, 100.0, 0.009523809523809523809524, pdf(0.0)); - test_case(-5.0, 100.0, 0.0, pdf(101.0)); - test_case(0.0, f64::INFINITY, 0.0, pdf(-5.0)); - test_case(0.0, f64::INFINITY, 0.0, pdf(10.0)); - test_case(0.0, f64::INFINITY, 0.0, pdf(f64::INFINITY)); + test_exact(0.0, 0.1, 0.0, pdf(-5.0)); + test_exact(0.0, 0.1, 10.0, pdf(0.05)); + test_exact(0.0, 0.1, 0.0, pdf(5.0)); + test_exact(0.0, 1.0, 0.0, pdf(-5.0)); + test_exact(0.0, 1.0, 1.0, pdf(0.5)); + test_exact(0.0, 0.1, 0.0, pdf(5.0)); + test_exact(0.0, 10.0, 0.0, pdf(-5.0)); + test_exact(0.0, 10.0, 0.1, pdf(1.0)); + test_exact(0.0, 10.0, 0.1, pdf(5.0)); + test_exact(0.0, 10.0, 0.0, pdf(11.0)); + test_exact(-5.0, 100.0, 0.0, pdf(-10.0)); + test_exact(-5.0, 100.0, 0.009523809523809523809524, pdf(-5.0)); + test_exact(-5.0, 100.0, 0.009523809523809523809524, pdf(0.0)); + test_exact(-5.0, 100.0, 0.0, pdf(101.0)); } #[test] fn test_ln_pdf() { let ln_pdf = |arg: f64| move |x: Uniform| x.ln_pdf(arg); - test_case(0.0, 0.0, f64::NEG_INFINITY, ln_pdf(-5.0)); - test_case(0.0, 0.0, f64::INFINITY, ln_pdf(0.0)); - test_case(0.0, 0.0, f64::NEG_INFINITY, ln_pdf(5.0)); - test_case(0.0, 0.1, f64::NEG_INFINITY, ln_pdf(-5.0)); - test_almost(0.0, 0.1, 2.302585092994045684018, 1e-15, ln_pdf(0.05)); - test_case(0.0, 0.1, f64::NEG_INFINITY, ln_pdf(5.0)); 
- test_case(0.0, 1.0, f64::NEG_INFINITY, ln_pdf(-5.0)); - test_case(0.0, 1.0, 0.0, ln_pdf(0.5)); - test_case(0.0, 0.1, f64::NEG_INFINITY, ln_pdf(5.0)); - test_case(0.0, 10.0, f64::NEG_INFINITY, ln_pdf(-5.0)); - test_case(0.0, 10.0, -2.302585092994045684018, ln_pdf(1.0)); - test_case(0.0, 10.0, -2.302585092994045684018, ln_pdf(5.0)); - test_case(0.0, 10.0, f64::NEG_INFINITY, ln_pdf(11.0)); - test_case(-5.0, 100.0, f64::NEG_INFINITY, ln_pdf(-10.0)); - test_case(-5.0, 100.0, -4.653960350157523371101, ln_pdf(-5.0)); - test_case(-5.0, 100.0, -4.653960350157523371101, ln_pdf(0.0)); - test_case(-5.0, 100.0, f64::NEG_INFINITY, ln_pdf(101.0)); - test_case(0.0, f64::INFINITY, f64::NEG_INFINITY, ln_pdf(-5.0)); - test_case(0.0, f64::INFINITY, f64::NEG_INFINITY, ln_pdf(10.0)); - test_case(0.0, f64::INFINITY, f64::NEG_INFINITY, ln_pdf(f64::INFINITY)); + test_exact(0.0, 0.1, f64::NEG_INFINITY, ln_pdf(-5.0)); + test_absolute(0.0, 0.1, 2.302585092994045684018, 1e-15, ln_pdf(0.05)); + test_exact(0.0, 0.1, f64::NEG_INFINITY, ln_pdf(5.0)); + test_exact(0.0, 1.0, f64::NEG_INFINITY, ln_pdf(-5.0)); + test_exact(0.0, 1.0, 0.0, ln_pdf(0.5)); + test_exact(0.0, 0.1, f64::NEG_INFINITY, ln_pdf(5.0)); + test_exact(0.0, 10.0, f64::NEG_INFINITY, ln_pdf(-5.0)); + test_exact(0.0, 10.0, -2.302585092994045684018, ln_pdf(1.0)); + test_exact(0.0, 10.0, -2.302585092994045684018, ln_pdf(5.0)); + test_exact(0.0, 10.0, f64::NEG_INFINITY, ln_pdf(11.0)); + test_exact(-5.0, 100.0, f64::NEG_INFINITY, ln_pdf(-10.0)); + test_exact(-5.0, 100.0, -4.653960350157523371101, ln_pdf(-5.0)); + test_exact(-5.0, 100.0, -4.653960350157523371101, ln_pdf(0.0)); + test_exact(-5.0, 100.0, f64::NEG_INFINITY, ln_pdf(101.0)); } #[test] fn test_cdf() { let cdf = |arg: f64| move |x: Uniform| x.cdf(arg); - test_case(0.0, 0.0, 0.0, cdf(0.0)); - test_case(0.0, 0.1, 0.5, cdf(0.05)); - test_case(0.0, 1.0, 0.5, cdf(0.5)); - test_case(0.0, 10.0, 0.1, cdf(1.0)); - test_case(0.0, 10.0, 0.5, cdf(5.0)); - test_case(-5.0, 100.0, 0.0, 
cdf(-5.0)); - test_case(-5.0, 100.0, 0.04761904761904761904762, cdf(0.0)); - test_case(0.0, f64::INFINITY, 0.0, cdf(10.0)); - test_case(0.0, f64::INFINITY, 1.0, cdf(f64::INFINITY)); + test_exact(0.0, 0.1, 0.5, cdf(0.05)); + test_exact(0.0, 1.0, 0.5, cdf(0.5)); + test_exact(0.0, 10.0, 0.1, cdf(1.0)); + test_exact(0.0, 10.0, 0.5, cdf(5.0)); + test_exact(-5.0, 100.0, 0.0, cdf(-5.0)); + test_exact(-5.0, 100.0, 0.04761904761904761904762, cdf(0.0)); + } + + #[test] + fn test_inverse_cdf() { + let inverse_cdf = |arg: f64| move |x: Uniform| x.inverse_cdf(arg); + test_exact(0.0, 0.1, 0.05, inverse_cdf(0.5)); + test_exact(0.0, 10.0, 5.0, inverse_cdf(0.5)); + test_exact(1.0, 10.0, 1.0, inverse_cdf(0.0)); + test_exact(1.0, 10.0, 4.0, inverse_cdf(1.0 / 3.0)); + test_exact(1.0, 10.0, 10.0, inverse_cdf(1.0)); } #[test] fn test_cdf_lower_bound() { let cdf = |arg: f64| move |x: Uniform| x.cdf(arg); - test_case(0.0, 3.0, 0.0, cdf(-1.0)); + test_exact(0.0, 3.0, 0.0, cdf(-1.0)); } #[test] fn test_cdf_upper_bound() { let cdf = |arg: f64| move |x: Uniform| x.cdf(arg); - test_case(0.0, 3.0, 1.0, cdf(5.0)); + test_exact(0.0, 3.0, 1.0, cdf(5.0)); } #[test] fn test_sf() { let sf = |arg: f64| move |x: Uniform| x.sf(arg); - test_case(0.0, 0.0, 1.0, sf(0.0)); - test_case(0.0, 0.1, 0.5, sf(0.05)); - test_case(0.0, 1.0, 0.5, sf(0.5)); - test_case(0.0, 10.0, 0.9, sf(1.0)); - test_case(0.0, 10.0, 0.5, sf(5.0)); - test_case(-5.0, 100.0, 1.0, sf(-5.0)); - test_case(-5.0, 100.0, 0.9523809523809523, sf(0.0)); - test_case(0.0, f64::INFINITY, 1.0, sf(10.0)); - test_case(0.0, f64::INFINITY, 0.0, sf(f64::INFINITY)); + test_exact(0.0, 0.1, 0.5, sf(0.05)); + test_exact(0.0, 1.0, 0.5, sf(0.5)); + test_exact(0.0, 10.0, 0.9, sf(1.0)); + test_exact(0.0, 10.0, 0.5, sf(5.0)); + test_exact(-5.0, 100.0, 1.0, sf(-5.0)); + test_exact(-5.0, 100.0, 0.9523809523809523, sf(0.0)); } #[test] fn test_sf_lower_bound() { let sf = |arg: f64| move |x: Uniform| x.sf(arg); - test_case(0.0, 3.0, 1.0, sf(-1.0)); + test_exact(0.0, 
3.0, 1.0, sf(-1.0)); } #[test] fn test_sf_upper_bound() { let sf = |arg: f64| move |x: Uniform| x.sf(arg); - test_case(0.0, 3.0, 0.0, sf(5.0)); + test_exact(0.0, 3.0, 0.0, sf(5.0)); } #[test] fn test_continuous() { - test::check_continuous_distribution(&try_create(0.0, 10.0), 0.0, 10.0); - test::check_continuous_distribution(&try_create(-2.0, 15.0), -2.0, 15.0); + test::check_continuous_distribution(&create_ok(0.0, 10.0), 0.0, 10.0); + test::check_continuous_distribution(&create_ok(-2.0, 15.0), -2.0, 15.0); } + #[cfg(feature = "rand")] #[test] fn test_samples_in_range() { use rand::rngs::StdRng; @@ -477,11 +510,24 @@ mod tests { let min = -0.5; let max = 0.5; let num_trials = 10_000; - let n = try_create(min, max); + let n = create_ok(min, max); assert!((0..num_trials) .map(|_| n.sample::(&mut r)) .all(|v| (min <= v) && (v < max)) ); } + + #[test] + fn test_default() { + let n = Uniform::default(); + + let n_mean = n.mean().unwrap(); + let n_std = n.std_dev().unwrap(); + + // Check that the mean of the distribution is close to 1 / 2 + assert_almost_eq!(n_mean, 0.5, 1e-15); + // Check that the standard deviation of the distribution is close to 1 / sqrt(12) + assert_almost_eq!(n_std, 0.288_675_134_594_812_9, 1e-15); + } } diff --git a/src/distribution/weibull.rs b/src/distribution/weibull.rs index eab7d942..aacf7bb6 100644 --- a/src/distribution/weibull.rs +++ b/src/distribution/weibull.rs @@ -1,9 +1,7 @@ +use crate::consts; use crate::distribution::{Continuous, ContinuousCDF}; use crate::function::gamma; -use crate::is_zero; use crate::statistics::*; -use crate::{consts, Result, StatsError}; -use rand::Rng; use std::f64; /// Implements the [Weibull](https://en.wikipedia.org/wiki/Weibull_distribution) @@ -21,13 +19,36 @@ use std::f64; /// 0.95135076986687318362924871772654021925505786260884, 1e-15)); /// assert_eq!(n.pdf(1.0), 3.6787944117144232159552377016146086744581113103177); /// ``` -#[derive(Debug, Copy, Clone, PartialEq)] +#[derive(Copy, Clone, PartialEq, 
Debug)] pub struct Weibull { shape: f64, scale: f64, scale_pow_shape_inv: f64, } +/// Represents the errors that can occur when creating a [`Weibull`]. +#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] +#[non_exhaustive] +pub enum WeibullError { + /// The shape is NaN, zero or less than zero. + ShapeInvalid, + + /// The scale is NaN, zero or less than zero. + ScaleInvalid, +} + +impl std::fmt::Display for WeibullError { + #[cfg_attr(coverage_nightly, coverage(off))] + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + WeibullError::ShapeInvalid => write!(f, "Shape is NaN, zero or less than zero."), + WeibullError::ScaleInvalid => write!(f, "Scale is NaN, zero or less than zero."), + } + } +} + +impl std::error::Error for WeibullError {} + impl Weibull { /// Constructs a new weibull distribution with a shape (k) of `shape` /// and a scale (λ) of `scale` @@ -48,17 +69,20 @@ impl Weibull { /// result = Weibull::new(0.0, 0.0); /// assert!(result.is_err()); /// ``` - pub fn new(shape: f64, scale: f64) -> Result { - let is_nan = shape.is_nan() || scale.is_nan(); - match (shape, scale, is_nan) { - (_, _, true) => Err(StatsError::BadParams), - (_, _, false) if shape <= 0.0 || scale <= 0.0 => Err(StatsError::BadParams), - (_, _, false) => Ok(Weibull { - shape, - scale, - scale_pow_shape_inv: scale.powf(-shape), - }), + pub fn new(shape: f64, scale: f64) -> Result { + if shape.is_nan() || shape <= 0.0 { + return Err(WeibullError::ShapeInvalid); } + + if scale.is_nan() || scale <= 0.0 { + return Err(WeibullError::ScaleInvalid); + } + + Ok(Weibull { + shape, + scale, + scale_pow_shape_inv: scale.powf(-shape), + }) } /// Returns the shape of the weibull distribution @@ -90,8 +114,15 @@ impl Weibull { } } +impl std::fmt::Display for Weibull { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Weibull({},{})", self.scale, self.shape) + } +} + +#[cfg(feature = "rand")] impl ::rand::distributions::Distribution for 
Weibull { - fn sample(&self, rng: &mut R) -> f64 { + fn sample(&self, rng: &mut R) -> f64 { let x: f64 = rng.gen(); self.scale * (-x.ln()).powf(1.0 / self.shape) } @@ -103,7 +134,7 @@ impl ContinuousCDF for Weibull { /// /// # Formula /// - /// ```ignore + /// ```text /// 1 - e^-((x/λ)^k) /// ``` /// @@ -121,7 +152,7 @@ impl ContinuousCDF for Weibull { /// /// # Formula /// - /// ```ignore + /// ```text /// e^-((x/λ)^k) /// ``` /// @@ -133,6 +164,24 @@ impl ContinuousCDF for Weibull { (-x.powf(self.shape) * self.scale_pow_shape_inv).exp() } } + + /// Calculates the inverse cumulative distribution function for the weibull + /// distribution at `x` + /// + /// # Formula + /// + /// ```text + /// λ (-ln(1 - x))^(1 / k) + /// ``` + /// + /// where `k` is the shape and `λ` is the scale + fn inverse_cdf(&self, p: f64) -> f64 { + if !(0.0..=1.0).contains(&p) { + panic!("x must be in [0, 1]"); + } + + (-((-p).ln_1p() / self.scale_pow_shape_inv)).powf(1.0 / self.shape) + } } impl Min for Weibull { @@ -141,7 +190,7 @@ impl Min for Weibull { /// /// # Formula /// - /// ```ignore + /// ```text /// 0 /// ``` fn min(&self) -> f64 { @@ -155,8 +204,8 @@ impl Max for Weibull { /// /// # Formula /// - /// ```ignore - /// INF + /// ```text + /// f64::INFINITY /// ``` fn max(&self) -> f64 { f64::INFINITY @@ -168,7 +217,7 @@ impl Distribution for Weibull { /// /// # Formula /// - /// ```ignore + /// ```text /// λΓ(1 + 1 / k) /// ``` /// @@ -177,11 +226,12 @@ impl Distribution for Weibull { fn mean(&self) -> Option { Some(self.scale * gamma::gamma(1.0 + 1.0 / self.shape)) } + /// Returns the variance of the weibull distribution /// /// # Formula /// - /// ```ignore + /// ```text /// λ^2 * (Γ(1 + 2 / k) - Γ(1 + 1 / k)^2) /// ``` /// @@ -191,11 +241,12 @@ impl Distribution for Weibull { let mean = self.mean()?; Some(self.scale * self.scale * gamma::gamma(1.0 + 2.0 / self.shape) - mean * mean) } + /// Returns the entropy of the weibull distribution /// /// # Formula /// - /// ```ignore + 
/// ```text /// γ(1 - 1 / k) + ln(λ / k) + 1 /// ``` /// @@ -207,11 +258,12 @@ impl Distribution for Weibull { + 1.0; Some(entr) } + /// Returns the skewness of the weibull distribution /// /// # Formula /// - /// ```ignore + /// ```text /// (Γ(1 + 3 / k) * λ^3 - 3μσ^2 - μ^3) / σ^3 /// ``` /// @@ -236,7 +288,7 @@ impl Median for Weibull { /// /// # Formula /// - /// ```ignore + /// ```text /// λ(ln(2))^(1 / k) /// ``` /// @@ -251,7 +303,7 @@ impl Mode> for Weibull { /// /// # Formula /// - /// ```ignore + /// ```text /// if k == 1 { /// 0 /// } else { @@ -276,7 +328,7 @@ impl Continuous for Weibull { /// /// # Formula /// - /// ```ignore + /// ```text /// (k / λ) * (x / λ)^(k - 1) * e^(-(x / λ)^k) /// ``` /// @@ -284,7 +336,7 @@ impl Continuous for Weibull { fn pdf(&self, x: f64) -> f64 { if x < 0.0 { 0.0 - } else if is_zero(x) && ulps_eq!(self.shape, 1.0) { + } else if x == 0.0 && ulps_eq!(self.shape, 1.0) { 1.0 / self.scale } else if x.is_infinite() { 0.0 @@ -301,7 +353,7 @@ impl Continuous for Weibull { /// /// # Formula /// - /// ```ignore + /// ```text /// ln((k / λ) * (x / λ)^(k - 1) * e^(-(x / λ)^k)) /// ``` /// @@ -309,7 +361,7 @@ impl Continuous for Weibull { fn ln_pdf(&self, x: f64) -> f64 { if x < 0.0 { f64::NEG_INFINITY - } else if is_zero(x) && ulps_eq!(self.shape, 1.0) { + } else if x == 0.0 && ulps_eq!(self.shape, 1.0) { 0.0 - self.scale.ln() } else if x.is_infinite() { f64::NEG_INFINITY @@ -322,204 +374,182 @@ impl Continuous for Weibull { } #[rustfmt::skip] -#[cfg(all(test, feature = "nightly"))] +#[cfg(test)] mod tests { - use crate::statistics::*; - use crate::distribution::{ContinuousCDF, Continuous, Weibull}; + use super::*; use crate::distribution::internal::*; - use crate::consts::ACC; + use crate::testing_boiler; - fn try_create(shape: f64, scale: f64) -> Weibull { - let n = Weibull::new(shape, scale); - assert!(n.is_ok()); - n.unwrap() - } - - fn create_case(shape: f64, scale: f64) { - let n = try_create(shape, scale); - assert_eq!(shape, 
n.shape()); - assert_eq!(scale, n.scale()); - } - - fn bad_create_case(shape: f64, scale: f64) { - let n = Weibull::new(shape, scale); - assert!(n.is_err()); - } - - fn get_value(shape: f64, scale: f64, eval: F) -> f64 - where F: Fn(Weibull) -> f64 - { - let n = try_create(shape, scale); - eval(n) - } - - fn test_case(shape: f64, scale: f64, expected: f64, eval: F) - where F: Fn(Weibull) -> f64 - { - let x = get_value(shape, scale, eval); - assert_eq!(expected, x); - } - - fn test_almost(shape: f64, scale: f64, expected: f64, acc: f64, eval: F) - where F: Fn(Weibull) -> f64 - { - let x = get_value(shape, scale, eval); - assert_almost_eq!(expected, x, acc); - } + testing_boiler!(shape: f64, scale: f64; Weibull; WeibullError); #[test] fn test_create() { - create_case(1.0, 0.1); - create_case(10.0, 1.0); - create_case(11.0, 10.0); - create_case(12.0, f64::INFINITY); + create_ok(1.0, 0.1); + create_ok(10.0, 1.0); + create_ok(11.0, 10.0); + create_ok(12.0, f64::INFINITY); } #[test] fn test_bad_create() { - bad_create_case(f64::NAN, 1.0); - bad_create_case(1.0, f64::NAN); - bad_create_case(f64::NAN, f64::NAN); - bad_create_case(1.0, -1.0); - bad_create_case(-1.0, 1.0); - bad_create_case(-1.0, -1.0); - bad_create_case(0.0, 0.0); - bad_create_case(0.0, 1.0); - bad_create_case(1.0, 0.0); + test_create_err(f64::NAN, 1.0, WeibullError::ShapeInvalid); + test_create_err(1.0, f64::NAN, WeibullError::ScaleInvalid); + create_err(f64::NAN, f64::NAN); + create_err(1.0, -1.0); + create_err(-1.0, 1.0); + create_err(-1.0, -1.0); + create_err(0.0, 0.0); + create_err(0.0, 1.0); + create_err(1.0, 0.0); } #[test] fn test_mean() { let mean = |x: Weibull| x.mean().unwrap(); - test_case(1.0, 0.1, 0.1, mean); - test_case(1.0, 1.0, 1.0, mean); - test_almost(10.0, 10.0, 9.5135076986687318362924871772654021925505786260884, 1e-14, mean); - test_almost(10.0, 1.0, 0.95135076986687318362924871772654021925505786260884, 1e-15, mean); + test_exact(1.0, 0.1, 0.1, mean); + test_exact(1.0, 1.0, 1.0, mean); 
+ test_absolute(10.0, 10.0, 9.5135076986687318362924871772654021925505786260884, 1e-14, mean); + test_absolute(10.0, 1.0, 0.95135076986687318362924871772654021925505786260884, 1e-15, mean); } #[test] fn test_variance() { let variance = |x: Weibull| x.variance().unwrap(); - test_almost(1.0, 0.1, 0.01, 1e-16, variance); - test_almost(1.0, 1.0, 1.0, 1e-14, variance); - test_almost(10.0, 10.0, 1.3100455073468309147154581687505295026863354547057, 1e-12, variance); - test_almost(10.0, 1.0, 0.013100455073468309147154581687505295026863354547057, 1e-14, variance); + test_absolute(1.0, 0.1, 0.01, 1e-16, variance); + test_absolute(1.0, 1.0, 1.0, 1e-14, variance); + test_absolute(10.0, 10.0, 1.3100455073468309147154581687505295026863354547057, 1e-12, variance); + test_absolute(10.0, 1.0, 0.013100455073468309147154581687505295026863354547057, 1e-14, variance); } #[test] fn test_entropy() { let entropy = |x: Weibull| x.entropy().unwrap(); - test_almost(1.0, 0.1, -1.302585092994045684018, 1e-15, entropy); - test_case(1.0, 1.0, 1.0, entropy); - test_case(10.0, 10.0, 1.519494098411379574546, entropy); - test_almost(10.0, 1.0, -0.783090994582666109472, 1e-15, entropy); + test_absolute(1.0, 0.1, -1.302585092994045684018, 1e-15, entropy); + test_exact(1.0, 1.0, 1.0, entropy); + test_exact(10.0, 10.0, 1.519494098411379574546, entropy); + test_absolute(10.0, 1.0, -0.783090994582666109472, 1e-15, entropy); } #[test] fn test_skewnewss() { let skewness = |x: Weibull| x.skewness().unwrap(); - test_almost(1.0, 0.1, 2.0, 1e-13, skewness); - test_almost(1.0, 1.0, 2.0, 1e-13, skewness); - test_almost(10.0, 10.0, -0.63763713390314440916597757156663888653981696212127, 1e-11, skewness); - test_almost(10.0, 1.0, -0.63763713390314440916597757156663888653981696212127, 1e-11, skewness); + test_absolute(1.0, 0.1, 2.0, 1e-13, skewness); + test_absolute(1.0, 1.0, 2.0, 1e-13, skewness); + test_absolute(10.0, 10.0, -0.63763713390314440916597757156663888653981696212127, 1e-11, skewness); + 
test_absolute(10.0, 1.0, -0.63763713390314440916597757156663888653981696212127, 1e-11, skewness); } #[test] fn test_median() { let median = |x: Weibull| x.median(); - test_case(1.0, 0.1, 0.069314718055994530941723212145817656807550013436026, median); - test_case(1.0, 1.0, 0.69314718055994530941723212145817656807550013436026, median); - test_case(10.0, 10.0, 9.6401223546778973665856033763604752124634905617583, median); - test_case(10.0, 1.0, 0.96401223546778973665856033763604752124634905617583, median); + test_exact(1.0, 0.1, 0.069314718055994530941723212145817656807550013436026, median); + test_exact(1.0, 1.0, 0.69314718055994530941723212145817656807550013436026, median); + test_exact(10.0, 10.0, 9.6401223546778973665856033763604752124634905617583, median); + test_exact(10.0, 1.0, 0.96401223546778973665856033763604752124634905617583, median); } #[test] fn test_mode() { let mode = |x: Weibull| x.mode().unwrap(); - test_case(1.0, 0.1, 0.0, mode); - test_case(1.0, 1.0, 0.0, mode); - test_case(10.0, 10.0, 9.8951925820621439264623017041980483215553841533709, mode); - test_case(10.0, 1.0, 0.98951925820621439264623017041980483215553841533709, mode); + test_exact(1.0, 0.1, 0.0, mode); + test_exact(1.0, 1.0, 0.0, mode); + test_exact(10.0, 10.0, 9.8951925820621439264623017041980483215553841533709, mode); + test_exact(10.0, 1.0, 0.98951925820621439264623017041980483215553841533709, mode); } #[test] fn test_min_max() { let min = |x: Weibull| x.min(); let max = |x: Weibull| x.max(); - test_case(1.0, 1.0, 0.0, min); - test_case(1.0, 1.0, f64::INFINITY, max); + test_exact(1.0, 1.0, 0.0, min); + test_exact(1.0, 1.0, f64::INFINITY, max); } #[test] fn test_pdf() { let pdf = |arg: f64| move |x: Weibull| x.pdf(arg); - test_case(1.0, 0.1, 10.0, pdf(0.0)); - test_case(1.0, 0.1, 0.00045399929762484851535591515560550610237918088866565, pdf(1.0)); - test_case(1.0, 0.1, 3.7200759760208359629596958038631183373588922923768e-43, pdf(10.0)); - test_case(1.0, 1.0, 1.0, pdf(0.0)); - 
test_case(1.0, 1.0, 0.36787944117144232159552377016146086744581113103177, pdf(1.0)); - test_case(1.0, 1.0, 0.000045399929762484851535591515560550610237918088866565, pdf(10.0)); - test_case(10.0, 10.0, 0.0, pdf(0.0)); - test_almost(10.0, 10.0, 9.9999999990000000000499999999983333333333750000000e-10, 1e-24, pdf(1.0)); - test_case(10.0, 10.0, 0.36787944117144232159552377016146086744581113103177, pdf(10.0)); - test_case(10.0, 1.0, 0.0, pdf(0.0)); - test_case(10.0, 1.0, 3.6787944117144232159552377016146086744581113103177, pdf(1.0)); - test_case(10.0, 1.0, 0.0, pdf(10.0)); + test_exact(1.0, 0.1, 10.0, pdf(0.0)); + test_exact(1.0, 0.1, 0.00045399929762484851535591515560550610237918088866565, pdf(1.0)); + test_exact(1.0, 0.1, 3.7200759760208359629596958038631183373588922923768e-43, pdf(10.0)); + test_exact(1.0, 1.0, 1.0, pdf(0.0)); + test_exact(1.0, 1.0, 0.36787944117144232159552377016146086744581113103177, pdf(1.0)); + test_exact(1.0, 1.0, 0.000045399929762484851535591515560550610237918088866565, pdf(10.0)); + test_exact(10.0, 10.0, 0.0, pdf(0.0)); + test_absolute(10.0, 10.0, 9.9999999990000000000499999999983333333333750000000e-10, 1e-24, pdf(1.0)); + test_exact(10.0, 10.0, 0.36787944117144232159552377016146086744581113103177, pdf(10.0)); + test_exact(10.0, 1.0, 0.0, pdf(0.0)); + test_exact(10.0, 1.0, 3.6787944117144232159552377016146086744581113103177, pdf(1.0)); + test_exact(10.0, 1.0, 0.0, pdf(10.0)); } #[test] fn test_ln_pdf() { let ln_pdf = |arg: f64| move |x: Weibull| x.ln_pdf(arg); - test_almost(1.0, 0.1, 2.3025850929940456840179914546843642076011014886288, 1e-15, ln_pdf(0.0)); - test_almost(1.0, 0.1, -7.6974149070059543159820085453156357923988985113712, 1e-15, ln_pdf(1.0)); - test_case(1.0, 0.1, -97.697414907005954315982008545315635792398898511371, ln_pdf(10.0)); - test_case(1.0, 1.0, 0.0, ln_pdf(0.0)); - test_case(1.0, 1.0, -1.0, ln_pdf(1.0)); - test_case(1.0, 1.0, -10.0, ln_pdf(10.0)); - test_case(10.0, 10.0, f64::NEG_INFINITY, ln_pdf(0.0)); - test_almost(10.0, 
10.0, -20.723265837046411156161923092159277868409913397659, 1e-14, ln_pdf(1.0)); - test_case(10.0, 10.0, -1.0, ln_pdf(10.0)); - test_case(10.0, 1.0, f64::NEG_INFINITY, ln_pdf(0.0)); - test_almost(10.0, 1.0, 1.3025850929940456840179914546843642076011014886288, 1e-15, ln_pdf(1.0)); - test_case(10.0, 1.0, -9.999999976974149070059543159820085453156357923988985113712e9, ln_pdf(10.0)); + test_absolute(1.0, 0.1, 2.3025850929940456840179914546843642076011014886288, 1e-15, ln_pdf(0.0)); + test_absolute(1.0, 0.1, -7.6974149070059543159820085453156357923988985113712, 1e-15, ln_pdf(1.0)); + test_exact(1.0, 0.1, -97.697414907005954315982008545315635792398898511371, ln_pdf(10.0)); + test_exact(1.0, 1.0, 0.0, ln_pdf(0.0)); + test_exact(1.0, 1.0, -1.0, ln_pdf(1.0)); + test_exact(1.0, 1.0, -10.0, ln_pdf(10.0)); + test_exact(10.0, 10.0, f64::NEG_INFINITY, ln_pdf(0.0)); + test_absolute(10.0, 10.0, -20.723265837046411156161923092159277868409913397659, 1e-14, ln_pdf(1.0)); + test_exact(10.0, 10.0, -1.0, ln_pdf(10.0)); + test_exact(10.0, 1.0, f64::NEG_INFINITY, ln_pdf(0.0)); + test_absolute(10.0, 1.0, 1.3025850929940456840179914546843642076011014886288, 1e-15, ln_pdf(1.0)); + test_exact(10.0, 1.0, -9.999999976974149070059543159820085453156357923988985113712e9, ln_pdf(10.0)); } #[test] fn test_cdf() { let cdf = |arg: f64| move |x: Weibull| x.cdf(arg); - test_case(1.0, 0.1, 0.0, cdf(0.0)); - test_case(1.0, 0.1, 0.99995460007023751514846440848443944938976208191113, cdf(1.0)); - test_case(1.0, 0.1, 0.99999999999999999999999999999999999999999996279924, cdf(10.0)); - test_case(1.0, 1.0, 0.0, cdf(0.0)); - test_case(1.0, 1.0, 0.63212055882855767840447622983853913255418886896823, cdf(1.0)); - test_case(1.0, 1.0, 0.99995460007023751514846440848443944938976208191113, cdf(10.0)); - test_case(10.0, 10.0, 0.0, cdf(0.0)); - test_almost(10.0, 10.0, 9.9999999995000000000166666666662500000000083333333e-11, 1e-25, cdf(1.0)); - test_case(10.0, 10.0, 0.63212055882855767840447622983853913255418886896823, 
cdf(10.0)); - test_case(10.0, 1.0, 0.0, cdf(0.0)); - test_case(10.0, 1.0, 0.63212055882855767840447622983853913255418886896823, cdf(1.0)); - test_case(10.0, 1.0, 1.0, cdf(10.0)); + test_exact(1.0, 0.1, 0.0, cdf(0.0)); + test_exact(1.0, 0.1, 0.99995460007023751514846440848443944938976208191113, cdf(1.0)); + test_exact(1.0, 0.1, 0.99999999999999999999999999999999999999999996279924, cdf(10.0)); + test_exact(1.0, 1.0, 0.0, cdf(0.0)); + test_exact(1.0, 1.0, 0.63212055882855767840447622983853913255418886896823, cdf(1.0)); + test_exact(1.0, 1.0, 0.99995460007023751514846440848443944938976208191113, cdf(10.0)); + test_exact(10.0, 10.0, 0.0, cdf(0.0)); + test_absolute(10.0, 10.0, 9.9999999995000000000166666666662500000000083333333e-11, 1e-25, cdf(1.0)); + test_exact(10.0, 10.0, 0.63212055882855767840447622983853913255418886896823, cdf(10.0)); + test_exact(10.0, 1.0, 0.0, cdf(0.0)); + test_exact(10.0, 1.0, 0.63212055882855767840447622983853913255418886896823, cdf(1.0)); + test_exact(10.0, 1.0, 1.0, cdf(10.0)); } #[test] fn test_sf() { let sf = |arg: f64| move |x: Weibull| x.sf(arg); - test_case(1.0, 0.1, 1.0, sf(0.0)); - test_case(1.0, 0.1, 4.5399929762484854e-5, sf(1.0)); - test_case(1.0, 0.1, 3.720075976020836e-44, sf(10.0)); - test_case(1.0, 1.0, 1.0, sf(0.0)); - test_case(1.0, 1.0, 0.36787944117144233, sf(1.0)); - test_case(1.0, 1.0, 4.5399929762484854e-5, sf(10.0)); - test_case(10.0, 10.0, 1.0, sf(0.0)); - test_almost(10.0, 10.0, 0.9999999999, 1e-25, sf(1.0)); - test_case(10.0, 10.0, 0.36787944117144233, sf(10.0)); - test_case(10.0, 1.0, 1.0, sf(0.0)); - test_case(10.0, 1.0, 0.36787944117144233, sf(1.0)); - test_case(10.0, 1.0, 0.0, sf(10.0)); + test_exact(1.0, 0.1, 1.0, sf(0.0)); + test_exact(1.0, 0.1, 4.5399929762484854e-5, sf(1.0)); + test_exact(1.0, 0.1, 3.720075976020836e-44, sf(10.0)); + test_exact(1.0, 1.0, 1.0, sf(0.0)); + test_exact(1.0, 1.0, 0.36787944117144233, sf(1.0)); + test_exact(1.0, 1.0, 4.5399929762484854e-5, sf(10.0)); + test_exact(10.0, 10.0, 1.0, 
sf(0.0)); + test_absolute(10.0, 10.0, 0.9999999999, 1e-25, sf(1.0)); + test_exact(10.0, 10.0, 0.36787944117144233, sf(10.0)); + test_exact(10.0, 1.0, 1.0, sf(0.0)); + test_exact(10.0, 1.0, 0.36787944117144233, sf(1.0)); + test_exact(10.0, 1.0, 0.0, sf(10.0)); + } + + #[test] + fn test_inverse_cdf() { + let func = |arg: f64| move |x: Weibull| x.inverse_cdf(x.cdf(arg)); + test_exact(1.0, 0.1, 0.0, func(0.0)); + test_absolute(1.0, 0.1, 1.0, 1e-13, func(1.0)); + test_exact(1.0, 1.0, 0.0, func(0.0)); + test_exact(1.0, 1.0, 1.0, func(1.0)); + test_absolute(1.0, 1.0, 10.0, 1e-10, func(10.0)); + test_exact(10.0, 10.0, 0.0, func(0.0)); + test_absolute(10.0, 10.0, 1.0, 1e-5, func(1.0)); + test_absolute(10.0, 10.0, 10.0, 1e-10, func(10.0)); + test_exact(10.0, 1.0, 0.0, func(0.0)); + test_exact(10.0, 1.0, 1.0, func(1.0)); } #[test] fn test_continuous() { - test::check_continuous_distribution(&try_create(1.0, 0.2), 0.0, 10.0); + test::check_continuous_distribution(&create_ok(1.0, 0.2), 0.0, 10.0); } } diff --git a/src/error.rs b/src/error.rs index f13ecfc6..e6f7ca40 100644 --- a/src/error.rs +++ b/src/error.rs @@ -2,10 +2,12 @@ use std::error::Error; use std::fmt; /// Enumeration of possible errors thrown within the `statrs` library -#[derive(Debug)] +#[derive(Clone, PartialEq, Debug)] pub enum StatsError { /// Generic bad input parameter error BadParams, + /// An argument must be finite + ArgFinite(&'static str), /// An argument should have been positive and was not ArgMustBePositive(&'static str), /// An argument should have been non-negative and was not @@ -48,16 +50,13 @@ pub enum StatsError { SpecialCase(&'static str), } -impl Error for StatsError { - fn description(&self) -> &str { - "Error performing statistical calculation" - } -} +impl Error for StatsError {} impl fmt::Display for StatsError { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match *self { StatsError::BadParams => write!(f, "Bad distribution parameters"), + StatsError::ArgFinite(s) => write!(f, 
"Argument {} must be finite", s), StatsError::ArgMustBePositive(s) => write!(f, "Argument {} must be positive", s), StatsError::ArgNotNegative(s) => write!(f, "Argument {} must be non-negative", s), StatsError::ArgIntervalIncl(s, min, max) => { @@ -104,3 +103,18 @@ impl fmt::Display for StatsError { } } } + +#[cfg(test)] +mod tests { + use super::*; + + fn assert_sync() {} + fn assert_send() {} + + #[test] + fn test_sync_send() { + // Error types should implement Sync and Send + assert_sync::(); + assert_send::(); + } +} diff --git a/src/function/beta.rs b/src/function/beta.rs index 23fb3430..794a8504 100644 --- a/src/function/beta.rs +++ b/src/function/beta.rs @@ -3,7 +3,6 @@ use crate::error::StatsError; use crate::function::gamma; -use crate::is_zero; use crate::prec; use crate::Result; use std::f64; @@ -118,7 +117,7 @@ pub fn checked_beta_reg(a: f64, b: f64, x: f64) -> Result { } else if !(0.0..=1.0).contains(&x) { Err(StatsError::ArgIntervalIncl("x", 0.0, 1.0)) } else { - let bt = if is_zero(x) || ulps_eq!(x, 1.0) { + let bt = if x == 0.0 || ulps_eq!(x, 1.0) { 0.0 } else { (gamma::ln_gamma(a + b) - gamma::ln_gamma(a) - gamma::ln_gamma(b) @@ -204,7 +203,6 @@ pub fn checked_beta_reg(a: f64, b: f64, x: f64) -> Result { } /// Computes the inverse of the regularized incomplete beta function -// // This code is based on the implementation in the ["special"][1] crate, // which in turn is based on a [C implementation][2] by John Burkardt. 
The // original algorithm was published in Applied Statistics and is known as @@ -327,11 +325,7 @@ pub fn inv_beta_reg(mut a: f64, mut b: f64, mut x: f64) -> f64 { } } - if p < 0.0001 { - p = 0.0001; - } else if 0.9999 < p { - p = 0.9999; - } + p = p.clamp(0.0001, 0.9999); // Remark AS R83 // http://www.jstor.org/stable/2347779 @@ -365,7 +359,7 @@ pub fn inv_beta_reg(mut a: f64, mut b: f64, mut x: f64) -> f64 { if sq < prev { pnext = p - adj; - if 0.0 <= pnext && pnext <= 1.0 { + if (0.0..=1.0).contains(&pnext) { break; } } diff --git a/src/function/erf.rs b/src/function/erf.rs index e42fd584..92b68858 100644 --- a/src/function/erf.rs +++ b/src/function/erf.rs @@ -2,7 +2,6 @@ //! related functions use crate::function::evaluate; -use crate::is_zero; use std::f64; /// `erf` calculates the error function at `x`. @@ -13,7 +12,7 @@ pub fn erf(x: f64) -> f64 { 1.0 } else if x <= 0.0 && x.is_infinite() { -1.0 - } else if is_zero(x) { + } else if x == 0.0 { 0.0 } else { erf_impl(x, false) diff --git a/src/function/exponential.rs b/src/function/exponential.rs index 1124c55c..55280d7c 100644 --- a/src/function/exponential.rs +++ b/src/function/exponential.rs @@ -14,26 +14,18 @@ use crate::{consts, Result, StatsError}; /// # Remarks /// /// This implementation follows the derivation in -///
-///
-/// "Handbook of Mathematical Functions, Applied Mathematics Series, Volume -/// 55" - Abramowitz, M., and Stegun, I.A 1964 -///
+/// +/// _"Handbook of Mathematical Functions, Applied Mathematics Series, Volume +/// 55"_ - Abramowitz, M., and Stegun, I.A 1964 +/// /// AND -///
-///
-/// "Advanced mathematical methods for scientists and engineers" - Bender, -/// Carl M.; Steven A. Orszag (1978). page 253 -///
-///
-/// The continued fraction approac is used for `x > 1.0` while the taylor -/// series expansions -/// is used for `0.0 < x <= 1` /// -/// # Examples +/// _"Advanced mathematical methods for scientists and engineers"_ - Bender, +/// Carl M.; Steven A. Orszag (1978). page 253 /// -/// ``` -/// ``` +/// The continued fraction approach is used for `x > 1.0` while the taylor +/// series expansions is used for `0.0 < x <= 1`. +// TODO: Add examples pub fn integral(x: f64, n: u64) -> Result { let eps = 0.00000000000000001; let max_iter = 100; diff --git a/src/function/factorial.rs b/src/function/factorial.rs index eac59bcf..77bcaa5b 100644 --- a/src/function/factorial.rs +++ b/src/function/factorial.rs @@ -4,7 +4,6 @@ use crate::error::StatsError; use crate::function::gamma; use crate::Result; -use core::f64::INFINITY as INF; /// The maximum factorial representable /// by a 64-bit floating point without @@ -20,7 +19,7 @@ pub const MAX_FACTORIAL: usize = 170; /// Returns `f64::INFINITY` if `x > 170` pub fn factorial(x: u64) -> f64 { let x = x as usize; - FCACHE.get(x).map_or(INF, |&fac| fac) + FCACHE.get(x).map_or(f64::INFINITY, |&fac| fac) } /// Computes the logarithmic factorial function `x -> ln(x!)` @@ -91,26 +90,34 @@ pub fn checked_multinomial(n: u64, ni: &[u64]) -> Result { // Initialization for pre-computed cache of 171 factorial // values 0!...170! -lazy_static! 
{ - static ref FCACHE: [f64; MAX_FACTORIAL + 1] = { - let mut fcache = [1.0; MAX_FACTORIAL + 1]; - fcache - .iter_mut() - .enumerate() - .skip(1) - .fold(1.0, |acc, (i, elt)| { - let fac = acc * i as f64; - *elt = fac; - fac - }); - fcache - }; -} +const FCACHE: [f64; MAX_FACTORIAL + 1] = { + let mut fcache = [1.0; MAX_FACTORIAL + 1]; + + // `const` only allow while loops + let mut i = 1; + while i < MAX_FACTORIAL + 1 { + fcache[i] = fcache[i - 1] * i as f64; + i += 1; + } + + fcache +}; #[cfg(test)] mod tests { use super::*; + #[test] + fn test_fcache() { + assert!((FCACHE[0] - 1.0).abs() < f64::EPSILON); + assert!((FCACHE[1] - 1.0).abs() < f64::EPSILON); + assert!((FCACHE[2] - 2.0).abs() < f64::EPSILON); + assert!((FCACHE[3] - 6.0).abs() < f64::EPSILON); + assert!((FCACHE[4] - 24.0).abs() < f64::EPSILON); + assert!((FCACHE[70] - 1197857166996989e85).abs() < f64::EPSILON); + assert!((FCACHE[170] - 7257415615307994e291).abs() < f64::EPSILON); + } + #[test] fn test_factorial_and_ln_factorial() { let mut fac = 1.0; @@ -124,8 +131,8 @@ mod tests { #[test] fn test_factorial_overflow() { - assert_eq!(factorial(172), INF); - assert_eq!(factorial(u64::MAX), INF); + assert_eq!(factorial(172), f64::INFINITY); + assert_eq!(factorial(u64::MAX), f64::INFINITY); } #[test] diff --git a/src/function/gamma.rs b/src/function/gamma.rs index 9d5124f9..ced63871 100644 --- a/src/function/gamma.rs +++ b/src/function/gamma.rs @@ -3,7 +3,6 @@ use crate::consts; use crate::error::StatsError; -use crate::is_zero; use crate::prec; use crate::Result; use std::f64; @@ -216,7 +215,7 @@ pub fn checked_gamma_ur(a: f64, x: f64) -> Result { qkm1 *= big_inv; } - if !is_zero(qk) { + if qk != 0.0 { let r = pk / qk; let t = ((ans - r) / r).abs(); ans = r; diff --git a/src/generate.rs b/src/generate.rs index 1f6102d2..e834c6c5 100644 --- a/src/generate.rs +++ b/src/generate.rs @@ -29,6 +29,7 @@ pub fn log_spaced(length: usize, start_exp: f64, stop_exp: f64) -> Vec { } /// Infinite iterator returning 
floats that form a periodic wave +#[derive(Clone, Copy, PartialEq, Debug)] pub struct InfinitePeriodic { amplitude: f64, step: f64, @@ -80,6 +81,12 @@ impl InfinitePeriodic { } } +impl std::fmt::Display for InfinitePeriodic { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{:#?}", self) + } +} + impl Iterator for InfinitePeriodic { type Item = f64; @@ -96,6 +103,7 @@ impl Iterator for InfinitePeriodic { } /// Infinite iterator returning floats that form a sinusoidal wave +#[derive(Debug, Clone, Copy, PartialEq)] pub struct InfiniteSinusoidal { amplitude: f64, mean: f64, @@ -159,6 +167,12 @@ impl InfiniteSinusoidal { } } +impl std::fmt::Display for InfiniteSinusoidal { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{:#?}", &self) + } +} + impl Iterator for InfiniteSinusoidal { type Item = f64; @@ -175,6 +189,7 @@ impl Iterator for InfiniteSinusoidal { /// Infinite iterator returning floats forming a square wave starting /// with the high phase +#[derive(Debug, Clone, Copy, PartialEq)] pub struct InfiniteSquare { periodic: InfinitePeriodic, high_duration: f64, @@ -212,6 +227,12 @@ impl InfiniteSquare { } } +impl std::fmt::Display for InfiniteSquare { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{:#?}", &self) + } +} + impl Iterator for InfiniteSquare { type Item = f64; @@ -228,6 +249,7 @@ impl Iterator for InfiniteSquare { /// Infinite iterator returning floats forming a triangle wave starting with /// the raise phase from the lowest sample +#[derive(Debug, Clone, Copy, PartialEq)] pub struct InfiniteTriangle { periodic: InfinitePeriodic, raise_duration: f64, @@ -278,6 +300,12 @@ impl InfiniteTriangle { } } +impl std::fmt::Display for InfiniteTriangle { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{:#?}", &self) + } +} + impl Iterator for InfiniteTriangle { type Item = f64; @@ -294,6 +322,7 @@ impl Iterator for 
InfiniteTriangle { /// Infinite iterator returning floats forming a sawtooth wave /// starting with the lowest sample +#[derive(Debug, Clone, Copy, PartialEq)] pub struct InfiniteSawtooth { periodic: InfinitePeriodic, low_value: f64, @@ -323,11 +352,17 @@ impl InfiniteSawtooth { 0.0, delay, ), - low_value: low_value as f64, + low_value, } } } +impl std::fmt::Display for InfiniteSawtooth { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{:#?}", &self) + } +} + impl Iterator for InfiniteSawtooth { type Item = f64; diff --git a/src/lib.rs b/src/lib.rs index 9a9d0a70..7ca3157c 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,42 +5,57 @@ //! Math.NET in so far as they are used in the computation of distribution //! values. This crate depends on the `rand` crate to provide RNG. //! -//! # Example -//! The following example samples from a standard normal distribution -//! +//! # Sampling +//! The common use case is to set up the distributions and sample from them which depends on the `Rand` crate for random number generation. //! ``` -//! # extern crate rand; -//! # extern crate statrs; +//! use statrs::distribution::Exp; //! use rand::distributions::Distribution; -//! use statrs::distribution::Normal; +//! let mut r = rand::rngs::OsRng; +//! let n = Exp::new(0.5).unwrap(); +//! print!("{}", n.sample(&mut r)); +//! ``` +//! +//! # Introspecting distributions +//! Statrs also comes with a number of useful utility traits for more detailed introspection of distributions. +//! ``` +//! use statrs::distribution::{Exp, Continuous, ContinuousCDF}; // `cdf` and `pdf` +//! use statrs::statistics::Distribution; // statistical moments and entropy +//! +//! let n = Exp::new(1.0).unwrap(); +//! assert_eq!(n.mean(), Some(1.0)); +//! assert_eq!(n.variance(), Some(1.0)); +//! assert_eq!(n.entropy(), Some(1.0)); +//! assert_eq!(n.skewness(), Some(2.0)); +//! assert_eq!(n.cdf(1.0), 0.6321205588285576784045); +//! 
assert_eq!(n.pdf(1.0), 0.3678794411714423215955); +//! ``` +//! +//! # Utility functions +//! as well as utility functions including `erf`, `gamma`, `ln_gamma`, `beta`, etc. //! -//! # fn main() { -//! let mut r = rand::thread_rng(); -//! let n = Normal::new(0.0, 1.0).unwrap(); -//! for _ in 0..10 { -//! print!("{}", n.sample(&mut r)); -//! } -//! # } //! ``` +//! use statrs::distribution::FisherSnedecor; +//! use statrs::statistics::Distribution; +//! +//! let n = FisherSnedecor::new(1.0, 1.0).unwrap(); +//! assert!(n.variance().is_none()); +//! ``` +//! ## Distributions implemented +//! Statrs comes with a number of commonly used distributions including Normal, Gamma, Student's T, Exponential, Weibull, etc. view all implemented in `distributions` module. #![crate_type = "lib"] #![crate_name = "statrs"] #![allow(clippy::excessive_precision)] #![allow(clippy::many_single_char_names)] -#![allow(unused_imports)] #![forbid(unsafe_code)] -#![cfg_attr(all(test, feature = "nightly"), feature(unboxed_closures))] -#![cfg_attr(all(test, feature = "nightly"), feature(fn_traits))] +#![cfg_attr(coverage_nightly, feature(coverage_attribute))] #[macro_use] extern crate approx; -#[macro_use] -extern crate lazy_static; - #[macro_export] macro_rules! assert_almost_eq { - ($a:expr, $b:expr, $prec:expr) => { + ($a:expr, $b:expr, $prec:expr $(,)?) => { if !$crate::prec::almost_eq($a, $b, $prec) { panic!( "assertion failed: `abs(left - right) < {:e}`, (left: `{}`, right: `{}`)", @@ -58,18 +73,10 @@ pub mod function; pub mod generate; pub mod prec; pub mod statistics; +pub mod stats_tests; mod error; -// function to silence clippy on the special case when comparing to zero. 
-#[inline(always)] -pub(crate) fn is_zero(x: f64) -> bool { - ulps_eq!(x, 0.0, max_ulps = 0) -} - -// #[cfg(test)] -mod testing; - pub use crate::error::StatsError; /// Result type for the statrs library package that returns diff --git a/src/prec.rs b/src/prec.rs index 59ad3714..fc9a5836 100644 --- a/src/prec.rs +++ b/src/prec.rs @@ -1,5 +1,7 @@ //! Provides utility functions for working with floating point precision +use approx::AbsDiffEq; + /// Standard epsilon, maximum relative precision of IEEE 754 double-precision /// floating point numbers (64 bit) e.g. `2^-53` pub const F64_PREC: f64 = 0.00000000000000011102230246251565; @@ -7,21 +9,20 @@ pub const F64_PREC: f64 = 0.00000000000000011102230246251565; /// Default accuracy for `f64`, equivalent to `0.0 * F64_PREC` pub const DEFAULT_F64_ACC: f64 = 0.0000000000000011102230246251565; -/// Returns true if `a` and `b `are within `acc` of each other. -/// If `a` or `b` are infinite, returns `true` only if both are -/// infinite and similarly signed. Always returns `false` if -/// either number is a `NAN`. +/// Compares if two floats are close via `approx::abs_diff_eq` +/// using a maximum absolute difference (epsilon) of `acc`. pub fn almost_eq(a: f64, b: f64, acc: f64) -> bool { - // only true if a and b are infinite with same - // sign - if a.is_infinite() || b.is_infinite() { + if a.is_infinite() && b.is_infinite() { return a == b; } + a.abs_diff_eq(&b, acc) +} - // NANs are never equal - if a.is_nan() && b.is_nan() { - return false; - } - - (a - b).abs() < acc +/// Compares if two floats are close via `approx::relative_eq!` +/// and `crate::consts::ACC` relative precision. 
+/// Updates first argument to value of second argument +pub fn convergence(x: &mut f64, x_new: f64) -> bool { + let res = approx::relative_eq!(*x, x_new, max_relative = crate::consts::ACC); + *x = x_new; + res } diff --git a/src/statistics/iter_statistics.rs b/src/statistics/iter_statistics.rs index 3a53d835..97b7fd17 100644 --- a/src/statistics/iter_statistics.rs +++ b/src/statistics/iter_statistics.rs @@ -244,133 +244,8 @@ where #[cfg(test)] mod tests { use std::f64::consts; - use rand::rngs::StdRng; - use rand::{SeedableRng}; - use rand::distributions::Distribution; - use crate::distribution::Normal; use crate::statistics::Statistics; use crate::generate::{InfinitePeriodic, InfiniteSinusoidal}; - use crate::testing; - - #[test] - fn test_mean() { - let mut data = testing::load_data("nist/lottery.txt"); - assert_almost_eq!((&data).mean(), 518.958715596330, 1e-12); - - data = testing::load_data("nist/lew.txt"); - assert_almost_eq!((&data).mean(), -177.435000000000, 1e-13); - - data = testing::load_data("nist/mavro.txt"); - assert_almost_eq!((&data).mean(), 2.00185600000000, 1e-15); - - data = testing::load_data("nist/michaelso.txt"); - assert_almost_eq!((&data).mean(), 299.852400000000, 1e-13); - - data = testing::load_data("nist/numacc1.txt"); - assert_eq!((&data).mean(), 10000002.0); - - data = testing::load_data("nist/numacc2.txt"); - assert_almost_eq!((&data).mean(), 1.2, 1e-15); - - data = testing::load_data("nist/numacc3.txt"); - assert_eq!((&data).mean(), 1000000.2); - - data = testing::load_data("nist/numacc4.txt"); - assert_almost_eq!((&data).mean(), 10000000.2, 1e-8); - } - - #[test] - fn test_std_dev() { - let mut data = testing::load_data("nist/lottery.txt"); - assert_almost_eq!((&data).std_dev(), 291.699727470969, 1e-13); - - data = testing::load_data("nist/lew.txt"); - assert_almost_eq!((&data).std_dev(), 277.332168044316, 1e-12); - - data = testing::load_data("nist/mavro.txt"); - assert_almost_eq!((&data).std_dev(), 0.000429123454003053, 1e-15); - 
- data = testing::load_data("nist/michaelso.txt"); - assert_almost_eq!((&data).std_dev(), 0.0790105478190518, 1e-13); - - data = testing::load_data("nist/numacc1.txt"); - assert_eq!((&data).std_dev(), 1.0); - - data = testing::load_data("nist/numacc2.txt"); - assert_almost_eq!((&data).std_dev(), 0.1, 1e-16); - - data = testing::load_data("nist/numacc3.txt"); - assert_almost_eq!((&data).std_dev(), 0.1, 1e-10); - - data = testing::load_data("nist/numacc4.txt"); - assert_almost_eq!((&data).std_dev(), 0.1, 1e-9); - } - - #[test] - fn test_min_max_short() { - let data = [-1.0, 5.0, 0.0, -3.0, 10.0, -0.5, 4.0]; - assert_eq!(data.min(), -3.0); - assert_eq!(data.max(), 10.0); - } - - #[test] - fn test_mean_variance_stability() { - let seed = [ - 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, - 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31 - ]; - let mut rng: StdRng = SeedableRng::from_seed(seed); - let normal = Normal::new(1e9, 2.0).unwrap(); - let samples = (0..10000).map(|_| normal.sample::(&mut rng)).collect::>(); - assert_almost_eq!((&samples).mean(), 1e9, 10.0); - assert_almost_eq!((&samples).variance(), 4.0, 0.1); - assert_almost_eq!((&samples).std_dev(), 2.0, 0.01); - assert_almost_eq!((&samples).quadratic_mean(), 1e9, 10.0); - } - - #[test] - fn test_covariance_consistent_with_variance() { - let mut data = testing::load_data("nist/lottery.txt"); - assert_almost_eq!((&data).variance(), (&data).covariance(&data), 1e-10); - - data = testing::load_data("nist/lew.txt"); - assert_almost_eq!((&data).variance(), (&data).covariance(&data), 1e-10); - - data = testing::load_data("nist/mavro.txt"); - assert_almost_eq!((&data).variance(), (&data).covariance(&data), 1e-10); - - data = testing::load_data("nist/michaelso.txt"); - assert_almost_eq!((&data).variance(), (&data).covariance(&data), 1e-10); - - data = testing::load_data("nist/numacc1.txt"); - assert_almost_eq!((&data).variance(), (&data).covariance(&data), 1e-10); - } - - #[test] - fn 
test_pop_covar_consistent_with_pop_var() { - let mut data = testing::load_data("nist/lottery.txt"); - assert_almost_eq!((&data).population_variance(), (&data).population_covariance(&data), 1e-10); - - data = testing::load_data("nist/lew.txt"); - assert_almost_eq!((&data).population_variance(), (&data).population_covariance(&data), 1e-10); - - data = testing::load_data("nist/mavro.txt"); - assert_almost_eq!((&data).population_variance(), (&data).population_covariance(&data), 1e-10); - - data = testing::load_data("nist/michaelso.txt"); - assert_almost_eq!((&data).population_variance(), (&data).population_covariance(&data), 1e-10); - - data = testing::load_data("nist/numacc1.txt"); - assert_almost_eq!((&data).population_variance(), (&data).population_covariance(&data), 1e-10); - } - - #[test] - fn test_covariance_is_symmetric() { - let data_a = &testing::load_data("nist/lottery.txt")[0..200]; - let data_b = &testing::load_data("nist/lew.txt")[0..200]; - assert_almost_eq!(data_a.covariance(data_b), data_b.covariance(data_a), 1e-10); - assert_almost_eq!(data_a.population_covariance(data_b), data_b.population_covariance(data_a), 1e-11); - } #[test] fn test_empty_data_returns_nan() { diff --git a/src/statistics/mod.rs b/src/statistics/mod.rs index b156c6ec..272091a4 100644 --- a/src/statistics/mod.rs +++ b/src/statistics/mod.rs @@ -1,6 +1,5 @@ //! 
Provides traits for statistical computation -pub use self::iter_statistics::*; pub use self::order_statistics::*; pub use self::slice_statistics::*; pub use self::statistics::*; @@ -10,5 +9,6 @@ mod iter_statistics; mod order_statistics; // TODO: fix later mod slice_statistics; +#[allow(clippy::module_inception)] mod statistics; mod traits; diff --git a/src/statistics/slice_statistics.rs b/src/statistics/slice_statistics.rs index 1d1b79cc..b0b1d2a7 100644 --- a/src/statistics/slice_statistics.rs +++ b/src/statistics/slice_statistics.rs @@ -1,12 +1,37 @@ use crate::statistics::*; use core::ops::{Index, IndexMut}; -use rand::prelude::SliceRandom; -#[derive(Clone, Debug, PartialEq, Eq)] +#[derive(Clone, PartialEq, Eq, PartialOrd, Ord, Debug)] pub struct Data(D); +impl std::fmt::Display for Data +where + D: Clone + IntoIterator, + I: Clone + std::fmt::Display, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let mut tee = self.0.clone().into_iter(); + write!(f, "Data([")?; + + if let Some(v) = tee.next() { + write!(f, "{}", v)?; + } + for _ in 1..5 { + if let Some(v) = tee.next() { + write!(f, ", {}", v)?; + } + } + if tee.next().is_some() { + write!(f, "...")?; + } + + write!(f, "])") + } +} + impl> Index for Data { type Output = f64; + fn index(&self, i: usize) -> &f64 { &self.0.as_ref()[i] } @@ -22,18 +47,23 @@ impl + AsRef<[f64]>> Data { pub fn new(data: D) -> Self { Data(data) } + pub fn swap(&mut self, i: usize, j: usize) { self.0.as_mut().swap(i, j) } + pub fn len(&self) -> usize { self.0.as_ref().len() } + pub fn is_empty(&self) -> bool { self.0.as_ref().len() == 0 } + pub fn iter(&self) -> core::slice::Iter<'_, f64> { self.0.as_ref().iter() } + // Selection algorithm from Numerical Recipes // See: https://en.wikipedia.org/wiki/Selection_algorithm fn select_inplace(&mut self, rank: usize) -> f64 { @@ -102,8 +132,11 @@ impl + AsRef<[f64]>> Data { } } +#[cfg(feature = "rand")] impl> ::rand::distributions::Distribution for Data { fn 
sample(&self, rng: &mut R) -> f64 { + use rand::prelude::SliceRandom; + *self.0.as_ref().choose(rng).unwrap() } } @@ -299,6 +332,7 @@ impl + AsRef<[f64]>> Distribution for Data { fn mean(&self) -> Option { Some(Statistics::mean(self.iter())) } + /// Estimates the unbiased population variance from the provided samples /// /// # Remarks diff --git a/src/statistics/statistics.rs b/src/statistics/statistics.rs index 40b91c72..3081791b 100644 --- a/src/statistics/statistics.rs +++ b/src/statistics/statistics.rs @@ -1,6 +1,6 @@ /// Enumeration of possible tie-breaking strategies /// when computing ranks -#[derive(Debug, Copy, Clone)] +#[derive(Copy, Clone, Debug)] pub enum RankTieBreaker { /// Replaces ties with their mean Average, diff --git a/src/statistics/traits.rs b/src/statistics/traits.rs index d264d719..9140eab4 100644 --- a/src/statistics/traits.rs +++ b/src/statistics/traits.rs @@ -1,10 +1,5 @@ -use ::nalgebra::{ - base::allocator::Allocator, base::dimension::DimName, DefaultAllocator, Dim, DimMin, U1, -}; use ::num_traits::float::Float; -const STEPS: usize = 1_000; - /// The `Min` trait specifies than an object has a minimum value pub trait Min { /// Returns the minimum value in the domain of a given distribution @@ -38,7 +33,7 @@ pub trait Max { /// ``` fn max(&self) -> T; } -pub trait DiscreteDistribution: ::rand::distributions::Distribution { +pub trait DiscreteDistribution { /// Returns the mean, if it exists. fn mean(&self) -> Option { None @@ -61,14 +56,8 @@ pub trait DiscreteDistribution: ::rand::distributions::Distribution: ::rand::distributions::Distribution { +pub trait Distribution { /// Returns the mean, if it exists. - /// The default implementation returns an estimation - /// based on random samples. This is a crude estimate - /// for when no further information is known about the - /// distribution. More accurate statements about the - /// mean can and should be given by overriding the - /// default implementation. 
/// /// # Examples /// @@ -80,23 +69,9 @@ pub trait Distribution: ::rand::distributions::Distribution { /// assert_eq!(0.5, n.mean().unwrap()); /// ``` fn mean(&self) -> Option { - // TODO: Does not need cryptographic rng - let mut rng = ::rand::rngs::OsRng; - let mut mean = T::zero(); - let mut steps = T::zero(); - for _ in 0..STEPS { - steps = steps + T::one(); - mean = mean + Self::sample(self, &mut rng); - } - Some(mean / steps) + None } /// Returns the variance, if it exists. - /// The default implementation returns an estimation - /// based on random samples. This is a crude estimate - /// for when no further information is known about the - /// distribution. More accurate statements about the - /// variance can and should be given by overriding the - /// default implementation. /// /// # Examples /// @@ -108,19 +83,7 @@ pub trait Distribution: ::rand::distributions::Distribution { /// assert_eq!(1.0 / 12.0, n.variance().unwrap()); /// ``` fn variance(&self) -> Option { - // TODO: Does not need cryptographic rng - let mut rng = ::rand::rngs::OsRng; - let mut mean = T::zero(); - let mut variance = T::zero(); - let mut steps = T::zero(); - for _ in 0..STEPS { - steps = steps + T::one(); - let sample = Self::sample(self, &mut rng); - variance = variance + (steps - T::one()) * (sample - mean) * (sample - mean) / steps; - mean = mean + (sample - mean) / steps; - } - steps = steps - T::one(); - Some(variance / steps) + None } /// Returns the standard deviation, if it exists. 
/// diff --git a/src/stats_tests/fisher.rs b/src/stats_tests/fisher.rs new file mode 100644 index 00000000..909b4e7b --- /dev/null +++ b/src/stats_tests/fisher.rs @@ -0,0 +1,397 @@ +use super::Alternative; +use crate::distribution::{Discrete, DiscreteCDF, Hypergeometric, HypergeometricError}; + +const EPSILON: f64 = 1.0 - 1e-4; + +/// Binary search in two-sided test with starting bound as argument +fn binary_search( + n: u64, + n1: u64, + n2: u64, + mode: u64, + p_exact: f64, + epsilon: f64, + upper: bool, +) -> u64 { + let (mut min_val, mut max_val) = { + if upper { + (mode, n) + } else { + (0, mode) + } + }; + + let population = n1 + n2; + let successes = n1; + let draws = n; + let dist = Hypergeometric::new(population, successes, draws).unwrap(); + + let mut guess = 0; + loop { + if max_val - min_val <= 1 { + break; + } + guess = { + if max_val == min_val + 1 && guess == min_val { + max_val + } else { + (max_val + min_val) / 2 + } + }; + + let ng = { + if upper { + guess - 1 + } else { + guess + 1 + } + }; + + let pmf_comp = dist.pmf(ng); + let p_guess = dist.pmf(guess); + if p_guess <= p_exact && p_exact < pmf_comp { + break; + } + if p_guess < p_exact { + max_val = guess + } else { + min_val = guess + } + } + + if guess == 0 { + guess = min_val + } + if upper { + loop { + if guess > 0 && dist.pmf(guess) < p_exact * epsilon { + guess -= 1; + } else { + break; + } + } + loop { + if dist.pmf(guess) > p_exact / epsilon { + guess += 1; + } else { + break; + } + } + } else { + loop { + if dist.pmf(guess) < p_exact * epsilon { + guess += 1; + } else { + break; + } + } + loop { + if guess > 0 && dist.pmf(guess) > p_exact / epsilon { + guess -= 1; + } else { + break; + } + } + } + guess +} + +#[derive(Copy, Clone, PartialEq, Eq, Debug, Hash)] +#[non_exhaustive] +pub enum FishersExactTestError { + /// The table does not describe a valid [`Hypergeometric`] distribution. + /// Make sure that the contingency table stores the data in row-major order. 
+ TableInvalidForHypergeometric(HypergeometricError), +} + +impl std::fmt::Display for FishersExactTestError { + #[cfg_attr(coverage_nightly, coverage(off))] + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + match self { + FishersExactTestError::TableInvalidForHypergeometric(hg_err) => { + writeln!(f, "Cannot create a Hypergeometric distribution from the data in the contingency table.")?; + writeln!(f, "Is it in row-major order?")?; + write!(f, "Inner error: '{}'", hg_err) + } + } + } +} + +impl std::error::Error for FishersExactTestError {} + +impl From for FishersExactTestError { + fn from(value: HypergeometricError) -> Self { + Self::TableInvalidForHypergeometric(value) + } +} + +/// Perform a Fisher exact test on a 2x2 contingency table. +/// Based on scipy's fisher test: +/// Expects a table in row-major order +/// Returns the [odds ratio](https://en.wikipedia.org/wiki/Odds_ratio) and p_value +/// # Examples +/// +/// ``` +/// use statrs::stats_tests::fishers_exact_with_odds_ratio; +/// use statrs::stats_tests::Alternative; +/// let table = [3, 5, 4, 50]; +/// let (odds_ratio, p_value) = fishers_exact_with_odds_ratio(&table, Alternative::Less).unwrap(); +/// ``` +pub fn fishers_exact_with_odds_ratio( + table: &[u64; 4], + alternative: Alternative, +) -> Result<(f64, f64), FishersExactTestError> { + // If both values in a row or column are zero, p-value is 1 and odds ratio is NaN. + match table { + [0, _, 0, _] | [_, 0, _, 0] => return Ok((f64::NAN, 1.0)), // both 0 in a row + [0, 0, _, _] | [_, _, 0, 0] => return Ok((f64::NAN, 1.0)), // both 0 in a column + _ => (), // continue + } + + let odds_ratio = { + if table[1] > 0 && table[2] > 0 { + (table[0] * table[3]) as f64 / (table[1] * table[2]) as f64 + } else { + f64::INFINITY + } + }; + + let p_value = fishers_exact(table, alternative)?; + Ok((odds_ratio, p_value)) +} + +/// Perform a Fisher exact test on a 2x2 contingency table. 
+/// Based on scipy's fisher test: +/// Expects a table in row-major order +/// Returns only the p_value +/// # Examples +/// +/// ``` +/// use statrs::stats_tests::fishers_exact; +/// use statrs::stats_tests::Alternative; +/// let table = [3, 5, 4, 50]; +/// let p_value = fishers_exact(&table, Alternative::Less).unwrap(); +/// ``` +pub fn fishers_exact( + table: &[u64; 4], + alternative: Alternative, +) -> Result { + // If both values in a row or column are zero, the p-value is 1 and the odds ratio is NaN. + match table { + [0, _, 0, _] | [_, 0, _, 0] => return Ok(1.0), // both 0 in a row + [0, 0, _, _] | [_, _, 0, 0] => return Ok(1.0), // both 0 in a column + _ => (), // continue + } + + let n1 = table[0] + table[1]; + let n2 = table[2] + table[3]; + let n = table[0] + table[2]; + + let p_value = { + let population = n1 + n2; + let successes = n1; + + match alternative { + Alternative::Less => { + let draws = n; + let dist = Hypergeometric::new(population, successes, draws)?; + dist.cdf(table[0]) + } + Alternative::Greater => { + let draws = table[1] + table[3]; + let dist = Hypergeometric::new(population, successes, draws)?; + dist.cdf(table[1]) + } + Alternative::TwoSided => { + let draws = n; + let dist = Hypergeometric::new(population, successes, draws)?; + + let p_exact = dist.pmf(table[0]); + let mode = ((n + 1) * (n1 + 1)) / (n1 + n2 + 2); + let p_mode = dist.pmf(mode); + + if (p_exact - p_mode).abs() / p_exact.max(p_mode) <= 1.0 - EPSILON { + return Ok(1.0); + } + + if table[0] < mode { + let p_lower = dist.cdf(table[0]); + if dist.pmf(n) > p_exact / EPSILON { + return Ok(p_lower); + } + let guess = binary_search(n, n1, n2, mode, p_exact, EPSILON, true); + return Ok(p_lower + 1.0 - dist.cdf(guess - 1)); + } + + let p_upper = 1.0 - dist.cdf(table[0] - 1); + if dist.pmf(0) > p_exact / EPSILON { + return Ok(p_upper); + } + + let guess = binary_search(n, n1, n2, mode, p_exact, EPSILON, false); + p_upper + dist.cdf(guess) + } + } + }; + + Ok(p_value.min(1.0)) 
+} + +#[cfg(test)] +mod tests { + use super::*; + use crate::prec; + + /// Test fishers_exact by comparing against values from scipy. + #[test] + fn test_fishers_exact() { + let cases = [ + ( + [3, 5, 4, 50], + 0.9963034765672599, + 0.03970749246529277, + 0.03970749246529276, + ), + ( + [61, 118, 2, 1], + 0.27535061623455315, + 0.9598172545684959, + 0.27535061623455315, + ), + ( + [172, 46, 90, 127], + 1.0, + 6.662405187351769e-16, + 9.041009036528785e-16, + ), + ( + [127, 38, 112, 43], + 0.8637599357870167, + 0.20040942958644145, + 0.3687862842650179, + ), + ( + [186, 177, 111, 154], + 0.9918518696328176, + 0.012550663906725129, + 0.023439141644624434, + ), + ( + [137, 49, 135, 183], + 0.999999999998533, + 5.6517533666400615e-12, + 8.870999836202932e-12, + ), + ( + [37, 115, 37, 152], + 0.8834621182590621, + 0.17638403366123565, + 0.29400927608021704, + ), + ( + [124, 117, 119, 175], + 0.9956704915461392, + 0.007134712391455461, + 0.011588218284387445, + ), + ( + [70, 114, 41, 118], + 0.9945558498544903, + 0.010384865876586255, + 0.020438291037108678, + ), + ( + [173, 21, 89, 7], + 0.2303739114068352, + 0.8808002774812677, + 0.4027047267306024, + ), + ( + [18, 147, 123, 58], + 4.077820702304103e-29, + 0.9999999999999817, + 0.0, + ), + ( + [116, 20, 92, 186], + 0.9999999999998267, + 6.598118571034892e-25, + 8.164831402188242e-25, + ), + ( + [9, 22, 44, 38], + 0.01584272038710196, + 0.9951463496539362, + 0.021581786662999272, + ), + ( + [9, 101, 135, 7], + 3.3336213533847776e-50, + 1.0, + 3.3336213533847776e-50, + ), + ( + [153, 27, 191, 144], + 0.9999999999950817, + 2.473736787266208e-11, + 3.185816623300107e-11, + ), + ( + [111, 195, 189, 69], + 6.665245982898848e-19, + 0.9999999999994574, + 1.0735744915712542e-18, + ), + ( + [125, 21, 31, 131], + 0.99999999999974, + 9.720661317939016e-34, + 1.0352129312860277e-33, + ), + ( + [201, 192, 69, 179], + 0.9999999988714893, + 3.1477232259550017e-09, + 4.761075937088169e-09, + ), + ( + [124, 138, 159, 160], + 
0.30153826772785475, + 0.7538974235759873, + 0.5601766196310243, + ), + ]; + + for (table, less_expected, greater_expected, two_sided_expected) in cases.iter() { + for (alternative, expected) in [ + Alternative::Less, + Alternative::Greater, + Alternative::TwoSided, + ] + .iter() + .zip(vec![less_expected, greater_expected, two_sided_expected]) + { + let p_value = fishers_exact(table, *alternative).unwrap(); + assert!(prec::almost_eq(p_value, *expected, 1e-12)); + } + } + } + + #[test] + fn test_fishers_exact_for_trivial() { + let cases = [[0, 0, 1, 2], [1, 2, 0, 0], [1, 0, 2, 0], [0, 1, 0, 2]]; + + for table in cases.iter() { + assert_eq!(fishers_exact(table, Alternative::Less).unwrap(), 1.0) + } + } + + #[test] + fn test_fishers_exact_with_odds() { + let table = [3, 5, 4, 50]; + let (odds_ratio, p_value) = + fishers_exact_with_odds_ratio(&table, Alternative::Less).unwrap(); + assert!(prec::almost_eq(p_value, 0.9963034765672599, 1e-12)); + assert!(prec::almost_eq(odds_ratio, 7.5, 1e-1)); + } +} diff --git a/src/stats_tests/mod.rs b/src/stats_tests/mod.rs new file mode 100644 index 00000000..84a01fc7 --- /dev/null +++ b/src/stats_tests/mod.rs @@ -0,0 +1,17 @@ +pub mod fisher; + +/// Specifies an [alternative hypothesis](https://en.wikipedia.org/wiki/Alternative_hypothesis) +#[derive(Debug, Copy, Clone)] +pub enum Alternative { + #[doc(alias = "two-tailed")] + #[doc(alias = "two tailed")] + TwoSided, + #[doc(alias = "one-tailed")] + #[doc(alias = "one tailed")] + Less, + #[doc(alias = "one-tailed")] + #[doc(alias = "one tailed")] + Greater, +} + +pub use fisher::{fishers_exact, fishers_exact_with_odds_ratio}; diff --git a/src/testing/mod.rs b/src/testing/mod.rs deleted file mode 100644 index 45eb3cab..00000000 --- a/src/testing/mod.rs +++ /dev/null @@ -1,32 +0,0 @@ -//! Provides testing helpers and utilities - -use std::fs::File; -use std::io::{BufRead, BufReader}; -use std::str; - -/// Loads a test data file into a vector of `f64`'s. -/// Path is relative to /data. 
-/// -/// # Panics -/// -/// Panics if the file does not exist or could not be opened, or -/// there was an error reading the file. -#[cfg(test)] -pub fn load_data(path: &str) -> Vec { - // note: the copious use of unwrap is because this is a test helper and - // if reading the data file fails, we want to panic immediately - - let path_prefix = "./data/".to_string(); - let true_path = path_prefix + path.trim().trim_start_matches('/'); - - let f = File::open(true_path).unwrap(); - let mut reader = BufReader::new(f); - - let mut buf = String::new(); - let mut data: Vec = vec![]; - while reader.read_line(&mut buf).unwrap() > 0 { - data.push(buf.trim().parse::().unwrap()); - buf.clear(); - } - data -} diff --git a/tests/gather_nist_data.sh b/tests/gather_nist_data.sh new file mode 100755 index 00000000..2b663734 --- /dev/null +++ b/tests/gather_nist_data.sh @@ -0,0 +1,36 @@ +#! /bin/bash +# this script is to download and preprocess datafiles for the nist_tests.rs +# integration test for statrs downloads data to directory specified by env +# var STATRS_NIST_DATA_DIR + +process_file() { + # Define input and output file names + SOURCE=$1 + FILENAME=$2 + TARGET=${STATRS_NIST_DATA_DIR-tests}/${FILENAME} + echo -e ${FILENAME} '\n\tDownloading...' + curl -fsSL ${SOURCE}/$FILENAME > ${TARGET} + + # Extract line numbers for Certified Values and Data from the header + INFO=$(grep "Certified Values:" $TARGET) + CERTIFIED_VALUES_START=$(echo $INFO | awk '{print $4}') + CERTIFIED_VALUES_END=$(echo $INFO | awk '{print $6}') + + INFO=$(grep "Data :" $TARGET) + DATA_START=$(echo $INFO | awk '{print $4}') + DATA_END=$(echo $INFO | awk '{print $6}') + + echo -e '\tFormatting...' 
+ # Extract and reformat sections + sed -n -i \ + -e "${CERTIFIED_VALUES_START},${CERTIFIED_VALUES_END}p" \ + -e "${DATA_START},${DATA_END}p" \ + $TARGET +} + +URL='https://www.itl.nist.gov/div898/strd/univ/data' +for file in Lottery.dat Lew.dat Mavro.dat Michelso.dat NumAcc1.dat NumAcc2.dat NumAcc3.dat +do + process_file $URL $file +done + diff --git a/tests/nist_tests.rs b/tests/nist_tests.rs new file mode 100644 index 00000000..0f731067 --- /dev/null +++ b/tests/nist_tests.rs @@ -0,0 +1,137 @@ +//! This test relies on data that is reusable but not distributable by statrs as +//! such, the data will need to be downloaded from the relevant NIST StRD dataset +//! the parsing for testing assumes data to be of form, +//! ```text +//! sample mean : +//! sample std_dev : +//! sample correlation: +//! [zero or more blank lines] +//! data0 +//! data1 +//! data2 +//! ... +//! ``` +//! This test can be run on it's own from the shell from this folder as +//! ```sh +//! ./gather_nist_data.sh && cargo test -- --ignored nist_ +//! 
``` +use anyhow::Result; +use approx::assert_relative_eq; +use statrs::statistics::Statistics; + +use std::io::{BufRead, BufReader}; +use std::path::PathBuf; +use std::{env, fs}; + +struct TestCase { + certified: CertifiedValues, + values: Vec, +} + +impl std::fmt::Debug for TestCase { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "TestCase({:?}, [...]", self.certified) + } +} + +#[derive(Debug)] +struct CertifiedValues { + mean: f64, + std_dev: f64, + corr: f64, +} + +impl std::fmt::Display for CertifiedValues { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "μ={:.3e}, σ={:.3e}, r={:.3e}", + self.mean, self.std_dev, self.corr + ) + } +} + +const NIST_DATA_DIR_ENV: &str = "STATRS_NIST_DATA_DIR"; +const FILENAMES: [&str; 7] = [ + "Lottery.dat", + "Lew.dat", + "Mavro.dat", + "Michelso.dat", + "NumAcc1.dat", + "NumAcc2.dat", + "NumAcc3.dat", +]; + +fn get_path(fname: &str, prefix: Option<&str>) -> PathBuf { + if let Some(prefix) = prefix { + [prefix, fname].iter().collect() + } else { + ["tests", fname].iter().collect() + } +} + +#[test] +#[ignore = "NIST tests should not run from typical `cargo test` calls"] +fn nist_strd_univariate_mean() { + for fname in FILENAMES { + let filepath = get_path(fname, env::var(NIST_DATA_DIR_ENV).ok().as_deref()); + let case = parse_file(filepath) + .unwrap_or_else(|e| panic!("failed parsing file {} with `{:?}`", fname, e)); + assert_relative_eq!(case.values.mean(), case.certified.mean, epsilon = 1e-12); + } +} + +#[test] +#[ignore] +fn nist_strd_univariate_std_dev() { + for fname in FILENAMES { + let filepath = get_path(fname, env::var(NIST_DATA_DIR_ENV).ok().as_deref()); + let case = parse_file(filepath) + .unwrap_or_else(|e| panic!("failed parsing file {} with `{:?}`", fname, e)); + assert_relative_eq!( + case.values.std_dev(), + case.certified.std_dev, + epsilon = 1e-10 + ); + } +} + +fn parse_certified_value(line: String) -> Result { + line.chars() + 
.skip_while(|&c| c != ':') + .skip(1) // skip through ':' delimiter + .skip_while(|&c| c.is_whitespace()) // effectively `String` trim + .take_while(|&c| matches!(c, '0'..='9' | '-' | '.')) + .collect::() + .parse::() + .map_err(|e| e.into()) +} + +fn parse_file(path: impl AsRef) -> anyhow::Result { + let f = fs::File::open(path)?; + let reader = BufReader::new(f); + let mut lines = reader.lines(); + + let mean = parse_certified_value(lines.next().expect("file should not be exhausted")?)?; + let std_dev = parse_certified_value(lines.next().expect("file should not be exhausted")?)?; + let corr = parse_certified_value(lines.next().expect("file should not be exhausted")?)?; + + Ok(TestCase { + certified: CertifiedValues { + mean, + std_dev, + corr, + }, + values: lines + .map_while(|line| line.ok()?.trim().parse().ok()) + .collect(), + }) +} + +#[test] +#[ignore = "NIST tests should not run from typical `cargo test` calls"] +fn nist_test_covariance_consistent_with_variance() {} + +#[test] +#[ignore = "NIST tests should not run from typical `cargo test` calls"] +fn nist_test_covariance_is_symmetric() {}