Skip to content

Commit

Permalink
first working prototype
Browse files Browse the repository at this point in the history
  • Loading branch information
Pepijn de Vos committed Nov 13, 2023
1 parent 8c2d0f4 commit 59017fe
Showing 1 changed file with 43 additions and 33 deletions.
76 changes: 43 additions & 33 deletions src/dense/generic_dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,22 @@ function evaluate_interpolant(f, Θ, dt, timeseries, i₋, i₊, cache, idxs,
end
end

function get_differential_vars(f, len)
differential_vars = trues(len)
if hasproperty(f, :mass_matrix)
mm = f.mass_matrix
if mm isa UniformScaling
# already correct
elseif isdiag(mm)
differential_vars = Diagonal(mm).diag .!= 0
else
@show typeof(mm)
error("QR factorizations is annoying")
end
end
@show differential_vars
end

"""
ode_interpolation(tvals,ts,timeseries,ks)
Expand All @@ -326,6 +342,7 @@ function ode_interpolation(tvals, id::I, idxs, deriv::D, p,
@unpack ts, timeseries, ks, f, cache = id
@inbounds tdir = sign(ts[end] - ts[1])
idx = sortperm(tvals, rev = tdir < 0)
differential_vars = get_differential_vars(f, length(timeseries[begin]))
# start the search thinking it's ts[1]-ts[2]
i₋₊ref = Ref((1, 2))
vals = map(idx) do j
Expand Down Expand Up @@ -363,6 +380,7 @@ function ode_interpolation!(vals, tvals, id::I, idxs, deriv::D, p,
@unpack ts, timeseries, ks, f, cache = id
@inbounds tdir = sign(ts[end] - ts[1])
idx = sortperm(tvals, rev = tdir < 0)
differential_vars = get_differential_vars(f, length(timeseries[begin]))

# start the search thinking it's in ts[1]-ts[2]
i₋ = 1
Expand Down Expand Up @@ -448,20 +466,7 @@ function ode_interpolation(tval::Number, id::I, idxs, deriv::D, p,
@unpack ts, timeseries, ks, f, cache = id
@inbounds tdir = sign(ts[end] - ts[1])

differential_vars = trues(length(timeseries[begin]))
if hasproperty(f, :mass_matrix)
mm = f.mass_matrix
if mm isa UniformScaling
# already correct
elseif isdiag(mm)
differential_vars = Diagonal(mm).diag .!= 0
else
@show typeof(mm)
error("QR factorizations is annoying")
end
end
@show differential_vars

differential_vars = get_differential_vars(f, length(timeseries[begin]))

if continuity === :left
# we have i₋ = i₊ = 1 if tval = ts[1], i₊ = i₋ + 1 = lastindex(ts) if tval > ts[end],
Expand Down Expand Up @@ -511,20 +516,7 @@ function ode_interpolation!(out, tval::Number, id::I, idxs, deriv::D, p,
@unpack ts, timeseries, ks, f, cache = id
@inbounds tdir = sign(ts[end] - ts[1])

differential_vars = trues(length(timeseries[begin]))
if hasproperty(f, :mass_matrix)
mm = f.mass_matrix
if mm isa UniformScaling
# already correct
elseif isdiag(mm)
differential_vars = Diagonal(mm).diag .!= 0
else
# @show typeof(mm)
error("QR factorizations is annoying")
end
end
@show "!" differential_vars

differential_vars = get_differential_vars(f, length(timeseries[begin]))

if continuity === :left
# we have i₋ = i₊ = 1 if tval = ts[1], i₊ = i₋ + 1 = lastindex(ts) if tval > ts[end],
Expand Down Expand Up @@ -634,24 +626,42 @@ end

function _ode_interpolant(Θ, dt, y₀::AbstractArray, y₁::AbstractArray, k, cache, idxs::AbstractArray, T::Type{Val{TI}}, differential_vars) where {TI}
sel = differential_vars[idxs]
@show y₀ idxs differential_vars sel
diffidxs = idxs[sel]
linidxs = idxs[.!sel]
diff = hermite_interpolant(Θ, dt, y₀, y₁, k, Val{cache isa OrdinaryDiffEqMutableCache}, diffidxs, T)
@show lin = linear_interpolant(Θ, dt, y₀, y₁, linidxs, T)
@show res = similar(y₀, eltype(y₀), size(sel))
res[diffidxs] = diff
res[linidxs] = lin
lin = linear_interpolant(Θ, dt, y₀, y₁, linidxs, T)
res = similar(y₀, eltype(y₀), size(sel))
res[sel] = diff
res[.!sel] = lin
res
end

function _ode_interpolant(Θ, dt, y₀, y₁, k, cache, idxs::Nothing, T::Type{Val{TI}}, differential_vars) where {TI}
_ode_interpolant(Θ, dt, y₀, y₁, k, cache, eachindex(y₀), T, differential_vars)
end

function _ode_interpolant(Θ, dt, y₀, y₁, k, cache, idxs, T::Type{Val{TI}}, differential_vars) where {TI}
@assert all(differential_vars)
@show differential_vars
hermite_interpolant(Θ, dt, y₀, y₁, k, Val{cache isa OrdinaryDiffEqMutableCache}, idxs, T)
end

function _ode_interpolant!(out::AbstractArray, Θ, dt, y₀, y₁, k, cache, idxs::AbstractArray, T::Type{Val{TI}}, differential_vars) where {TI}
sel = differential_vars[idxs]
diffidxs = idxs[sel]
linidxs = idxs[.!sel]
hermite_interpolant!(@view(out[sel]), Θ, dt, y₀, y₁, k, diffidxs, T)
linear_interpolant!(@view(out[.!sel]), Θ, dt, y₀, y₁, linidxs, T)
out
end

function _ode_interpolant!(out, Θ, dt, y₀, y₁, k, cache, idxs::Nothing, T::Type{Val{TI}}, differential_vars) where {TI}
_ode_interpolant!(out, Θ, dt, y₀, y₁, k, cache, eachindex(y₀), T, differential_vars)
end

function _ode_interpolant!(out, Θ, dt, y₀, y₁, k, cache, idxs, T::Type{Val{TI}}, differential_vars) where {TI}
@assert all(differential_vars)
@show differential_vars
hermite_interpolant!(out, Θ, dt, y₀, y₁, k, idxs, T)
end

Expand Down

0 comments on commit 59017fe

Please sign in to comment.