Skip to content

Commit

Permalink
Merge pull request #347 from AayushSabharwal/as/forwarddiff
Browse files Browse the repository at this point in the history
fix: implement ForwarDiff.extract_derivative for AbstractVectorOfArray
  • Loading branch information
ChrisRackauckas authored Feb 2, 2024
2 parents c976b29 + 194e276 commit daff140
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 0 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"

[weakdeps]
FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7"
MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Expand All @@ -27,6 +28,7 @@ ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"

[extensions]
RecursiveArrayToolsFastBroadcastExt = "FastBroadcast"
RecursiveArrayToolsForwardDiffExt = "ForwardDiff"
RecursiveArrayToolsMeasurementsExt = "Measurements"
RecursiveArrayToolsMonteCarloMeasurementsExt = "MonteCarloMeasurements"
RecursiveArrayToolsReverseDiffExt = ["ReverseDiff", "Zygote"]
Expand Down
10 changes: 10 additions & 0 deletions ext/RecursiveArrayToolsForwardDiffExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
module RecursiveArrayToolsForwardDiffExt

using RecursiveArrayTools
using ForwardDiff

function ForwardDiff.extract_derivative(::Type{T}, y::AbstractVectorOfArray) where {T}
ForwardDiff.extract_derivative.(T, y)
end

end
5 changes: 5 additions & 0 deletions test/adjoints.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,10 @@ function loss8(x)
return sum(abs2, res)
end

function loss9(x)
return VectorOfArray([collect(3i:3i+3) .* x for i in 1:5])
end

x = float.(6:10)
loss(x)
@test Zygote.gradient(loss, x)[1] == ForwardDiff.gradient(loss, x)
Expand All @@ -72,3 +76,4 @@ loss(x)
@test Zygote.gradient(loss6, x)[1] == ForwardDiff.gradient(loss6, x)
@test Zygote.gradient(loss7, x)[1] == ForwardDiff.gradient(loss7, x)
@test Zygote.gradient(loss8, x)[1] == ForwardDiff.gradient(loss8, x)
@test ForwardDiff.derivative(loss9, 0.0) == VectorOfArray([collect(3i:3i+3) for i in 1:5])

0 comments on commit daff140

Please sign in to comment.