Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Cache endpoints in InterpolatedUnivariateSpline #116

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions jax_cosmo/scipy/interpolate.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,7 @@ def __init__(self, x, y, k=3, endpoints="not-a-knot", coefficients=None):

# Saving spline parameters for evaluation later
self.k = k
self._endpoints = endpoints
self._x = x
self._y = y
self._coefficients = coefficients
Expand Down
60 changes: 60 additions & 0 deletions tests/test_spline.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from jax.config import config

config.update("jax_enable_x64", True)
import jax
import jax.numpy as np
from numpy.testing import assert_allclose

Expand Down Expand Up @@ -72,3 +73,62 @@ def test_cubic_spline():
a = spl_ref.antiderivative()(t) - spl_ref.antiderivative()(0.01)
b = spl.antiderivative(t) - spl.antiderivative(0.01)
assert_allclose(a, b, rtol=1e-10)


def test_spline_pytree():
"""
Test that we can interpolate over pytrees.
"""

# Time and data structure to interpolate over.
ts = np.linspace(0, 1, 10)
us = {
"a": np.linspace(0.0, 1.0, 10),
"b": {
"b0": np.linspace(0.0, 0.1, 10),
"b1": np.linspace(0.0, 0.2, 10),
},
}

# Generate a pytree of splines with the same structure as "us".
spline_order = 1
spline_tree = jax.tree_util.tree_map(
lambda u: InterpolatedUnivariateSpline(ts, u, spline_order), us
)

def eval_splines(t):
return jax.tree_util.tree_map(
lambda sp: sp(t),
spline_tree,
is_leaf=lambda obj: isinstance(obj, InterpolatedUnivariateSpline),
)

# Evaluate the splines at t=0.0.
out0 = eval_splines(0.0)
assert out0 == {
"a": 0.0,
"b": {
"b0": 0.0,
"b1": 0.0,
},
}

# Evaluate the splines at t=0.5.
out05 = eval_splines(0.5)
assert out05 == {
"a": 0.5,
"b": {
"b0": 0.05,
"b1": 0.1,
},
}

# Evaluate the splines at t=1.0.
out1 = eval_splines(1.0)
assert out1 == {
"a": 1.0,
"b": {
"b0": 0.1,
"b1": 0.2,
},
}