From 9737c508af96c94b49ab8a5ec9f011190a82daf0 Mon Sep 17 00:00:00 2001 From: Bart de Koning Date: Sun, 7 Jul 2024 13:39:36 +0200 Subject: [PATCH] half a LinearInterpolation POC --- ext/DataInterpolationsChainRulesCoreExt.jl | 50 ++++++++++++++++------ src/parameter_caches.jl | 9 +++- 2 files changed, 44 insertions(+), 15 deletions(-) diff --git a/ext/DataInterpolationsChainRulesCoreExt.jl b/ext/DataInterpolationsChainRulesCoreExt.jl index 34e27841..ebd00743 100644 --- a/ext/DataInterpolationsChainRulesCoreExt.jl +++ b/ext/DataInterpolationsChainRulesCoreExt.jl @@ -3,26 +3,50 @@ module DataInterpolationsChainRulesCoreExt if isdefined(Base, :get_extension) using DataInterpolations: _interpolate, derivative, AbstractInterpolation, LagrangeInterpolation, AkimaInterpolation, - BSplineInterpolation, BSplineApprox + BSplineInterpolation, BSplineApprox, LinearInterpolation, + linear_interpolation_parameters using ChainRulesCore else using ..DataInterpolations: _interpolate, derivative, AbstractInterpolation, LagrangeInterpolation, AkimaInterpolation, - BSplineInterpolation, BSplineApprox + BSplineInterpolation, BSplineApprox, LinearInterpolation, + linear_interpolation_parameters using ..ChainRulesCore end -function ChainRulesCore.rrule(::typeof(_interpolate), - A::Union{ - LagrangeInterpolation, - AkimaInterpolation, - BSplineInterpolation, - BSplineApprox - }, - t::Number) - deriv = derivative(A, t) - interpolate_pullback(Δ) = (NoTangent(), NoTangent(), deriv * Δ) - return _interpolate(A, t), interpolate_pullback +function ChainRulesCore.rrule(::typeof(linear_interpolation_parameters), u, t, idx) + slope = linear_interpolation_parameters(u, t, idx) + + function linear_interpolation_parameters_pullback(Δslope) + du = @thunk(Δslope*...) # TODO: How to handle sparsity? + dt = @thunk(Δslope*...) # TODO: How to handle sparsity? + df = NoTangent() + didx = NoTangent() + + return (df, du, dt, didx) + end + + return slope, linear_interpolation_parameters_pullback +end + +function _tangent_u(A::LinearInterpolation, t) + ... # TODO: How to handle sparsity? +end + +function _tangent_t(A::LinearInterpolation, t) + ... # TODO: How to handle sparsity? +end + +function ChainRulesCore.rrule(::typeof(_interpolate), A::AType, t, iguess) where {AType} + u = _interpolate(A, t, iguess)[1] + function _interpolate_pullback(Δ) + df = NoTangent() + dA = Tangent{AType}(; u = _tangent_u(A, t), t = _tangent_t(A, t)) + dt = @thunk(derivative(A, t)*Δ) + diguess = NoTangent() + return df, dA, dt, diguess + end + u, _interpolate_pullback end function ChainRulesCore.frule((_, _, Δt), ::typeof(_interpolate), A::AbstractInterpolation, diff --git a/src/parameter_caches.jl b/src/parameter_caches.jl index 2820dc8f..b550b621 100644 --- a/src/parameter_caches.jl +++ b/src/parameter_caches.jl @@ -3,11 +3,16 @@ struct LinearParameterCache{pType} end function LinearParameterCache(u, t) - slope = linear_interpolation_parameters.(Ref(u), Ref(t), 1:(length(t) - 1)) + slope_prototype = linear_interpolation_parameters(u, t, 1) + idxs = 1:(length(t) - 1) + slope = [zero(slope_prototype) for i in idxs] + for idx in idxs + slope[idx] = linear_interpolation_parameters(u, t, idx) + end return LinearParameterCache(slope) end -function linear_interpolation_parameters(u, t, idx) +function linear_interpolation_parameters(u, t, idx::Integer) Δu = u isa AbstractMatrix ? u[:, idx + 1] - u[:, idx] : u[idx + 1] - u[idx] Δt = t[idx + 1] - t[idx] slope = Δu / Δt