Some more SDE sensitivity dispatches #368

Closed
ChrisRackauckas opened this issue Dec 21, 2020 · 6 comments

@ChrisRackauckas
Member

This is a nice set of tests:

using DiffEqFlux, StochasticDiffEq, DiffEqSensitivity, Zygote, Random

u0 = Float32[2.; 0.]
datasize = 30
tspan = (0.0f0, 1.0f0)
tsteps = range(tspan[1], tspan[2], length = datasize)

drift_dudt = FastChain((x, p) -> x.^3,
                       FastDense(2, 50, tanh),
                       FastDense(50, 2))
diffusion_dudt = FastChain(FastDense(2, 2))

neuralsderd = NeuralDSDE(drift_dudt, diffusion_dudt, tspan, EulerHeun(),dt=1e-2,
                         sensealg = ReverseDiffAdjoint(),
                         saveat = tsteps)

p = neuralsderd.p

# ReverseDiff: Works

function predict_neuralsde_reversediff(p)
  return sum(Array(neuralsderd(u0, p)))
end

Random.seed!(1001)
@time Zygote.gradient(predict_neuralsde_reversediff,p)
Random.seed!(1001)
@time Zygote.gradient(predict_neuralsde_reversediff,p)
Random.seed!(1001)
@time Zygote.gradient(predict_neuralsde_reversediff,p)

# Tracker: Works

neuralsdetracker = NeuralDSDE(drift_dudt, diffusion_dudt, tspan, EulerHeun(),dt=1e-2,
                        sensealg = TrackerAdjoint(),
                        saveat = tsteps)

function predict_neuralsde_tracker(p)
  return sum(Array(neuralsdetracker(u0, p)))
end

Random.seed!(1001)
@time Zygote.gradient(predict_neuralsde_tracker,p)
Random.seed!(1001)
@time Zygote.gradient(predict_neuralsde_tracker,p)
Random.seed!(1001)
@time Zygote.gradient(predict_neuralsde_tracker,p)

# Works

neuralsde_backsolve = NeuralDSDE(drift_dudt, diffusion_dudt, tspan,
                        EulerHeun(),dt=1e-2,
                        sensealg = BacksolveAdjoint(),
                        saveat = tsteps)

function predict_neuralsdebacksolve(p)
  return sum(Array(neuralsde_backsolve(u0, p)))
end

Random.seed!(1001)
@time Zygote.gradient(predict_neuralsdebacksolve,p)
Random.seed!(1001)
@time Zygote.gradient(predict_neuralsdebacksolve,p)
Random.seed!(1001)
@time Zygote.gradient(predict_neuralsdebacksolve,p)

# Ito Fails : Mutating arrays not supported

neuralsde_backsolveem = NeuralDSDE(drift_dudt, diffusion_dudt, tspan,
                        EM(),dt=1e-2,
                        sensealg = BacksolveAdjoint(),
                        saveat = tsteps)

function predict_neuralsde_backsolveem(p)
  return sum(Array(neuralsde_backsolveem(u0, p)))
end

Random.seed!(1001)
@time Zygote.gradient(predict_neuralsde_backsolveem,p)
Random.seed!(1001)
@time Zygote.gradient(predict_neuralsde_backsolveem,p)
Random.seed!(1001)
@time Zygote.gradient(predict_neuralsde_backsolveem,p)

# Adaptive Ito Works

neuralsde1_adaptive = NeuralDSDE(drift_dudt, diffusion_dudt, tspan, SOSRI(),
                        sensealg = TrackerAdjoint(),
                        saveat = tsteps, reltol=1e-1, abstol=1e-1)

function predict_neuralsde_adaptive(p)
  return sum(Array(neuralsde1_adaptive(u0, p)))
end

using Random
Random.seed!(1001)
@time Zygote.gradient(predict_neuralsde_adaptive,p)
Random.seed!(1001)
@time Zygote.gradient(predict_neuralsde_adaptive,p)
Random.seed!(1001)
@time Zygote.gradient(predict_neuralsde_adaptive,p)

# Adaptive Ito ReverseDiff : DimensionMismatch

neuralsde1_adaptive_rd = NeuralDSDE(drift_dudt, diffusion_dudt, tspan, SOSRI(),
                            sensealg = ReverseDiffAdjoint(),
                            saveat = tsteps, reltol=1e-1, abstol=1e-1)

function predict_neuralsde_adaptive_rd(p)
  return sum(Array(neuralsde1_adaptive_rd(u0, p)))
end

using Random
Random.seed!(1001)
@time Zygote.gradient(predict_neuralsde_adaptive_rd,p)
Random.seed!(1001)
@time Zygote.gradient(predict_neuralsde_adaptive_rd,p)
Random.seed!(1001)
@time Zygote.gradient(predict_neuralsde_adaptive_rd,p)

# RKMilCommute does not have a ConstantCache dispatch

neuralsde3 = NeuralDSDE(drift_dudt, diffusion_dudt, tspan,
                        RKMilCommute(),dt=1e-4,
                        sensealg = BacksolveAdjoint(),
                        saveat = tsteps)

function predict_neuralsde3(p)
  return sum(Array(neuralsde3(u0, p)))
end

@time Zygote.gradient(predict_neuralsde3,p)
@time Zygote.gradient(predict_neuralsde3,p)
@time Zygote.gradient(predict_neuralsde3,p)

Notice that most of these pass or behave predictably. The ones that are failing are:

  • Backsolve + Ito: failing because of nested AD. It needs to use Zygote for the vjp, so it's a Zygote-over-Zygote failure. This should be fixed by Diffractor, so don't touch it for now. It works if you differentiate the outer loss with ReverseDiff instead (see the sketch after this list).
  • RKMilCommute is missing an out-of-place (OOP) dispatch. This is not related to AD at all; it's tracked upstream as "RKMilCommute is missing a dispatch for OOP" StochasticDiffEq.jl#372
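
A minimal sketch of that ReverseDiff workaround (not part of the original report; it reuses predict_neuralsde_backsolveem, p, and Random from the script above):

# Sketch only: same Ito/EM + BacksolveAdjoint loss as above, but the outer
# gradient is taken with ReverseDiff instead of Zygote, so the Zygote-over-Zygote
# nesting described in the first bullet is not triggered.
using ReverseDiff

Random.seed!(1001)
@time ReverseDiff.gradient(predict_neuralsde_backsolveem, p)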

The one that is peculiar is ReverseDiffAdjoint + the adaptive algorithm. I'll see if I can get an MWE.

@frankschae let me know if I missed anything big in this summary.

@ChrisRackauckas
Member Author

@mohamed82008 I have a reduced version of the ReverseDiffAdjoint issue:

using StochasticDiffEq, DiffEqSensitivity, Zygote

function loss(p)
  f(u,p,t) = u
  g(u,p,t) = u
  prob = SDEProblem{false}(f,g,Float32[2.; 0.],(0.0f0, 1.0f0),p)
  sol = solve(prob, SOSRI();sensealg=ReverseDiffAdjoint())
  return sum(Array(sol))
end

@time Zygote.gradient(loss,[1f0])

@frankschae
Member

Was there already an attempt to fix the ReverseDiffAdjoint issues? For the MWE from above, the error has (probably?) changed in the meantime:

using StochasticDiffEq, DiffEqSensitivity, Zygote, Random
function loss1(p)
  f(u,p,t) = u
  g(u,p,t) = u
  prob = SDEProblem{false}(f,g,Float32[2.; 0.],(0.0f0, 1.0f0),p)
  sol = solve(prob, SOSRI();sensealg=ReverseDiffAdjoint())
  return sum(Array(sol))
end

@time Zygote.gradient(loss1,[1f0])

#  type Float32 has no field partials

Adding a dt keyword to the solve call changes the error message:

function loss1dt(p)
  f(u,p,t) = u
  g(u,p,t) = u
  prob = SDEProblem{false}(f,g,Float32[2.; 0.],(0.0f0, 1.0f0),p)
  sol = solve(prob, SOSRI(), dt=0.01;sensealg=ReverseDiffAdjoint())
  return sum(Array(sol))
end

@time Zygote.gradient(loss1dt,[1f0])

# Mutating arrays is not supported

and ultimately writing it as:

f2(u,p,t) = u
g2(u,p,t) = u
prob2 = SDEProblem{false}(f2,g2,Float32[2.; 0.],(0.0f0, 1.0f0),[1f0])
function loss5(p,prob,sensealg)
  _prob = remake(prob)
  sol = solve(_prob,SOSRI(),dt=0.01;sensealg=sensealg)
  return sum(Array(sol))
end
@time Zygote.gradient(p->loss5(p,prob2,ReverseDiffAdjoint()),[1f0])

is fine.

Is there anything odd with the MWE in general? Also, for TrackerAdjoint(),

function loss3(p)
  f(u,p,t) = u
  g(u,p,t) = u
  prob = SDEProblem{false}(f,g,Float32[2.; 0.],(0.0f0, 1.0f0),p)
  sol = solve(prob, SOSRI();sensealg=TrackerAdjoint())
  return sum(Array(sol))
end
@time Zygote.gradient(loss3,[1f0])
# Mutating arrays is not supported

the mutation error is thrown, pointing to the prob = .. line.

It doesn't seem to me that the error is specific to SDEs at the moment, because

using OrdinaryDiffEq, DiffEqSensitivity, Zygote, Random
function lossODE(p)
  f(u,p,t) = u
  prob = ODEProblem{false}(f,Float32[2.; 0.],(0.0f0, 1.0f0),p)
  sol = solve(prob, Tsit5();sensealg=ReverseDiffAdjoint())
  return sum(Array(sol))
end

@time Zygote.gradient(lossODE,[1f0])

# Mutating arrays is not supported

also fails.

@ChrisRackauckas
Member Author

using StochasticDiffEq, DiffEqSensitivity, Zygote, Random
function loss1(p)
  f(u,p,t) = u
  g(u,p,t) = u
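  # note: u0 is written with a comma (Float32[2., 0.]) rather than the semicolon used in the earlier MWEs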
  prob = SDEProblem{false}(f,g,Float32[2., 0.],(0.0f0, 1.0f0),p)
  sol = solve(prob, SOSRI();dt=0.01,sensealg=ReverseDiffAdjoint())
  return sum(Array(sol))
end

@time Zygote.gradient(loss1,[1f0])

is fine, so the issue is just that Float32[2.; 0.] uses hvcat, which internally mutates 🤦. That's an AD issue and not our issue.
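
A quick way to see that lowering difference (this snippet is not from the original comment; it only inspects how the two u0 literals lower):

# The semicolon form lowers to a typed concatenation call (typed_vcat /
# typed_hvcat), which builds the result array in place; the comma form lowers
# to Base.getindex(Float32, ...), which Zygote handled here.
@show Meta.lower(Main, :(Float32[2.; 0.]))
@show Meta.lower(Main, :(Float32[2., 0.]))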

@ChrisRackauckas
Member Author

Most things are working here now. I'll split this into specific issues.

@ChrisRackauckas
Member Author

The two remaining issues were opened, and they are both upstream AD issues with the respective maintainers tagged.
