diff --git a/test/adjoint_shapes.jl b/test/adjoint_shapes.jl index ba8f49453..7be7c0bcd 100644 --- a/test/adjoint_shapes.jl +++ b/test/adjoint_shapes.jl @@ -53,3 +53,18 @@ solve( ODEAdjointProblem(fwd_sol, sensealg, Tsit5(), [1.0], (out, x, p, t, i) -> (out .= 1)), Tsit5()) + +# https://github.com/SciML/SciMLSensitivity.jl/issues/581 + +p = rand(1) + +function dudt(u, p, t) + u .* p +end + +function loss(p) + prob = ODEProblem(dudt, [3.0], (0.0, 1.0), p) + sol = solve(prob, Tsit5(), dt=0.01, sensealg=ReverseDiffAdjoint()) + sum(abs2, Array(sol)) +end +Zygote.gradient(loss, p)