Skip to content

Commit

Permalink
JuliaFormatter.jl
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanaelbosch committed Feb 14, 2024
1 parent 2fb44be commit 427392e
Showing 1 changed file with 14 additions and 8 deletions.
22 changes: 14 additions & 8 deletions src/blockdiagonals.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,16 @@ BlockDiagonals.jl didn't cut it, so we're rolling our own.
TODO: Add a way to convert to a `BlockDiagonal`.
"""
struct MinimalAndFastBlockDiagonal{T<:Number, V<:AbstractMatrix{T}} <: AbstractMatrix{T}
struct MinimalAndFastBlockDiagonal{T<:Number,V<:AbstractMatrix{T}} <: AbstractMatrix{T}
blocks::Vector{V}
function MinimalAndFastBlockDiagonal{T, V}(blocks::Vector{V}) where {T, V<:AbstractMatrix{T}}
return new{T, V}(blocks)
function MinimalAndFastBlockDiagonal{T,V}(
blocks::Vector{V},
) where {T,V<:AbstractMatrix{T}}
return new{T,V}(blocks)
end
end
function MinimalAndFastBlockDiagonal(blocks::Vector{V}) where {T, V<:AbstractMatrix{T}}
return MinimalAndFastBlockDiagonal{T, V}(blocks)
function MinimalAndFastBlockDiagonal(blocks::Vector{V}) where {T,V<:AbstractMatrix{T}}
return MinimalAndFastBlockDiagonal{T,V}(blocks)
end
const MFBD = MinimalAndFastBlockDiagonal
blocks(B::MFBD) = B.blocks
Expand All @@ -33,10 +35,14 @@ function _block_indices(B::MFBD, i::Integer, j::Integer)
end
return p, i, j
end
Base.@propagate_inbounds function Base.getindex(B::MFBD{T}, i::Integer, j::Integer) where T
Base.@propagate_inbounds function Base.getindex(
B::MFBD{T},
i::Integer,
j::Integer,
) where {T}
p, i, j = _block_indices(B, i, j)
# if not in on-diagonal block `p` then value at `i, j` must be zero
@inbounds return p > 0 ? blocks(B)[p][i, end + j] : zero(T)
@inbounds return p > 0 ? blocks(B)[p][i, end+j] : zero(T)
end

Base.view(::MFBD, idxs...) =
Expand Down Expand Up @@ -104,7 +110,7 @@ end
_matmul!(
C::MFBD{T},
A::MFBD{T},
B::Adjoint{T, <:MFBD{T}},
B::Adjoint{T,<:MFBD{T}},
) where {T<:LinearAlgebra.BlasFloat} = begin
@assert length(C.blocks) == length(A.blocks) == length(B.parent.blocks)
@simd ivdep for i in eachindex(blocks(C))
Expand Down

0 comments on commit 427392e

Please sign in to comment.