From f8a6cb03ac242820c9d83506030ec270768dfa66 Mon Sep 17 00:00:00 2001 From: Filippo Airaldi Date: Wed, 29 Jan 2025 19:23:06 +0100 Subject: [PATCH] added digamma1p function --- src/csnlp/util/math.py | 34 ++++++++++++++++++++++++++++------ tests/test_util.py | 12 +++++++++++- 2 files changed, 39 insertions(+), 7 deletions(-) diff --git a/src/csnlp/util/math.py b/src/csnlp/util/math.py index eadb4c1..092167e 100644 --- a/src/csnlp/util/math.py +++ b/src/csnlp/util/math.py @@ -444,8 +444,8 @@ def gammaln( # cs.if_else(z < 1, cs.substitute(out, z, z + 1) - cs.log(z), out) -def digamma(z: Union[cs.SX, cs.MX, cs.DM], n: int) -> Union[cs.SX, cs.MX, cs.DM]: - """Computes the digamma function via asymptotic expansion. +def digamma1p(z: Union[cs.SX, cs.MX, cs.DM], n: int) -> Union[cs.SX, cs.MX, cs.DM]: + """Computes the digamma function evaluated at :math:`z+1` via asymptotic expansion. Only valid for non-negative real scalars. Parameters @@ -465,13 +465,35 @@ def digamma(z: Union[cs.SX, cs.MX, cs.DM], n: int) -> Union[cs.SX, cs.MX, cs.DM] Requires :mod:`scipy` to be installed. For important details, see - https://www.boost.org/doc/libs/1_87_0/libs/math/doc/html/math_toolkit/sf_gamma/digamma.html """ - # we shift by one since digamma(z) = digamma(z + 1) - 1 / z, and the approximation - # is not good for z < 1 - from scipy.special import bernoulli N_2 = 2 * np.arange(1, n + 1) B = bernoulli(2 * n)[2::2] z1p = z + 1 powers = cs.power(z1p, N_2) - return -1 / z + cs.log1p(z) - 0.5 / z1p - cs.sum1(1 / ((N_2 / B) * powers)) + return cs.log1p(z) - 0.5 / z1p - cs.sum1(1 / ((N_2 / B) * powers)) + +def digamma(z: Union[cs.SX, cs.MX, cs.DM], n: int) -> Union[cs.SX, cs.MX, cs.DM]: + """Computes the digamma function via asymptotic expansion. + Only valid for non-negative real scalars. + + Parameters + ---------- + z : casadi.SX, MX or DM + The value at which to compute the logarithm of the gamma function. + n : int, optional + The number of coefficients to compute for the approximation. + + Returns + ------- + casadi.SX, MX or DM + The value of the digamma function at ``z``. + + Notes + ----- + Requires :mod:`scipy` to be installed. For important details, see + - https://www.boost.org/doc/libs/1_87_0/libs/math/doc/html/math_toolkit/sf_gamma/digamma.html + """ + # we shift by one since digamma(z) = digamma(z + 1) - 1 / z, and the approximation + # is not good for z < 1 + return digamma1p(z, n) - 1 / z diff --git a/tests/test_util.py b/tests/test_util.py index f514739..f081b35 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -299,12 +299,22 @@ def test_gammaln(self): actual = cs_gammaln(z).full().flatten() np.testing.assert_allclose(actual, expected, rtol=1e-6, atol=1e-6) + def test_digamma1p(self): + z = cs.SX.sym("z") + y = math.digamma1p(z, 3) + cs_digamma1p = cs.Function("digamma1p", [z], [y]) + + z = np.linspace(1e-6, 10, 10_000) + expected = digamma(z + 1) + actual = cs_digamma1p(z).full().flatten() + np.testing.assert_allclose(actual, expected, rtol=1e-2, atol=1e-3) + def test_digamma(self): z = cs.SX.sym("z") y = math.digamma(z, 3) cs_digamma = cs.Function("digamma", [z], [y]) - z = np.linspace(1e-6, 10, 10000) + z = np.linspace(1e-6, 10, 10_000) expected = digamma(z) actual = cs_digamma(z).full().flatten() np.testing.assert_allclose(actual, expected, rtol=1e-4, atol=1e-4)