From 194e276e138f7958ea7999fded2985df3619037f Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Fri, 2 Feb 2024 15:08:39 +0530 Subject: [PATCH] fix: implement ForwarDiff.extract_derivative for AbstractVectorOfArray --- Project.toml | 2 ++ ext/RecursiveArrayToolsForwardDiffExt.jl | 10 ++++++++++ test/adjoints.jl | 5 +++++ 3 files changed, 17 insertions(+) create mode 100644 ext/RecursiveArrayToolsForwardDiffExt.jl diff --git a/Project.toml b/Project.toml index c5e37b15..9d0cf3a8 100644 --- a/Project.toml +++ b/Project.toml @@ -19,6 +19,7 @@ Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c" [weakdeps] FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7" MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca" Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" @@ -27,6 +28,7 @@ ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" [extensions] RecursiveArrayToolsFastBroadcastExt = "FastBroadcast" +RecursiveArrayToolsForwardDiffExt = "ForwardDiff" RecursiveArrayToolsMeasurementsExt = "Measurements" RecursiveArrayToolsMonteCarloMeasurementsExt = "MonteCarloMeasurements" RecursiveArrayToolsReverseDiffExt = ["ReverseDiff", "Zygote"] diff --git a/ext/RecursiveArrayToolsForwardDiffExt.jl b/ext/RecursiveArrayToolsForwardDiffExt.jl new file mode 100644 index 00000000..f6089f64 --- /dev/null +++ b/ext/RecursiveArrayToolsForwardDiffExt.jl @@ -0,0 +1,10 @@ +module RecursiveArrayToolsForwardDiffExt + +using RecursiveArrayTools +using ForwardDiff + +function ForwardDiff.extract_derivative(::Type{T}, y::AbstractVectorOfArray) where {T} + ForwardDiff.extract_derivative.(T, y) +end + +end diff --git a/test/adjoints.jl b/test/adjoints.jl index c657dcf9..becdc9f2 100644 --- a/test/adjoints.jl +++ b/test/adjoints.jl @@ -62,6 +62,10 @@ function loss8(x) return sum(abs2, res) end +function loss9(x) + return VectorOfArray([collect(3i:3i+3) .* x for i in 1:5]) +end + x = float.(6:10) loss(x) @test Zygote.gradient(loss, x)[1] == ForwardDiff.gradient(loss, x) @@ -72,3 +76,4 @@ loss(x) @test Zygote.gradient(loss6, x)[1] == ForwardDiff.gradient(loss6, x) @test Zygote.gradient(loss7, x)[1] == ForwardDiff.gradient(loss7, x) @test Zygote.gradient(loss8, x)[1] == ForwardDiff.gradient(loss8, x) +@test ForwardDiff.derivative(loss9, 0.0) == VectorOfArray([collect(3i:3i+3) for i in 1:5])