Skip to content

Commit

Permalink
Merge pull request #337 from AayushSabharwal/as/debug-adjoint
Browse files Browse the repository at this point in the history
build: add SciMLSensitivity tests to downstream CI
  • Loading branch information
ChrisRackauckas authored Jan 24, 2024
2 parents c486e62 + 4d6845f commit 067944a
Show file tree
Hide file tree
Showing 6 changed files with 65 additions and 15 deletions.
6 changes: 6 additions & 0 deletions .github/workflows/Downstream.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,12 @@ jobs:
- {user: SciML, repo: OrdinaryDiffEq.jl, group: Core}
- {user: SciML, repo: OrdinaryDiffEq.jl, group: Interface}
- {user: SciML, repo: DelayDiffEq.jl, group: Interface}
- {user: SciML, repo: SciMLSensitivity.jl, group: Core1}
- {user: SciML, repo: SciMLSensitivity.jl, group: Core2}
- {user: SciML, repo: SciMLSensitivity.jl, group: Core3}
- {user: SciML, repo: SciMLSensitivity.jl, group: Core4}
- {user: SciML, repo: SciMLSensitivity.jl, group: Core5}
- {user: SciML, repo: SciMLSensitivity.jl, group: Core6}
steps:
- uses: actions/checkout@v4
- uses: julia-actions/setup-julia@v1
Expand Down
3 changes: 3 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,13 @@ Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7"
MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"

[extensions]
RecursiveArrayToolsFastBroadcastExt = "FastBroadcast"
RecursiveArrayToolsMeasurementsExt = "Measurements"
RecursiveArrayToolsMonteCarloMeasurementsExt = "MonteCarloMeasurements"
RecursiveArrayToolsReverseDiffExt = ["ReverseDiff", "Zygote"]
RecursiveArrayToolsTrackerExt = "Tracker"
RecursiveArrayToolsZygoteExt = "Zygote"

Expand All @@ -49,6 +51,7 @@ OrdinaryDiffEq = "6.62"
Pkg = "1"
Random = "1"
RecipesBase = "1.1"
ReverseDiff = "1.15"
SafeTestsets = "0.1"
SparseArrays = "1.10"
StaticArrays = "1.6"
Expand Down
25 changes: 25 additions & 0 deletions ext/RecursiveArrayToolsReverseDiffExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
module RecursiveArrayToolsReverseDiffExt

using RecursiveArrayTools
using ReverseDiff
using Zygote: @adjoint

function trackedarraycopyto!(dest, src)
for (i, slice) in zip(eachindex(dest.u), eachslice(src, dims=ndims(src)))
if dest.u[i] isa AbstractArray
dest.u[i] = reshape(reduce(vcat, slice), size(dest.u[i]))
else
trackedarraycopyto!(dest.u[i], slice)
end
end
end

@adjoint function Array(VA::AbstractVectorOfArray{<:ReverseDiff.TrackedReal})
function Array_adjoint(y)
VA = recursivecopy(VA)
trackedarraycopyto!(VA, y)
return (VA,)
end
return Array(VA), Array_adjoint
end
end # module
22 changes: 14 additions & 8 deletions ext/RecursiveArrayToolsZygoteExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -110,23 +110,29 @@ end
@adjoint function Base.Array(VA::AbstractVectorOfArray)
adj = let VA=VA
function Array_adjoint(y)
VA = copy(VA)
VA = recursivecopy(VA)
copyto!(VA, y)
return (VA,)
end
end
Array(VA), adj
end

@adjoint function Base.view(A::AbstractVectorOfArray, I::Colon...)
function adjoint(y)
(recursivecopy(parent(y)), map(_ -> nothing, I)...)
end
return view(A, I...), adjoint
end

@adjoint function Base.view(A::AbstractVectorOfArray, I...)
adj = let A = A, I = I
function view_adjoint(y)
A = zero(A)
view(A, I...) .= y
return (A, map(_ -> nothing, I)...)
end
function view_adjoint(y)
A = recursivecopy(parent(y))
recursivefill!(A, zero(eltype(A)))
A[I...] .= y
return (A, map(_ -> nothing, I)...)
end
view(A, I...), adj
view(A, I...), view_adjoint
end

ChainRulesCore.ProjectTo(a::AbstractVectorOfArray) = ChainRulesCore.ProjectTo{VectorOfArray}((sz = size(a)))
Expand Down
2 changes: 1 addition & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ end

function recursivecopy(a::AbstractVectorOfArray)
b = copy(a)
b.u = recursivecopy.(a.u)
b.u .= recursivecopy.(a.u)
return b
end

Expand Down
22 changes: 16 additions & 6 deletions src/vector_of_array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -585,16 +585,26 @@ end
function Base.checkbounds(VA::AbstractVectorOfArray, idx...)
checkbounds(Bool, VA, idx...) || throw(BoundsError(VA, idx))
end
function Base.copyto!(dest::AbstractVectorOfArray{T,N}, src::AbstractVectorOfArray{T,N}) where {T,N}
copyto!.(dest.u, src.u)
function Base.copyto!(dest::AbstractVectorOfArray{T,N}, src::AbstractVectorOfArray{T2,N}) where {T, T2, N}
for (i, j) in zip(eachindex(dest.u), eachindex(src.u))
if ArrayInterface.ismutable(dest.u[i]) || dest.u[i] isa AbstractVectorOfArray
copyto!(dest.u[i], src.u[j])
else
dest.u[i] = StaticArraysCore.similar_type(dest.u[i])(src.u[j])
end
end
end
function Base.copyto!(dest::AbstractVectorOfArray{T, N}, src::AbstractArray{T, N}) where {T, N}
for (i, slice) in enumerate(eachslice(src, dims = ndims(src)))
copyto!(dest.u[i], slice)
function Base.copyto!(dest::AbstractVectorOfArray{T, N}, src::AbstractArray{T2, N}) where {T, T2, N}
for (i, slice) in zip(eachindex(dest.u), eachslice(src, dims = ndims(src)))
if ArrayInterface.ismutable(dest.u[i]) || dest.u[i] isa AbstractVectorOfArray
copyto!(dest.u[i], slice)
else
dest.u[i] = StaticArraysCore.similar_type(dest.u[i])(slice)
end
end
dest
end
function Base.copyto!(dest::AbstractVectorOfArray{T, N, <:AbstractVector{T}}, src::AbstractVector{T}) where {T, N}
function Base.copyto!(dest::AbstractVectorOfArray{T, N, <:AbstractVector{T}}, src::AbstractVector{T2}) where {T, T2, N}
copyto!(dest.u, src)
dest
end
Expand Down

0 comments on commit 067944a

Please sign in to comment.