Skip to content

Commit

Permalink
Merge pull request #1057 from JuliaSymbolics/myb/array
Browse files Browse the repository at this point in the history
Normalize `D(x)[i]` to `D(x[i])` and remove implicit broadcasting
  • Loading branch information
YingboMa authored Feb 21, 2024
2 parents b70f6d7 + 0216298 commit 3020097
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 11 deletions.
5 changes: 4 additions & 1 deletion src/array-lib.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@ end
function Base.getindex(x::SymArray, idx...)
idx = unwrap.(idx)
meta = metadata(unwrap(x))
if shape(x) !== Unknown() && all(i -> i isa Integer, idx)
if istree(x) && (op = operation(x)) isa Operator
args = arguments(x)
return op(only(args)[idx...])
elseif shape(x) !== Unknown() && all(i -> i isa Integer, idx)
II = CartesianIndices(axes(x))
ii = CartesianIndex(idx)
@boundscheck begin
Expand Down
12 changes: 10 additions & 2 deletions src/diff.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
abstract type Operator <: Function end
propagate_shape(::Operator, x) = axes(x)

"""
$(TYPEDEF)
Expand Down Expand Up @@ -33,8 +34,15 @@ struct Differential <: Operator
x
Differential(x) = new(value(x))
end
(D::Differential)(x) = Term{symtype(x)}(D, [x])
(D::Differential)(x::Num) = Num(D(value(x)))
function (D::Differential)(x)
x = unwrap(x)
if isarraysymbolic(x)
array_term(D, x)
else
term(D, x)
end
end
(D::Differential)(x::Union{Num, Arr}) = wrap(D(unwrap(x)))
(D::Differential)(x::Complex{Num}) = wrap(ComplexTerm{Real}(D(unwrap(real(x))), D(unwrap(imag(x)))))
SymbolicUtils.promote_symtype(::Differential, T) = T

Expand Down
10 changes: 3 additions & 7 deletions src/equations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ function Base.show(io::IO, eq::Equation)
end
end

scalarize(eq::Equation) = scalarize(eq.lhs) ~ scalarize(eq.rhs)
scalarize(eq::Equation) = scalarize(eq.lhs) .~ scalarize(eq.rhs)
SymbolicUtils.simplify(x::Equation; kw...) = simplify(x.lhs; kw...) ~ simplify(x.rhs; kw...)
# ambiguity
for T in [:Pair, :Any]
Expand Down Expand Up @@ -156,12 +156,8 @@ julia> A .~ 3x
```
"""
function Base.:~(lhs, rhs)
if isarraysymbolic(lhs) || isarraysymbolic(rhs)
if isarraysymbolic(lhs) && isarraysymbolic(rhs)
lhs .~ rhs
else
throw(ArgumentError("Cannot equate an array with a scalar. Please use broadcast `.~`."))
end
if (isarraysymbolic(lhs) || isarraysymbolic(rhs)) && ((sl = size(lhs)) != (sr = size(rhs)))
throw(ArgumentError("Cannot equate an array of different sizes. Got $sl and $sr."))
else
Equation(lhs, rhs)
end
Expand Down
6 changes: 5 additions & 1 deletion test/arrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,12 @@ getdef(v) = getmetadata(v, Symbolics.VariableDefaultValue)
b[3] * A[1, 3])))

D = Differential(t)
@test isequal(collect(D.(x) ~ x), map(i -> D(x[i]) ~ x[i], eachindex(x)))
@test isequal(collect(D.(x) .~ x), map(i -> D(x[i]) ~ x[i], eachindex(x)))
@test_throws ArgumentError A ~ t
@test isequal(D(x[1]), D(x)[1])
a = Symbolics.unwrap(D(x)[1])
@test Symbolics.operation(a) == D
@test isequal(only(Symbolics.arguments(a)), Symbolics.unwrap(x[1]))

# #448
@test isequal(Symbolics.scalarize(u + u), [2u[1]])
Expand Down

0 comments on commit 3020097

Please sign in to comment.