Skip to content

Commit

Permalink
Fix the bad getindex for BlockDiag
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanaelbosch committed Feb 17, 2024
1 parent 367b0bf commit c85602a
Showing 1 changed file with 16 additions and 19 deletions.
35 changes: 16 additions & 19 deletions src/blockdiagonals.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,30 +20,27 @@ blocks(B::BlockDiag) = B.blocks
nblocks(B::BlockDiag) = length(B.blocks)
size(B::BlockDiag) = mapreduce(size, ((a, b), (c, d)) -> (a + c, b + d), blocks(B))

function _block_indices(B::BlockDiag, i::Integer, j::Integer)
all((0, 0) .< (i, j) .<= size(B)) || throw(BoundsError(B, (i, j)))
# find the on-diagonal block `p` in column `j`
p = 0
@inbounds while j > 0
p += 1
j -= size(blocks(B)[p], 2)
end
# isempty to avoid reducing over an empty collection
@views @inbounds i -= isempty(1:(p-1)) ? 0 : sum(size.(blocks(B)[1:(p-1)], 1))
# if row `i` outside of block `p`, set `p` to place-holder value `-1`
if i <= 0 || i > size(blocks(B)[p], 2)
p = -1
end
return p, i, j
end
Base.@propagate_inbounds function Base.getindex(
B::BlockDiag{T},
i::Integer,
j::Integer,
) where {T}
p, i, j = _block_indices(B, i, j)
# if not in on-diagonal block `p` then value at `i, j` must be zero
@inbounds return p > 0 ? blocks(B)[p][i, end+j] : zero(T)
all((0, 0) .< (i, j) .<= size(B)) || throw(BoundsError(B, (i, j)))

p = 1
Si, Sj = size(blocks(B)[p])
while p <= nblocks(B)
if i <= Si && j <= Sj
return blocks(B)[p][i, j]
elseif (i <= Si && j > Sj) || (j <= Sj && i > Si)
return zero(T)
else
i -= Si
j -= Sj
p += 1
end
end
error("This shouldn't happen")
end

Base.view(::BlockDiag, idxs...) =
Expand Down

0 comments on commit c85602a

Please sign in to comment.