From f86f920607bff0f9d273643133a3626e0ca3e2bc Mon Sep 17 00:00:00 2001 From: himoto Date: Wed, 20 Nov 2024 17:26:29 +0000 Subject: [PATCH] Add test for SDE problem with de.jit --- diffeqpy/tests/test_sde.py | 26 ++++++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/diffeqpy/tests/test_sde.py b/diffeqpy/tests/test_sde.py index f5d53eb..09ddf57 100644 --- a/diffeqpy/tests/test_sde.py +++ b/diffeqpy/tests/test_sde.py @@ -35,3 +35,29 @@ def g(du,u,p,t): p = [10.0,28.0,2.66] prob = de.SDEProblem(numba_f, numba_g, u0, tspan, p) sol = de.solve(prob) + + +def test_jit(): + + def f(du, u, p, t): + x, y, z = u + sigma, rho, beta = p + du[0] = sigma * (y - x) + du[1] = x * (rho - z) - y + du[2] = x * y - beta * z + + def g(du, u, p, t): + du[0] = 0.3 * u[0] + du[1] = 0.3 * u[1] + du[2] = 0.3 * u[2] + + u0 = [1.0, 0.0, 0.0] + tspan = (0.0, 100.0) + p = [10.0, 28.0, 2.66] + prob = de.jit(de.SDEProblem(f, g, u0, tspan, p)) + sol = de.solve(prob) + assert sol.t[-1] == tspan[-1], f"Solver did not reach the final time. Last time: {sol.t[-1]}" + assert len(sol.u) > 0, "Solution is empty." + assert all( + abs(sol.u[i][j]) < float("inf") for j in range(len(u0)) for i in range(len(sol.t)) + ), "Solution contains non-finite values." \ No newline at end of file