diff --git a/Project.toml b/Project.toml index ad18348e1..9a8674a93 100644 --- a/Project.toml +++ b/Project.toml @@ -1,9 +1,10 @@ name = "Symbolics" uuid = "0c5d862f-8b57-4792-8d23-62f2024744c7" authors = ["Shashi Gowda "] -version = "5.29.0" +version = "5.30.0" [deps] +ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" Bijections = "e2ed5e7c-b2de-5872-ae92-c73ca462fb04" ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" @@ -53,6 +54,7 @@ SymbolicsPreallocationToolsExt = "PreallocationTools" SymbolicsSymPyExt = "SymPy" [compat] +ADTypes = "1.0" ArrayInterface = "7" Bijections = "0.1" ConstructionBase = "1.2" diff --git a/docs/src/manual/sparsity_detection.md b/docs/src/manual/sparsity_detection.md index 03e0843a4..9d974cdca 100644 --- a/docs/src/manual/sparsity_detection.md +++ b/docs/src/manual/sparsity_detection.md @@ -27,3 +27,9 @@ Symbolics.hessian_sparsity Symbolics.islinear Symbolics.isaffine ``` + +## ADTypes.jl interface + +```@docs +Symbolics.SymbolicsSparsityDetector +``` diff --git a/src/Symbolics.jl b/src/Symbolics.jl index a20ab7f01..598a91509 100644 --- a/src/Symbolics.jl +++ b/src/Symbolics.jl @@ -40,6 +40,8 @@ using PrecompileTools using SymbolicIndexingInterface import SymbolicLimits + + using ADTypes: ADTypes end @reexport using SymbolicUtils RuntimeGeneratedFunctions.init(@__MODULE__) @@ -104,6 +106,10 @@ export Differential, expand_derivatives, is_derivative include("diff.jl") +export SymbolicsSparsityDetector + +include("adtypes.jl") + export Difference, DiscreteUpdate include("difference.jl") diff --git a/src/adtypes.jl b/src/adtypes.jl new file mode 100644 index 000000000..41bf9fe36 --- /dev/null +++ b/src/adtypes.jl @@ -0,0 +1,31 @@ +""" + SymbolicsSparsityDetector <: ADTypes.AbstractSparsityDetector + +Sparsity detection algorithm based on the [Symbolics.jl tracing system](https://symbolics.juliasymbolics.org/stable/manual/sparsity_detection/). + +This type makes Symbolics.jl compatible with the [ADTypes.jl sparsity detection framework](https://sciml.github.io/ADTypes.jl/stable/#Sparsity-detector). The following functions are implemented: + +- `ADTypes.jacobian_sparsity` based on [`Symbolics.jacobian_sparsity`](@ref) +- `ADTypes.hessian_sparsity` based on [`Symbolics.hessian_sparsity`](@ref) + +# Reference + +> [Sparsity Programming: Automated Sparsity-Aware Optimizations in Differentiable Programming](https://openreview.net/forum?id=rJlPdcY38B), Gowda et al. (2019) +""" +struct SymbolicsSparsityDetector <: ADTypes.AbstractSparsityDetector end + +function ADTypes.jacobian_sparsity(f, x::AbstractArray, ::SymbolicsSparsityDetector) + y = similar(f(x)) + f!(y, x) = copyto!(y, f(x)) + return jacobian_sparsity(f!, y, x) +end + +function ADTypes.jacobian_sparsity(f!, y::AbstractArray, x::AbstractArray, ::SymbolicsSparsityDetector) + f!_vec(y_vec, x_vec) = f!(reshape(y_vec, size(y)), reshape(x_vec, size(x))) + return jacobian_sparsity(f!_vec, vec(y), vec(x)) +end + +function ADTypes.hessian_sparsity(f, x::AbstractArray, ::SymbolicsSparsityDetector) + f_vec(x_vec) = f(reshape(x_vec, size(x))) + return hessian_sparsity(f_vec, vec(x)) +end diff --git a/test/adtypes.jl b/test/adtypes.jl new file mode 100644 index 000000000..6417f3f35 --- /dev/null +++ b/test/adtypes.jl @@ -0,0 +1,20 @@ +using Symbolics +using ADTypes: ADTypes +using Test + +detector = SymbolicsSparsityDetector() +@test detector isa ADTypes.AbstractSparsityDetector + +f(x) = reshape(vcat(x[1], diff(vec(x))), size(x)) +f!(y, x) = copyto!(vec(y), vec(f(x))) + +for x in (rand(4), rand(4, 5)) + @test sum(ADTypes.jacobian_sparsity(f, x, detector)) == 2length(x) - 1 + @test sum(ADTypes.jacobian_sparsity(f!, similar(x), x, detector)) == 2length(x) - 1 +end + +g(x) = sum(abs2, diff(vec(x))) + +for x in (rand(4), rand(4, 5)) + @test sum(ADTypes.hessian_sparsity(g, x, detector)) == 3length(x) - 2 +end diff --git a/test/runtests.jl b/test/runtests.jl index 4d848e695..5573573eb 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -18,7 +18,7 @@ limit(a, N) = a == N + 1 ? 1 : a == 0 ? N : a @register_symbolic limit(a, N)::Integer if GROUP == "All" || GROUP == "Core" - @testset begin + @testset begin @safetestset "Struct Test" begin include("struct.jl") end @safetestset "Macro Test" begin include("macro.jl") end @safetestset "Arrays" begin include("arrays.jl") end @@ -27,6 +27,7 @@ if GROUP == "All" || GROUP == "Core" @safetestset "Semi-polynomial" begin include("semipoly.jl") end @safetestset "Fuzz Arrays" begin include("fuzz-arrays.jl") end @safetestset "Differentiation Test" begin include("diff.jl") end + @safetestset "ADTypes Test" begin include("adtypes.jl") end @safetestset "Difference Test" begin include("difference.jl") end @safetestset "Degree Test" begin include("degree.jl") end @safetestset "Coeff Test" begin include("coeff.jl") end