Skip to content

Commit

Permalink
Rename taylor_mode() to taylor_mode_scan() and move to toplevel (#666)
Browse files Browse the repository at this point in the history
* taylor_mode() is taylor_mode_scan() now

* Forward mode is forward_mode_recursive

* Updated changelog

* Updated linter dependencies and rerun

* Doc update

* Updated index and removed 3.12

* Moved taylor to toplevel

* Update docs

* Show annotations in docs

* Removed print from tests

* Reformat notebook

* Fixed tests
  • Loading branch information
pnkraemer authored Oct 5, 2023
1 parent 191bfd3 commit 1d91090
Show file tree
Hide file tree
Showing 75 changed files with 198 additions and 217 deletions.
11 changes: 6 additions & 5 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -8,22 +8,23 @@ repos:
- id: end-of-file-fixer
- id: check-merge-conflict
- repo: https://github.com/psf/black
rev: 23.7.0
rev: 23.9.1
hooks:
- id: black-jupyter
language_version: python3
- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: v0.0.284
rev: v0.0.292
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
- repo: https://github.com/mwouts/jupytext
rev: v1.15.0
rev: v1.15.2
hooks:
- id: jupytext
files: (docs/).+
files: ^(docs/(benchmarks|examples_solver_config|examples_parameter_estimation|getting_started)/).+.ipynb
args: [--sync]
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.4.1
rev: v1.5.1
hooks:
- id: mypy
args: [--ignore-missing-imports]
Expand Down
16 changes: 6 additions & 10 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,19 +1,17 @@
# probdiffeq


[![PyPi Version](https://img.shields.io/pypi/v/probdiffeq.svg?style=flat-square&color=darkgray)](https://pypi.org/project/probdiffeq/)
[![gh-actions](https://img.shields.io/github/actions/workflow/status/pnkraemer/probdiffeq/ci.yaml?branch=main&style=flat-square)](https://github.com/pnkraemer/probdiffeq/actions?query=workflow%3Aci)
<a href="https://github.com/pnkraemer/probdiffeq/blob/master/LICENSE"><img src="https://img.shields.io/github/license/pnkraemer/probdiffeq?style=flat-square&color=2b9348" alt="License Badge"/></a>
[![GitHub stars](https://img.shields.io/github/stars/pnkraemer/probdiffeq.svg?style=flat-square&logo=github&label=Stars&logoColor=white)](https://github.com/pnkraemer/probdiffeq)
![Python](https://img.shields.io/badge/python-3.9+-black.svg?style=flat-square)

[![Actions status](https://github.com/pnkraemer/probdiffeq/workflows/ci/badge.svg)](https://github.com/pnkraemer/probdiffeq/actions)
[![image](https://img.shields.io/pypi/v/probdiffeq.svg)](https://pypi.python.org/pypi/probdiffeq)
[![image](https://img.shields.io/pypi/l/probdiffeq.svg)](https://pypi.python.org/pypi/probdiffeq)
[![image](https://img.shields.io/pypi/pyversions/probdiffeq.svg)](https://pypi.python.org/pypi/probdiffeq)

## Probabilistic solvers for differential equations in JAX

ProbDiffEq implements adaptive probabilistic numerical solvers for initial value problems.

It inherits automatic differentiation, vectorisation, and GPU capability from JAX.
Features include:

**Features include:**

* Stable implementation
* Calibration, step-size adaptation, and checkpointing
Expand All @@ -26,8 +24,6 @@ Features include:

and many more.



* **AN EASY EXAMPLE:** [LINK](https://pnkraemer.github.io/probdiffeq/getting_started/easy_example/)
* **EXAMPLES:** [LINK](https://pnkraemer.github.io/probdiffeq/examples_solver_config/posterior_uncertainties/)
* **CHOOSING A SOLVER:** [LINK](https://pnkraemer.github.io/probdiffeq/getting_started/choosing_a_solver/)
Expand Down
1 change: 0 additions & 1 deletion docs/api_docs/solvers/taylor/affine.md

This file was deleted.

1 change: 0 additions & 1 deletion docs/api_docs/solvers/taylor/autodiff.md

This file was deleted.

1 change: 0 additions & 1 deletion docs/api_docs/solvers/taylor/estim.md

This file was deleted.

1 change: 1 addition & 0 deletions docs/api_docs/taylor/affine.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
::: probdiffeq.taylor.affine
1 change: 1 addition & 0 deletions docs/api_docs/taylor/autodiff.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
::: probdiffeq.taylor.autodiff
1 change: 1 addition & 0 deletions docs/api_docs/taylor/estim.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
::: probdiffeq.taylor.estim
2 changes: 1 addition & 1 deletion docs/benchmarks/hires/plot.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ jupyter:
extension: .md
format_name: markdown
format_version: '1.3'
jupytext_version: 1.15.0
jupytext_version: 1.15.2
kernelspec:
display_name: Python 3 (ipykernel)
language: python
Expand Down
4 changes: 2 additions & 2 deletions docs/benchmarks/hires/run_hires.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from probdiffeq.solvers import calibrated
from probdiffeq.solvers.strategies import filters
from probdiffeq.solvers.strategies.components import corrections, priors
from probdiffeq.solvers.taylor import autodiff
from probdiffeq.taylor import autodiff
from probdiffeq.util.doc_util import info


Expand Down Expand Up @@ -105,7 +105,7 @@ def param_to_solution(tol):

# Initial state
vf_auto = functools.partial(vf_probdiffeq, t=t0)
tcoeffs = autodiff.taylor_mode(vf_auto, (u0,), num=num_derivatives)
tcoeffs = autodiff.taylor_mode_scan(vf_auto, (u0,), num=num_derivatives)
init = solver.initial_condition(tcoeffs, output_scale=1.0)

# Solve
Expand Down
2 changes: 1 addition & 1 deletion docs/benchmarks/lotkavolterra/plot.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ jupyter:
extension: .md
format_name: markdown
format_version: '1.3'
jupytext_version: 1.15.0
jupytext_version: 1.15.2
kernelspec:
display_name: Python 3 (ipykernel)
language: python
Expand Down
4 changes: 2 additions & 2 deletions docs/benchmarks/lotkavolterra/run_lotkavolterra.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from probdiffeq.solvers import calibrated
from probdiffeq.solvers.strategies import filters
from probdiffeq.solvers.strategies.components import corrections, priors
from probdiffeq.solvers.taylor import autodiff
from probdiffeq.taylor import autodiff
from probdiffeq.util.doc_util import info


Expand Down Expand Up @@ -93,7 +93,7 @@ def param_to_solution(tol):

# Initial state
vf_auto = functools.partial(vf_probdiffeq, t=t0)
tcoeffs = autodiff.taylor_mode(vf_auto, (u0,), num=num_derivatives)
tcoeffs = autodiff.taylor_mode_scan(vf_auto, (u0,), num=num_derivatives)
output_scale = 1.0 * jnp.ones((2,)) if implementation == "blockdiag" else 1.0
init = solver.initial_condition(tcoeffs, output_scale=output_scale)

Expand Down
2 changes: 1 addition & 1 deletion docs/benchmarks/pleiades/plot.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ jupyter:
extension: .md
format_name: markdown
format_version: '1.3'
jupytext_version: 1.15.0
jupytext_version: 1.15.2
kernelspec:
display_name: Python 3 (ipykernel)
language: python
Expand Down
4 changes: 2 additions & 2 deletions docs/benchmarks/pleiades/run_pleiades.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from probdiffeq.solvers import calibrated
from probdiffeq.solvers.strategies import filters
from probdiffeq.solvers.strategies.components import corrections, priors
from probdiffeq.solvers.taylor import autodiff
from probdiffeq.taylor import autodiff
from probdiffeq.util.doc_util import info


Expand Down Expand Up @@ -117,7 +117,7 @@ def param_to_solution(tol):

# Initial state
vf_auto = functools.partial(vf_probdiffeq, t=t0)
tcoeffs = autodiff.taylor_mode(vf_auto, (u0, du0), num=num_derivatives - 1)
tcoeffs = autodiff.taylor_mode_scan(vf_auto, (u0, du0), num=num_derivatives - 1)
init = solver.initial_condition(tcoeffs, output_scale=1.0)

# Solve
Expand Down
2 changes: 1 addition & 1 deletion docs/benchmarks/taylor_fitzhughnagumo/plot.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ jupyter:
extension: .md
format_name: markdown
format_version: '1.3'
jupytext_version: 1.15.0
jupytext_version: 1.15.2
kernelspec:
display_name: Python 3 (ipykernel)
language: python
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from jax import config

from probdiffeq.impl import impl
from probdiffeq.solvers.taylor import autodiff
from probdiffeq.taylor import autodiff
from probdiffeq.util.doc_util import info


Expand Down Expand Up @@ -57,13 +57,13 @@ def timer(fun, /):
return timer


def taylor_mode() -> Callable:
def taylor_mode_scan() -> Callable:
"""Taylor-mode estimation."""
vf_auto, (u0,) = _fitzhugh_nagumo()

@functools.partial(jax.jit, static_argnames=["num"])
def estimate(num):
tcoeffs = autodiff.taylor_mode(vf_auto, (u0,), num=num)
tcoeffs = autodiff.taylor_mode_scan(vf_auto, (u0,), num=num)
return jax.block_until_ready(tcoeffs)

return estimate
Expand Down Expand Up @@ -93,13 +93,13 @@ def estimate(num):
return estimate


def forward_mode() -> Callable:
def forward_mode_recursive() -> Callable:
"""Forward-mode estimation."""
vf_auto, (u0,) = _fitzhugh_nagumo()

@functools.partial(jax.jit, static_argnames=["num"])
def estimate(num):
tcoeffs = autodiff.forward_mode(vf_auto, (u0,), num=num)
tcoeffs = autodiff.forward_mode_recursive(vf_auto, (u0,), num=num)
return jax.block_until_ready(tcoeffs)

return estimate
Expand Down Expand Up @@ -153,8 +153,8 @@ def adaptive_benchmark(fun, *, timeit_fun: Callable, max_time) -> dict:
if __name__ == "__main__":
set_jax_config()
algorithms = {
r"Forward-mode": forward_mode(),
r"Taylor-mode (scan)": taylor_mode(),
r"Forward-mode": forward_mode_recursive(),
r"Taylor-mode (scan)": taylor_mode_scan(),
r"Taylor-mode (unroll)": taylor_mode_unroll(),
r"Taylor-mode (doubling)": taylor_mode_doubling(),
}
Expand Down
2 changes: 1 addition & 1 deletion docs/benchmarks/taylor_node/plot.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ jupyter:
extension: .md
format_name: markdown
format_version: '1.3'
jupytext_version: 1.15.0
jupytext_version: 1.15.2
kernelspec:
display_name: Python 3 (ipykernel)
language: python
Expand Down
14 changes: 7 additions & 7 deletions docs/benchmarks/taylor_node/run_taylor_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from jax import config

from probdiffeq.impl import impl
from probdiffeq.solvers.taylor import autodiff
from probdiffeq.taylor import autodiff
from probdiffeq.util.doc_util import info


Expand Down Expand Up @@ -58,13 +58,13 @@ def timer(fun, /):
return timer


def taylor_mode() -> Callable:
def taylor_mode_scan() -> Callable:
"""Taylor-mode estimation."""
vf_auto, (u0,) = _node()

@functools.partial(jax.jit, static_argnames=["num"])
def estimate(num):
tcoeffs = autodiff.taylor_mode(vf_auto, (u0,), num=num)
tcoeffs = autodiff.taylor_mode_scan(vf_auto, (u0,), num=num)
return jax.block_until_ready(tcoeffs)

return estimate
Expand Down Expand Up @@ -94,13 +94,13 @@ def estimate(num):
return estimate


def forward_mode() -> Callable:
def forward_mode_recursive() -> Callable:
"""Forward-mode estimation."""
vf_auto, (u0,) = _node()

@functools.partial(jax.jit, static_argnames=["num"])
def estimate(num):
tcoeffs = autodiff.forward_mode(vf_auto, (u0,), num=num)
tcoeffs = autodiff.forward_mode_recursive(vf_auto, (u0,), num=num)
return jax.block_until_ready(tcoeffs)

return estimate
Expand Down Expand Up @@ -167,8 +167,8 @@ def adaptive_benchmark(fun, *, timeit_fun: Callable, max_time) -> dict:
set_jax_config()
backend.select("jax")
algorithms = {
r"Forward-mode": forward_mode(),
r"Taylor-mode (scan)": taylor_mode(),
r"Forward-mode": forward_mode_recursive(),
r"Taylor-mode (scan)": taylor_mode_scan(),
r"Taylor-mode (unroll)": taylor_mode_unroll(),
r"Taylor-mode (doubling)": taylor_mode_doubling(),
}
Expand Down
2 changes: 1 addition & 1 deletion docs/benchmarks/taylor_pleiades/plot.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ jupyter:
extension: .md
format_name: markdown
format_version: '1.3'
jupytext_version: 1.15.0
jupytext_version: 1.15.2
kernelspec:
display_name: Python 3 (ipykernel)
language: python
Expand Down
14 changes: 7 additions & 7 deletions docs/benchmarks/taylor_pleiades/run_taylor_pleiades.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from jax import config

from probdiffeq.impl import impl
from probdiffeq.solvers.taylor import autodiff
from probdiffeq.taylor import autodiff
from probdiffeq.util.doc_util import info


Expand Down Expand Up @@ -57,13 +57,13 @@ def timer(fun, /):
return timer


def taylor_mode() -> Callable:
def taylor_mode_scan() -> Callable:
"""Taylor-mode estimation."""
vf_auto, (u0, du0) = _pleiades()

@functools.partial(jax.jit, static_argnames=["num"])
def estimate(num):
tcoeffs = autodiff.taylor_mode(vf_auto, (u0, du0), num=num)
tcoeffs = autodiff.taylor_mode_scan(vf_auto, (u0, du0), num=num)
return jax.block_until_ready(tcoeffs)

return estimate
Expand All @@ -81,13 +81,13 @@ def estimate(num):
return estimate


def forward_mode() -> Callable:
def forward_mode_recursive() -> Callable:
"""Forward-mode estimation."""
vf_auto, (u0, du0) = _pleiades()

@functools.partial(jax.jit, static_argnames=["num"])
def estimate(num):
tcoeffs = autodiff.forward_mode(vf_auto, (u0, du0), num=num)
tcoeffs = autodiff.forward_mode_recursive(vf_auto, (u0, du0), num=num)
return jax.block_until_ready(tcoeffs)

return estimate
Expand Down Expand Up @@ -161,8 +161,8 @@ def adaptive_benchmark(fun, *, timeit_fun: Callable, max_time) -> dict:
if __name__ == "__main__":
set_jax_config()
algorithms = {
r"Forward-mode": forward_mode(),
r"Taylor-mode (scan)": taylor_mode(),
r"Forward-mode": forward_mode_recursive(),
r"Taylor-mode (scan)": taylor_mode_scan(),
r"Taylor-mode (unroll)": taylor_mode_unroll(),
}

Expand Down
2 changes: 1 addition & 1 deletion docs/benchmarks/vanderpol/plot.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ jupyter:
extension: .md
format_name: markdown
format_version: '1.3'
jupytext_version: 1.15.0
jupytext_version: 1.15.2
kernelspec:
display_name: Python 3 (ipykernel)
language: python
Expand Down
4 changes: 2 additions & 2 deletions docs/benchmarks/vanderpol/run_vanderpol.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from probdiffeq.solvers import calibrated
from probdiffeq.solvers.strategies import filters
from probdiffeq.solvers.strategies.components import corrections, priors
from probdiffeq.solvers.taylor import autodiff
from probdiffeq.taylor import autodiff
from probdiffeq.util.doc_util import info


Expand Down Expand Up @@ -96,7 +96,7 @@ def param_to_solution(tol):

# Initial state
vf_auto = functools.partial(vf_probdiffeq, t=t0)
tcoeffs = autodiff.taylor_mode(vf_auto, (u0, du0), num=num_derivatives - 1)
tcoeffs = autodiff.taylor_mode_scan(vf_auto, (u0, du0), num=num_derivatives - 1)
init = solver.initial_condition(tcoeffs, output_scale=1.0)

# Solve
Expand Down
10 changes: 8 additions & 2 deletions docs/dev_docs/changelog.md
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
# Changelog

## v0.2.3
## v0.3.0

**New features:**

* A new function `taylor_mode_unroll` implements Taylor-series estimation without a `scan`.

**Breaking changes:**

* What was formerly `taylor_mode()`, is now `taylor_mode_scan()` and stands in contrast to the new `taylor_mode_unroll()`.
* What was formerly `forward_mode()`, is now `forward_mode_recursive()`.
* The entire `taylor` subpackage moved to top-level. Instead of `from probdiffeq.solvers.taylor import ...`, use `from probdiffeq.taylor import ...`.


## v0.2.2

Expand All @@ -16,7 +22,7 @@ This release was due to issues in the publishing workflow.
**Breaking changes:**

* The input-argument to `taylor_mode_doubling` is `num_doublings` instead of `num`.
This argument behaves differently to e.g., `taylor_mode(..., num)`.
This argument behaves differently to e.g., `taylor_mode_scan(..., num)`.


## v0.2.0
Expand Down
Loading

0 comments on commit 1d91090

Please sign in to comment.