Skip to content

Commit

Permalink
fix: implement symbolic_type for ArrayOp and CallWithMetadata
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Sep 5, 2024
1 parent fb0265e commit 299dba7
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 2 deletions.
2 changes: 2 additions & 0 deletions src/arrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ SymbolicUtils.sorted_arguments(s::ArrayOp) = sorted_arguments(s.term)

shape(aop::ArrayOp) = aop.shape

SymbolicIndexingInterface.symbolic_type(::Type{<:Symbolics.ArrayOp}) = ArraySymbolic()

const show_arrayop = Ref{Bool}(false)
function Base.show(io::IO, aop::ArrayOp)
if iscall(aop.term) && !show_arrayop[]
Expand Down
2 changes: 2 additions & 0 deletions src/variable.jl
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,8 @@ SymbolicUtils.Code.toexpr(x::CallWithMetadata, st) = SymbolicUtils.Code.toexpr(x

CallWithMetadata(f) = CallWithMetadata(f, nothing)

SymbolicIndexingInterface.symbolic_type(::Type{<:CallWithMetadata}) = ScalarSymbolic()

function Base.show(io::IO, c::CallWithMetadata)
show(io, c.f)
print(io, "")
Expand Down
8 changes: 6 additions & 2 deletions test/symbolic_indexing_interface_trait.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,18 @@ using Symbolics
using SymbolicUtils
using SymbolicIndexingInterface

@test all(symbolic_type.([SymbolicUtils.BasicSymbolic, Symbolics.Num]) .==
@test all(symbolic_type.([SymbolicUtils.BasicSymbolic, Symbolics.Num, Symbolics.CallWithMetadata]) .==
(ScalarSymbolic(),))
@test symbolic_type(Symbolics.Arr) == ArraySymbolic()
@test all(symbolic_type.([Symbolics.ArrayOp, Symbolics.Arr]) .==
(ArraySymbolic(),))
@variables x
@test symbolic_type(x) == ScalarSymbolic()
@variables y[1:3]
@test symbolic_type(y) == ArraySymbolic()
@test all(symbolic_type.(collect(y)) .== (ScalarSymbolic(),))
@test symbolic_type(Symbolics.unwrap(y .* y)) == ArraySymbolic()
@variables z(..)
@test symbolic_type(z) == ScalarSymbolic()

@variables x y z
subs = Dict(x => 0.1, y => 2z)
Expand Down

0 comments on commit 299dba7

Please sign in to comment.