Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: remove RecursiveArrayTools dependency, use SymbolicIndexingInterface #1005

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
RuntimeGeneratedFunctions = "7e49a35a-f44a-4d26-94aa-eba1b4ca6b47"
Expand All @@ -35,6 +34,7 @@ Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
TreeViews = "a2a6695c-b41b-5b7d-aed9-dbfdeacea5d7"

Expand Down Expand Up @@ -64,7 +64,6 @@ MacroTools = "0.5"
NaNMath = "0.3, 1"
PrecompileTools = "1"
RecipesBase = "1.1"
RecursiveArrayTools = "2"
Reexport = "0.2, 1"
ReferenceTests = "0.9"
Requires = "1.1"
Expand All @@ -73,6 +72,7 @@ SciMLBase = "1.8, 2"
Setfield = "0.7, 0.8, 1"
SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0, 2"
StaticArrays = "1.1"
SymbolicIndexingInterface = "0.3"
SymbolicUtils = "1.4"
TreeViews = "0.3"
julia = "1.6"
Expand Down
2 changes: 2 additions & 0 deletions src/Symbolics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ using PrecompileTools
using RuntimeGeneratedFunctions
using SciMLBase, IfElse
using MacroTools

using SymbolicIndexingInterface
end
@reexport using SymbolicUtils
RuntimeGeneratedFunctions.init(@__MODULE__)
Expand Down
6 changes: 0 additions & 6 deletions src/num.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,6 @@ Num(x::Num) = x # ideally this should never be called
(n::Num)(args...) = Num(value(n)(map(value,args)...))
value(x) = unwrap(x)

SciMLBase.issymbollike(::Num) = true
SciMLBase.issymbollike(::SymbolicUtils.Symbolic) = true

SymbolicUtils.@number_methods(
Num,
Num(f(value(a))),
Expand Down Expand Up @@ -197,6 +194,3 @@ function Base.Docs.getdoc(x::Num)
end
Markdown.parse(join(strings, "\n\n "))
end

using RecursiveArrayTools
RecursiveArrayTools.issymbollike(::Union{BasicSymbolic,Num}) = true
11 changes: 10 additions & 1 deletion src/variable.jl
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,16 @@ end

getsource(x, val=_fail) = getmetadata(unwrap(x), VariableSource, val)

getname(x, val=_fail) = _getname(unwrap(x), val)
SymbolicIndexingInterface.symbolic_type(::Type{<:Symbolics.Num}) = ScalarSymbolic()
SymbolicIndexingInterface.symbolic_type(::Type{<:Symbolics.Arr}) = ArraySymbolic()

SymbolicIndexingInterface.hasname(x::Union{Num,Arr}) = hasname(unwrap(x))

function SymbolicIndexingInterface.hasname(x::Symbolic)
issym(x) || !istree(x) || istree(x) && (issym(operation(x)) || operation(x) == getindex)
end

SymbolicIndexingInterface.getname(x, val=_fail) = _getname(unwrap(x), val)

function getparent(x, val=_fail)
maybe_parent = getmetadata(x, Symbolics.GetindexParent, nothing)
Expand Down
4 changes: 2 additions & 2 deletions src/wrapper-types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ function set_where(subt, supert)
Expr(:where, supert, Ts...)
end

getname(x::Symbol) = x
SymbolicIndexingInterface.getname(x::Symbol) = x

function getname(x::Expr)
function SymbolicIndexingInterface.getname(x::Expr)
@assert x.head == :curly
return x.args[1]
end
Expand Down
3 changes: 0 additions & 3 deletions test/overloads.jl
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,3 @@ for f in [<, <=, >, >=, isless]
end

@test_nowarn binomial(t, 1)

using RecursiveArrayTools
@test RecursiveArrayTools.issymbollike(t)
9 changes: 9 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,15 @@ if GROUP == "All" || GROUP == "Core"
@safetestset "LogExpFunctions Test" begin include("logexpfunctions.jl") end
end

if GROUP == "All" || GROUP == "Core" || GROUP == "SymbolicIndexingInterface"
@safetestset "SymbolicIndexingInterface Trait Test" begin
include("symbolic_indexing_interface_trait.jl")
end
@safetestset "SymbolicIndexingInterface Parameter Indexing Test" begin
include("symbolic_indexing_interface_parameter_indexing.jl")
end
end

if GROUP == "Downstream"
activate_downstream_env()
#@time @safetestset "ParameterizedFunctions MATLABDiffEq Regression Test" begin include("downstream/ParameterizedFunctions_MATLAB.jl") end
Expand Down
23 changes: 23 additions & 0 deletions test/symbolic_indexing_interface_parameter_indexing.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
using SymbolicIndexingInterface
using Symbolics

struct FakeIntegrator{P}
p::P
end

SymbolicIndexingInterface.symbolic_container(fp::FakeIntegrator) = fp.sys
SymbolicIndexingInterface.parameter_values(fp::FakeIntegrator) = fp.p

@variables a[1:2] b
sys = SymbolCache([:x, :y, :z], [a[1], a[2], b], [:t])
p = [1.0, 2.0, 3.0]
fi = FakeIntegrator(copy(p))
for (i, sym) in [(1, a[1]), (2, a[2]), (3, b), ([1,2], a), ([1, 3], [a[1], b]), ((2, 3), (a[2], b))]
get = getp(sys, sym)
set! = setp(sys, sym)
true_value = i isa Tuple ? getindex.((p,), i) : p[i]
@test get(fi) == true_value
set!(fi, 0.5 .* i)
@test get(fi) == 0.5 .* i
set!(fi, true_value)
end
12 changes: 12 additions & 0 deletions test/symbolic_indexing_interface_trait.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
using Symbolics
using SymbolicUtils
using SymbolicIndexingInterface

@test all(symbolic_type.([SymbolicUtils.BasicSymbolic, Symbolics.Num]) .==
(ScalarSymbolic(),))
@test symbolic_type(Symbolics.Arr) == ArraySymbolic()
@variables x
@test symbolic_type(x) == ScalarSymbolic()
@variables y[1:3]
@test symbolic_type(y) == ArraySymbolic()
@test all(symbolic_type.(collect(y)) .== (ScalarSymbolic(),))
Loading