diff --git a/src/Symbolics.jl b/src/Symbolics.jl index b2d8155c6..4e9fd9703 100644 --- a/src/Symbolics.jl +++ b/src/Symbolics.jl @@ -175,6 +175,10 @@ for T in [Num, Complex{Num}] end end +for sType in [Pair, Vector, Dict] + @eval substitute(expr::Arr, s::$sType; kw...) = wrap(substituter(s)(unwrap(expr); kw...)) +end + function symbolics_to_sympy end export symbolics_to_sympy diff --git a/src/arrays.jl b/src/arrays.jl index ce6504eae..bd4a4699c 100644 --- a/src/arrays.jl +++ b/src/arrays.jl @@ -470,6 +470,8 @@ end Base.hash(x::Arr, u::UInt) = hash(unwrap(x), u) Base.isequal(a::Arr, b::Arr) = isequal(unwrap(a), unwrap(b)) +Base.isequal(a::Arr, b::Symbolic) = isequal(unwrap(a), b) +Base.isequal(a::Symbolic, b::Arr) = isequal(b, a) ArrayOp(x::Arr) = unwrap(x) diff --git a/src/num.jl b/src/num.jl index 7244801a7..0c28ef903 100644 --- a/src/num.jl +++ b/src/num.jl @@ -81,7 +81,7 @@ substitute(expr, s::Vector; kw...) = substituter(s)(expr; kw...) substituter(pair::Pair) = substituter((pair,)) function substituter(pairs) dict = Dict(value(k) => value(v) for (k, v) in pairs) - (expr; kw...) -> SymbolicUtils.substitute(expr, dict; kw...) + (expr; kw...) -> SymbolicUtils.substitute(value(expr), dict; kw...) end SymbolicUtils.symtype(n::Num) = symtype(value(n)) diff --git a/test/arrays.jl b/test/arrays.jl index 24497a0a1..42a383cc5 100644 --- a/test/arrays.jl +++ b/test/arrays.jl @@ -382,6 +382,27 @@ end @test isequal(collect(dtv), collect(A .* u .- u.^2 .* v .+ alpha .* lapv)) end +@testset "Unwrapped array equality" begin + @variables x[1:3] + ux = unwrap(x) + @test isequal(x, x) + @test isequal(x, ux) + @test isequal(ux, x) +end + +@testset "Array expression substitution" begin + @variables x[1:3] p[1:3, 1:3] + bar(x, p) = p * x + @register_array_symbolic bar(x::AbstractVector, p::AbstractMatrix) begin + size = size(x) + eltype = promote_type(eltype(x), eltype(p)) + end + + @test isequal(substitute(bar(x, p), x => ones(3)), bar(ones(3), p)) + @test isequal(substitute(bar(x, p), Dict(x => ones(3), p => ones(3, 3))), wrap(3ones(3))) + @test isequal(substitute(bar(x, p), [x => ones(3), p => ones(3, 3)]), wrap(3ones(3))) +end + @testset "Partial array substitution" begin @variables x[1:3] A[1:2, 1:2, 1:2]