Internal: Delete unused code #780

Merged: 1 commit, Aug 13, 2024
docs/examples_parameter_estimation/neural_ode.py (2 changes: 1 addition & 1 deletion)
@@ -91,7 +91,7 @@ def build_update_fn(*, optimizer, loss_fn):
     @jax.jit
     def update(params, opt_state):
         """Update the optimiser state."""
-        loss, grads = jax.value_and_grad(loss_fn)(params)
+        _loss, grads = jax.value_and_grad(loss_fn)(params)
         updates, opt_state = optimizer.update(grads, opt_state)
         params = optax.apply_updates(params, updates)
         return params, opt_state
@@ -142,7 +142,7 @@ def build_update_fn(*, optimizer, loss_fn):
     @jax.jit
     def update(params, opt_state):
         """Update the optimiser state."""
-        loss, grads = jax.value_and_grad(loss_fn)(params)
+        _loss, grads = jax.value_and_grad(loss_fn)(params)
         updates, opt_state = optimizer.update(grads, opt_state)
         params = optax.apply_updates(params, updates)
         return params, opt_state
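Both hunks above only rename the unused loss value to `_loss`; the surrounding update function is unchanged. For context, a minimal sketch of how such a `build_update_fn` might be driven in a training loop. The quadratic `loss_fn`, the Adam optimizer, the initial `params`, and the step count are illustrative assumptions, not part of the PR:

import jax
import jax.numpy as jnp
import optax


def loss_fn(params):
    # Placeholder scalar loss, assumed only for this sketch.
    return jnp.sum(params**2)


def build_update_fn(*, optimizer, loss_fn):
    @jax.jit
    def update(params, opt_state):
        """Update the optimiser state."""
        # The loss value itself is not used here, hence the underscore prefix.
        _loss, grads = jax.value_and_grad(loss_fn)(params)
        updates, opt_state = optimizer.update(grads, opt_state)
        params = optax.apply_updates(params, updates)
        return params, opt_state

    return update


optimizer = optax.adam(learning_rate=1e-2)
update = build_update_fn(optimizer=optimizer, loss_fn=loss_fn)

params = jnp.ones((3,))
opt_state = optimizer.init(params)
for _ in range(100):
    params, opt_state = update(params, opt_state)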
probdiffeq/backend/containers.py (5 changes: 0 additions & 5 deletions)
@@ -9,8 +9,3 @@
 @dataclass_transform()
 def dataclass(*args, **kwargs):
     return dataclasses.dataclass(*args, **kwargs)
-
-
-@dataclass_transform()
-def dataclass_astuple(datacls):
-    return dataclasses.astuple(datacls)
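The retained `dataclass` wrapper stays; only the unused `dataclass_astuple` helper is removed. A minimal sketch of why the `dataclass_transform` decoration matters for such a wrapper, with a hypothetical `Interval` class that is not part of the library:

import dataclasses

from typing_extensions import dataclass_transform  # stdlib typing on Python 3.11+


@dataclass_transform()
def dataclass(*args, **kwargs):
    # Thin wrapper; the decorator tells type checkers to treat decorated
    # classes as ordinary dataclasses (synthesised __init__, field types).
    return dataclasses.dataclass(*args, **kwargs)


@dataclass
class Interval:
    lower: float
    upper: float


interval = Interval(lower=0.0, upper=1.0)
print(interval)  # Interval(lower=0.0, upper=1.0)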
probdiffeq/backend/linalg.py (4 changes: 0 additions & 4 deletions)
@@ -54,10 +54,6 @@ def vector_norm(arr, /, *, order=None):
     return jnp.linalg.norm(arr, ord=order)
 
 
-def matrix_norm(arr, /, *, order=None):
-    return jnp.linalg.norm(arr, ord=order)
-
-
 def solve_triangular(matrix, rhs, /, *, trans=0, lower=False):
     return jax.scipy.linalg.solve_triangular(matrix, rhs, trans=trans, lower=lower)
 
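The remaining helpers, `vector_norm` and `solve_triangular`, are thin wrappers around `jnp.linalg.norm` and `jax.scipy.linalg.solve_triangular`. A minimal sketch of the triangular solve they expose; the matrix `R` and right-hand side `b` are made-up example data:

import jax.numpy as jnp
from jax.scipy.linalg import solve_triangular

R = jnp.array([[2.0, 1.0], [0.0, 3.0]])  # upper-triangular factor
b = jnp.array([5.0, 6.0])

# Back-substitution: solves R @ x = b without forming an inverse.
x = solve_triangular(R, b, lower=False)
print(jnp.allclose(R @ x, b))  # True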
probdiffeq/impl/_conditional.py (9 changes: 2 additions & 7 deletions)
@@ -2,16 +2,11 @@
 
 from probdiffeq.backend import abc, containers, functools, linalg, tree_util
 from probdiffeq.backend import numpy as np
-from probdiffeq.backend.typing import Any, Array, Callable
+from probdiffeq.backend.typing import Any, Array
 from probdiffeq.impl import _normal
 from probdiffeq.util import cholesky_util, linop_util
 
 
-class Transformation(containers.NamedTuple):
-    matmul: Callable
-    bias: Array
-
-
 class TransformBackend(abc.ABC):
     @abc.abstractmethod
     def marginalise(self, rv, transformation, /):
@@ -165,7 +160,7 @@ def conditional(self, matmul, noise):
         return Conditional(matmul, noise)
 
     @abc.abstractmethod
-    def identity(self, num_derivatives_per_ode_dimension, /):
+    def identity(self, ndim, /):
         raise NotImplementedError
 
     @abc.abstractmethod
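The deleted `Transformation` record was unused, while the retained code still builds `Conditional(matmul, noise)` pairs. A minimal sketch of this NamedTuple container style, using `typing.NamedTuple` as a stand-in for the backend's `containers.NamedTuple` (an assumption) and placeholder field values:

from typing import Callable, NamedTuple

import jax.numpy as jnp


class Conditional(NamedTuple):
    """A linear map together with additive noise, as in `conditional(matmul, noise)`."""

    matmul: Callable
    noise: jnp.ndarray  # stands in for a Normal random variable


cond = Conditional(matmul=lambda x: 2.0 * x, noise=jnp.zeros(2))
print(cond.matmul(jnp.ones(2)), cond.noise)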
probdiffeq/impl/_normal.py (2 changes: 1 addition & 1 deletion)
@@ -18,7 +18,7 @@ def preconditioner_apply(self, rv: Normal, p, /) -> Normal:
         raise NotImplementedError
 
     @abc.abstractmethod
-    def standard(self, num_derivatives_per_ode_dimension, /, output_scale) -> Normal:
+    def standard(self, num, /, output_scale) -> Normal:
         raise NotImplementedError
 
 
probdiffeq/util/cholesky_util.py (14 changes: 0 additions & 14 deletions)
@@ -106,20 +106,6 @@ def revert_conditional(R_X_F, R_X, R_YX):
     return R_Y, (R_XY, G)
 
 
-def _triu_via_shortcut(R, R_YX):
-    """If R_X and R_X_F are zero, triu_via_qr(R) can be implemented cheaply.
-
-    Namely, by applying it only to R_YX and embedding the result in zeros.
-    This is not only more efficient because it requires fewer floating-point operations
-    than qr-decomposing the full matrix, but it also admits a well-defined
-    reverse-mode derivative!
-    """
-    R = np.zeros_like(R)
-    R_YX = triu_via_qr(R_YX)
-    n, m = np.shape(R_YX)
-    return R.at[:n, :m].set(R_YX)
-
-
 def _is_matrix(mat, matrix_ndim=2):
     return np.ndim(mat) == matrix_ndim
 
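The deleted shortcut embedded `triu_via_qr(R_YX)` into a zero matrix instead of QR-decomposing the full stacked matrix. A minimal sketch of what a `triu_via_qr`-style helper computes, written against plain `jax.numpy` with an assumed sign convention (the library's actual implementation may differ):

import jax.numpy as jnp


def triu_via_qr_sketch(matrix):
    """Return an upper-triangular R with R.T @ R == matrix.T @ matrix."""
    _q, r = jnp.linalg.qr(matrix, mode="reduced")
    # Fix the sign of each row so the triangular factor is unique.
    signs = jnp.sign(jnp.diag(r))
    signs = jnp.where(signs == 0.0, 1.0, signs)
    return signs[:, None] * r


matrix = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
R = triu_via_qr_sketch(matrix)
print(jnp.allclose(R.T @ R, matrix.T @ matrix, atol=1e-4))  # True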