Linear interpolation for algebraic variables #2048
Conversation
force-pushed from a303634 to e7dbe93
I made a very first attempt at getting this to work, and am now at a stage where I could again use some help from someone more familiar with the code. There are several questions.

I'm currently testing with the rober example from https://docs.sciml.ai/DiffEqDocs/stable/tutorials/dae_example/#Mass-Matrix-Differential-Algebraic-Equations-(DAEs). To elaborate on the first point: there are two methods of ode_interpolation that currently don't pass differential_vars anywhere, and I'm not sure when they are used. Are there other code paths that lead here that we need to handle? The second point is a bit annoying. It seems like
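For reference, the linked tutorial's ROBER setup looks roughly like this (reproduced from the tutorial from memory, so treat the details as approximate); the zero row of the singular mass matrix makes y₃ algebraic, which is the variable whose dense interpolation this PR changes:

```julia
using OrdinaryDiffEq

function rober(du, u, p, t)
    y₁, y₂, y₃ = u
    k₁, k₂, k₃ = p
    du[1] = -k₁ * y₁ + k₃ * y₂ * y₃
    du[2] = k₁ * y₁ - k₃ * y₂ * y₃ - k₂ * y₂^2
    du[3] = y₁ + y₂ + y₃ - 1  # algebraic constraint: a residual, not a derivative
    nothing
end

M = [1.0 0.0 0.0
     0.0 1.0 0.0
     0.0 0.0 0.0]  # singular mass matrix ⇒ y₃ is an algebraic variable
f = ODEFunction(rober, mass_matrix = M)
prob = ODEProblem(f, [1.0, 0.0, 0.0], (0.0, 1e5), (0.04, 3e7, 1e4))
sol = solve(prob, Rodas5())
sol(100.0)  # dense output: interpolating the algebraic y₃ is the case this PR fixes
```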
I did the implementation as best I can; from here on it's a ton of cleanup, review, and handling edge cases, I suppose. I'm currently throwing an error in the base case to make sure we handle all the cases that we should, and I see that it is getting hit by a
Summary of the current failures and my ongoing understanding of them
src/dense/generic_dense.jl (Outdated)

```diff
@@ -324,6 +340,7 @@ function ode_interpolation(tvals, id::I, idxs, deriv::D, p,
     @unpack ts, timeseries, ks, f, cache = id
     @inbounds tdir = sign(ts[end] - ts[1])
     idx = sortperm(tvals, rev = tdir < 0)
+    differential_vars = get_differential_vars(f, size(timeseries[begin]))
```
This should maybe be passed somewhere, but @oscardssmith had some argument (which I forgot) for why not all cache types need it, so I don't know whether there's somewhere deep down in the evaluate_interpolant machinery that this needs to go.
we do want to cache this eventually, but at first it's not necessary.
Sure, for now I'm talking more about covering all the places. This variable is currently unused.
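As a sketch of what such a helper could do (hypothetical, not necessarily the PR's code; based on the `diag(mm) .!= 0` trick that comes up later in this thread):

```julia
using LinearAlgebra

# Hypothetical sketch: derive the differential/algebraic mask from the
# mass matrix. A zero diagonal entry marks an algebraic variable.
function get_differential_vars(f, size_u)
    f.mass_matrix === I && return nothing  # plain ODE: no algebraic variables
    return reshape(diag(f.mass_matrix) .!= 0, size_u)
end
```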
src/dense/generic_dense.jl (Outdated)

```diff
 end
 
-function _ode_interpolant!(out, Θ, dt, y₀, y₁, k, cache, idxs, T::Type{Val{TI}}) where {TI}
+function partial_hermite_interpolant!(out, Θ, dt, y₀, y₁, k, cache, idxs, T::Type{Val{TI}}, differential_vars) where {TI}
+    throw("how did we get here")
```
Once I'm confident I've handled all the cases, I'll either completely remove this method or remove the error and leave it as a fallback, in case it turns out there is some insane edge case somewhere in the ecosystem.
It is officially faster to broadcast both interpolations than to do the indexing, go figure.

You should probably have a fast path for no algebraic variables, though. Lots of people only solve ODEs.
Ok so basically a complete rewrite to avoid allocations, use broadcasting, and optimize for the ODE case. I think if we default …, for broadcasting we need to change …, and I fear that for completely avoiding allocations in the in-place version I need to write a fully integrated and broadcasted interpolation.

For the out-of-place version you can just do

```julia
h = hermite_interpolant(Θ, dt, y₀, y₁, k, Val{cache isa OrdinaryDiffEqMutableCache}, idxs, T)
l = linear_interpolant(Θ, dt, y₀, y₁, idxs, T)
@.. broadcast=false h*differential_vars + l*!differential_vars
```

but for the in-place version you can't have both write to out:

```julia
hermite_interpolant!(out, Θ, dt, y₀, y₁, k, idxs, T)
l = linear_interpolant(Θ, dt, y₀, y₁, idxs, T)
@.. broadcast=false out=out*differential_vars + l*!differential_vars
```

So I'm thinking to change

```julia
return (1 - Θ) * y₀[idxs] + Θ * y₁[idxs] +
       Θ * (Θ - 1) *
       ((1 - 2Θ) * (y₁[idxs] - y₀[idxs]) + (Θ - 1) * dt * k[1][idxs] +
        Θ * dt * k[2][idxs])
```

to

```julia
return (1 - Θ) * y₀[idxs] + Θ * y₁[idxs] +
       isdiff * Θ * (Θ - 1) *
       ((1 - 2Θ) * (y₁[idxs] - y₀[idxs]) + (Θ - 1) * dt * k[1][idxs] +
        Θ * dt * k[2][idxs])
```

That adds like one multiplication. Maybe that's insignificant enough that we don't need a complete fast path? Would it make sense to pass it as a type-level boolean so that in the fast path you can compile away the multiplication, or would that cause more trouble than it's worth when it can't be statically inferred? I'd really like to avoid having two complete copies of the Hermite interpolation for a single multiplication.
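Here's a quick self-contained sanity check of that masking identity, with made-up numbers and plain `@.` standing in for `@..`:

```julia
Θ, dt = 0.3, 0.1
y₀ = [1.0, 2.0]; y₁ = [1.5, 2.5]
k = ([0.1, 0.2], [0.3, 0.4])        # stand-ins for the derivative arrays k[1], k[2]
differential_vars = [true, false]   # second variable is algebraic

h = @. (1 - Θ) * y₀ + Θ * y₁ +
       Θ * (Θ - 1) * ((1 - 2Θ) * (y₁ - y₀) + (Θ - 1) * dt * k[1] + Θ * dt * k[2])
l = @. (1 - Θ) * y₀ + Θ * y₁    # linear interpolant

out = @. h * differential_vars + l * !differential_vars
@assert out[1] == h[1] && out[2] == l[2]  # Hermite where differential, linear where algebraic
```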
That wasn't so bad for the primal. For derivatives it's a bit trickier, but if it passes the tests I'll do the rest.
src/dense/generic_dense.jl (Outdated)

```julia
    hermite_interpolant!(out, Θ, dt, y₀, y₁, k, idxs, T)
else
    hermite_interpolant!(out, Θ, dt, y₀, y₁, k, idxs, T, differential_vars)
end
```
Why not just use dispatch?
Ambiguous method error
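For illustration, here's a minimal toy reproducing that kind of ambiguity (made-up function, not the PR's actual signatures):

```julia
# Neither method is more specific for the call below: the first constrains
# the first argument, the second constrains the second, so Julia errors.
g(out::Array, dv) = "array method"
g(out, dv::Nothing) = "no-differential-vars method"
g([1.0], nothing)   # ERROR: MethodError: g(::Vector{Float64}, ::Nothing) is ambiguous
```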
src/dense/generic_dense.jl
Outdated
if idxs === nothing || differential_vars === nothing | ||
return differential_vars | ||
else | ||
return differential_vars[idxs] |
not a view?
Good idea, I suppose.
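A minimal sketch of that helper with a view, assuming the surrounding names from the diff above (the real name and signature may differ):

```julia
# Return the mask trimmed to the requested indices without allocating a copy.
function trim_differential_vars(differential_vars, idxs)
    if idxs === nothing || differential_vars === nothing
        return differential_vars
    else
        return @view differential_vars[idxs]
    end
end
```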
I'm honestly a bit puzzled by some of these CI errors. Like the …
The foop one isn't your fault.
force-pushed from fbba739 to a48046d
Is this my fault? https://github.com/SciML/OrdinaryDiffEq.jl/actions/runs/6940604309/job/18879835886?pr=2048#step:6:637 It seems to be a regression test for #2055 and I don't really see the connection or the bug.
I just merged a PR that was all green except for the format check (the formatter is still having issues; I'm going to put a bounty on that): #2069. So if interpolation things are failing, I'd venture to guess there was a merge issue.
force-pushed from a48046d to d6e10b9
I'm hoping … Meanwhile, I've started adding all the derivative implementations.
```julia
@muladd function hermite_interpolant!(out::Array, Θ, dt, y₀, y₁, k, idxs, T::Type{Val{2}}) # Default interpolant is Hermite
    @views @.. broadcast=false out=(-4 * dt * k[1][idxs] - 2 * dt * k[2][idxs] -
                                    6 * y₀[idxs] +
                                    Θ * (6 * dt * k[1][idxs] + 6 * dt * k[2][idxs] +
                                         12 * y₀[idxs] - 12 * y₁[idxs]) + 6 * y₁[idxs]) /
                                   (dt * dt)
```
This one had two complete copies of the interpolation??? I just kept the last one, which was the one effectively being used, I guess.
Seems to throw an error for …
For what problem/solver? Anyway, I'll add a bunch of tests to gain some confidence that all of those dozens of methods I changed actually work.
I'm pretty sure I've seen matrix-shaped idxs and outputs, but when I just pass a matrix to sol, it errors. Is there some other code path that results in a matrix?
It seems like master fails the same checks as this branch. If the tests I've just added pass on CI and didn't break anything else, this should be good for another round of reviews and, hopefully, a merge.
src/dense/rosenbrock_interpolants.jl (Outdated)

```diff
@@ -282,7 +282,7 @@ end
 
 @muladd function _ode_interpolant!(out, Θ, dt, y₀, y₁, k,
                                    cache::Union{Rosenbrock5ConstantCache, Rosenbrock5Cache},
-                                   idxs, T::Type{Val{1}})
+                                   idxs, T::Type{Val{1}}, dv=nothing)
```
why would dv show up here?
I think we pass it as nothing to all interpolants
```julia
using StaticArrays, LinearAlgebra
mm = SMatrix{2, 2}(rand(2, 2))
u = SA[1.0, 2.0]
reshape(diag(mm) .!= 0, size(u)) isa SVector # true
```
I'll need to do DelayDiffEq.jl because it reaches into internals that are changed by this PR.
Wow you've been busy, thanks!
BTW, by finite-differencing I mean that the coefficients of the polynomial interpolants and their derivatives for algebraic variables are (usually linear) functions of …
That's not a local scheme though, as you're using more than just the current step's information? That can only be done in the post-solution interpolation.
That just means you have to restrict to one-sided finite-differencing techniques, which I was kind of expecting. But if there are scenarios where it uses a poor finite-differencing scheme or returns zero results when it shouldn't, then maybe we just need a way to know when sol(t, Val{n}) is the right thing to use and when it isn't (esp. for post-solution interpolation).
But you only have u_n and u_{n+1} and the k's, which means that if you finite difference the u's within a step, it must be a linear interpolation.
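(Spelled out: two endpoint values determine only a degree-one polynomial, so the within-step interpolant for an algebraic variable is

$$u(t_n + \Theta\,\Delta t) \approx (1 - \Theta)\,u_n + \Theta\,u_{n+1}, \qquad \Theta \in [0, 1],$$

which is exactly the linear fallback this PR implements.)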
Someone apparently already ran into this problem 😅 SciML/DelayDiffEq.jl#274
Yeah it was a fundamental break and just needs a downstream fix. I'll put that in ASAP. Just been doing grant writing all week and was using this as a sidepiece while avoiding the real work.
I suspect this PR caused a regression: #2086
What makes you say so? If that's the case, the thing to look for is a broadcast in the Hermite interpolant that returns an array rather than a scalar.
There are only two changes between 6.60 and 6.61, and the other doesn't appear related, but I might be wrong. v6.60.0...v6.61.0
Except for those solvers with specialized stiffness-aware* interpolation, Hermite interpolation is used, which uses the derivative. For algebraic variables this ends up incorrectly using the residual. This PR aims to make it fall back to linear interpolation for algebraic variables.

*What does that actually mean?

At the moment I just pushed whatever I was doing with @oscardssmith last time.