diff --git a/src/array-lib.jl b/src/array-lib.jl index 242d8eff8..74f11e420 100644 --- a/src/array-lib.jl +++ b/src/array-lib.jl @@ -1,4 +1,5 @@ import Base: getindex +inner_unwrap(x) = x isa AbstractArray ? unwrap.(x) : x ##### getindex ##### struct GetindexPosthookCtx end @@ -167,10 +168,11 @@ isonedim(x, i) = shape(x) == Unknown() ? false : isone(size(x, i)) function Broadcast.copy(bc::Broadcast.Broadcasted{SymBroadcast}) # Do the thing here - ndim = mapfoldl(ndims, max, bc.args, init=0) + args = inner_unwrap.(bc.args) + ndim = mapfoldl(ndims, max, args, init=0) subscripts = makesubscripts(ndim) - onedim_count = mapreduce(+, bc.args) do x + onedim_count = mapreduce(+, args) do x if ndims(x) != 0 map(i -> isonedim(x, i) ? 1 : 0, 1:ndim) else @@ -178,9 +180,9 @@ function Broadcast.copy(bc::Broadcast.Broadcasted{SymBroadcast}) end end - extruded = map(x -> x < length(bc.args), onedim_count) + extruded = map(x -> x < length(args), onedim_count) - expr_args′ = map(bc.args) do x + expr_args′ = map(args) do x if ndims(x) != 0 subs = map(i -> extruded[i] && isonedim(x, i) ? 1 : subscripts[i], 1:ndims(x)) @@ -194,8 +196,8 @@ function Broadcast.copy(bc::Broadcast.Broadcasted{SymBroadcast}) expr = term(bc.f, expr_args′...) # Imagine x .=> y -- if you don't have a term # then you get pairs, and index matcher cannot # recurse into pairs - Atype = propagate_atype(broadcast, bc.f, bc.args...) - args = map(x -> x isa Base.RefValue ? Term{Any}(Ref, [x[]]) : x, bc.args) + Atype = propagate_atype(broadcast, bc.f, args...) + args = map(x -> x isa Base.RefValue ? Term{Any}(Ref, [x[]]) : x, args) ArrayOp(Atype{symtype(expr),ndim}, (subscripts...,), expr, @@ -261,17 +263,22 @@ end isadjointvec(A::ArrayOp) = isadjointvec(A.term) +__symtype(x::Type{<:Symbolic{T}}) where T = T +function symeltype(A) + T = eltype(A) + T <: Symbolic ? __symtype(T) : T +end # TODO: add more such methods function getindex(A::AbstractArray, i::Symbolic{<:Integer}, ii::Symbolic{<:Integer}...) - Term{eltype(A)}(getindex, [A, i, ii...]) + Term{symeltype(A)}(getindex, [A, i, ii...]) end function getindex(A::AbstractArray, i::Int, j::Symbolic{<:Integer}) - Term{eltype(A)}(getindex, [A, i, j]) + Term{symeltype(A)}(getindex, [A, i, j]) end function getindex(A::AbstractArray, j::Symbolic{<:Integer}, i::Int) - Term{eltype(A)}(getindex, [A, j, i]) + Term{symeltype(A)}(getindex, [A, j, i]) end function getindex(A::Arr, i::Int, j::Symbolic{<:Integer}) @@ -283,6 +290,8 @@ function getindex(A::Arr, j::Symbolic{<:Integer}, i::Int) end function _matmul(A, B) + A = inner_unwrap(A) + B = inner_unwrap(B) @syms i::Int j::Int k::Int if isadjointvec(A) op = operation(A.term) @@ -295,6 +304,8 @@ end @wrapped (*)(A::AbstractVector, B::AbstractMatrix) = _matmul(A, B) function _matvec(A, b) + A = inner_unwrap(A) + b = inner_unwrap(b) @syms i::Int k::Int sym_res = @arrayop (i,) A[i, k] * b[k] term=(A*b) if isdot(A, b) @@ -320,6 +331,8 @@ end function _map(f, x, xs...) N = ndims(x) idx = makesubscripts(N) + x = inner_unwrap(x) + xs = inner_unwrap.(xs) expr = f(map(a -> a[idx...], [x, xs...])...) diff --git a/test/overloads.jl b/test/overloads.jl index 84fcc3ecd..0bfbc0a90 100644 --- a/test/overloads.jl +++ b/test/overloads.jl @@ -252,3 +252,11 @@ let false end end + +using Symbolics: scalarize +@variables X[1:3, 1:3] x +sX = fill(x, 3, 3) +sx = fill(x, 3) +@test isequal(scalarize(X + sX), scalarize(X) + sX) +@test isequal(scalarize(X * sX), scalarize(X) * sX) +@test isequal(scalarize(X * sx), scalarize(X) * sx)