diff --git a/Project.toml b/Project.toml index 7105932..cfcef21 100644 --- a/Project.toml +++ b/Project.toml @@ -32,7 +32,7 @@ CTBasePlots = "Plots" [compat] DataStructures = "0.18" -DifferentiationInterface = "0.5" +DifferentiationInterface = "0.6" DocStringExtensions = "0.9" ForwardDiff = "0.10" Interpolations = "0.15" diff --git a/src/utils.jl b/src/utils.jl index 7993439..009d038 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -74,8 +74,7 @@ $(TYPEDSIGNATURES) Return the gradient of `f` at `x`. """ function ctgradient(f::Function, x::ctNumber; backend = __get_AD_backend()) - extras = prepare_derivative(f, backend, x) - return derivative(f, backend, x, extras) + return derivative(f, backend, x) end function __ctgradient(f::Function, x::ctNumber) @@ -88,8 +87,7 @@ $(TYPEDSIGNATURES) Return the gradient of `f` at `x`. """ function ctgradient(f::Function, x; backend = __get_AD_backend()) - extras = prepare_gradient(f, backend, x) - return gradient(f, backend, x, extras) + return gradient(f, backend, x) end function __ctgradient(f::Function, x) @@ -112,8 +110,7 @@ Return the Jacobian of `f` at `x`. """ function ctjacobian(f::Function, x::ctNumber; backend = __get_AD_backend()) f_number_to_number = only ∘ f ∘ only - extras = prepare_derivative(f_number_to_number, backend, x) - der = derivative(f_number_to_number, backend, x, extras) + der = derivative(f_number_to_number, backend, x) return [der;;] end @@ -127,8 +124,7 @@ $(TYPEDSIGNATURES) Return the Jacobian of `f` at `x`. """ function ctjacobian(f::Function, x; backend = __get_AD_backend()) - extras = prepare_jacobian(f, backend, x) - return jacobian(f, backend, x, extras) + return jacobian(f, backend, x) end __ctjacobian(f::Function, x) = ForwardDiff.jacobian(f, x) diff --git a/test/Project.toml b/test/Project.toml index 80e4417..a687798 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -8,7 +8,7 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [compat] Aqua = "0.8" -DifferentiationInterface = "0.5" +DifferentiationInterface = "0.6" JSON3 = "1" JLD2 = "0.5" Plots = "1"