Some more SDE sensitivity dispatches #368
@mohamed82008 I have a reduced version of the ReverseDiffAdjoint issue:

using StochasticDiffEq, DiffEqSensitivity, Zygote
function loss(p)
    f(u,p,t) = u
    g(u,p,t) = u
    prob = SDEProblem{false}(f,g,Float32[2.; 0.],(0.0f0, 1.0f0),p)
    sol = solve(prob, SOSRI();sensealg=ReverseDiffAdjoint())
    return sum(Array(sol))
end
@time Zygote.gradient(loss,[1f0])
Was there already an attempt to fix the

using StochasticDiffEq, DiffEqSensitivity, Zygote, Random
function loss1(p)
    f(u,p,t) = u
    g(u,p,t) = u
    prob = SDEProblem{false}(f,g,Float32[2.; 0.],(0.0f0, 1.0f0),p)
    sol = solve(prob, SOSRI();sensealg=ReverseDiffAdjoint())
    return sum(Array(sol))
end
@time Zygote.gradient(loss1,[1f0])
# type Float32 has no field partials

Adding a

function loss1dt(p)
    f(u,p,t) = u
    g(u,p,t) = u
    prob = SDEProblem{false}(f,g,Float32[2.; 0.],(0.0f0, 1.0f0),p)
    sol = solve(prob, SOSRI(), dt=0.01;sensealg=ReverseDiffAdjoint())
    return sum(Array(sol))
end
@time Zygote.gradient(loss1dt,[1f0])
# Mutating arrays is not supported

and ultimately writing it as:

f2(u,p,t) = u
g2(u,p,t) = u
prob2 = SDEProblem{false}(f2,g2,Float32[2.; 0.],(0.0f0, 1.0f0),[1f0])
function loss5(p,prob,sensealg)
    _prob = remake(prob)
    sol = solve(_prob,SOSRI(),dt=0.01;sensealg=sensealg)
    return sum(Array(sol))
end
@time Zygote.gradient(p->loss5(p,prob2,ReverseDiffAdjoint()),[1f0])

is fine. Is there anything odd with the MWE in general? Also, for

function loss3(p)
    f(u,p,t) = u
    g(u,p,t) = u
    prob = SDEProblem{false}(f,g,Float32[2.; 0.],(0.0f0, 1.0f0),p)
    sol = solve(prob, SOSRI();sensealg=TrackerAdjoint())
    return sum(Array(sol))
end
@time Zygote.gradient(loss3,[1f0])
# Mutating arrays is not supported

the mutation error is thrown -- pointing to the

It doesn't seem to me that the error is all that much related to SDEs at the moment, because

using OrdinaryDiffEq, DiffEqSensitivity, Zygote, Random
function lossODE(p)
    f(u,p,t) = u
    prob = ODEProblem{false}(f,Float32[2.; 0.],(0.0f0, 1.0f0),p)
    sol = solve(prob, Tsit5();sensealg=ReverseDiffAdjoint())
    return sum(Array(sol))
end
@time Zygote.gradient(lossODE,[1f0])
# Mutating arrays is not supported

also fails.
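For context on the error message that keeps appearing here: Zygote raises "Mutating arrays is not supported" whenever the code it differentiates writes into an ordinary `Array` with `setindex!`. In the MWEs above the write happens inside the adjoint machinery rather than in the user's `f`/`g`, but the mechanism is the same. The sketch below uses two hypothetical functions, `mutating` and `buffered`, which are not part of DiffEqSensitivity. It reproduces the error in isolation and shows `Zygote.Buffer`, Zygote's own container for in-place writes during differentiation. It is an illustration of the error only, not the fix that went into the package.

```julia
using Zygote

# Hypothetical example (not from DiffEqSensitivity): writes into a plain Array
# with setindex!, which Zygote refuses to differentiate.
function mutating(p)
    u = zeros(eltype(p), 2)
    u[1] = 2 * p[1]          # in-place write on a regular Array
    return sum(u)
end

# Same computation using Zygote.Buffer, which allows in-place writes while
# differentiating; copy(u) turns the buffer back into an ordinary array.
function buffered(p)
    u = Zygote.Buffer(p, 2)  # buffer with the same element type as p, length 2
    u[1] = 2 * p[1]
    u[2] = zero(eltype(p))
    return sum(copy(u))
end

try
    Zygote.gradient(mutating, [1f0])
catch err
    # Prints Zygote's "Mutating arrays is not supported ..." message.
    println(sprint(showerror, err))
end

# d/dp[1] of 2*p[1] is 2, so this gradient should be a vector holding 2.
g = Zygote.gradient(buffered, [1f0])[1]
println(g)
```

Swapping the plain array for a `Zygote.Buffer` is the usual pattern when the mutating code is under your control. When the mutation lives inside a library's adjoint, as in this thread, it has to be fixed upstream.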
Yes, ReverseDiffAdjoint was largely fixed via https://github.com/SciML/DiffEqSensitivity.jl/pull/366/files#diff-4c1a437283c1ed3aeb4267af6e255f3434d8ba2b62eb67192c482a6becd3192dR360
using StochasticDiffEq, DiffEqSensitivity, Zygote, Random
function loss1(p)
    f(u,p,t) = u
    g(u,p,t) = u
    prob = SDEProblem{false}(f,g,Float32[2., 0.],(0.0f0, 1.0f0),p)
    sol = solve(prob, SOSRI();dt=0.01,sensealg=ReverseDiffAdjoint())
    return sum(Array(sol))
end
@time Zygote.gradient(loss1,[1f0])

is fine, so the issue is just
Most things are working here now. I'll split the rest into specific issues.
The two remaining issues have been opened; both are upstream AD issues, and the respective maintainers are tagged.
This is a nice set of tests:
Notice that most pass or behave predictably. The ones that are failing are:
The one that is peculiar is ReverseDiffAdjoint + the adaptive algorithm. I'll see if I can get an MWE.
@frankschae let me know if I missed anything big in this summary.