diff --git a/sunode/solver.py b/sunode/solver.py index 227e7f2..6bccf88 100644 --- a/sunode/solver.py +++ b/sunode/solver.py @@ -217,7 +217,6 @@ def _init_sundials(self): self._state_buffer = sunode.empty_vector(n_states) self._state_buffer.data[:] = 0 - self._jac = check(lib.SUNDenseMatrix(n_states, n_states)) self._ode = check(lib.CVodeCreate(self._solver_kind)) rhs = self._problem.make_sundials_rhs() @@ -233,17 +232,57 @@ def _init_sundials(self): self._constraints_vec = sunode.from_numpy(self._constraints) check(lib.CVodeSetConstraints(self._ode, self._constraints_vec.c_ptr)) - self._make_linsol(self._linear_solver_kind) + self._make_linsol(self._linear_solver_kind, **self._linear_solver_kwargs) self._compute_sens = self._sens_mode is not None if self._compute_sens: sens_rhs = self._problem.make_sundials_sensitivity_rhs() self._init_sens(sens_rhs, self._sens_mode) - def __init__(self, problem: Problem, *, - abstol: float = 1e-10, reltol: float = 1e-10, - sens_mode: Optional[str] = None, scaling_factors: Optional[np.ndarray] = None, - constraints: Optional[np.ndarray] = None, solver='BDF', linear_solver="dense"): + def __init__( + self, + problem: Problem, + *, + abstol: float = 1e-10, + reltol: float = 1e-10, + sens_mode: Optional[str] = None, + scaling_factors: Optional[np.ndarray] = None, + constraints: Optional[np.ndarray] = None, + solver='BDF', + linear_solver="dense", + linear_solver_kwargs=None, + ): + """ + Parameters + ---------- + problem: sunode Problem + abstol: float, optional + Absolute tolerance (default is 1e-10). + reltol: float, optional + Relative tolerance (default is 1e-10). + sense_mode: {"simultaneous", "staggered"}, optional + Forward sensitivity method (see [this explanation in the SUNDIALS documentation][1]). + scaling_factors: numpy.ndarray, optional + Vector of positive scaling factors used for the sensitivity calculations. + constraints: numpy.ndarray, optional + Vector of inequality constraints for the solution. The length of the vector must correspond + to the number of states. See the SUNDIALS documentation for the [constraint options][2]. + solver: {"BDF", "ADAMS"}, optional + Algorithm for solving the ODE (the default is ``"BDF"``). + linear_solver: {"dense", "dense_finitediff", "spgmr", "spgmr_finitediff", "band"}, optional + Type of linear solver to use (the default is "dense"). + If linear_solver is ``"band"``, ``linear_solver_kwargs`` must contain ``lower_bandwidth`` + and ``upper_bandwidth``, defining the lower and upper half-bandwidth of the banded matrix + (see the [SUNDIALS documentation][3] for details). + linear_solver_kwargs: dict, optional + Keyword arguments for the linear solver. + + [1]: https://sundials.readthedocs.io/en/latest/idas/Mathematics_link.html#forward-sensitivity-methods + [2]: https://sundials.readthedocs.io/en/latest/cvode/Usage/index.html#c.CVodeSetConstraints + [3]: https://sundials.readthedocs.io/en/latest/sunmatrix/SUNMatrix_links.html#the-sunmatrix-band-module + """ + if linear_solver_kwargs is None: + linear_solver_kwargs = {} self._problem = problem self._user_data = problem.make_user_data() self._constraints = constraints @@ -252,6 +291,7 @@ def __init__(self, problem: Problem, *, self._reltol = reltol self._linear_solver_kind = linear_solver + self._linear_solver_kwargs = linear_solver_kwargs self._sens_mode = sens_mode if solver == 'BDF': @@ -268,6 +308,7 @@ def __init__(self, problem: Problem, *, "_abstol", "_reltol", "_linear_solver_kind", + "_linear_solver_kwargs", "_sens_mode", "_solver_kind", ] @@ -281,14 +322,17 @@ def __setstate__(self, state): self.__dict__.update(state) self._init_sundials() - def _make_linsol(self, linear_solver) -> None: + def _make_linsol(self, linear_solver, **kwargs) -> None: + n_states = self._problem.n_states if linear_solver == "dense": + self._jac = check(lib.SUNDenseMatrix(n_states, n_states)) linsolver = check(lib.SUNLinSol_Dense(self._state_buffer.c_ptr, self._jac)) check(lib.CVodeSetLinearSolver(self._ode, linsolver, self._jac)) self._jac_func = self._problem.make_sundials_jac_dense() check(lib.CVodeSetJacFn(self._ode, self._jac_func.cffi)) elif linear_solver == "dense_finitediff": + self._jac = check(lib.SUNDenseMatrix(n_states, n_states)) linsolver = check(lib.SUNLinSol_Dense(self._state_buffer.c_ptr, self._jac)) check(lib.CVodeSetLinearSolver(self._ode, linsolver, self._jac)) elif linear_solver == "spgmr_finitediff": @@ -301,6 +345,14 @@ def _make_linsol(self, linear_solver) -> None: check(lib.SUNLinSolInitialize_SPGMR(linsolver)) jac_prod = self._problem.make_sundials_jac_prod() check(lib.CVodeSetJacTimes(self._ode, ffi.NULL, jac_prod.cffi)) + elif linear_solver == "band": + upper_bandwidth = kwargs.get("upper_bandwidth", None) + lower_bandwidth = kwargs.get("lower_bandwidth", None) + if upper_bandwidth is None or lower_bandwidth is None: + raise ValueError("Specify 'lower_bandwidth' and 'upper_bandwidth' arguments for banded solver.") + self._jac = check(lib.SUNBandMatrix(n_states, upper_bandwidth, lower_bandwidth)) + linsolver = check(lib.SUNLinSol_Band(self._state_buffer.c_ptr, self._jac)) + check(lib.CVodeSetLinearSolver(self._ode, linsolver, self._jac)) else: raise ValueError(f"Unknown linear solver: {linear_solver}") diff --git a/sunode/test_solve.py b/sunode/test_solve.py index 3b14355..5dae0f2 100644 --- a/sunode/test_solve.py +++ b/sunode/test_solve.py @@ -171,7 +171,11 @@ def rhs(t, y, p): 'b': 0.2 } problem = SympyProblem(params, states, rhs, derivative_params=[]) - linear_solver_opts = ["dense", "dense_finitediff", "spgmr_finitediff", "spgmr"] + linear_solver_opts = ["dense", "dense_finitediff", "spgmr_finitediff", "spgmr", "band"] for linear_solver in linear_solver_opts: - solver = Solver(problem, linear_solver=linear_solver) + if linear_solver == "band": + linear_solver_kwargs = {"upper_bandwidth": 1, "lower_bandwidth": 1} + else: + linear_solver_kwargs = {} + solver = Solver(problem, linear_solver=linear_solver, linear_solver_kwargs=linear_solver_kwargs) check_call_solve(solver, param_vals, None)