From b7f8b5c79e918c638d747e6baf5b6384ec988911 Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Thu, 15 Feb 2024 17:32:43 -0500 Subject: [PATCH 1/3] Normalize `D(x)[i]` to `D(x[i])` and remove implicit broadcasting --- src/array-lib.jl | 5 ++++- src/diff.jl | 10 +++++++++- src/equations.jl | 10 +++------- 3 files changed, 16 insertions(+), 9 deletions(-) diff --git a/src/array-lib.jl b/src/array-lib.jl index 65e87ca7a..cc5923032 100644 --- a/src/array-lib.jl +++ b/src/array-lib.jl @@ -20,7 +20,10 @@ end function Base.getindex(x::SymArray, idx...) idx = unwrap.(idx) meta = metadata(unwrap(x)) - if shape(x) !== Unknown() && all(i -> i isa Integer, idx) + if istree(x) && (op = operation(x)) isa Operator + args = arguments(x) + return op(only(args)[idx...]) + elseif shape(x) !== Unknown() && all(i -> i isa Integer, idx) II = CartesianIndices(axes(x)) ii = CartesianIndex(idx) @boundscheck begin diff --git a/src/diff.jl b/src/diff.jl index 1ec0908a2..15309d7e7 100644 --- a/src/diff.jl +++ b/src/diff.jl @@ -1,4 +1,5 @@ abstract type Operator <: Function end +propagate_shape(::Operator, x) = axes(x) """ $(TYPEDEF) @@ -33,7 +34,14 @@ struct Differential <: Operator x Differential(x) = new(value(x)) end -(D::Differential)(x) = Term{symtype(x)}(D, [x]) +function (D::Differential)(x) + x = unwrap(x) + if isarraysymbolic(x) + wrap(array_term(D, x)) + else + wrap(term(D, x)) + end +end (D::Differential)(x::Num) = Num(D(value(x))) (D::Differential)(x::Complex{Num}) = wrap(ComplexTerm{Real}(D(unwrap(real(x))), D(unwrap(imag(x))))) SymbolicUtils.promote_symtype(::Differential, T) = T diff --git a/src/equations.jl b/src/equations.jl index 5736afcc5..b036ba51a 100644 --- a/src/equations.jl +++ b/src/equations.jl @@ -50,7 +50,7 @@ function Base.show(io::IO, eq::Equation) end end -scalarize(eq::Equation) = scalarize(eq.lhs) ~ scalarize(eq.rhs) +scalarize(eq::Equation) = scalarize(eq.lhs) .~ scalarize(eq.rhs) SymbolicUtils.simplify(x::Equation; kw...) = simplify(x.lhs; kw...) ~ simplify(x.rhs; kw...) # ambiguity for T in [:Pair, :Any] @@ -99,12 +99,8 @@ julia> A .~ 3x ``` """ function Base.:~(lhs, rhs) - if isarraysymbolic(lhs) || isarraysymbolic(rhs) - if isarraysymbolic(lhs) && isarraysymbolic(rhs) - lhs .~ rhs - else - throw(ArgumentError("Cannot equate an array with a scalar. Please use broadcast `.~`.")) - end + if (isarraysymbolic(lhs) || isarraysymbolic(rhs)) && ((sl = size(lhs)) != (sr = size(rhs))) + throw(ArgumentError("Cannot equate an array of different sizes. Got $sl and $sr.")) else Equation(lhs, rhs) end From 0a90248794f5474a639c77e5dd04ac211c71c954 Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Mon, 19 Feb 2024 16:54:15 -0500 Subject: [PATCH 2/3] Fix D instantiation --- src/diff.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/diff.jl b/src/diff.jl index 15309d7e7..41bfc458a 100644 --- a/src/diff.jl +++ b/src/diff.jl @@ -37,12 +37,12 @@ end function (D::Differential)(x) x = unwrap(x) if isarraysymbolic(x) - wrap(array_term(D, x)) + array_term(D, x) else - wrap(term(D, x)) + term(D, x) end end -(D::Differential)(x::Num) = Num(D(value(x))) +(D::Differential)(x::Union{Num, Arr}) = wrap(D(unwrap(x))) (D::Differential)(x::Complex{Num}) = wrap(ComplexTerm{Real}(D(unwrap(real(x))), D(unwrap(imag(x))))) SymbolicUtils.promote_symtype(::Differential, T) = T From 02162983710680952a273c893ab34f449e9ec4d6 Mon Sep 17 00:00:00 2001 From: Yingbo Ma Date: Wed, 21 Feb 2024 12:50:30 -0500 Subject: [PATCH 3/3] Update tests --- test/arrays.jl | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/test/arrays.jl b/test/arrays.jl index 925ed1459..24497a0a1 100644 --- a/test/arrays.jl +++ b/test/arrays.jl @@ -93,8 +93,12 @@ getdef(v) = getmetadata(v, Symbolics.VariableDefaultValue) b[3] * A[1, 3]))) D = Differential(t) - @test isequal(collect(D.(x) ~ x), map(i -> D(x[i]) ~ x[i], eachindex(x))) + @test isequal(collect(D.(x) .~ x), map(i -> D(x[i]) ~ x[i], eachindex(x))) @test_throws ArgumentError A ~ t + @test isequal(D(x[1]), D(x)[1]) + a = Symbolics.unwrap(D(x)[1]) + @test Symbolics.operation(a) == D + @test isequal(only(Symbolics.arguments(a)), Symbolics.unwrap(x[1])) # #448 @test isequal(Symbolics.scalarize(u + u), [2u[1]])