Skip to content

Commit

Permalink
Merge pull request #141 from control-toolbox/differentiationinterface
Browse files Browse the repository at this point in the history
WIP: Differentiationinterface
  • Loading branch information
jbcaillau authored Jun 14, 2024
2 parents eda4107 + 605dcc6 commit 61c71dd
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 9 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ version = "0.9.1"

[deps]
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Interpolations = "a98d9a8b-a2ab-59e6-89dd-64a1c18fca59"
Expand All @@ -21,6 +22,7 @@ Unicode = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5"

[compat]
DataStructures = "0.18"
DifferentiationInterface = "0.5"
DocStringExtensions = "0.9"
ForwardDiff = "0.10"
Interpolations = "0.15"
Expand Down
5 changes: 4 additions & 1 deletion src/CTBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ module CTBase
# using
import Base
using DocStringExtensions
using ForwardDiff: jacobian, gradient, ForwardDiff # automatic differentiation
using DifferentiationInterface: AutoForwardDiff, derivative, gradient, jacobian, prepare_derivative, prepare_gradient, prepare_jacobian
import ForwardDiff
using Interpolations: linear_interpolation, Line, Interpolations # for default interpolation
using MLStyle # pattern matching
using Parameters # @with_kw: to have default values in struct
Expand Down Expand Up @@ -193,6 +194,8 @@ See also: [`ctVector`](@ref), [`DState`](@ref).
"""
const DCostate = ctVector

__auto() = AutoForwardDiff() # default AD backend

#
include("exception.jl")
include("description.jl")
Expand Down
2 changes: 1 addition & 1 deletion src/print.jl
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,7 @@ function Base.show(io::IO, ::MIME"text/plain", ocp::OptimalControlModel{<: TimeD
#
println(io)
printstyled(io, "Declarations ", bold=true)
printstyled(io, "(* required):\n", bold=false)
printstyled(io, "(* for required):\n", bold=false)
#println(io)

# print table of settings
Expand Down
22 changes: 15 additions & 7 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,17 +73,19 @@ $(TYPEDSIGNATURES)
Return the gradient of `f` at `x`.
"""
function ctgradient(f::Function, x::ctNumber)
return ForwardDiff.derivative(x -> f(x), x)
function ctgradient(f::Function, x::ctNumber; backend=__auto())
extras = prepare_derivative(f, backend, x)
return derivative(f, backend, x, extras)
end

"""
$(TYPEDSIGNATURES)
Return the gradient of `f` at `x`.
"""
function ctgradient(f::Function, x)
return ForwardDiff.gradient(f, x)
function ctgradient(f::Function, x; backend=__auto())
extras = prepare_gradient(f, backend, x)
return gradient(f, backend, x, extras)
end

"""
Expand All @@ -98,16 +100,22 @@ $(TYPEDSIGNATURES)
Return the Jacobian of `f` at `x`.
"""
function ctjacobian(f::Function, x::ctNumber)
return ForwardDiff.jacobian(x -> f(x[1]), [x])
function ctjacobian(f::Function, x::ctNumber; backend=__auto())
f_number_to_number = only f only
extras = prepare_derivative(f_number_to_number, backend, x)
der = derivative(f_number_to_number, backend, x, extras)
return [der;;]
end

"""
$(TYPEDSIGNATURES)
Return the Jacobian of `f` at `x`.
"""
ctjacobian(f::Function, x) = ForwardDiff.jacobian(f, x)
function ctjacobian(f::Function, x; backend=__auto())
extras = prepare_jacobian(f, backend, x)
return jacobian(f, backend, x, extras)
end

"""
$(TYPEDSIGNATURES)
Expand Down

0 comments on commit 61c71dd

Please sign in to comment.