diff --git a/src/derivative_wrappers.jl b/src/derivative_wrappers.jl index 3f71df703..cbded0b1c 100644 --- a/src/derivative_wrappers.jl +++ b/src/derivative_wrappers.jl @@ -693,10 +693,10 @@ function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::EnzymeVJP, dgrad, dup, Enzyme.Const(t), Enzyme.Const(W)) end - dλ !== nothing && (dλ .= tmp1) + dλ !== nothing && recursive_copyto!(dλ,tmp1) dgrad !== nothing && !(typeof(tmp2) <: DiffEqBase.NullParameters) && recursive_copyto!(dgrad, tmp2) - dy !== nothing && (dy .= tmp3) + dy !== nothing && recursive_copyto!(dy,tmp3) else if W === nothing Enzyme.autodiff(Enzyme.Reverse, S.diffcache.pf, Enzyme.Duplicated(tmp3, tmp4), @@ -715,10 +715,10 @@ function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::EnzymeVJP, dgrad, end recursive_copyto!(dy, out_) end - dλ !== nothing && (dλ .= tmp1) + dλ !== nothing && recursive_copyto!(dλ,tmp1) dgrad !== nothing && !(typeof(tmp2) <: DiffEqBase.NullParameters) && recursive_copyto!(dgrad, tmp2) - dy !== nothing && (dy .= tmp3) + dy !== nothing && recursive_copyto!(dy,tmp3) end return end diff --git a/test/size_handling_adjoint.jl b/test/size_handling_adjoint.jl index 94717551c..d8c139ef7 100644 --- a/test/size_handling_adjoint.jl +++ b/test/size_handling_adjoint.jl @@ -1,4 +1,4 @@ -using SciMLSensitivity, Flux, OrdinaryDiffEq, Test # , Plots +using SciMLSensitivity, Zygote, Flux, OrdinaryDiffEq, Test # , Plots p = [1.5 1.0; 3.0 1.0] function lotka_volterra(du, u, p, t) @@ -36,3 +36,38 @@ cb() Flux.train!(loss_adjoint, ps, data, opt, cb = cb) @test loss_adjoint() < 1 + +tspan = (0, 1) +tran = collect(0:0.1:1) +p0 = rand(2) +f0 = randn(30, 50) + +function rhs!(df, f, p, t) + for j in axes(f, 2) + for i in axes(f, 1) + df[i, j] = p[1] * i + p[2] * j + end + end + return nothing +end + +function loss(p; vjp) + prob = ODEProblem(rhs!, f0, tspan, p) + sol = solve(prob, Midpoint(), saveat = tran, sensealg=InterpolatingAdjoint(autojacvec=vjp)) |> Array + l = sum(abs2, sol) + + return l +end + +dp1 = Zygote.pullback(x -> loss(x; vjp = EnzymeVJP()), p0)[2](1)[1] +dp2 = Zygote.pullback(x -> loss(x; vjp = ReverseDiffVJP()), p0)[2](1)[1] +dp3 = Zygote.pullback(x -> loss(x; vjp = TrackerVJP()), p0)[2](1)[1] +dp4 = Zygote.pullback(x -> loss(x; vjp = EnzymeVJP()), p0)[2](1)[1] +dp5 = Zygote.pullback(x -> loss(x; vjp = true), p0)[2](1)[1] +dp6 = Zygote.pullback(x -> loss(x; vjp = false), p0)[2](1)[1] + +@test dp1 ≈ dp2 +@test dp1 ≈ dp3 +@test dp1 ≈ dp4 +@test dp1 ≈ dp5 +@test dp1 ≈ dp6 \ No newline at end of file