Skip to content

Commit

Permalink
Merge pull request #1026 from SciML/aos_to_soa
Browse files Browse the repository at this point in the history
Fix Array of Structs to Struct of Array in ReverseDiff
  • Loading branch information
ChrisRackauckas authored May 14, 2024
2 parents 52edc2b + 99124d9 commit 548b8ae
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 5 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
ADTypes = "1"
ADTypes = "0.1, 0.2, 1.0"
Adapt = "1.0, 2.0, 3.0, 4"
AlgebraicMultigrid = "0.6.0"
Aqua = "0.8.4"
Expand Down Expand Up @@ -81,7 +81,7 @@ PreallocationTools = "0.4.4"
QuadGK = "2.9.1"
Random = "1.10"
RandomNumbers = "1.5.3"
RecursiveArrayTools = "3.5.1"
RecursiveArrayTools = "3.18.1"
Reexport = "1.0"
ReverseDiff = "1.15.1"
SafeTestsets = "0.1.0"
Expand Down
4 changes: 2 additions & 2 deletions src/concrete_solve.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1309,9 +1309,9 @@ function DiffEqBase._concrete_solve_adjoint(
end
else
# use TrackedArray for efficiency of the tape
_f(args...) = reduce(vcat, prob.f(args...))
_f(args...) = ArrayInterface.aos_to_soa(prob.f(args...))
if prob isa SDEProblem
_g(args...) = reduce(vcat, prob.g(args...))
_g(args...) = ArrayInterface.aos_to_soa(prob.g(args...))
_prob = remake(prob,
f = DiffEqBase.parameterless_type(prob.f){
SciMLBase.isinplace(prob),
Expand Down
17 changes: 16 additions & 1 deletion test/adjoint_shapes.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using OrdinaryDiffEq, SciMLSensitivity, Zygote
using OrdinaryDiffEq, SciMLSensitivity, Zygote, ForwardDiff

tspan = (0.0, 1.0)
X = randn(3, 4)
Expand Down Expand Up @@ -53,3 +53,18 @@ solve(
ODEAdjointProblem(fwd_sol, sensealg, Tsit5(), [1.0],
(out, x, p, t, i) -> (out .= 1)),
Tsit5())

# https://github.com/SciML/SciMLSensitivity.jl/issues/581

p = rand(1)

function dudt(u, p, t)
u .* p
end

function loss(p)
prob = ODEProblem(dudt, [3.0], (0.0, 1.0), p)
sol = solve(prob, Tsit5(), dt = 0.01, sensealg = ReverseDiffAdjoint())
sum(abs2, Array(sol))
end
Zygote.gradient(loss, p)[1][1] ForwardDiff.gradient(loss, p)[1]

0 comments on commit 548b8ae

Please sign in to comment.