Skip to content

Commit

Permalink
undo Rosenbrock23 changes
Browse files Browse the repository at this point in the history
  • Loading branch information
oscardssmith committed Sep 27, 2024
1 parent ad6fdc2 commit 9d2a4c8
Show file tree
Hide file tree
Showing 2 changed files with 132 additions and 0 deletions.
11 changes: 11 additions & 0 deletions lib/OrdinaryDiffEqRosenbrock/src/interp_func.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,14 @@
function DiffEqBase.interp_summary(::Type{cacheType},
dense::Bool) where {
cacheType <:
Union{Rosenbrock23ConstantCache,
Rosenbrock32ConstantCache,
Rosenbrock23Cache,
Rosenbrock32Cache}}
dense ? "specialized 2nd order \"free\" stiffness-aware interpolation" :
"1st order linear"
end

function DiffEqBase.interp_summary(::Type{cacheType},
dense::Bool) where {
cacheType <:
Expand Down
121 changes: 121 additions & 0 deletions lib/OrdinaryDiffEqRosenbrock/src/rosenbrock_interpolants.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,124 @@
### Fallbacks to capture
ROSENBROCKS_WITH_INTERPOLATIONS = Union{Rosenbrock23ConstantCache, Rosenbrock23Cache,
Rosenbrock32ConstantCache, Rosenbrock32Cache,
RosenbrockCombinedConstantCache,
RosenbrockCache}

function _ode_interpolant(Θ, dt, y₀, y₁, k,
cache::ROSENBROCKS_WITH_INTERPOLATIONS,
idxs, T::Type{Val{D}}, differential_vars) where {D}
throw(DerivativeOrderNotPossibleError())
end

function _ode_interpolant!(out, Θ, dt, y₀, y₁, k,
cache::ROSENBROCKS_WITH_INTERPOLATIONS,
idxs, T::Type{Val{D}}, differential_vars) where {D}
throw(DerivativeOrderNotPossibleError())
end

"""
From MATLAB ODE Suite by Shampine
"""
@def rosenbrock2332unpack begin
if cache isa OrdinaryDiffEqMutableCache
d = cache.tab.d
else
d = cache.d
end
end

@def rosenbrock2332pre0 begin
@rosenbrock2332unpack
c1 = Θ * (1 - Θ) / (1 - 2d)
c2 = Θ *- 2d) / (1 - 2d)
end

@muladd function _ode_interpolant(Θ, dt, y₀, y₁, k,
cache::Union{Rosenbrock23ConstantCache,
Rosenbrock32ConstantCache}, idxs::Nothing,
T::Type{Val{0}}, differential_vars)
@rosenbrock2332pre0
@inbounds y₀ + dt * (c1 * k[1] + c2 * k[2])
end

@muladd function _ode_interpolant(Θ, dt, y₀, y₁, k,
cache::Union{Rosenbrock23Cache, Rosenbrock32Cache},
idxs::Nothing, T::Type{Val{0}}, differential_vars)
@rosenbrock2332pre0
@inbounds @.. y₀+dt * (c1 * k[1] + c2 * k[2])
end

@muladd function _ode_interpolant(Θ, dt, y₀, y₁, k,
cache::Union{Rosenbrock23ConstantCache, Rosenbrock23Cache,
Rosenbrock32ConstantCache, Rosenbrock32Cache
}, idxs, T::Type{Val{0}}, differential_vars)
@rosenbrock2332pre0
@.. y₀[idxs]+dt * (c1 * k[1][idxs] + c2 * k[2][idxs])
end

@muladd function _ode_interpolant!(out, Θ, dt, y₀, y₁, k,
cache::Union{Rosenbrock23ConstantCache,
Rosenbrock23Cache,
Rosenbrock32ConstantCache, Rosenbrock32Cache
}, idxs::Nothing, T::Type{Val{0}}, differential_vars)
@rosenbrock2332pre0
@inbounds @.. out=y₀ + dt * (c1 * k[1] + c2 * k[2])
out
end

@muladd function _ode_interpolant!(out, Θ, dt, y₀, y₁, k,
cache::Union{Rosenbrock23ConstantCache,
Rosenbrock23Cache,
Rosenbrock32ConstantCache, Rosenbrock32Cache
}, idxs, T::Type{Val{0}}, differential_vars)
@rosenbrock2332pre0
@views @.. out=y₀[idxs] + dt * (c1 * k[1][idxs] + c2 * k[2][idxs])
out
end

# First Derivative of the dense output
@def rosenbrock2332pre1 begin
@rosenbrock2332unpack
c1diff = (1 - 2 * Θ) / (1 - 2 * d)
c2diff = (2 * Θ - 2 * d) / (1 - 2 * d)
end

@muladd function _ode_interpolant(Θ, dt, y₀, y₁, k,
cache::Union{Rosenbrock23ConstantCache, Rosenbrock23Cache,
Rosenbrock32ConstantCache, Rosenbrock32Cache
}, idxs::Nothing, T::Type{Val{1}}, differential_vars)
@rosenbrock2332pre1
@.. c1diff * k[1]+c2diff * k[2]
end

@muladd function _ode_interpolant(Θ, dt, y₀, y₁, k,
cache::Union{Rosenbrock23ConstantCache, Rosenbrock23Cache,
Rosenbrock32ConstantCache, Rosenbrock32Cache
}, idxs, T::Type{Val{1}}, differential_vars)
@rosenbrock2332pre1
@.. c1diff * k[1][idxs]+c2diff * k[2][idxs]
end

@muladd function _ode_interpolant!(out, Θ, dt, y₀, y₁, k,
cache::Union{Rosenbrock23ConstantCache,
Rosenbrock23Cache,
Rosenbrock32ConstantCache, Rosenbrock32Cache
}, idxs::Nothing, T::Type{Val{1}}, differential_vars)
@rosenbrock2332pre1
@.. out=c1diff * k[1] + c2diff * k[2]
out
end

@muladd function _ode_interpolant!(out, Θ, dt, y₀, y₁, k,
cache::Union{Rosenbrock23ConstantCache,
Rosenbrock23Cache,
Rosenbrock32ConstantCache, Rosenbrock32Cache
}, idxs, T::Type{Val{1}}, differential_vars)
@rosenbrock2332pre1
@views @.. out=c1diff * k[1][idxs] + c2diff * k[2][idxs]
out
end

"""
From MATLAB ODE Suite by Shampine
"""
Expand Down

0 comments on commit 9d2a4c8

Please sign in to comment.