Skip to content

Commit

Permalink
Merge pull request #1125 from JuliaSymbolics/myb/unwrap
Browse files Browse the repository at this point in the history
Fix array ops overloading on mixing symbolic arrays and arrays of symbolics
  • Loading branch information
ChrisRackauckas authored Jul 5, 2024
2 parents dd848d9 + a4b1920 commit 84f00ee
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 9 deletions.
31 changes: 22 additions & 9 deletions src/array-lib.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import Base: getindex
inner_unwrap(x) = x isa AbstractArray ? unwrap.(x) : x

##### getindex #####
struct GetindexPosthookCtx end
Expand Down Expand Up @@ -167,20 +168,21 @@ isonedim(x, i) = shape(x) == Unknown() ? false : isone(size(x, i))

function Broadcast.copy(bc::Broadcast.Broadcasted{SymBroadcast})
# Do the thing here
ndim = mapfoldl(ndims, max, bc.args, init=0)
args = inner_unwrap.(bc.args)
ndim = mapfoldl(ndims, max, args, init=0)
subscripts = makesubscripts(ndim)

onedim_count = mapreduce(+, bc.args) do x
onedim_count = mapreduce(+, args) do x
if ndims(x) != 0
map(i -> isonedim(x, i) ? 1 : 0, 1:ndim)
else
map(i -> 1, 1:ndim)
end
end

extruded = map(x -> x < length(bc.args), onedim_count)
extruded = map(x -> x < length(args), onedim_count)

expr_args′ = map(bc.args) do x
expr_args′ = map(args) do x
if ndims(x) != 0
subs = map(i -> extruded[i] && isonedim(x, i) ?
1 : subscripts[i], 1:ndims(x))
Expand All @@ -194,8 +196,8 @@ function Broadcast.copy(bc::Broadcast.Broadcasted{SymBroadcast})
expr = term(bc.f, expr_args′...) # Imagine x .=> y -- if you don't have a term
# then you get pairs, and index matcher cannot
# recurse into pairs
Atype = propagate_atype(broadcast, bc.f, bc.args...)
args = map(x -> x isa Base.RefValue ? Term{Any}(Ref, [x[]]) : x, bc.args)
Atype = propagate_atype(broadcast, bc.f, args...)
args = map(x -> x isa Base.RefValue ? Term{Any}(Ref, [x[]]) : x, args)
ArrayOp(Atype{symtype(expr),ndim},
(subscripts...,),
expr,
Expand Down Expand Up @@ -261,17 +263,22 @@ end

isadjointvec(A::ArrayOp) = isadjointvec(A.term)

__symtype(x::Type{<:Symbolic{T}}) where T = T
function symeltype(A)
T = eltype(A)
T <: Symbolic ? __symtype(T) : T
end
# TODO: add more such methods
function getindex(A::AbstractArray, i::Symbolic{<:Integer}, ii::Symbolic{<:Integer}...)
Term{eltype(A)}(getindex, [A, i, ii...])
Term{symeltype(A)}(getindex, [A, i, ii...])
end

function getindex(A::AbstractArray, i::Int, j::Symbolic{<:Integer})
Term{eltype(A)}(getindex, [A, i, j])
Term{symeltype(A)}(getindex, [A, i, j])
end

function getindex(A::AbstractArray, j::Symbolic{<:Integer}, i::Int)
Term{eltype(A)}(getindex, [A, j, i])
Term{symeltype(A)}(getindex, [A, j, i])
end

function getindex(A::Arr, i::Int, j::Symbolic{<:Integer})
Expand All @@ -283,6 +290,8 @@ function getindex(A::Arr, j::Symbolic{<:Integer}, i::Int)
end

function _matmul(A, B)
A = inner_unwrap(A)
B = inner_unwrap(B)
@syms i::Int j::Int k::Int
if isadjointvec(A)
op = operation(A.term)
Expand All @@ -295,6 +304,8 @@ end
@wrapped (*)(A::AbstractVector, B::AbstractMatrix) = _matmul(A, B)

function _matvec(A, b)
A = inner_unwrap(A)
b = inner_unwrap(b)
@syms i::Int k::Int
sym_res = @arrayop (i,) A[i, k] * b[k] term=(A*b)
if isdot(A, b)
Expand All @@ -320,6 +331,8 @@ end
function _map(f, x, xs...)
N = ndims(x)
idx = makesubscripts(N)
x = inner_unwrap(x)
xs = inner_unwrap.(xs)

expr = f(map(a -> a[idx...], [x, xs...])...)

Expand Down
8 changes: 8 additions & 0 deletions test/overloads.jl
Original file line number Diff line number Diff line change
Expand Up @@ -252,3 +252,11 @@ let
false
end
end

using Symbolics: scalarize
@variables X[1:3, 1:3] x
sX = fill(x, 3, 3)
sx = fill(x, 3)
@test isequal(scalarize(X + sX), scalarize(X) + sX)
@test isequal(scalarize(X * sX), scalarize(X) * sX)
@test isequal(scalarize(X * sx), scalarize(X) * sx)

0 comments on commit 84f00ee

Please sign in to comment.