From 299dba74dc0a82500d661b1722988d4ade0200d7 Mon Sep 17 00:00:00 2001 From: Aayush Sabharwal Date: Thu, 5 Sep 2024 16:35:35 +0530 Subject: [PATCH] fix: implement `symbolic_type` for `ArrayOp` and `CallWithMetadata` --- src/arrays.jl | 2 ++ src/variable.jl | 2 ++ test/symbolic_indexing_interface_trait.jl | 8 ++++++-- 3 files changed, 10 insertions(+), 2 deletions(-) diff --git a/src/arrays.jl b/src/arrays.jl index 2f0e277dc..6df9fbffb 100644 --- a/src/arrays.jl +++ b/src/arrays.jl @@ -72,6 +72,8 @@ SymbolicUtils.sorted_arguments(s::ArrayOp) = sorted_arguments(s.term) shape(aop::ArrayOp) = aop.shape +SymbolicIndexingInterface.symbolic_type(::Type{<:Symbolics.ArrayOp}) = ArraySymbolic() + const show_arrayop = Ref{Bool}(false) function Base.show(io::IO, aop::ArrayOp) if iscall(aop.term) && !show_arrayop[] diff --git a/src/variable.jl b/src/variable.jl index f4813fcc8..a8845171b 100644 --- a/src/variable.jl +++ b/src/variable.jl @@ -242,6 +242,8 @@ SymbolicUtils.Code.toexpr(x::CallWithMetadata, st) = SymbolicUtils.Code.toexpr(x CallWithMetadata(f) = CallWithMetadata(f, nothing) +SymbolicIndexingInterface.symbolic_type(::Type{<:CallWithMetadata}) = ScalarSymbolic() + function Base.show(io::IO, c::CallWithMetadata) show(io, c.f) print(io, "⋆") diff --git a/test/symbolic_indexing_interface_trait.jl b/test/symbolic_indexing_interface_trait.jl index 6c8af1457..15d38e328 100644 --- a/test/symbolic_indexing_interface_trait.jl +++ b/test/symbolic_indexing_interface_trait.jl @@ -2,14 +2,18 @@ using Symbolics using SymbolicUtils using SymbolicIndexingInterface -@test all(symbolic_type.([SymbolicUtils.BasicSymbolic, Symbolics.Num]) .== +@test all(symbolic_type.([SymbolicUtils.BasicSymbolic, Symbolics.Num, Symbolics.CallWithMetadata]) .== (ScalarSymbolic(),)) -@test symbolic_type(Symbolics.Arr) == ArraySymbolic() +@test all(symbolic_type.([Symbolics.ArrayOp, Symbolics.Arr]) .== + (ArraySymbolic(),)) @variables x @test symbolic_type(x) == ScalarSymbolic() @variables y[1:3] @test symbolic_type(y) == ArraySymbolic() @test all(symbolic_type.(collect(y)) .== (ScalarSymbolic(),)) +@test symbolic_type(Symbolics.unwrap(y .* y)) == ArraySymbolic() +@variables z(..) +@test symbolic_type(z) == ScalarSymbolic() @variables x y z subs = Dict(x => 0.1, y => 2z)