From 9d2a4c83b92d64a9e92ba3658539f87d3cae609c Mon Sep 17 00:00:00 2001 From: oscarddssmith Date: Fri, 20 Sep 2024 16:17:07 -0400 Subject: [PATCH] undo Rosenbrock23 changes --- .../src/interp_func.jl | 11 ++ .../src/rosenbrock_interpolants.jl | 121 ++++++++++++++++++ 2 files changed, 132 insertions(+) diff --git a/lib/OrdinaryDiffEqRosenbrock/src/interp_func.jl b/lib/OrdinaryDiffEqRosenbrock/src/interp_func.jl index 077e27c7b0..43f66149fc 100644 --- a/lib/OrdinaryDiffEqRosenbrock/src/interp_func.jl +++ b/lib/OrdinaryDiffEqRosenbrock/src/interp_func.jl @@ -1,3 +1,14 @@ +function DiffEqBase.interp_summary(::Type{cacheType}, + dense::Bool) where { + cacheType <: + Union{Rosenbrock23ConstantCache, + Rosenbrock32ConstantCache, + Rosenbrock23Cache, + Rosenbrock32Cache}} + dense ? "specialized 2nd order \"free\" stiffness-aware interpolation" : + "1st order linear" +end + function DiffEqBase.interp_summary(::Type{cacheType}, dense::Bool) where { cacheType <: diff --git a/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_interpolants.jl b/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_interpolants.jl index 20cf702e14..5a65299042 100644 --- a/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_interpolants.jl +++ b/lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_interpolants.jl @@ -1,3 +1,124 @@ +### Fallbacks to capture +ROSENBROCKS_WITH_INTERPOLATIONS = Union{Rosenbrock23ConstantCache, Rosenbrock23Cache, + Rosenbrock32ConstantCache, Rosenbrock32Cache, + RosenbrockCombinedConstantCache, + RosenbrockCache} + +function _ode_interpolant(Θ, dt, y₀, y₁, k, + cache::ROSENBROCKS_WITH_INTERPOLATIONS, + idxs, T::Type{Val{D}}, differential_vars) where {D} + throw(DerivativeOrderNotPossibleError()) +end + +function _ode_interpolant!(out, Θ, dt, y₀, y₁, k, + cache::ROSENBROCKS_WITH_INTERPOLATIONS, + idxs, T::Type{Val{D}}, differential_vars) where {D} + throw(DerivativeOrderNotPossibleError()) +end + +""" +From MATLAB ODE Suite by Shampine +""" +@def rosenbrock2332unpack begin + if cache isa OrdinaryDiffEqMutableCache + d = cache.tab.d + else + d = cache.d + end +end + +@def rosenbrock2332pre0 begin + @rosenbrock2332unpack + c1 = Θ * (1 - Θ) / (1 - 2d) + c2 = Θ * (Θ - 2d) / (1 - 2d) +end + +@muladd function _ode_interpolant(Θ, dt, y₀, y₁, k, + cache::Union{Rosenbrock23ConstantCache, + Rosenbrock32ConstantCache}, idxs::Nothing, + T::Type{Val{0}}, differential_vars) + @rosenbrock2332pre0 + @inbounds y₀ + dt * (c1 * k[1] + c2 * k[2]) +end + +@muladd function _ode_interpolant(Θ, dt, y₀, y₁, k, + cache::Union{Rosenbrock23Cache, Rosenbrock32Cache}, + idxs::Nothing, T::Type{Val{0}}, differential_vars) + @rosenbrock2332pre0 + @inbounds @.. y₀+dt * (c1 * k[1] + c2 * k[2]) +end + +@muladd function _ode_interpolant(Θ, dt, y₀, y₁, k, + cache::Union{Rosenbrock23ConstantCache, Rosenbrock23Cache, + Rosenbrock32ConstantCache, Rosenbrock32Cache + }, idxs, T::Type{Val{0}}, differential_vars) + @rosenbrock2332pre0 + @.. y₀[idxs]+dt * (c1 * k[1][idxs] + c2 * k[2][idxs]) +end + +@muladd function _ode_interpolant!(out, Θ, dt, y₀, y₁, k, + cache::Union{Rosenbrock23ConstantCache, + Rosenbrock23Cache, + Rosenbrock32ConstantCache, Rosenbrock32Cache + }, idxs::Nothing, T::Type{Val{0}}, differential_vars) + @rosenbrock2332pre0 + @inbounds @.. out=y₀ + dt * (c1 * k[1] + c2 * k[2]) + out +end + +@muladd function _ode_interpolant!(out, Θ, dt, y₀, y₁, k, + cache::Union{Rosenbrock23ConstantCache, + Rosenbrock23Cache, + Rosenbrock32ConstantCache, Rosenbrock32Cache + }, idxs, T::Type{Val{0}}, differential_vars) + @rosenbrock2332pre0 + @views @.. out=y₀[idxs] + dt * (c1 * k[1][idxs] + c2 * k[2][idxs]) + out +end + +# First Derivative of the dense output +@def rosenbrock2332pre1 begin + @rosenbrock2332unpack + c1diff = (1 - 2 * Θ) / (1 - 2 * d) + c2diff = (2 * Θ - 2 * d) / (1 - 2 * d) +end + +@muladd function _ode_interpolant(Θ, dt, y₀, y₁, k, + cache::Union{Rosenbrock23ConstantCache, Rosenbrock23Cache, + Rosenbrock32ConstantCache, Rosenbrock32Cache + }, idxs::Nothing, T::Type{Val{1}}, differential_vars) + @rosenbrock2332pre1 + @.. c1diff * k[1]+c2diff * k[2] +end + +@muladd function _ode_interpolant(Θ, dt, y₀, y₁, k, + cache::Union{Rosenbrock23ConstantCache, Rosenbrock23Cache, + Rosenbrock32ConstantCache, Rosenbrock32Cache + }, idxs, T::Type{Val{1}}, differential_vars) + @rosenbrock2332pre1 + @.. c1diff * k[1][idxs]+c2diff * k[2][idxs] +end + +@muladd function _ode_interpolant!(out, Θ, dt, y₀, y₁, k, + cache::Union{Rosenbrock23ConstantCache, + Rosenbrock23Cache, + Rosenbrock32ConstantCache, Rosenbrock32Cache + }, idxs::Nothing, T::Type{Val{1}}, differential_vars) + @rosenbrock2332pre1 + @.. out=c1diff * k[1] + c2diff * k[2] + out +end + +@muladd function _ode_interpolant!(out, Θ, dt, y₀, y₁, k, + cache::Union{Rosenbrock23ConstantCache, + Rosenbrock23Cache, + Rosenbrock32ConstantCache, Rosenbrock32Cache + }, idxs, T::Type{Val{1}}, differential_vars) + @rosenbrock2332pre1 + @views @.. out=c1diff * k[1][idxs] + c2diff * k[2][idxs] + out +end + """ From MATLAB ODE Suite by Shampine """