diff --git a/src/AbstractDifferentiation.jl b/src/AbstractDifferentiation.jl index 55504ed..55f6554 100644 --- a/src/AbstractDifferentiation.jl +++ b/src/AbstractDifferentiation.jl @@ -221,28 +221,26 @@ _zero(::AbstractVector, d::AbstractMatrix) = zero(similar(d, size(d, 2))) _zero(::AbstractMatrix, d::AbstractMatrix) = zero(d) _zero(::Any, d::Any) = zero(d) -@inline _dot(x, y) = dot(x, y) -@inline function _dot(x::AbstractVector, y::UniformScaling) +# support pullbacks for complex numbers +@inline _realdot(x, y) = real(dot(x, y)) +@inline _realdot(x::Number, y::Number) = muladd(real(x), real(y), imag(x) * imag(y)) +@inline _realdot(x::Real, y::Number) = x * real(y) +@inline _realdot(x::Number, y::Real) = real(x) * y +@inline _realdot(x::Real, y::Real) = x * y +@inline function _realdot(x::AbstractVector, y::UniformScaling) @assert length(x) == 1 - return @inbounds dot(x[1], y.λ) + return @inbounds _realdot(x[1], y.λ) end -@inline function _dot(x::AbstractVector, y::AbstractMatrix) +@inline function _realdot(x::AbstractVector, y::AbstractMatrix) @assert size(y, 2) == 1 - return dot(x, y) + return _realdot(x, vec(y)) +end +@inline function _realdot(xs::NTuple{N}, ys::NTuple{N}) where {N} + return sum(Base.splat(_realdot), zip(xs, ys)) end function pullback_function(ab::AbstractBackend, f, xs...) - return (ws) -> begin - return gradient(lowest(ab), (xs...,) -> begin - vs = f(xs...) - if ws isa Tuple - @assert length(vs) == length(ws) - return sum(Base.splat(_dot), zip(ws, vs)) - else - return _dot(vs, ws) - end - end, xs...) - end + return (ws) -> gradient(lowest(ab), (xs...,) -> _realdot(ws, f(xs...)), xs...) end function value_and_pullback_function( ab::AbstractBackend,