Skip to content

Commit

Permalink
Use vulture to detect dead code and remove what the command found (#788)
Browse files Browse the repository at this point in the history
  • Loading branch information
pnkraemer authored Oct 25, 2024
1 parent 31f027a commit 5891d75
Show file tree
Hide file tree
Showing 5 changed files with 5 additions and 58 deletions.
5 changes: 0 additions & 5 deletions docs/examples_solver_config/conditioning-on-zero-residual.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,11 +132,6 @@ def vector_field(y, t): # noqa: ARG001
}


def log_residual(*args):
"""Evaluate the log-ODE-residual."""
return jnp.log10(jnp.abs(residual(*args)))


def residual(x, t):
"""Evaluate the ODE residual."""
return x[1] - jax.vmap(jax.vmap(vector_field), in_axes=(0, None))(x[0], t)
Expand Down
3 changes: 3 additions & 0 deletions makefile
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,6 @@ doc:
make example
make benchmarks-plot-results
JUPYTER_PLATFORM_DIRS=1 mkdocs build

find-dead-code:
vulture . --ignore-names case*,fixture*,*jvp --exclude probdiffeq/_version.py
4 changes: 0 additions & 4 deletions probdiffeq/backend/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,6 @@ def diff(arr, /):
return jnp.diff(arr)


def diff_along_axis(arr, /, *, axis):
return jnp.diff(arr, axis=axis)


def reshape(arr, /, new_shape, order="C"):
return jnp.reshape(arr, new_shape, order=order)

Expand Down
45 changes: 0 additions & 45 deletions probdiffeq/backend/ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,23 +49,6 @@ def solution(t):
return solution


def ivp_logistic():
# Local imports because diffeqzoo is not an official dependency
from diffeqzoo import backend, ivps

if not backend.has_been_selected:
backend.select("jax")

f, u0, (t0, _), f_args = ivps.logistic()
t1 = 0.75

@jax.jit
def vf(x, *, t): # noqa: ARG001
return f(x, *f_args)

return vf, (u0,), (t0, t1)


def ivp_lotka_volterra():
# Local imports because diffeqzoo is not an official dependency
from diffeqzoo import backend, ivps
Expand All @@ -83,34 +66,6 @@ def vf(x, *, t): # noqa: ARG001
return vf, (u0,), (t0, t1)


def ivp_affine_multi_dimensional():
t0, t1 = 0.0, 2.0
u0 = jnp.ones((2,))

@jax.jit
def vf(x, *, t): # noqa: ARG001
return 2 * x

def solution(t):
return jnp.exp(2 * t) * jnp.ones((2,))

return vf, (u0,), (t0, t1), solution


def ivp_affine_scalar():
t0, t1 = 0.0, 2.0
u0 = 1.0

@jax.jit
def vf(x, *, t): # noqa: ARG001
return 2 * x

def solution(t):
return jnp.exp(2 * t)

return vf, (u0,), (t0, t1), solution


def ivp_three_body_1st():
# Local imports because diffeqzoo is not an official dependency
from diffeqzoo import backend, ivps
Expand Down
6 changes: 2 additions & 4 deletions probdiffeq/util/filter_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@
from probdiffeq.backend.typing import Any


# TODO: fixedpointsmoother and kalmanfilter should be estimate()
# with two different methods. This saves a lot of code.
def estimate_fwd(data, /, init, prior_transitions, observation_model, *, estimator):
"""Estimate forward-in-time."""
initialise, step = estimator
Expand Down Expand Up @@ -91,7 +89,7 @@ def _initialise(rv, data, model) -> _KFState:
observed, conditional = ssm.conditional.revert(rv, model)
corrected = ssm.conditional.apply(data, conditional)
logpdf = ssm.stats.logpdf(data, observed)
return _KFState(corrected, 1.0, logpdf)
return _KFState(corrected, num_data_points=0.0, logpdf=logpdf)

def _step(state: _KFState, cond_and_data_and_obs) -> tuple[_KFState, _KFState]:
conditional, data, observation = cond_and_data_and_obs
Expand All @@ -105,7 +103,7 @@ def _step(state: _KFState, cond_and_data_and_obs) -> tuple[_KFState, _KFState]:
# Update logpdf
logpdf_new = ssm.stats.logpdf(data, observed)
logpdf_mean = (logpdf * num_data + logpdf_new) / (num_data + 1)
state = _KFState(corrected, num_data + 1.0, logpdf_mean)
state = _KFState(corrected, num_data_points=num_data + 1.0, logpdf=logpdf_mean)

# Scan-compatible output
return state, state
Expand Down

0 comments on commit 5891d75

Please sign in to comment.