Skip to content

Commit

Permalink
Specialize on functions in StaticArrays extension
Browse files Browse the repository at this point in the history
  • Loading branch information
devmotion authored Nov 15, 2024
1 parent 8eaba05 commit d2c29a2
Showing 1 changed file with 26 additions and 26 deletions.
52 changes: 26 additions & 26 deletions ext/ForwardDiffStaticArraysExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ using DiffResults: DiffResult, ImmutableDiffResult, MutableDiffResult
end
end

@inline static_dual_eval(::Type{T}, f, x::StaticArray) where T = f(dualize(T, x))
@inline static_dual_eval(::Type{T}, f::F, x::StaticArray) where {T,F} = f(dualize(T, x))

# To fix method ambiguity issues:
function LinearAlgebra.eigvals(A::Symmetric{<:Dual{Tg,T,N}, <:StaticArrays.StaticMatrix}) where {Tg,T<:Real,N}
Expand All @@ -35,13 +35,13 @@ end
ForwardDiff._lyap_div!!(A::StaticArrays.MMatrix, λ::AbstractVector) = ForwardDiff._lyap_div!(A, λ)

# Gradient
@inline ForwardDiff.gradient(f, x::StaticArray) = vector_mode_gradient(f, x)
@inline ForwardDiff.gradient(f, x::StaticArray, cfg::GradientConfig) = gradient(f, x)
@inline ForwardDiff.gradient(f, x::StaticArray, cfg::GradientConfig, ::Val) = gradient(f, x)
@inline ForwardDiff.gradient(f::F, x::StaticArray) where {F} = vector_mode_gradient(f, x)
@inline ForwardDiff.gradient(f::F, x::StaticArray, cfg::GradientConfig) where {F} = gradient(f, x)
@inline ForwardDiff.gradient(f::F, x::StaticArray, cfg::GradientConfig, ::Val) where {F} = gradient(f, x)

@inline ForwardDiff.gradient!(result::Union{AbstractArray,DiffResult}, f, x::StaticArray) = vector_mode_gradient!(result, f, x)
@inline ForwardDiff.gradient!(result::Union{AbstractArray,DiffResult}, f, x::StaticArray, cfg::GradientConfig) = gradient!(result, f, x)
@inline ForwardDiff.gradient!(result::Union{AbstractArray,DiffResult}, f, x::StaticArray, cfg::GradientConfig, ::Val) = gradient!(result, f, x)
@inline ForwardDiff.gradient!(result::Union{AbstractArray,DiffResult}, f::F, x::StaticArray) where {F} = vector_mode_gradient!(result, f, x)
@inline ForwardDiff.gradient!(result::Union{AbstractArray,DiffResult}, f::F, x::StaticArray, cfg::GradientConfig) where {F} = gradient!(result, f, x)
@inline ForwardDiff.gradient!(result::Union{AbstractArray,DiffResult}, f::F, x::StaticArray, cfg::GradientConfig, ::Val) where {F} = gradient!(result, f, x)

@generated function extract_gradient(::Type{T}, y::Real, x::S) where {T,S<:StaticArray}
result = Expr(:tuple, [:(partials(T, y, $i)) for i in 1:length(x)]...)
Expand All @@ -52,24 +52,24 @@ ForwardDiff._lyap_div!!(A::StaticArrays.MMatrix, λ::AbstractVector) = ForwardDi
end
end

@inline function ForwardDiff.vector_mode_gradient(f, x::StaticArray)
@inline function ForwardDiff.vector_mode_gradient(f::F, x::StaticArray) where {F}
T = typeof(Tag(f, eltype(x)))
return extract_gradient(T, static_dual_eval(T, f, x), x)
end

@inline function ForwardDiff.vector_mode_gradient!(result, f, x::StaticArray)
@inline function ForwardDiff.vector_mode_gradient!(result, f::F, x::StaticArray) where {F}
T = typeof(Tag(f, eltype(x)))
return extract_gradient!(T, result, static_dual_eval(T, f, x))
end

# Jacobian
@inline ForwardDiff.jacobian(f, x::StaticArray) = vector_mode_jacobian(f, x)
@inline ForwardDiff.jacobian(f, x::StaticArray, cfg::JacobianConfig) = jacobian(f, x)
@inline ForwardDiff.jacobian(f, x::StaticArray, cfg::JacobianConfig, ::Val) = jacobian(f, x)
@inline ForwardDiff.jacobian(f::F, x::StaticArray) where {F} = vector_mode_jacobian(f, x)
@inline ForwardDiff.jacobian(f::F, x::StaticArray, cfg::JacobianConfig) where {F} = jacobian(f, x)
@inline ForwardDiff.jacobian(f::F, x::StaticArray, cfg::JacobianConfig, ::Val) where {F} = jacobian(f, x)

@inline ForwardDiff.jacobian!(result::Union{AbstractArray,DiffResult}, f, x::StaticArray) = vector_mode_jacobian!(result, f, x)
@inline ForwardDiff.jacobian!(result::Union{AbstractArray,DiffResult}, f, x::StaticArray, cfg::JacobianConfig) = jacobian!(result, f, x)
@inline ForwardDiff.jacobian!(result::Union{AbstractArray,DiffResult}, f, x::StaticArray, cfg::JacobianConfig, ::Val) = jacobian!(result, f, x)
@inline ForwardDiff.jacobian!(result::Union{AbstractArray,DiffResult}, f::F, x::StaticArray) where {F} = vector_mode_jacobian!(result, f, x)
@inline ForwardDiff.jacobian!(result::Union{AbstractArray,DiffResult}, f::F, x::StaticArray, cfg::JacobianConfig) where {F} = jacobian!(result, f, x)
@inline ForwardDiff.jacobian!(result::Union{AbstractArray,DiffResult}, f::F, x::StaticArray, cfg::JacobianConfig, ::Val) where {F} = jacobian!(result, f, x)

@generated function extract_jacobian(::Type{T}, ydual::StaticArray, x::S) where {T,S<:StaticArray}
M, N = length(ydual), length(x)
Expand All @@ -81,7 +81,7 @@ end
end
end

@inline function ForwardDiff.vector_mode_jacobian(f, x::StaticArray)
@inline function ForwardDiff.vector_mode_jacobian(f::F, x::StaticArray) where {F}
T = typeof(Tag(f, eltype(x)))
return extract_jacobian(T, static_dual_eval(T, f, x), x)
end
Expand All @@ -91,15 +91,15 @@ function extract_jacobian(::Type{T}, ydual::AbstractArray, x::StaticArray) where
return extract_jacobian!(T, result, ydual, length(x))
end

@inline function ForwardDiff.vector_mode_jacobian!(result, f, x::StaticArray)
@inline function ForwardDiff.vector_mode_jacobian!(result, f::F, x::StaticArray) where {F}
T = typeof(Tag(f, eltype(x)))
ydual = static_dual_eval(T, f, x)
result = extract_jacobian!(T, result, ydual, length(x))
result = extract_value!(T, result, ydual)
return result
end

@inline function ForwardDiff.vector_mode_jacobian!(result::ImmutableDiffResult, f, x::StaticArray)
@inline function ForwardDiff.vector_mode_jacobian!(result::ImmutableDiffResult, f::F, x::StaticArray) where {F}
T = typeof(Tag(f, eltype(x)))
ydual = static_dual_eval(T, f, x)
result = DiffResults.jacobian!(result, extract_jacobian(T, ydual, x))
Expand All @@ -108,18 +108,18 @@ end
end

# Hessian
ForwardDiff.hessian(f, x::StaticArray) = jacobian(y -> gradient(f, y), x)
ForwardDiff.hessian(f, x::StaticArray, cfg::HessianConfig) = hessian(f, x)
ForwardDiff.hessian(f, x::StaticArray, cfg::HessianConfig, ::Val) = hessian(f, x)
ForwardDiff.hessian(f::F, x::StaticArray) where {F} = jacobian(y -> gradient(f, y), x)
ForwardDiff.hessian(f::F, x::StaticArray, cfg::HessianConfig) where {F} = hessian(f, x)
ForwardDiff.hessian(f::F, x::StaticArray, cfg::HessianConfig, ::Val) where {F} = hessian(f, x)

ForwardDiff.hessian!(result::AbstractArray, f, x::StaticArray) = jacobian!(result, y -> gradient(f, y), x)
ForwardDiff.hessian!(result::AbstractArray, f::F, x::StaticArray) where {F} = jacobian!(result, y -> gradient(f, y), x)

ForwardDiff.hessian!(result::MutableDiffResult, f, x::StaticArray) = hessian!(result, f, x, HessianConfig(f, result, x))
ForwardDiff.hessian!(result::MutableDiffResult, f::F, x::StaticArray) where {F} = hessian!(result, f, x, HessianConfig(f, result, x))

ForwardDiff.hessian!(result::ImmutableDiffResult, f, x::StaticArray, cfg::HessianConfig) = hessian!(result, f, x)
ForwardDiff.hessian!(result::ImmutableDiffResult, f, x::StaticArray, cfg::HessianConfig, ::Val) = hessian!(result, f, x)
ForwardDiff.hessian!(result::ImmutableDiffResult, f::F, x::StaticArray, cfg::HessianConfig) where {F} = hessian!(result, f, x)
ForwardDiff.hessian!(result::ImmutableDiffResult, f::F, x::StaticArray, cfg::HessianConfig, ::Val) where {F} = hessian!(result, f, x)

function ForwardDiff.hessian!(result::ImmutableDiffResult, f, x::StaticArray)
function ForwardDiff.hessian!(result::ImmutableDiffResult, f::F, x::StaticArray) where {F}
T = typeof(Tag(f, eltype(x)))
d1 = dualize(T, x)
d2 = dualize(T, d1)
Expand Down

0 comments on commit d2c29a2

Please sign in to comment.