Skip to content

Commit

Permalink
Merge pull request #1134 from gdalle/gd/adtypes
Browse files Browse the repository at this point in the history
Add ADTypes sparsity detector
  • Loading branch information
ChrisRackauckas authored Jun 3, 2024
2 parents 33fe3e9 + 07fdabd commit 630b442
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 2 deletions.
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
name = "Symbolics"
uuid = "0c5d862f-8b57-4792-8d23-62f2024744c7"
authors = ["Shashi Gowda <[email protected]>"]
version = "5.29.0"
version = "5.30.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
Bijections = "e2ed5e7c-b2de-5872-ae92-c73ca462fb04"
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
Expand Down Expand Up @@ -53,6 +54,7 @@ SymbolicsPreallocationToolsExt = "PreallocationTools"
SymbolicsSymPyExt = "SymPy"

[compat]
ADTypes = "1.0"
ArrayInterface = "7"
Bijections = "0.1"
ConstructionBase = "1.2"
Expand Down
6 changes: 6 additions & 0 deletions docs/src/manual/sparsity_detection.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,9 @@ Symbolics.hessian_sparsity
Symbolics.islinear
Symbolics.isaffine
```

## ADTypes.jl interface

```@docs
Symbolics.SymbolicsSparsityDetector
```
6 changes: 6 additions & 0 deletions src/Symbolics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ using PrecompileTools
using SymbolicIndexingInterface

import SymbolicLimits

using ADTypes: ADTypes
end
@reexport using SymbolicUtils
RuntimeGeneratedFunctions.init(@__MODULE__)
Expand Down Expand Up @@ -104,6 +106,10 @@ export Differential, expand_derivatives, is_derivative

include("diff.jl")

export SymbolicsSparsityDetector

include("adtypes.jl")

export Difference, DiscreteUpdate

include("difference.jl")
Expand Down
31 changes: 31 additions & 0 deletions src/adtypes.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
"""
SymbolicsSparsityDetector <: ADTypes.AbstractSparsityDetector
Sparsity detection algorithm based on the [Symbolics.jl tracing system](https://symbolics.juliasymbolics.org/stable/manual/sparsity_detection/).
This type makes Symbolics.jl compatible with the [ADTypes.jl sparsity detection framework](https://sciml.github.io/ADTypes.jl/stable/#Sparsity-detector). The following functions are implemented:
- `ADTypes.jacobian_sparsity` based on [`Symbolics.jacobian_sparsity`](@ref)
- `ADTypes.hessian_sparsity` based on [`Symbolics.hessian_sparsity`](@ref)
# Reference
> [Sparsity Programming: Automated Sparsity-Aware Optimizations in Differentiable Programming](https://openreview.net/forum?id=rJlPdcY38B), Gowda et al. (2019)
"""
struct SymbolicsSparsityDetector <: ADTypes.AbstractSparsityDetector end

function ADTypes.jacobian_sparsity(f, x::AbstractArray, ::SymbolicsSparsityDetector)
y = similar(f(x))
f!(y, x) = copyto!(y, f(x))
return jacobian_sparsity(f!, y, x)
end

function ADTypes.jacobian_sparsity(f!, y::AbstractArray, x::AbstractArray, ::SymbolicsSparsityDetector)
f!_vec(y_vec, x_vec) = f!(reshape(y_vec, size(y)), reshape(x_vec, size(x)))
return jacobian_sparsity(f!_vec, vec(y), vec(x))
end

function ADTypes.hessian_sparsity(f, x::AbstractArray, ::SymbolicsSparsityDetector)
f_vec(x_vec) = f(reshape(x_vec, size(x)))
return hessian_sparsity(f_vec, vec(x))
end
20 changes: 20 additions & 0 deletions test/adtypes.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
using Symbolics
using ADTypes: ADTypes
using Test

detector = SymbolicsSparsityDetector()
@test detector isa ADTypes.AbstractSparsityDetector

f(x) = reshape(vcat(x[1], diff(vec(x))), size(x))
f!(y, x) = copyto!(vec(y), vec(f(x)))

for x in (rand(4), rand(4, 5))
@test sum(ADTypes.jacobian_sparsity(f, x, detector)) == 2length(x) - 1
@test sum(ADTypes.jacobian_sparsity(f!, similar(x), x, detector)) == 2length(x) - 1
end

g(x) = sum(abs2, diff(vec(x)))

for x in (rand(4), rand(4, 5))
@test sum(ADTypes.hessian_sparsity(g, x, detector)) == 3length(x) - 2
end
3 changes: 2 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ limit(a, N) = a == N + 1 ? 1 : a == 0 ? N : a
@register_symbolic limit(a, N)::Integer

if GROUP == "All" || GROUP == "Core"
@testset begin
@testset begin
@safetestset "Struct Test" begin include("struct.jl") end
@safetestset "Macro Test" begin include("macro.jl") end
@safetestset "Arrays" begin include("arrays.jl") end
Expand All @@ -27,6 +27,7 @@ if GROUP == "All" || GROUP == "Core"
@safetestset "Semi-polynomial" begin include("semipoly.jl") end
@safetestset "Fuzz Arrays" begin include("fuzz-arrays.jl") end
@safetestset "Differentiation Test" begin include("diff.jl") end
@safetestset "ADTypes Test" begin include("adtypes.jl") end
@safetestset "Difference Test" begin include("difference.jl") end
@safetestset "Degree Test" begin include("degree.jl") end
@safetestset "Coeff Test" begin include("coeff.jl") end
Expand Down

0 comments on commit 630b442

Please sign in to comment.