From 3afe412e0cf141578e09d9ed839228351c6eac13 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Tue, 14 Nov 2023 20:58:18 +0530 Subject: [PATCH] refactor: overload getname from SymbolicIndexingInterface --- Project.toml | 2 ++ src/Symbolics.jl | 2 ++ src/variable.jl | 5 ++++- src/wrapper-types.jl | 4 ++-- 4 files changed, 10 insertions(+), 3 deletions(-) diff --git a/Project.toml b/Project.toml index 8aca8c0dc..d7d14695f 100644 --- a/Project.toml +++ b/Project.toml @@ -34,6 +34,7 @@ Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" +SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b" TreeViews = "a2a6695c-b41b-5b7d-aed9-dbfdeacea5d7" @@ -71,6 +72,7 @@ SciMLBase = "1.8, 2" Setfield = "0.7, 0.8, 1" SpecialFunctions = "0.7, 0.8, 0.9, 0.10, 1.0, 2" StaticArrays = "1.1" +SymbolicIndexingInterface = "0.3" SymbolicUtils = "1.4" TreeViews = "0.3" julia = "1.6" diff --git a/src/Symbolics.jl b/src/Symbolics.jl index f9565f4dc..e0afeaa75 100644 --- a/src/Symbolics.jl +++ b/src/Symbolics.jl @@ -35,6 +35,8 @@ using PrecompileTools using RuntimeGeneratedFunctions using SciMLBase, IfElse using MacroTools + + using SymbolicIndexingInterface end @reexport using SymbolicUtils RuntimeGeneratedFunctions.init(@__MODULE__) diff --git a/src/variable.jl b/src/variable.jl index 19f101503..4d09212ab 100644 --- a/src/variable.jl +++ b/src/variable.jl @@ -406,7 +406,10 @@ end getsource(x, val=_fail) = getmetadata(unwrap(x), VariableSource, val) -getname(x, val=_fail) = _getname(unwrap(x), val) +SymbolicIndexingInterface.symbolic_type(::Type{<:Symbolics.Num}) = ScalarSymbolic() +SymbolicIndexingInterface.symbolic_type(::Type{<:Symbolics.Arr}) = ArraySymbolic() + +SymbolicIndexingInterface.getname(x, val=_fail) = _getname(unwrap(x), val) function getparent(x, val=_fail) maybe_parent = getmetadata(x, Symbolics.GetindexParent, nothing) diff --git a/src/wrapper-types.jl b/src/wrapper-types.jl index c36fc9eca..f1007dbd3 100644 --- a/src/wrapper-types.jl +++ b/src/wrapper-types.jl @@ -17,9 +17,9 @@ function set_where(subt, supert) Expr(:where, supert, Ts...) end -getname(x::Symbol) = x +SymbolicIndexingInterface.getname(x::Symbol) = x -function getname(x::Expr) +function SymbolicIndexingInterface.getname(x::Expr) @assert x.head == :curly return x.args[1] end