Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor!: remove indexing dispatches and add dispatch for higher order derivatives with Symbolics #247

Merged
merged 7 commits into from
May 6, 2024
4 changes: 4 additions & 0 deletions ext/DataInterpolationsSymbolicsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@
end
SymbolicUtils.promote_symtype(::typeof(derivative), _...) = Real

function Symbolics.derivative(::typeof(derivative), args::NTuple{3, Any}, ::Val{2})
Symbolics.unwrap(derivative(args[1], Symbolics.wrap(args[2]), args[3] + 1))

Check warning on line 25 in ext/DataInterpolationsSymbolicsExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/DataInterpolationsSymbolicsExt.jl#L24-L25

Added lines #L24 - L25 were not covered by tests
end
ChrisRackauckas marked this conversation as resolved.
Show resolved Hide resolved

function Symbolics.derivative(interp::AbstractInterpolation, args::NTuple{1, Any}, ::Val{1})
Symbolics.unwrap(derivative(interp, Symbolics.wrap(args[1])))
end
Expand Down
2 changes: 1 addition & 1 deletion src/DataInterpolations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module DataInterpolations

### Interface Functionality

abstract type AbstractInterpolation{FT, T} <: AbstractVector{T} end
abstract type AbstractInterpolation{FT, T} end

Base.size(A::AbstractInterpolation) = size(A.u)
Base.size(A::AbstractInterpolation{true}) = length(A.u) .+ size(A.t)
Expand Down
2 changes: 1 addition & 1 deletion src/integrals.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ function integral(A::AbstractInterpolation, t1::Number, t2::Number)
if A.t[idx2] == t2
idx2 -= 1
end
total = zero(eltype(A))
total = zero(eltype(A.u))
for idx in idx1:idx2
lt1 = idx == idx1 ? t1 : A.t[idx]
lt2 = idx == idx2 ? t2 : A.t[idx + 1]
Expand Down
2 changes: 1 addition & 1 deletion src/interpolation_methods.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
function _interpolate(A::LinearInterpolation{<:AbstractVector}, t::Number, iguess)
if isnan(t)
# For correct derivative with NaN
idx = firstindex(A) - 1
idx = firstindex(A.u) - 1

Check warning on line 11 in src/interpolation_methods.jl

View check run for this annotation

Codecov / codecov/patch

src/interpolation_methods.jl#L11

Added line #L11 was not covered by tests
t1 = t2 = one(eltype(A.t))
u1 = u2 = one(eltype(A.u))
else
Expand Down
22 changes: 17 additions & 5 deletions test/derivative_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -223,22 +223,34 @@ end
A = QuadraticSpline(u, t)
@variables τ, ω(τ)
D = Symbolics.Differential(τ)
D2 = Symbolics.Differential(τ)^2
expr = A(ω)
@test isequal(Symbolics.derivative(expr, τ), D(ω) * DataInterpolations.derivative(A, ω))

derivexpr = expand_derivatives(substitute(D(A(ω)), Dict(ω => 0.5τ)))
symfunc = Symbolics.build_function(derivexpr, τ; expression = Val{false})
@test symfunc(0.5) == 0.5 * 3
derivexpr1 = expand_derivatives(substitute(D(A(ω)), Dict(ω => 0.5τ)))
derivexpr2 = expand_derivatives(substitute(D2(A(ω)), Dict(ω => 0.5τ)))
symfunc1 = Symbolics.build_function(derivexpr1, τ; expression = Val{false})
symfunc2 = Symbolics.build_function(derivexpr2, τ; expression = Val{false})
@test symfunc1(0.5) == 0.5 * 3
@test symfunc2(0.5) == 0.5 * 6

u = [0.0, 1.5, 0.0]
t = [0.0, 0.5, 1.0]
@variables τ
D = Symbolics.Differential(τ)
D2 = Symbolics.Differential(τ)^2
D3 = Symbolics.Differential(τ)^3
f = LinearInterpolation(u, t)
df = expand_derivatives(D(f(τ)))
symfunc = Symbolics.build_function(df, τ; expression = Val{false})
df2 = expand_derivatives(D2(f(τ)))
df3 = expand_derivatives(D3(f(τ)))
symfunc1 = Symbolics.build_function(df, τ; expression = Val{false})
symfunc2 = Symbolics.build_function(df2, τ; expression = Val{false})
symfunc3 = Symbolics.build_function(df3, τ; expression = Val{false})
ts = 0.0:0.1:1.0
@test all(map(ti -> symfunc(ti) == derivative(f, ti), ts))
@test all(map(ti -> symfunc1(ti) == derivative(f, ti), ts))
@test all(map(ti -> symfunc2(ti) == derivative(f, ti, 2), ts))
@test_throws DataInterpolations.DerivativeNotFoundError symfunc3(ts[1])
end

@testset "Jacobian tests" begin
Expand Down
2 changes: 0 additions & 2 deletions test/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ u = 2.0collect(1:10)
t = 1.0collect(1:10)
A = LinearInterpolation(u, t)

@test length(A) == 20
for i in 1:10
@test u[i] == A[i]
end
Expand All @@ -13,7 +12,6 @@ for i in 11:20
end

A = LinearInterpolation{false}(u, t, true)
@test length(A) == 10
for i in 1:10
@test u[i] == A[i]
end
Expand Down
8 changes: 6 additions & 2 deletions test/online_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,13 @@ u = [0, 1, 0]
for di in [LinearInterpolation, QuadraticInterpolation, ConstantInterpolation]
li = di(copy(u), copy(t))
append!(li, u, t)
@test li == di(vcat(u, u), vcat(t, t))
li2 = di(vcat(u, u), vcat(t, t))
@test li.u == li2.u
@test li.t == li2.t

li = di(copy(u), copy(t))
push!(li, 1, 4)
@test li == di(vcat(u, 1), vcat(t, 4))
li2 = di(vcat(u, 1), vcat(t, 4))
@test li.u == li2.u
@test li.t == li2.t
end
Loading