Skip to content

Commit

Permalink
Fix getindex symtype
Browse files Browse the repository at this point in the history
  • Loading branch information
YingboMa committed Apr 25, 2024
1 parent a915561 commit c8fa6f1
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 12 deletions.
28 changes: 18 additions & 10 deletions src/array-lib.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import Base: getindex
inner_unwrap(x) = x isa AbstractArray ? unwrap.(x) : x

##### getindex #####
struct GetindexPosthookCtx end
Expand Down Expand Up @@ -167,20 +168,21 @@ isonedim(x, i) = shape(x) == Unknown() ? false : isone(size(x, i))

function Broadcast.copy(bc::Broadcast.Broadcasted{SymBroadcast})
# Do the thing here
ndim = mapfoldl(ndims, max, bc.args, init=0)
args = inner_unwrap.(bc.args)
ndim = mapfoldl(ndims, max, args, init=0)
subscripts = makesubscripts(ndim)

onedim_count = mapreduce(+, bc.args) do x
onedim_count = mapreduce(+, args) do x
if ndims(x) != 0
map(i -> isonedim(x, i) ? 1 : 0, 1:ndim)
else
map(i -> 1, 1:ndim)
end
end

extruded = map(x -> x < length(bc.args), onedim_count)
extruded = map(x -> x < length(args), onedim_count)

expr_args′ = map(bc.args) do x
expr_args′ = map(args) do x
if ndims(x) != 0
subs = map(i -> extruded[i] && isonedim(x, i) ?
1 : subscripts[i], 1:ndims(x))
Expand All @@ -194,8 +196,8 @@ function Broadcast.copy(bc::Broadcast.Broadcasted{SymBroadcast})
expr = term(bc.f, expr_args′...) # Imagine x .=> y -- if you don't have a term
# then you get pairs, and index matcher cannot
# recurse into pairs
Atype = propagate_atype(broadcast, bc.f, bc.args...)
args = map(x -> x isa Base.RefValue ? Term{Any}(Ref, [x[]]) : x, bc.args)
Atype = propagate_atype(broadcast, bc.f, args...)
args = map(x -> x isa Base.RefValue ? Term{Any}(Ref, [x[]]) : x, args)
ArrayOp(Atype{symtype(expr),ndim},
(subscripts...,),
expr,
Expand Down Expand Up @@ -261,17 +263,22 @@ end

isadjointvec(A::ArrayOp) = isadjointvec(A.term)

__symtype(x::Type{<:Symbolic{T}}) where T = T
function symeltype(A)
T = eltype(A)
T <: Symbolic ? __symtype(T) : T
end
# TODO: add more such methods
function getindex(A::AbstractArray, i::Symbolic{<:Integer}, ii::Symbolic{<:Integer}...)
Term{eltype(A)}(getindex, [A, i, ii...])
Term{symeltype(A)}(getindex, [A, i, ii...])
end

function getindex(A::AbstractArray, i::Int, j::Symbolic{<:Integer})
Term{eltype(A)}(getindex, [A, i, j])
Term{symeltype(A)}(getindex, [A, i, j])
end

function getindex(A::AbstractArray, j::Symbolic{<:Integer}, i::Int)
Term{eltype(A)}(getindex, [A, j, i])
Term{symeltype(A)}(getindex, [A, j, i])
end

function getindex(A::Arr, i::Int, j::Symbolic{<:Integer})
Expand All @@ -282,7 +289,6 @@ function getindex(A::Arr, j::Symbolic{<:Integer}, i::Int)
wrap(unwrap(A)[j, i])
end

inner_unwrap(x) = x isa AbstractArray ? unwrap.(x) : x
function _matmul(A, B)
A = inner_unwrap(A)
B = inner_unwrap(B)
Expand Down Expand Up @@ -325,6 +331,8 @@ end
function _map(f, x, xs...)
N = ndims(x)
idx = makesubscripts(N)
x = inner_unwrap(x)
xs = inner_unwrap.(xs)

expr = f(map(a -> a[idx...], [x, xs...])...)

Expand Down
4 changes: 2 additions & 2 deletions test/overloads.jl
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,6 @@ using Symbolics: scalarize
@variables X[1:3, 1:3] x
sX = fill(x, 3, 3)
sx = fill(x, 3)
@test isequal(scalarize(X + XX), scalarize(X) + XX)
@test isequal(scalarize(X * XX), scalarize(X) * XX)
@test isequal(scalarize(X + sX), scalarize(X) + sX)
@test isequal(scalarize(X * sX), scalarize(X) * sX)
@test isequal(scalarize(X * sx), scalarize(X) * sx)

0 comments on commit c8fa6f1

Please sign in to comment.