diff --git a/src/array-lib.jl b/src/array-lib.jl index 65e87ca7a..cc5923032 100644 --- a/src/array-lib.jl +++ b/src/array-lib.jl @@ -20,7 +20,10 @@ end function Base.getindex(x::SymArray, idx...) idx = unwrap.(idx) meta = metadata(unwrap(x)) - if shape(x) !== Unknown() && all(i -> i isa Integer, idx) + if istree(x) && (op = operation(x)) isa Operator + args = arguments(x) + return op(only(args)[idx...]) + elseif shape(x) !== Unknown() && all(i -> i isa Integer, idx) II = CartesianIndices(axes(x)) ii = CartesianIndex(idx) @boundscheck begin diff --git a/src/diff.jl b/src/diff.jl index 1ec0908a2..41bfc458a 100644 --- a/src/diff.jl +++ b/src/diff.jl @@ -1,4 +1,5 @@ abstract type Operator <: Function end +propagate_shape(::Operator, x) = axes(x) """ $(TYPEDEF) @@ -33,8 +34,15 @@ struct Differential <: Operator x Differential(x) = new(value(x)) end -(D::Differential)(x) = Term{symtype(x)}(D, [x]) -(D::Differential)(x::Num) = Num(D(value(x))) +function (D::Differential)(x) + x = unwrap(x) + if isarraysymbolic(x) + array_term(D, x) + else + term(D, x) + end +end +(D::Differential)(x::Union{Num, Arr}) = wrap(D(unwrap(x))) (D::Differential)(x::Complex{Num}) = wrap(ComplexTerm{Real}(D(unwrap(real(x))), D(unwrap(imag(x))))) SymbolicUtils.promote_symtype(::Differential, T) = T diff --git a/src/equations.jl b/src/equations.jl index 842373151..9739f51de 100644 --- a/src/equations.jl +++ b/src/equations.jl @@ -107,7 +107,7 @@ function Base.show(io::IO, eq::Equation) end end -scalarize(eq::Equation) = scalarize(eq.lhs) ~ scalarize(eq.rhs) +scalarize(eq::Equation) = scalarize(eq.lhs) .~ scalarize(eq.rhs) SymbolicUtils.simplify(x::Equation; kw...) = simplify(x.lhs; kw...) ~ simplify(x.rhs; kw...) # ambiguity for T in [:Pair, :Any] @@ -156,12 +156,8 @@ julia> A .~ 3x ``` """ function Base.:~(lhs, rhs) - if isarraysymbolic(lhs) || isarraysymbolic(rhs) - if isarraysymbolic(lhs) && isarraysymbolic(rhs) - lhs .~ rhs - else - throw(ArgumentError("Cannot equate an array with a scalar. Please use broadcast `.~`.")) - end + if (isarraysymbolic(lhs) || isarraysymbolic(rhs)) && ((sl = size(lhs)) != (sr = size(rhs))) + throw(ArgumentError("Cannot equate an array of different sizes. Got $sl and $sr.")) else Equation(lhs, rhs) end diff --git a/test/arrays.jl b/test/arrays.jl index 925ed1459..24497a0a1 100644 --- a/test/arrays.jl +++ b/test/arrays.jl @@ -93,8 +93,12 @@ getdef(v) = getmetadata(v, Symbolics.VariableDefaultValue) b[3] * A[1, 3]))) D = Differential(t) - @test isequal(collect(D.(x) ~ x), map(i -> D(x[i]) ~ x[i], eachindex(x))) + @test isequal(collect(D.(x) .~ x), map(i -> D(x[i]) ~ x[i], eachindex(x))) @test_throws ArgumentError A ~ t + @test isequal(D(x[1]), D(x)[1]) + a = Symbolics.unwrap(D(x)[1]) + @test Symbolics.operation(a) == D + @test isequal(only(Symbolics.arguments(a)), Symbolics.unwrap(x[1])) # #448 @test isequal(Symbolics.scalarize(u + u), [2u[1]])