From a4fe8704c1c1f387abe00a186d481c934e3229ea Mon Sep 17 00:00:00 2001 From: Fernando Pelliccioni Date: Tue, 4 Feb 2025 20:27:22 +0100 Subject: [PATCH] feat: Extend hypergeometric distribution PMF for non-integral arguments --- .../math/distributions/hypergeometric.hpp | 67 +++++++++++++++++-- 1 file changed, 62 insertions(+), 5 deletions(-) diff --git a/include/boost/math/distributions/hypergeometric.hpp b/include/boost/math/distributions/hypergeometric.hpp index d7d1b6a069..fba8acf5af 100644 --- a/include/boost/math/distributions/hypergeometric.hpp +++ b/include/boost/math/distributions/hypergeometric.hpp @@ -15,7 +15,9 @@ #include #include #include +#include #include +#include #include namespace boost { namespace math { @@ -136,14 +138,69 @@ namespace boost { namespace math { { BOOST_MATH_STD_USING static const char* function = "boost::math::pdf(const hypergeometric_distribution<%1%>&, const %1%&)"; - RealType r = static_cast(x); - auto u = static_cast(lltrunc(r, typename policies::normalise >::type())); - if(u != r) + const RealType x_real = static_cast(x); + const auto u = static_cast(lltrunc(x_real, typename policies::normalise >::type())); + + // If x is an integer, call the PDF directly + if(u == x_real) + { + return pdf(dist, u); + } + + if (x_real < 0) { + return pdf(dist, static_cast(-1)); + } + + const auto r = dist.defective(); + const auto n = dist.sample_count(); + const auto N = dist.total(); + const std::int64_t max_valid_x = std::min(r, n); + + if (x_real > max_valid_x) { + return pdf(dist, max_valid_x + 1); + } + + // If x is not an integer, perform cubic Hermite interpolation + const std::int64_t x_rounded = static_cast(round(x_real)); + if (max_valid_x < 2) { return boost::math::policies::raise_domain_error( - function, "Random variable out of range: must be an integer but got %1%", r, Policy()); + function, "Not enough points available for interpolation, we got %1% points and we need at least 3", x, Policy()); + } + + std::int64_t lower_x = x_rounded - 1; + if (lower_x < 0) { + lower_x = 0; + } + std::int64_t upper_x = lower_x + 2; + if (upper_x > max_valid_x) { + upper_x = max_valid_x; + --lower_x; } - return pdf(dist, u); + + std::vector x_vals; + std::vector y_vals; + for (std::int64_t xi = lower_x; xi <= upper_x; ++xi) { + const auto pdf_val = pdf(dist, xi); + x_vals.push_back(static_cast(xi)); + y_vals.push_back(pdf_val); + } + + std::vector dydx_vals; + for (size_t i = 1; i < x_vals.size() - 1; ++i) { + const RealType deriv = (y_vals[i + 1] - y_vals[i - 1]) / (x_vals[i + 1] - x_vals[i - 1]); + dydx_vals.push_back(deriv); + } + + dydx_vals.insert(dydx_vals.begin(), (y_vals[1] - y_vals[0]) / (x_vals[1] - x_vals[0])); + dydx_vals.push_back((y_vals[y_vals.size() - 1] - y_vals[y_vals.size() - 2]) / + (x_vals[y_vals.size() - 1] - x_vals[y_vals.size() - 2])); + + using boost::math::interpolators::cubic_hermite; + const auto interpolator = cubic_hermite>( + std::move(x_vals), std::move(y_vals), std::move(dydx_vals) + ); + return interpolator(x); } template