Skip to content

Commit

Permalink
half a LinearInterpolation POC
Browse files Browse the repository at this point in the history
  • Loading branch information
SouthEndMusic committed Jul 7, 2024
1 parent 152bcbf commit 9737c50
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 15 deletions.
50 changes: 37 additions & 13 deletions ext/DataInterpolationsChainRulesCoreExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,26 +3,50 @@ module DataInterpolationsChainRulesCoreExt
if isdefined(Base, :get_extension)
using DataInterpolations: _interpolate, derivative, AbstractInterpolation,
LagrangeInterpolation, AkimaInterpolation,
BSplineInterpolation, BSplineApprox
BSplineInterpolation, BSplineApprox, LinearInterpolation,
linear_interpolation_parameters
using ChainRulesCore
else
using ..DataInterpolations: _interpolate, derivative, AbstractInterpolation,
LagrangeInterpolation, AkimaInterpolation,
BSplineInterpolation, BSplineApprox
BSplineInterpolation, BSplineApprox, LinearInterpolation,
linear_interpolation_parameters
using ..ChainRulesCore
end

function ChainRulesCore.rrule(::typeof(_interpolate),
A::Union{
LagrangeInterpolation,
AkimaInterpolation,
BSplineInterpolation,
BSplineApprox
},
t::Number)
deriv = derivative(A, t)
interpolate_pullback(Δ) = (NoTangent(), NoTangent(), deriv * Δ)
return _interpolate(A, t), interpolate_pullback
function ChainRulesCore.rrule(::typeof(linear_interpolation_parameters), u, t, idx)
slope = linear_interpolation_parameters(u, t, idx)

function linear_interpolation_parameters_pullback(Δslope)
du = @thunk(Δslope*...) # TODO: How to handle sparsity?
dt = @thunk(Δslope*...) # TODO: How to handle sparsity?
df = NoTangent()
didx = NoTangent()

return (df, du, dt, didx)
end

return slope, linear_interpolation_parameters_pullback
end

function _tangent_u(A::LinearInterpolation, t)
... # TODO: How to handle sparsity?
end

function _tangent_t(A::LinearInterpolation, t)
... # TODO: How to handle sparsity?
end

function ChainRulesCore.rrule(::typeof(_interpolate), A::AType, t, iguess) where {AType}
u = _interpolate(A, t, iguess)[1]
function _interpolate_pullback(Δ)
df = NoTangent()
dA = Tangent{AType}(; u = _tangent_u(A, t), t = _tangent_t(A, t))
dt = @thunk(derivative(A, t)*Δ)
diguess = NoTangent()
return df, dA, dt, diguess
end
u, _interpolate_pullback
end

function ChainRulesCore.frule((_, _, Δt), ::typeof(_interpolate), A::AbstractInterpolation,
Expand Down
9 changes: 7 additions & 2 deletions src/parameter_caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,16 @@ struct LinearParameterCache{pType}
end

function LinearParameterCache(u, t)
slope = linear_interpolation_parameters.(Ref(u), Ref(t), 1:(length(t) - 1))
slope_prototype = linear_interpolation_parameters(u, t, 1)
idxs = 1:(length(t) - 1)
slope = [zero(slope_prototype) for i in idxs]
for idx in idxs
slope[idx] = linear_interpolation_parameters(u, t, idx)
end
return LinearParameterCache(slope)
end

function linear_interpolation_parameters(u, t, idx)
function linear_interpolation_parameters(u, t, idx::Integer)
Δu = u isa AbstractMatrix ? u[:, idx + 1] - u[:, idx] : u[idx + 1] - u[idx]
Δt = t[idx + 1] - t[idx]
slope = Δu / Δt
Expand Down

0 comments on commit 9737c50

Please sign in to comment.