Skip to content

Commit

Permalink
Fix nesting of forwarddiff and sparsity detection
Browse files Browse the repository at this point in the history
Duo with SciML/PreallocationTools.jl#91, preventing circular dependency
  • Loading branch information
ChrisRackauckas committed Dec 31, 2023
1 parent 4ea5cf0 commit 7513ec6
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 1 deletion.
5 changes: 4 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,11 @@ SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"

[weakdeps]
PreallocationTools = "d236fae5-4411-538c-8e31-a6e3d9e00b46"
SymPy = "24249f21-da20-56a4-8eb1-6a02cf4ae2e6"

[extensions]
SymbolicsPreallocationToolsExt = "Symbolics"
SymbolicsSymPyExt = "SymPy"

[compat]
Expand Down Expand Up @@ -79,11 +81,12 @@ julia = "1.6"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
PkgBenchmark = "32113eaa-f34f-5b0d-bd6c-c81e245fc73d"
PreallocationTools = "d236fae5-4411-538c-8e31-a6e3d9e00b46"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReferenceTests = "324d217c-45ce-50fc-942e-d289b448e8cf"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
SymPy = "24249f21-da20-56a4-8eb1-6a02cf4ae2e6"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "SafeTestsets", "Pkg", "PkgBenchmark", "BenchmarkTools", "ReferenceTests", "SymPy", "Random"]
test = ["Test", "SafeTestsets", "Pkg", "PkgBenchmark", "PreallocationTools", "BenchmarkTools", "ReferenceTests", "SymPy", "Random"]
49 changes: 49 additions & 0 deletions ext/SymbolicsPreallocationToolsExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
module PreallocationToolsSymbolicsExt

using PreallocationTools
import PreallocationTools: _restructure, get_tmp
using Symbolics, ForwardDiff

function get_tmp(dc::DiffCache, u::Type{X}) where {T,N, X<: ForwardDiff.Dual{T, Num, N}}
if length(dc.du) > length(dc.any_du)
resize!(dc.any_du, length(dc.du))
end
_restructure(dc.du, dc.any_du)
end

function get_tmp(dc::DiffCache, u::X) where {T,N, X<: ForwardDiff.Dual{T, Num, N}}
if length(dc.du) > length(dc.any_du)
resize!(dc.any_du, length(dc.du))
end
_restructure(dc.du, dc.any_du)
end

function get_tmp(dc::DiffCache, u::AbstractArray{X}) where {T,N, X<: ForwardDiff.Dual{T, Num, N}}
if length(dc.du) > length(dc.any_du)
resize!(dc.any_du, length(dc.du))
end
_restructure(dc.du, dc.any_du)
end

function get_tmp(dc::FixedSizeDiffCache, u::Type{X}) where {T,N, X<: ForwardDiff.Dual{T, Num, N}}
if length(dc.du) > length(dc.any_du)
resize!(dc.any_du, length(dc.du))
end
_restructure(dc.du, dc.any_du)
end

function get_tmp(dc::FixedSizeDiffCache, u::X) where {T,N, X<: ForwardDiff.Dual{T, Num, N}}
if length(dc.du) > length(dc.any_du)
resize!(dc.any_du, length(dc.du))
end
_restructure(dc.du, dc.any_du)
end

function get_tmp(dc::FixedSizeDiffCache, u::AbstractArray{X}) where {T,N, X<: ForwardDiff.Dual{T, Num, N}}
if length(dc.du) > length(dc.any_du)
resize!(dc.any_du, length(dc.du))
end
_restructure(dc.du, dc.any_du)
end

end
23 changes: 23 additions & 0 deletions test/nested_forwarddiff_sparsity.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
using ForwardDiff, SparseArrays, Symbolics, PreallocationTools
# Test Nesting https://discourse.julialang.org/t/preallocationtools-jl-with-nested-forwarddiff-and-sparsity-pattern-detection-errors/107897

function foo(x, cache)
d = get_tmp(cache, x)

d[:] = x

0.5 * x'*x
end

function residual(r, x, cache)
function foo_wrap(x)
foo(x, cache)
end

r[:] = ForwardDiff.gradient(foo_wrap, x)
end

cache = DiffCache(zeros(2))
pattern = Symbolics.jacobian_sparsity((r, x) -> residual(r, x, cache), zeros(2), zeros(2))
@test pattern == sparse([1 0
0 1])
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ if GROUP == "All" || GROUP == "Core"
@safetestset "Algebraic Solver Test" begin include("solver.jl") end
@safetestset "Groebner Bases Test" begin include("groebner_basis.jl") end
@safetestset "Overloading Test" begin include("overloads.jl") end
@safetestset "Nested ForwardDiff Sparsity Test" begin include("nested_forwarddiff_sparsity.jl") end
@safetestset "Build Function Test" begin include("build_function.jl") end
@safetestset "Build Function Array Test" begin include("build_function_arrayofarray.jl") end
@safetestset "Build Function Array Test Named Tuples" begin include("build_function_arrayofarray_named_tuples.jl") end
Expand Down

0 comments on commit 7513ec6

Please sign in to comment.