Skip to content

Commit

Permalink
Derivative extrapolation POC
Browse files Browse the repository at this point in the history
  • Loading branch information
SouthEndMusic committed Nov 14, 2024
1 parent 832ca31 commit ea91f76
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 9 deletions.
47 changes: 40 additions & 7 deletions src/derivatives.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,48 @@
function derivative(A, t, order = 1)
((t < A.t[1] || t > A.t[end]) && !A.extrapolate) && throw(ExtrapolationError())
iguess = A.iguesser
(order (1, 2)) && throw(DerivativeNotFoundError())
if t < first(A.t)
_extrapolate_derivative_down(A, t, order)
elseif t > last(A.t)
_extrapolate_derivative_up(A, t, order)
else
(order == 1) ? _derivative(A, t, A.iguesser) :
ForwardDiff.derivative(t -> begin
_derivative(A, t, iguess)
end, t)
end
end

return if order == 1
_derivative(A, t, iguess)
elseif order == 2
function _extrapolate_derivative_down(A, t, order)
(; extrapolation_down) = A
typed_zero = zero(one(A.u[1]) / one(A.t[1]))
if extrapolation_down == ExtrapolationType.none
throw(UpExtrapolationError())
elseif extrapolation_down == ExtrapolationType.constant
typed_zero
elseif extrapolation_down == ExtrapolationType.linear
(order == 1) ? derivative(A, first(A.t)) : typed_zero
elseif extrapolation_down == ExtrapolationType.extension
(order == 1) ? _derivative(A, t, A.iguesser) :
ForwardDiff.derivative(t -> begin
_derivative(A, t, iguess)
end, t)
end
end

function _extrapolate_derivative_up(A, t, order)
(; extrapolation_up) = A
typed_zero = zero(one(A.u[1]) / one(A.t[1]))
if extrapolation_up == ExtrapolationType.none
throw(DownExtrapolationError())
elseif extrapolation_up == ExtrapolationType.constant
typed_zero
elseif extrapolation_up == ExtrapolationType.linear
(order == 1) ? derivative(A, last(A.t)) : typed_zero
elseif extrapolation_up == ExtrapolationType.extension
(order == 1) ? _derivative(A, t, A.iguesser) :
ForwardDiff.derivative(t -> begin
_derivative(A, t, iguess)
end, t)
else
throw(DerivativeNotFoundError())
end
end

Expand Down
4 changes: 2 additions & 2 deletions src/interpolation_methods.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ end
function _extrapolate_down(A, t)
(; extrapolation_down) = A
if extrapolation_down == ExtrapolationType.none
throw(ExtrapolationError(DOWN_EXTRAPOLATION_ERROR))
throw(DownExtrapolationError())
elseif extrapolation_down == ExtrapolationType.constant
first(A.u)
elseif extrapolation_down == ExtrapolationType.linear
Expand All @@ -25,7 +25,7 @@ end
function _extrapolate_up(A, t)
(; extrapolation_up) = A
if extrapolation_up == ExtrapolationType.none
throw(ExtrapolationError(UP_EXTRAPOLATION_ERROR))
throw(UpExtrapolationError())
elseif extrapolation_up == ExtrapolationType.constant
last(A.u)
elseif extrapolation_up == ExtrapolationType.linear
Expand Down

0 comments on commit ea91f76

Please sign in to comment.