diff --git a/Project.toml b/Project.toml index 123b98a2..8189faab 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "SparseDiffTools" uuid = "47a9eef4-7e08-11e9-0b38-333d64bd3804" authors = ["Pankaj Mishra ", "Chris Rackauckas "] -version = "2.9.2" +version = "2.10.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/ext/SparseDiffToolsZygoteExt.jl b/ext/SparseDiffToolsZygoteExt.jl index fd31dcea..5006f843 100644 --- a/ext/SparseDiffToolsZygoteExt.jl +++ b/ext/SparseDiffToolsZygoteExt.jl @@ -13,12 +13,12 @@ import SparseDiffTools: __f̂, __jacobian!, __gradient, __gradient! import ADTypes: AutoZygote, AutoSparseZygote ## Satisfying High-Level Interface for Sparse Jacobians -function __gradient(::Union{AutoSparseZygote, AutoZygote}, f, x, cols) +function __gradient(::Union{AutoSparseZygote, AutoZygote}, f::F, x, cols) where {F} _, ∂x, _ = Zygote.gradient(__f̂, f, x, cols) return vec(∂x) end -function __gradient!(::Union{AutoSparseZygote, AutoZygote}, f!, fx, x, cols) +function __gradient!(::Union{AutoSparseZygote, AutoZygote}, f!::F, fx, x, cols) where {F} return error("Zygote.jl cannot differentiate in-place (mutating) functions.") end @@ -26,7 +26,8 @@ end # https://github.com/FluxML/Zygote.jl/blob/82c7a000bae7fb0999275e62cc53ddb61aed94c7/src/lib/grad.jl#L140-L157C4 import Zygote: _jvec, _eyelike, _gradcopy! -@views function __jacobian!(J::AbstractMatrix, ::Union{AutoSparseZygote, AutoZygote}, f, x) +@views function __jacobian!(J::AbstractMatrix, ::Union{AutoSparseZygote, AutoZygote}, f::F, + x) where {F} y, back = Zygote.pullback(_jvec ∘ f, x) δ = _eyelike(y) for k in LinearIndices(y) @@ -36,13 +37,13 @@ import Zygote: _jvec, _eyelike, _gradcopy! return J end -function __jacobian!(J, ::Union{AutoSparseZygote, AutoZygote}, f!, fx, x) +function __jacobian!(_, ::Union{AutoSparseZygote, AutoZygote}, f!::F, fx, x) where {F} return error("Zygote.jl cannot differentiate in-place (mutating) functions.") end ### Jac, Hes products -function numback_hesvec!(dy, f, x, v, cache1 = similar(v), cache2 = similar(v)) +function numback_hesvec!(dy, f::F, x, v, cache1 = similar(v), cache2 = similar(v)) where {F} g = let f = f (dx, x) -> dx .= first(Zygote.gradient(f, x)) end @@ -57,15 +58,14 @@ function numback_hesvec!(dy, f, x, v, cache1 = similar(v), cache2 = similar(v)) @. dy = (cache1 - cache2) / (2ϵ) end -function numback_hesvec(f, x, v) - g = x -> first(Zygote.gradient(f, x)) +function numback_hesvec(f::F, x, v) where {F} T = eltype(x) # Should it be min? max? mean? ϵ = sqrt(eps(real(T))) * max(one(real(T)), abs(norm(x))) x += ϵ * v - gxp = g(x) + gxp = first(Zygote.gradient(f, x)) x -= 2ϵ * v - gxm = g(x) + gxm = first(Zygote.gradient(f, x)) (gxp - gxm) / (2ϵ) end @@ -94,38 +94,36 @@ end ## VecJac products # VJP methods -function auto_vecjac!(du, f, x, v) +function auto_vecjac!(du, f::F, x, v) where {F} !static_hasmethod(f, typeof((x,))) && error("For inplace function use autodiff = AutoFiniteDiff()") du .= reshape(SparseDiffTools.auto_vecjac(f, x, v), size(du)) end -function auto_vecjac(f, x, v) +function auto_vecjac(f::F, x, v) where {F} y, back = Zygote.pullback(f, x) - return vec(back(reshape(v, size(y)))[1]) + return vec(only(back(reshape(v, size(y))))) end # overload operator interface -function SparseDiffTools._vecjac(f, u, autodiff::AutoZygote) - cache = () +function SparseDiffTools._vecjac(f::F, _, u, autodiff::AutoZygote) where {F} + !static_hasmethod(f, typeof((u,))) && + error("For inplace function use autodiff = AutoFiniteDiff()") pullback = Zygote.pullback(f, u) - - return AutoDiffVJP(f, u, cache, autodiff, pullback) + return AutoDiffVJP(f, u, (), autodiff, pullback) end function update_coefficients(L::AutoDiffVJP{<:AutoZygote}, u, p, t; VJP_input = nothing) VJP_input !== nothing && (@set! L.u = VJP_input) - @set! L.f = update_coefficients(L.f, L.u, p, t) @set! L.pullback = Zygote.pullback(L.f, L.u) + return L end function update_coefficients!(L::AutoDiffVJP{<:AutoZygote}, u, p, t; VJP_input = nothing) VJP_input !== nothing && copy!(L.u, VJP_input) - update_coefficients!(L.f, L.u, p, t) L.pullback = Zygote.pullback(L.f, L.u) - return L end @@ -133,22 +131,14 @@ end function (L::AutoDiffVJP{<:AutoZygote})(v, p, t; VJP_input = nothing) # ignore VJP_input as pullback was computed in update_coefficients(...) y, back = L.pullback - V = reshape(v, size(y)) - - return vec(first(back(V))) + return vec(only(back(reshape(v, size(y))))) end # prefer non in-place method -function (L::AutoDiffVJP{<:AutoZygote, IIP, true})(dv, v, p, t; - VJP_input = nothing) where {IIP} +function (L::AutoDiffVJP{<:AutoZygote})(dv, v, p, t; VJP_input = nothing) # ignore VJP_input as pullback was computed in update_coefficients!(...) - - _dv = L(v, p, t; VJP_input = VJP_input) + _dv = L(v, p, t; VJP_input) copy!(dv, _dv) end -function (L::AutoDiffVJP{<:AutoZygote, true, false})(args...; kwargs...) - error("Zygote requires an out of place method with signature f(u).") -end - end # module diff --git a/src/differentiation/vecjac_products.jl b/src/differentiation/vecjac_products.jl index 87162012..835f9e19 100644 --- a/src/differentiation/vecjac_products.jl +++ b/src/differentiation/vecjac_products.jl @@ -1,10 +1,10 @@ -function num_vecjac!(du, f, x, v, cache1 = similar(v), cache2 = similar(v); - compute_f0 = true) +function num_vecjac!(du, f::F, x, v, cache1 = similar(v), cache2 = similar(v); + compute_f0 = true) where {F} compute_f0 && (f(cache1, x)) T = eltype(x) # Should it be min? max? mean? ϵ = sqrt(eps(real(T))) * max(one(real(T)), abs(norm(x))) - vv = reshape(v, size(x)) + vv = reshape(v, size(cache1)) for i in 1:length(x) x[i] += ϵ f(cache2, x) @@ -14,9 +14,9 @@ function num_vecjac!(du, f, x, v, cache1 = similar(v), cache2 = similar(v); return du end -function num_vecjac(f, x, v, f0 = nothing) - vv = reshape(v, axes(x)) - f0 === nothing ? _f0 = f(x) : _f0 = f0 +function num_vecjac(f::F, x, v, f0 = nothing) where {F} + f0 === nothing ? (_f0 = f(x)) : (_f0 = f0) + vv = reshape(v, axes(_f0)) T = eltype(x) # Should it be min? max? mean? ϵ = sqrt(eps(real(T))) * max(one(real(T)), abs(norm(x))) @@ -33,12 +33,16 @@ end ### Operator Forms """ - VecJac(f, u, [p, t]; autodiff = AutoFiniteDiff()) + VecJac(f, u, [p, t]; fu = nothing, autodiff = AutoFiniteDiff()) -Returns SciMLOperators.FunctionOperator which computes vector-jacobian -product `df/du * v`. +Returns SciMLOperators.FunctionOperator which computes vector-jacobian product `df/du * v`. -``` +!!! note + + For non-square jacobians with inplace `f`, `fu` must be specified, else `VecJac` assumes + a square jacobian. + +```julia L = VecJac(f, u) L * v # = df/du * v @@ -47,31 +51,121 @@ mul!(w, L, v) # = df/du * v L(v, p, t; VJP_input = w) # = df/dw * v L(x, v, p, t; VJP_input = w) # = df/dw * v ``` + +## Allowed Function Signatures for `f` + +For Out of Place Functions: + +```julia +f(u, p, t) # t !== nothing +f(u, p) # p !== nothing +f(u) # Otherwise +``` + +For In Place Functions: + +```julia +f(du, u, p, t) # t !== nothing +f(du, u, p) # p !== nothing +f(du, u) # Otherwise +``` """ -function VecJac(f, u::AbstractArray, p = nothing, t = nothing; +function VecJac(f, u::AbstractArray, p = nothing, t = nothing; fu = nothing, autodiff = AutoFiniteDiff(), kwargs...) - L = _vecjac(f, u, autodiff) - IIP, OOP = get_iip_oop(L) + ff = VecJacFunctionWrapper(f, fu, u, p, t) - if isa(autodiff, AutoZygote) & !OOP + if !__internal_oop(ff) && autodiff isa AutoZygote msg = "Zygote requires an out of place method with signature f(u)." throw(ArgumentError(msg)) end - # NOTE: The operator returned has both in-place and out-of-place definitions and - # doesn't follow the convention of `f` - return FunctionOperator(L, u, u; isinplace = true, outofplace = OOP, - p, t, islinear = true, accepted_kwargs = (:VJP_input,), kwargs...) + fu === nothing && (fu = __internal_oop(ff) ? ff(u) : u) + + op = _vecjac(ff, fu, u, autodiff) + + # FIXME: FunctionOperator is terribly type unstable. It makes it `::Any` + # NOTE: We pass `p`, `t` to Function Operator but we always use the cached version from + # VecJacFunctionWrapper + return FunctionOperator(op, fu, u; p, t, isinplace = true, outofplace = true, + islinear = true, accepted_kwargs = (:VJP_input,), kwargs...) end -function _vecjac(f, u, autodiff::AutoFiniteDiff) - cache = (similar(u), similar(u)) - pullback = nothing +mutable struct VecJacFunctionWrapper{iip, oop, mode, F, FU, P, T} <: Function + f::F + fu::FU + p::P + t::T +end + +function SciMLOperators.update_coefficients!(L::VecJacFunctionWrapper{iip, oop, mode}, _, + p, t) where {iip, oop, mode} + mode == 1 && (L.t = t) + mode == 2 && (L.p = p) + return L +end +function SciMLOperators.update_coefficients(L::VecJacFunctionWrapper{iip, oop, mode}, _, p, + t) where {iip, oop, mode} + return VecJacFunctionWrapper{iip, oop, mode, typeof(L.f), typeof(L.fu), typeof(p), + typeof(t)}(L.f, L.fu, p, + t) +end + +__internal_iip(::VecJacFunctionWrapper{iip}) where {iip} = iip +__internal_oop(::VecJacFunctionWrapper{iip, oop}) where {iip, oop} = oop + +(f::VecJacFunctionWrapper{true, oop, 1})(fu, u) where {oop} = f.f(fu, u, f.p, f.t) +(f::VecJacFunctionWrapper{true, oop, 2})(fu, u) where {oop} = f.f(fu, u, f.p) +(f::VecJacFunctionWrapper{true, oop, 3})(fu, u) where {oop} = f.f(fu, u) +(f::VecJacFunctionWrapper{true, true, 1})(u) = f.f(u, f.p, f.t) +(f::VecJacFunctionWrapper{true, true, 2})(u) = f.f(u, f.p) +(f::VecJacFunctionWrapper{true, true, 3})(u) = f.f(u) +(f::VecJacFunctionWrapper{true, false, 1})(u) = (f.f(f.fu, u, f.p, f.t); copy(f.fu)) +(f::VecJacFunctionWrapper{true, false, 2})(u) = (f.f(f.fu, u, f.p); copy(f.fu)) +(f::VecJacFunctionWrapper{true, false, 3})(u) = (f.f(f.fu, u); copy(f.fu)) + +(f::VecJacFunctionWrapper{false, true, 1})(fu, u) = (vec(fu) .= vec(f.f(u, f.p, f.t))) +(f::VecJacFunctionWrapper{false, true, 2})(fu, u) = (vec(fu) .= vec(f.f(u, f.p))) +(f::VecJacFunctionWrapper{false, true, 3})(fu, u) = (vec(fu) .= vec(f.f(u))) +(f::VecJacFunctionWrapper{false, true, 1})(u) = f.f(u, f.p, f.t) +(f::VecJacFunctionWrapper{false, true, 2})(u) = f.f(u, f.p) +(f::VecJacFunctionWrapper{false, true, 3})(u) = f.f(u) + +function VecJacFunctionWrapper(f::F, fu_, u, p, t) where {F} + fu = fu_ === nothing ? copy(u) : copy(fu_) + if t !== nothing + iip = static_hasmethod(f, typeof((fu, u, p, t))) + oop = static_hasmethod(f, typeof((u, p, t))) + if !iip && !oop + throw(ArgumentError("`f(u, p, t)` or `f(fu, u, p, t)` not defined for `f`")) + end + return VecJacFunctionWrapper{iip, oop, 1, F, typeof(fu), typeof(p), typeof(t)}(f, + fu, p, t) + elseif p !== nothing + iip = static_hasmethod(f, typeof((fu, u, p))) + oop = static_hasmethod(f, typeof((u, p))) + if !iip && !oop + throw(ArgumentError("`f(u, p)` or `f(fu, u, p)` not defined for `f`")) + end + return VecJacFunctionWrapper{iip, oop, 2, F, typeof(fu), typeof(p), typeof(t)}(f, + fu, p, t) + else + iip = static_hasmethod(f, typeof((fu, u))) + oop = static_hasmethod(f, typeof((u,))) + if !iip && !oop + throw(ArgumentError("`f(u)` or `f(fu, u)` not defined for `f`")) + end + return VecJacFunctionWrapper{iip, oop, 3, F, typeof(fu), typeof(p), typeof(t)}(f, + fu, p, t) + end +end - AutoDiffVJP(f, u, cache, autodiff, pullback) +function _vecjac(f::F, fu, u, autodiff::AutoFiniteDiff) where {F} + cache = (similar(fu), similar(fu)) + pullback = nothing + return AutoDiffVJP(f, u, cache, autodiff, pullback) end -mutable struct AutoDiffVJP{AD, IIP, OOP, F, U, C, PB} <: AbstractAutoDiffVecProd +mutable struct AutoDiffVJP{AD, F, U, C, PB} <: AbstractAutoDiffVecProd """ Compute VJP of `f` at `u`, applied to vector `v`: `df/du' * u` """ f::F """ input to `f` """ @@ -82,71 +176,35 @@ mutable struct AutoDiffVJP{AD, IIP, OOP, F, U, C, PB} <: AbstractAutoDiffVecProd autodiff::AD """ stores the result of Zygote.pullback for AutoZygote """ pullback::PB - - function AutoDiffVJP(f, u, cache, autodiff, pullback) - outofplace = static_hasmethod(f, typeof((u,))) - isinplace = static_hasmethod(f, typeof((u, u))) - - if !(isinplace) & !(outofplace) - msg = "$f must have signature f(u), or f(du, u)" - throw(ArgumentError(msg)) - end - - new{ - typeof(autodiff), - isinplace, - outofplace, - typeof(f), - typeof(u), - typeof(cache), - typeof(pullback), - }(f, - u, - cache, - autodiff, - pullback) - end -end - -function get_iip_oop(::AutoDiffVJP{AD, IIP, OOP}) where {AD, IIP, OOP} - IIP, OOP end -function update_coefficients(L::AutoDiffVJP{AD}, u, p, t; - VJP_input = nothing) where {AD <: AutoFiniteDiff} - if !isnothing(VJP_input) - @set! L.u = VJP_input - end - +function update_coefficients(L::AutoDiffVJP{<:AutoFiniteDiff}, u, p, t; VJP_input = nothing) + VJP_input !== nothing && (@set! L.u = VJP_input) @set! L.f = update_coefficients(L.f, L.u, p, t) + return L end -function update_coefficients!(L::AutoDiffVJP{AD}, u, p, t; - VJP_input = nothing) where {AD <: AutoFiniteDiff} - if !isnothing(VJP_input) - copy!(L.u, VJP_input) - end - +function update_coefficients!(L::AutoDiffVJP{<:AutoFiniteDiff}, u, p, t; + VJP_input = nothing) + VJP_input !== nothing && copy!(L.u, VJP_input) update_coefficients!(L.f, L.u, p, t) - - L + return L end # Interpret the call as df/du' * v -function (L::AutoDiffVJP{AD})(v, p, t; VJP_input = nothing) where {AD <: AutoFiniteDiff} +function (L::AutoDiffVJP{<:AutoFiniteDiff})(v, p, t; VJP_input = nothing) # ignore VJP_input as L.u was set in update_coefficients(...) - num_vecjac(L.f, L.u, v) + return num_vecjac(L.f, L.u, v) end -function (L::AutoDiffVJP{AD})(dv, v, p, t; VJP_input = nothing) where {AD <: AutoFiniteDiff} +function (L::AutoDiffVJP{<:AutoFiniteDiff})(dv, v, p, t; VJP_input = nothing) # ignore VJP_input as L.u was set in update_coefficients!(...) - num_vecjac!(dv, L.f, L.u, v, L.cache...) + return num_vecjac!(dv, L.f, L.u, v, L.cache...) end function Base.resize!(L::AutoDiffVJP, n::Integer) static_hasmethod(resize!, typeof((L.f, n))) && resize!(L.f, n) resize!(L.u, n) - for v in L.cache resize!(v, n) end diff --git a/test/test_vecjac_products.jl b/test/test_vecjac_products.jl index aa411bfb..44dadca9 100644 --- a/test/test_vecjac_products.jl +++ b/test/test_vecjac_products.jl @@ -21,8 +21,10 @@ a, b = rand(Float32, 2) A = rand(Float32, N, N) _f(y, x) = mul!(y, A, x .^ 2) _f(x) = A * (x .^ 2) +_f2(x, p, t) = _f(x) * p * t +_f2(y, x, p, t) = (_f(y, x); lmul!(p * t, y); y) -# Define state-dependent functions for operator tests +# Define state-dependent functions for operator tests include("update_coeffs_testutils.jl") f = WrapFunc(_f, 1.0f0, 1.0f0) @@ -36,7 +38,9 @@ f = WrapFunc(_f, 1.0f0, 1.0f0) @info "VecJac AutoZygote" p, t = rand(Float32, 2) -L = VecJac(f, copy(x1), p, t; autodiff = AutoZygote()) +f = WrapFunc(_f, p, t) + +L = VecJac(_f2, copy(x1), p, t; autodiff = AutoZygote()) update_coefficients!(L, v, p, t) update_coefficients!(f, v, p, t) @@ -83,7 +87,8 @@ y = zeros(N); @info "VecJac AutoFiniteDiff" p, t = rand(Float32, 2) -L = VecJac(f, copy(x1), 1.0f0, 1.0f0; autodiff = AutoFiniteDiff()) +f = WrapFunc(_f, p, t) +L = VecJac(_f2, copy(x1), p, t; autodiff = AutoFiniteDiff()) update_coefficients!(L, v, p, t) update_coefficients!(f, v, p, t) @@ -129,16 +134,39 @@ f2(y, x) = (copy!(y, x); lmul!(2, y); y) x = rand(Float32, N) for M in (100, 400) - local L = VecJac(f2, copy(x), 1.0f0, 1.0f0; autodiff = AutoZygote()) + local L = VecJac(f2, copy(x); autodiff = AutoZygote()) resize!(L, M) _x = resize!(copy(x), M) _u = rand(M) local J2 = Zygote.jacobian(f2, _x)[1] - update_coefficients!(L, _u, 1.0f0, 1.0f0; VJP_input = _x) + update_coefficients!(L, _u, nothing, 1.0f0; VJP_input = _x) @test L * _u≈J2' * _u rtol=1e-6 local _v = zeros(M) @test mul!(_v, L, _u)≈J2' * _u rtol=1e-6 end -# + +# Test Non-Square Jacobians +f3_oop(x) = vcat(x, x) +function f3_iip(y, x) + y[1:length(x)] .= x + y[(length(x) + 1):end] .= x + return nothing +end + +x = rand(Float32, 2) +y = rand(eltype(x), 4) + +L = VecJac(f3_oop, copy(x); autodiff = AutoFiniteDiff()) +@test size(L) == (length(x), length(y)) +@test L * y ≈ num_vecjac(f3_oop, copy(x), y) + +L = VecJac(f3_iip, copy(x); autodiff = AutoFiniteDiff(), fu = copy(y)) +@test size(L) == (length(x), length(y)) +@test mul!(zero(x), L, y) ≈ num_vecjac!(zero(x), f3_iip, copy(x), y) +@test L * y ≈ num_vecjac!(zero(x), f3_iip, copy(x), y) + +L = VecJac(f3_oop, copy(x); autodiff = AutoZygote()) +@test size(L) == (length(x), length(y)) +@test L * y ≈ Zygote.jacobian(f3_oop, copy(x))[1]' * y