From 255c1b13a941f7edf90acc2754996858a46f864b Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Tue, 7 May 2024 17:55:54 +0200 Subject: [PATCH 1/4] Add ADTypes sparsity detector --- Project.toml | 2 ++ docs/src/manual/sparsity_detection.md | 6 ++++++ src/Symbolics.jl | 6 ++++++ src/adtypes.jl | 31 +++++++++++++++++++++++++++ test/adtypes.jl | 19 ++++++++++++++++ test/runtests.jl | 1 + 6 files changed, 65 insertions(+) create mode 100644 src/adtypes.jl create mode 100644 test/adtypes.jl diff --git a/Project.toml b/Project.toml index a41f01e53..5cbe39fe0 100644 --- a/Project.toml +++ b/Project.toml @@ -4,6 +4,7 @@ authors = ["Shashi Gowda "] version = "5.28.1" [deps] +ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" Bijections = "e2ed5e7c-b2de-5872-ae92-c73ca462fb04" ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" @@ -52,6 +53,7 @@ SymbolicsPreallocationToolsExt = "PreallocationTools" SymbolicsSymPyExt = "SymPy" [compat] +ADTypes = "1.0" ArrayInterface = "7" Bijections = "0.1" ConstructionBase = "1.2" diff --git a/docs/src/manual/sparsity_detection.md b/docs/src/manual/sparsity_detection.md index 03e0843a4..9d974cdca 100644 --- a/docs/src/manual/sparsity_detection.md +++ b/docs/src/manual/sparsity_detection.md @@ -27,3 +27,9 @@ Symbolics.hessian_sparsity Symbolics.islinear Symbolics.isaffine ``` + +## ADTypes.jl interface + +```@docs +Symbolics.SymbolicsSparsityDetector +``` diff --git a/src/Symbolics.jl b/src/Symbolics.jl index b48673c71..d316fe608 100644 --- a/src/Symbolics.jl +++ b/src/Symbolics.jl @@ -39,6 +39,8 @@ using PrecompileTools using SymbolicIndexingInterface import SymbolicLimits + + using ADTypes: ADTypes end @reexport using SymbolicUtils RuntimeGeneratedFunctions.init(@__MODULE__) @@ -103,6 +105,10 @@ export Differential, expand_derivatives, is_derivative include("diff.jl") +export SymbolicsSparsityDetector + +include("adtypes.jl") + export Difference, DiscreteUpdate include("difference.jl") diff --git a/src/adtypes.jl b/src/adtypes.jl new file mode 100644 index 000000000..41bf9fe36 --- /dev/null +++ b/src/adtypes.jl @@ -0,0 +1,31 @@ +""" + SymbolicsSparsityDetector <: ADTypes.AbstractSparsityDetector + +Sparsity detection algorithm based on the [Symbolics.jl tracing system](https://symbolics.juliasymbolics.org/stable/manual/sparsity_detection/). + +This type makes Symbolics.jl compatible with the [ADTypes.jl sparsity detection framework](https://sciml.github.io/ADTypes.jl/stable/#Sparsity-detector). The following functions are implemented: + +- `ADTypes.jacobian_sparsity` based on [`Symbolics.jacobian_sparsity`](@ref) +- `ADTypes.hessian_sparsity` based on [`Symbolics.hessian_sparsity`](@ref) + +# Reference + +> [Sparsity Programming: Automated Sparsity-Aware Optimizations in Differentiable Programming](https://openreview.net/forum?id=rJlPdcY38B), Gowda et al. (2019) +""" +struct SymbolicsSparsityDetector <: ADTypes.AbstractSparsityDetector end + +function ADTypes.jacobian_sparsity(f, x::AbstractArray, ::SymbolicsSparsityDetector) + y = similar(f(x)) + f!(y, x) = copyto!(y, f(x)) + return jacobian_sparsity(f!, y, x) +end + +function ADTypes.jacobian_sparsity(f!, y::AbstractArray, x::AbstractArray, ::SymbolicsSparsityDetector) + f!_vec(y_vec, x_vec) = f!(reshape(y_vec, size(y)), reshape(x_vec, size(x))) + return jacobian_sparsity(f!_vec, vec(y), vec(x)) +end + +function ADTypes.hessian_sparsity(f, x::AbstractArray, ::SymbolicsSparsityDetector) + f_vec(x_vec) = f(reshape(x_vec, size(x))) + return hessian_sparsity(f_vec, vec(x)) +end diff --git a/test/adtypes.jl b/test/adtypes.jl new file mode 100644 index 000000000..93643cd39 --- /dev/null +++ b/test/adtypes.jl @@ -0,0 +1,19 @@ +using Symbolics: Symbolics, SymbolicsSparsityDetector +using ADTypes: ADTypes +using Test + +detector = SymbolicsSparsityDetector() + +f(x) = reshape(vcat(x[1], diff(vec(x))), size(x)) +f!(y, x) = copyto!(vec(y), vec(f(x))) + +for x in (rand(4), rand(4, 5)) + @test sum(ADTypes.jacobian_sparsity(f, x, detector)) == 2length(x) - 1 + @test sum(ADTypes.jacobian_sparsity(f!, similar(x), x, detector)) == 2length(x) - 1 +end + +g(x) = sum(abs2, diff(vec(x))) + +for x in (rand(4), rand(4, 5)) + @test sum(ADTypes.hessian_sparsity(g, x, detector)) == 3length(x) - 2 +end diff --git a/test/runtests.jl b/test/runtests.jl index fbd79cab7..fe7e614b7 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -26,6 +26,7 @@ if GROUP == "All" || GROUP == "Core" @safetestset "Semi-polynomial" begin include("semipoly.jl") end @safetestset "Fuzz Arrays" begin include("fuzz-arrays.jl") end @safetestset "Differentiation Test" begin include("diff.jl") end + @safetestset "ADTypes Test" begin include("adtypes.jl") end @safetestset "Difference Test" begin include("difference.jl") end @safetestset "Degree Test" begin include("degree.jl") end @safetestset "Coeff Test" begin include("coeff.jl") end From 3c3229cc10c9aba7befa08306f3e6e05d639cdd3 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Tue, 7 May 2024 17:57:30 +0200 Subject: [PATCH 2/4] Test subtype --- test/adtypes.jl | 1 + 1 file changed, 1 insertion(+) diff --git a/test/adtypes.jl b/test/adtypes.jl index 93643cd39..20ce26e7d 100644 --- a/test/adtypes.jl +++ b/test/adtypes.jl @@ -3,6 +3,7 @@ using ADTypes: ADTypes using Test detector = SymbolicsSparsityDetector() +@test detector isa ADTypes.AbstractSparsityDetector f(x) = reshape(vcat(x[1], diff(vec(x))), size(x)) f!(y, x) = copyto!(vec(y), vec(f(x))) From ab815ed7dff553e7f738d42636ffcb61b26ec743 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Tue, 7 May 2024 17:59:30 +0200 Subject: [PATCH 3/4] Test exported --- test/adtypes.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/adtypes.jl b/test/adtypes.jl index 20ce26e7d..6417f3f35 100644 --- a/test/adtypes.jl +++ b/test/adtypes.jl @@ -1,4 +1,4 @@ -using Symbolics: Symbolics, SymbolicsSparsityDetector +using Symbolics using ADTypes: ADTypes using Test From 535f88982594ed5279b5bd33237744ad6978af07 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Tue, 7 May 2024 18:00:14 +0200 Subject: [PATCH 4/4] Bump version to 5.29.1 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 5cbe39fe0..d34145063 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Symbolics" uuid = "0c5d862f-8b57-4792-8d23-62f2024744c7" authors = ["Shashi Gowda "] -version = "5.28.1" +version = "5.29.1" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"