Skip to content

Commit

Permalink
added digamma1p function
Browse files Browse the repository at this point in the history
  • Loading branch information
FilippoAiraldi committed Jan 29, 2025
1 parent 44baae3 commit f8a6cb0
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 7 deletions.
34 changes: 28 additions & 6 deletions src/csnlp/util/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -444,8 +444,8 @@ def gammaln(
# cs.if_else(z < 1, cs.substitute(out, z, z + 1) - cs.log(z), out)


def digamma(z: Union[cs.SX, cs.MX, cs.DM], n: int) -> Union[cs.SX, cs.MX, cs.DM]:
"""Computes the digamma function via asymptotic expansion.
def digamma1p(z: Union[cs.SX, cs.MX, cs.DM], n: int) -> Union[cs.SX, cs.MX, cs.DM]:
"""Computes the digamma function evaluated at :math:`z+1` via asymptotic expansion.
Only valid for non-negative real scalars.
Parameters
Expand All @@ -465,13 +465,35 @@ def digamma(z: Union[cs.SX, cs.MX, cs.DM], n: int) -> Union[cs.SX, cs.MX, cs.DM]
Requires :mod:`scipy` to be installed. For important details, see
- https://www.boost.org/doc/libs/1_87_0/libs/math/doc/html/math_toolkit/sf_gamma/digamma.html
"""
# we shift by one since digamma(z) = digamma(z + 1) - 1 / z, and the approximation
# is not good for z < 1

from scipy.special import bernoulli

N_2 = 2 * np.arange(1, n + 1)
B = bernoulli(2 * n)[2::2]
z1p = z + 1
powers = cs.power(z1p, N_2)
return -1 / z + cs.log1p(z) - 0.5 / z1p - cs.sum1(1 / ((N_2 / B) * powers))
return cs.log1p(z) - 0.5 / z1p - cs.sum1(1 / ((N_2 / B) * powers))

def digamma(z: Union[cs.SX, cs.MX, cs.DM], n: int) -> Union[cs.SX, cs.MX, cs.DM]:
"""Computes the digamma function via asymptotic expansion.
Only valid for non-negative real scalars.
Parameters
----------
z : casadi.SX, MX or DM
The value at which to compute the logarithm of the gamma function.
n : int, optional
The number of coefficients to compute for the approximation.
Returns
-------
casadi.SX, MX or DM
The value of the digamma function at ``z``.
Notes
-----
Requires :mod:`scipy` to be installed. For important details, see
- https://www.boost.org/doc/libs/1_87_0/libs/math/doc/html/math_toolkit/sf_gamma/digamma.html
"""
# we shift by one since digamma(z) = digamma(z + 1) - 1 / z, and the approximation
# is not good for z < 1
return digamma1p(z, n) - 1 / z
12 changes: 11 additions & 1 deletion tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,12 +299,22 @@ def test_gammaln(self):
actual = cs_gammaln(z).full().flatten()
np.testing.assert_allclose(actual, expected, rtol=1e-6, atol=1e-6)

def test_digamma1p(self):
z = cs.SX.sym("z")
y = math.digamma1p(z, 3)
cs_digamma1p = cs.Function("digamma1p", [z], [y])

z = np.linspace(1e-6, 10, 10_000)
expected = digamma(z + 1)
actual = cs_digamma1p(z).full().flatten()
np.testing.assert_allclose(actual, expected, rtol=1e-2, atol=1e-3)

def test_digamma(self):
z = cs.SX.sym("z")
y = math.digamma(z, 3)
cs_digamma = cs.Function("digamma", [z], [y])

z = np.linspace(1e-6, 10, 10000)
z = np.linspace(1e-6, 10, 10_000)
expected = digamma(z)
actual = cs_digamma(z).full().flatten()
np.testing.assert_allclose(actual, expected, rtol=1e-4, atol=1e-4)
Expand Down

0 comments on commit f8a6cb0

Please sign in to comment.