Skip to content

Commit

Permalink
fix: add RecursiveArrayToolsReverseDiffExt
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Jan 22, 2024
1 parent 13b2a67 commit 4d6845f
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 0 deletions.
3 changes: 3 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,13 @@ Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7"
MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"

[extensions]
RecursiveArrayToolsFastBroadcastExt = "FastBroadcast"
RecursiveArrayToolsMeasurementsExt = "Measurements"
RecursiveArrayToolsMonteCarloMeasurementsExt = "MonteCarloMeasurements"
RecursiveArrayToolsReverseDiffExt = ["ReverseDiff", "Zygote"]
RecursiveArrayToolsTrackerExt = "Tracker"
RecursiveArrayToolsZygoteExt = "Zygote"

Expand All @@ -49,6 +51,7 @@ OrdinaryDiffEq = "6.62"
Pkg = "1"
Random = "1"
RecipesBase = "1.1"
ReverseDiff = "1.15"
SafeTestsets = "0.1"
SparseArrays = "1.10"
StaticArrays = "1.6"
Expand Down
25 changes: 25 additions & 0 deletions ext/RecursiveArrayToolsReverseDiffExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
module RecursiveArrayToolsReverseDiffExt

using RecursiveArrayTools
using ReverseDiff
using Zygote: @adjoint

function trackedarraycopyto!(dest, src)
for (i, slice) in zip(eachindex(dest.u), eachslice(src, dims=ndims(src)))
if dest.u[i] isa AbstractArray
dest.u[i] = reshape(reduce(vcat, slice), size(dest.u[i]))

Check warning on line 10 in ext/RecursiveArrayToolsReverseDiffExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/RecursiveArrayToolsReverseDiffExt.jl#L7-L10

Added lines #L7 - L10 were not covered by tests
else
trackedarraycopyto!(dest.u[i], slice)

Check warning on line 12 in ext/RecursiveArrayToolsReverseDiffExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/RecursiveArrayToolsReverseDiffExt.jl#L12

Added line #L12 was not covered by tests
end
end

Check warning on line 14 in ext/RecursiveArrayToolsReverseDiffExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/RecursiveArrayToolsReverseDiffExt.jl#L14

Added line #L14 was not covered by tests
end

@adjoint function Array(VA::AbstractVectorOfArray{<:ReverseDiff.TrackedReal})
function Array_adjoint(y)
VA = recursivecopy(VA)
trackedarraycopyto!(VA, y)
return (VA,)

Check warning on line 21 in ext/RecursiveArrayToolsReverseDiffExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/RecursiveArrayToolsReverseDiffExt.jl#L17-L21

Added lines #L17 - L21 were not covered by tests
end
return Array(VA), Array_adjoint

Check warning on line 23 in ext/RecursiveArrayToolsReverseDiffExt.jl

View check run for this annotation

Codecov / codecov/patch

ext/RecursiveArrayToolsReverseDiffExt.jl#L23

Added line #L23 was not covered by tests
end
end # module

0 comments on commit 4d6845f

Please sign in to comment.