Skip to content

Commit

Permalink
refactor: overload getname from SymbolicIndexingInterface
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Nov 30, 2023
1 parent 59567b4 commit 3afe412
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 3 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
TreeViews = "a2a6695c-b41b-5b7d-aed9-dbfdeacea5d7"

Expand Down Expand Up @@ -71,6 +72,7 @@ SciMLBase = "1.8, 2"
Setfield = "0.7, 0.8, 1"
SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0, 2"
StaticArrays = "1.1"
SymbolicIndexingInterface = "0.3"
SymbolicUtils = "1.4"
TreeViews = "0.3"
julia = "1.6"
Expand Down
2 changes: 2 additions & 0 deletions src/Symbolics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ using PrecompileTools
using RuntimeGeneratedFunctions
using SciMLBase, IfElse
using MacroTools

using SymbolicIndexingInterface
end
@reexport using SymbolicUtils
RuntimeGeneratedFunctions.init(@__MODULE__)
Expand Down
5 changes: 4 additions & 1 deletion src/variable.jl
Original file line number Diff line number Diff line change
Expand Up @@ -406,7 +406,10 @@ end

getsource(x, val=_fail) = getmetadata(unwrap(x), VariableSource, val)

getname(x, val=_fail) = _getname(unwrap(x), val)
SymbolicIndexingInterface.symbolic_type(::Type{<:Symbolics.Num}) = ScalarSymbolic()
SymbolicIndexingInterface.symbolic_type(::Type{<:Symbolics.Arr}) = ArraySymbolic()

SymbolicIndexingInterface.getname(x, val=_fail) = _getname(unwrap(x), val)

function getparent(x, val=_fail)
maybe_parent = getmetadata(x, Symbolics.GetindexParent, nothing)
Expand Down
4 changes: 2 additions & 2 deletions src/wrapper-types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ function set_where(subt, supert)
Expr(:where, supert, Ts...)
end

getname(x::Symbol) = x
SymbolicIndexingInterface.getname(x::Symbol) = x

function getname(x::Expr)
function SymbolicIndexingInterface.getname(x::Expr)
@assert x.head == :curly
return x.args[1]
end
Expand Down

0 comments on commit 3afe412

Please sign in to comment.