From 40d77d613280713a0f3aaf0f267e97239196434c Mon Sep 17 00:00:00 2001 From: David Widmann Date: Fri, 15 Nov 2024 15:04:44 +0100 Subject: [PATCH] Specialize on functions in StaticArrays extension --- ext/ForwardDiffStaticArraysExt.jl | 52 +++++++++++++++---------------- 1 file changed, 26 insertions(+), 26 deletions(-) diff --git a/ext/ForwardDiffStaticArraysExt.jl b/ext/ForwardDiffStaticArraysExt.jl index f2b1540b..1f8c792c 100644 --- a/ext/ForwardDiffStaticArraysExt.jl +++ b/ext/ForwardDiffStaticArraysExt.jl @@ -21,7 +21,7 @@ using DiffResults: DiffResult, ImmutableDiffResult, MutableDiffResult end end -@inline static_dual_eval(::Type{T}, f, x::StaticArray) where T = f(dualize(T, x)) +@inline static_dual_eval(::Type{T}, f::F, x::StaticArray) where {T,F} = f(dualize(T, x)) # To fix method ambiguity issues: function LinearAlgebra.eigvals(A::Symmetric{<:Dual{Tg,T,N}, <:StaticArrays.StaticMatrix}) where {Tg,T<:Real,N} @@ -35,13 +35,13 @@ end ForwardDiff._lyap_div!!(A::StaticArrays.MMatrix, λ::AbstractVector) = ForwardDiff._lyap_div!(A, λ) # Gradient -@inline ForwardDiff.gradient(f, x::StaticArray) = vector_mode_gradient(f, x) -@inline ForwardDiff.gradient(f, x::StaticArray, cfg::GradientConfig) = gradient(f, x) -@inline ForwardDiff.gradient(f, x::StaticArray, cfg::GradientConfig, ::Val) = gradient(f, x) +@inline ForwardDiff.gradient(f::F, x::StaticArray) where {F} = vector_mode_gradient(f, x) +@inline ForwardDiff.gradient(f::F, x::StaticArray, cfg::GradientConfig) where {F} = gradient(f, x) +@inline ForwardDiff.gradient(f::F, x::StaticArray, cfg::GradientConfig, ::Val) where {F} = gradient(f, x) -@inline ForwardDiff.gradient!(result::Union{AbstractArray,DiffResult}, f, x::StaticArray) = vector_mode_gradient!(result, f, x) -@inline ForwardDiff.gradient!(result::Union{AbstractArray,DiffResult}, f, x::StaticArray, cfg::GradientConfig) = gradient!(result, f, x) -@inline ForwardDiff.gradient!(result::Union{AbstractArray,DiffResult}, f, x::StaticArray, cfg::GradientConfig, ::Val) = gradient!(result, f, x) +@inline ForwardDiff.gradient!(result::Union{AbstractArray,DiffResult}, f::F, x::StaticArray) where {F} = vector_mode_gradient!(result, f, x) +@inline ForwardDiff.gradient!(result::Union{AbstractArray,DiffResult}, f::F, x::StaticArray, cfg::GradientConfig) where {F} = gradient!(result, f, x) +@inline ForwardDiff.gradient!(result::Union{AbstractArray,DiffResult}, f::F, x::StaticArray, cfg::GradientConfig, ::Val) where {F} = gradient!(result, f, x) @generated function extract_gradient(::Type{T}, y::Real, x::S) where {T,S<:StaticArray} result = Expr(:tuple, [:(partials(T, y, $i)) for i in 1:length(x)]...) @@ -52,24 +52,24 @@ ForwardDiff._lyap_div!!(A::StaticArrays.MMatrix, λ::AbstractVector) = ForwardDi end end -@inline function ForwardDiff.vector_mode_gradient(f, x::StaticArray) +@inline function ForwardDiff.vector_mode_gradient(f::F, x::StaticArray) where {F} T = typeof(Tag(f, eltype(x))) return extract_gradient(T, static_dual_eval(T, f, x), x) end -@inline function ForwardDiff.vector_mode_gradient!(result, f, x::StaticArray) +@inline function ForwardDiff.vector_mode_gradient!(result, f::F, x::StaticArray) where {F} T = typeof(Tag(f, eltype(x))) return extract_gradient!(T, result, static_dual_eval(T, f, x)) end # Jacobian -@inline ForwardDiff.jacobian(f, x::StaticArray) = vector_mode_jacobian(f, x) -@inline ForwardDiff.jacobian(f, x::StaticArray, cfg::JacobianConfig) = jacobian(f, x) -@inline ForwardDiff.jacobian(f, x::StaticArray, cfg::JacobianConfig, ::Val) = jacobian(f, x) +@inline ForwardDiff.jacobian(f::F, x::StaticArray) where {F} = vector_mode_jacobian(f, x) +@inline ForwardDiff.jacobian(f::F, x::StaticArray, cfg::JacobianConfig) where {F} = jacobian(f, x) +@inline ForwardDiff.jacobian(f::F, x::StaticArray, cfg::JacobianConfig, ::Val) where {F} = jacobian(f, x) -@inline ForwardDiff.jacobian!(result::Union{AbstractArray,DiffResult}, f, x::StaticArray) = vector_mode_jacobian!(result, f, x) -@inline ForwardDiff.jacobian!(result::Union{AbstractArray,DiffResult}, f, x::StaticArray, cfg::JacobianConfig) = jacobian!(result, f, x) -@inline ForwardDiff.jacobian!(result::Union{AbstractArray,DiffResult}, f, x::StaticArray, cfg::JacobianConfig, ::Val) = jacobian!(result, f, x) +@inline ForwardDiff.jacobian!(result::Union{AbstractArray,DiffResult}, f::F, x::StaticArray) where {F} = vector_mode_jacobian!(result, f, x) +@inline ForwardDiff.jacobian!(result::Union{AbstractArray,DiffResult}, f::F, x::StaticArray, cfg::JacobianConfig) where {F} = jacobian!(result, f, x) +@inline ForwardDiff.jacobian!(result::Union{AbstractArray,DiffResult}, f::F, x::StaticArray, cfg::JacobianConfig, ::Val) where {F} = jacobian!(result, f, x) @generated function extract_jacobian(::Type{T}, ydual::StaticArray, x::S) where {T,S<:StaticArray} M, N = length(ydual), length(x) @@ -81,7 +81,7 @@ end end end -@inline function ForwardDiff.vector_mode_jacobian(f, x::StaticArray) +@inline function ForwardDiff.vector_mode_jacobian(f::F, x::StaticArray) where {F} T = typeof(Tag(f, eltype(x))) return extract_jacobian(T, static_dual_eval(T, f, x), x) end @@ -91,7 +91,7 @@ function extract_jacobian(::Type{T}, ydual::AbstractArray, x::StaticArray) where return extract_jacobian!(T, result, ydual, length(x)) end -@inline function ForwardDiff.vector_mode_jacobian!(result, f, x::StaticArray) +@inline function ForwardDiff.vector_mode_jacobian!(result, f::F, x::StaticArray) where {F} T = typeof(Tag(f, eltype(x))) ydual = static_dual_eval(T, f, x) result = extract_jacobian!(T, result, ydual, length(x)) @@ -99,7 +99,7 @@ end return result end -@inline function ForwardDiff.vector_mode_jacobian!(result::ImmutableDiffResult, f, x::StaticArray) +@inline function ForwardDiff.vector_mode_jacobian!(result::ImmutableDiffResult, f::F, x::StaticArray) where {F} T = typeof(Tag(f, eltype(x))) ydual = static_dual_eval(T, f, x) result = DiffResults.jacobian!(result, extract_jacobian(T, ydual, x)) @@ -108,18 +108,18 @@ end end # Hessian -ForwardDiff.hessian(f, x::StaticArray) = jacobian(y -> gradient(f, y), x) -ForwardDiff.hessian(f, x::StaticArray, cfg::HessianConfig) = hessian(f, x) -ForwardDiff.hessian(f, x::StaticArray, cfg::HessianConfig, ::Val) = hessian(f, x) +ForwardDiff.hessian(f::F, x::StaticArray) where {F} = jacobian(y -> gradient(f, y), x) +ForwardDiff.hessian(f::F, x::StaticArray, cfg::HessianConfig) where {F} = hessian(f, x) +ForwardDiff.hessian(f::F, x::StaticArray, cfg::HessianConfig, ::Val) where {F} = hessian(f, x) -ForwardDiff.hessian!(result::AbstractArray, f, x::StaticArray) = jacobian!(result, y -> gradient(f, y), x) +ForwardDiff.hessian!(result::AbstractArray, f::F, x::StaticArray) where {F} = jacobian!(result, y -> gradient(f, y), x) -ForwardDiff.hessian!(result::MutableDiffResult, f, x::StaticArray) = hessian!(result, f, x, HessianConfig(f, result, x)) +ForwardDiff.hessian!(result::MutableDiffResult, f::F, x::StaticArray) where {F} = hessian!(result, f, x, HessianConfig(f, result, x)) -ForwardDiff.hessian!(result::ImmutableDiffResult, f, x::StaticArray, cfg::HessianConfig) = hessian!(result, f, x) -ForwardDiff.hessian!(result::ImmutableDiffResult, f, x::StaticArray, cfg::HessianConfig, ::Val) = hessian!(result, f, x) +ForwardDiff.hessian!(result::ImmutableDiffResult, f::F, x::StaticArray, cfg::HessianConfig) where {F} = hessian!(result, f, x) +ForwardDiff.hessian!(result::ImmutableDiffResult, f::F, x::StaticArray, cfg::HessianConfig, ::Val) where {F} = hessian!(result, f, x) -function ForwardDiff.hessian!(result::ImmutableDiffResult, f, x::StaticArray) +function ForwardDiff.hessian!(result::ImmutableDiffResult, f::F, x::StaticArray) where {F} T = typeof(Tag(f, eltype(x))) d1 = dualize(T, x) d2 = dualize(T, d1)