-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Todo in taylor.py * Formatting in test case file * repr in implementations * Test utilities * cubature is cubature_rule now * Cleaned up conftest * fixed grid and while loop tests updated * Solve and save at updated * simulate terminal values updated * Tests for dense output updated * test_edges -> test_misc * Fixed grid differentiability tests * JVP tests for fixed grid solvers * Improved test readability * Update and rerun internal benchmark * Removed debug_nan flag * Cubature rule function in DenseSLR1 * SLR0 takes cubature factory * Updated benchmark * Fixed a doctest
- Loading branch information
Showing
19 changed files
with
2,961 additions
and
3,231 deletions.
There are no files selected for viewing
5,153 changes: 2,396 additions & 2,757 deletions
5,153
docs/benchmarks/lotka_volterra/internal.ipynb
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
"""Test utilities.""" | ||
|
||
from probdiffeq import solvers | ||
from probdiffeq.implementations import recipes | ||
from probdiffeq.strategies import filters | ||
|
||
|
||
def generate_solver( | ||
*, | ||
solver_factory=solvers.MLESolver, | ||
strategy_factory=filters.Filter, | ||
impl_factory=recipes.IsoTS0.from_params, | ||
**impl_factory_kwargs, | ||
): | ||
"""Generate a solver. | ||
Examples | ||
-------- | ||
>>> from jax.config import config | ||
>>> config.update("jax_platform_name", "cpu") | ||
>>> from probdiffeq import solvers | ||
>>> from probdiffeq.implementations import recipes | ||
>>> from probdiffeq.strategies import smoothers | ||
>>> print(generate_solver()) | ||
MLESolver(strategy=Filter(implementation=<IsoTS0 with num_derivatives=4>)) | ||
>>> print(generate_solver(num_derivatives=1)) | ||
MLESolver(strategy=Filter(implementation=<IsoTS0 with num_derivatives=1>)) | ||
>>> print(generate_solver(solver_factory=solvers.DynamicSolver)) | ||
DynamicSolver(strategy=Filter(implementation=<IsoTS0 with num_derivatives=4>)) | ||
>>> impl_fcty = recipes.DenseTS1.from_params | ||
>>> strat_fcty = smoothers.Smoother | ||
>>> print(generate_solver(strategy_factory=strat_fcty, impl_factory=impl_fcty, ode_shape=(1,))) # noqa: E501 | ||
MLESolver(strategy=Smoother(implementation=<DenseTS1 with num_derivatives=4>)) | ||
""" | ||
impl = impl_factory(**impl_factory_kwargs) | ||
strat = strategy_factory(impl) | ||
|
||
# I am not too happy with the need for this distinction below... | ||
|
||
if solver_factory in [solvers.MLESolver, solvers.DynamicSolver]: | ||
return solver_factory(strat) | ||
|
||
scale_sqrtm = impl.extrapolation.init_output_scale_sqrtm() | ||
return solver_factory(strat, output_scale_sqrtm=scale_sqrtm) |
Oops, something went wrong.