Skip to content

Commit

Permalink
Merge pull request #919 from SciML/enzyme_size
Browse files Browse the repository at this point in the history
Fix enzyme matrix resizing
  • Loading branch information
ChrisRackauckas authored Oct 20, 2023
2 parents a635cf2 + f628e44 commit 9cbadff
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 5 deletions.
8 changes: 4 additions & 4 deletions src/derivative_wrappers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -693,10 +693,10 @@ function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::EnzymeVJP, dgrad,
dup,
Enzyme.Const(t), Enzyme.Const(W))
end
!== nothing && (dλ .= tmp1)
!== nothing && recursive_copyto!(dλ,tmp1)
dgrad !== nothing && !(typeof(tmp2) <: DiffEqBase.NullParameters) &&
recursive_copyto!(dgrad, tmp2)
dy !== nothing && (dy .= tmp3)
dy !== nothing && recursive_copyto!(dy,tmp3)
else
if W === nothing
Enzyme.autodiff(Enzyme.Reverse, S.diffcache.pf, Enzyme.Duplicated(tmp3, tmp4),
Expand All @@ -715,10 +715,10 @@ function _vecjacobian!(dλ, y, λ, p, t, S::TS, isautojacvec::EnzymeVJP, dgrad,
end
recursive_copyto!(dy, out_)
end
!== nothing && (dλ .= tmp1)
!== nothing && recursive_copyto!(dλ,tmp1)
dgrad !== nothing && !(typeof(tmp2) <: DiffEqBase.NullParameters) &&
recursive_copyto!(dgrad, tmp2)
dy !== nothing && (dy .= tmp3)
dy !== nothing && recursive_copyto!(dy,tmp3)
end
return
end
Expand Down
37 changes: 36 additions & 1 deletion test/size_handling_adjoint.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using SciMLSensitivity, Flux, OrdinaryDiffEq, Test # , Plots
using SciMLSensitivity, Zygote, Flux, OrdinaryDiffEq, Test # , Plots

p = [1.5 1.0; 3.0 1.0]
function lotka_volterra(du, u, p, t)
Expand Down Expand Up @@ -36,3 +36,38 @@ cb()
Flux.train!(loss_adjoint, ps, data, opt, cb = cb)

@test loss_adjoint() < 1

tspan = (0, 1)
tran = collect(0:0.1:1)
p0 = rand(2)
f0 = randn(30, 50)

function rhs!(df, f, p, t)
for j in axes(f, 2)
for i in axes(f, 1)
df[i, j] = p[1] * i + p[2] * j
end
end
return nothing
end

function loss(p; vjp)
prob = ODEProblem(rhs!, f0, tspan, p)
sol = solve(prob, Midpoint(), saveat = tran, sensealg=InterpolatingAdjoint(autojacvec=vjp)) |> Array
l = sum(abs2, sol)

return l
end

dp1 = Zygote.pullback(x -> loss(x; vjp = EnzymeVJP()), p0)[2](1)[1]
dp2 = Zygote.pullback(x -> loss(x; vjp = ReverseDiffVJP()), p0)[2](1)[1]
dp3 = Zygote.pullback(x -> loss(x; vjp = TrackerVJP()), p0)[2](1)[1]
dp4 = Zygote.pullback(x -> loss(x; vjp = EnzymeVJP()), p0)[2](1)[1]
dp5 = Zygote.pullback(x -> loss(x; vjp = true), p0)[2](1)[1]
dp6 = Zygote.pullback(x -> loss(x; vjp = false), p0)[2](1)[1]

@test dp1 dp2
@test dp1 dp3
@test dp1 dp4
@test dp1 dp5
@test dp1 dp6

0 comments on commit 9cbadff

Please sign in to comment.