Skip to content

Commit

Permalink
Merge pull request #201 from sathvikbhagavan/sb/lagrange
Browse files Browse the repository at this point in the history
fix: Lagrange interpolations derivative
  • Loading branch information
ChrisRackauckas authored Nov 5, 2023
2 parents 98e9ffa + 1140d4a commit e16d610
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 58 deletions.
100 changes: 48 additions & 52 deletions src/derivatives.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,74 +37,70 @@ end

function derivative(A::LagrangeInterpolation{<:AbstractVector}, t::Number)
((t < A.t[1] || t > A.t[end]) && !A.extrapolate) && throw(ExtrapolationError())
idxs = findRequiredIdxs(A, t)
if A.t[idxs[1]] == t
return zero(A.u[idxs[1]])
end
G = zero(A.u[1])
F = zero(A.t[1])
DG = zero(A.u[1])
DF = zero(A.t[1])
tmp = G
for i in 1:length(idxs)
if isnan(A.bcache[idxs[i]])
der = zero(A.u[1])
for j in eachindex(A.t)
tmp = zero(A.t[1])
if isnan(A.bcache[j])
mult = one(A.t[1])
for j in 1:(i - 1)
mult *= (A.t[idxs[i]] - A.t[idxs[j]])
for i in 1:(j - 1)
mult *= (A.t[j] - A.t[i])
end
for j in (i + 1):length(idxs)
mult *= (A.t[idxs[i]] - A.t[idxs[j]])
for i in (j + 1):length(A.t)
mult *= (A.t[j] - A.t[i])
end
A.bcache[idxs[i]] = mult
A.bcache[j] = mult
else
mult = A.bcache[idxs[i]]
mult = A.bcache[j]
end
for l in eachindex(A.t)
if l != j
k = one(A.t[1])
for m in eachindex(A.t)
if m != j && m != l
k *= (t - A.t[m])
end
end
k *= inv(mult)
tmp += k
end
end
wi = inv(mult)
tti = t - A.t[idxs[i]]
tmp = wi / (t - A.t[idxs[i]])
g = tmp * A.u[idxs[i]]
G += g
DG -= g / (t - A.t[idxs[i]])
F += tmp
DF -= tmp / (t - A.t[idxs[i]])
der += A.u[j]*tmp
end
(DG * F - G * DF) / (F^2)
der
end

function derivative(A::LagrangeInterpolation{<:AbstractMatrix}, t::Number)
((t < A.t[1] || t > A.t[end]) && !A.extrapolate) && throw(ExtrapolationError())
idxs = findRequiredIdxs(A, t)
if A.t[idxs[1]] == t
return zero(A.u[:, idxs[1]])
end
G = zero(A.u[:, 1])
F = zero(A.t[1])
DG = zero(A.u[:, 1])
DF = zero(A.t[1])
tmp = G
for i in 1:length(idxs)
if isnan(A.bcache[idxs[i]])
der = zero(A.u[:, 1])
for j in eachindex(A.t)
tmp = zero(A.t[1])
if isnan(A.bcache[j])
mult = one(A.t[1])
for j in 1:(i - 1)
mult *= (A.t[idxs[i]] - A.t[idxs[j]])
for i in 1:(j - 1)
mult *= (A.t[j] - A.t[i])
end
for j in (i + 1):length(idxs)
mult *= (A.t[idxs[i]] - A.t[idxs[j]])
for i in (j + 1):length(A.t)
mult *= (A.t[j] - A.t[i])
end
A.bcache[idxs[i]] = mult
A.bcache[j] = mult
else
mult = A.bcache[idxs[i]]
mult = A.bcache[j]
end
for l in eachindex(A.t)
if l != j
k = one(A.t[1])
for m in eachindex(A.t)
if m != j && m != l
k *= (t - A.t[m])
end
end
k *= inv(mult)
tmp += k
end
end
wi = inv(mult)
tti = t - A.t[idxs[i]]
tmp = wi / (t - A.t[idxs[i]])
g = tmp * A.u[:, idxs[i]]
@. G += g
@. DG -= g / (t - A.t[idxs[i]])
F += tmp
DF -= tmp / (t - A.t[idxs[i]])
@. der += A.u[:, j]*tmp
end
@. (DG * F - G * DF) / (F^2)
der
end

derivative(A::LagrangeInterpolation{<:AbstractVector}, t::Number, i) = derivative(A, t), i
Expand Down
40 changes: 34 additions & 6 deletions test/derivative_tests.jl
Original file line number Diff line number Diff line change
@@ -1,25 +1,53 @@
using DataInterpolations, Test
using FiniteDifferences
using DataInterpolations: derivative
# using Symbolics
using Symbolics

function test_derivatives(method, u, t, args...; name::String)
func = method(u, t, args...; extrapolate = true)
trange = range(minimum(t) - 5.0, maximum(t) + 5.0, length = 32)
trange = collect(range(minimum(t) - 5.0, maximum(t) + 5.0, step = 0.1))
trange_exclude = filter(x -> !in(x, t), trange)
@testset "$name" begin
for t in trange
cdiff = central_fdm(5, 1; geom = true)(_t -> func(_t), t)
adiff = derivative(func, t)
# Rest of the points
for _t in trange_exclude
cdiff = central_fdm(5, 1; geom = true)(func, _t)
adiff = derivative(func, _t)
@test isapprox(cdiff, adiff, atol = 1e-8)
end

# Interpolation time points
for _t in t[2:end-1]
fdiff = if func isa BSplineInterpolation || func isa BSplineApprox
forward_fdm(5, 1; geom = true)(func, _t)
else
backward_fdm(5, 1; geom = true)(func, _t)
end
adiff = derivative(func, _t)
@test isapprox(fdiff, adiff, atol = 1e-8)
end

# t = t0
fdiff = forward_fdm(5, 1; geom = true)(func, t[1])
adiff = derivative(func, t[1])
if func isa BSplineInterpolation || func isa BSplineApprox
# Bug in BSplines
@test_broken isapprox(fdiff, adiff, atol = 1e-8)
else
@test isapprox(fdiff, adiff, atol = 1e-8)
end

# t = tend
fdiff = backward_fdm(5, 1; geom = true)(func, t[end])
adiff = derivative(func, t[end])
@test isapprox(fdiff, adiff, atol = 1e-8)
end
func = method(u, t, args...)
@test_throws DataInterpolations.ExtrapolationError derivative(func, t[1] - 1.0)
@test_throws DataInterpolations.ExtrapolationError derivative(func, t[end] + 1.0)
end

@testset "Linear Interpolation" begin
u = 2.0collect(1:10)
u = vcat(collect(1:5), 2*collect(6:10))
t = 1.0collect(1:10)
test_derivatives(LinearInterpolation, u, t; name = "Linear Interpolation (Vector)")
u = vcat(2.0collect(1:10)', 3.0collect(1:10)')
Expand Down

0 comments on commit e16d610

Please sign in to comment.