Skip to content

Commit

Permalink
Implement my own BlockDiag type
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanaelbosch committed Feb 14, 2024
1 parent efeaaad commit 2fb44be
Show file tree
Hide file tree
Showing 15 changed files with 134 additions and 74 deletions.
1 change: 0 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ version = "0.15.0"

[deps]
ArrayAllocators = "c9d4266f-a5cb-439d-837c-c97b191379f5"
BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def"
DiffEqDevTools = "f3b72e0c-5b89-59e1-b016-84e28bfd966d"
Expand Down
1 change: 0 additions & 1 deletion src/ProbNumDiffEq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ using ExponentialUtilities
using Octavian
using FastGaussQuadrature
import Kronecker
using BlockDiagonals
using ArrayAllocators
using FiniteHorizonGramians
using FillArrays
Expand Down
2 changes: 1 addition & 1 deletion src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ struct EK0{PT,DT,IT,RT,CF} <: AbstractEK
pn_observation_noise::RT=nothing,
covariance_factorization::CF=covariance_structure(EK0, prior, diffusionmodel),
) where {PT,DT,IT,RT,CF} = begin
ekargcheck(EK0; diffusionmodel, pn_observation_noise)
ekargcheck(EK0; diffusionmodel, pn_observation_noise, covariance_factorization)
new{PT,DT,IT,RT,CF}(
prior, diffusionmodel, smooth, initialization, pn_observation_noise,
covariance_factorization)
Expand Down
118 changes: 90 additions & 28 deletions src/blockdiagonals.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,60 @@
Base.view(::BlockDiagonal, idxs...) =
throw(MethodError("BlockDiagonal does not support views"))
"""
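    MinimalAndFastBlockDiagonal(blocks::Vector{<:AbstractMatrix})

Minimal block-diagonal matrix type that stores only its diagonal `blocks`.

BlockDiagonals.jl didn't cut it, so we're rolling our own.
TODO: Add a way to convert to a `BlockDiagonal`.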
"""
struct MinimalAndFastBlockDiagonal{T<:Number,V<:AbstractMatrix{T}} <: AbstractMatrix{T}
    # The diagonal blocks; every block has the same concrete matrix type `V`.
    blocks::Vector{V}

    # Inner constructor: wraps the given vector directly (no conversion, no copy).
    MinimalAndFastBlockDiagonal{T,V}(blocks::Vector{V}) where {T,V<:AbstractMatrix{T}} =
        new{T,V}(blocks)
end
# Convenience constructor: the element type `T` and the block type `V` are taken
# from the type of `blocks`, so callers do not have to spell out the parameters.
MinimalAndFastBlockDiagonal(blocks::Vector{V}) where {T,V<:AbstractMatrix{T}} =
    MinimalAndFastBlockDiagonal{T,V}(blocks)
# Short alias for the block-diagonal type.
const MFBD = MinimalAndFastBlockDiagonal

# Return the stored vector of diagonal blocks (the vector itself, not a copy).
function blocks(B::MFBD)
    return B.blocks
end

# Number of diagonal blocks.
function nblocks(B::MFBD)
    return length(B.blocks)
end

# Overall matrix size: the row counts and the column counts of all blocks are
# summed separately.
# NOTE(review): this is an unqualified `size` method, so it presumably relies on
# `size` being imported from `Base` elsewhere in the package — confirm.
function size(B::MFBD)
    nrows = sum(size.(B.blocks, 1))
    ncols = sum(size.(B.blocks, 2))
    return (nrows, ncols)
end

"""
    _block_indices(B::MFBD, i::Integer, j::Integer) -> (p, i, j)

Translate the global index `(i, j)` of `B` into block-local coordinates.

Returns a tuple `(p, i, j)` where
- `p` is the index of the diagonal block whose columns contain global column `j`,
  or the place-holder `-1` if global row `i` does not fall inside that block
  (the entry is then a structural zero),
- `i` is the row index relative to the first row of that block,
- `j` is a non-positive column offset relative to the *last* column of that block.

Throws a `BoundsError` if `(i, j)` lies outside `size(B)`.
"""
function _block_indices(B::MFBD, i::Integer, j::Integer)
    all((0, 0) .< (i, j) .<= size(B)) || throw(BoundsError(B, (i, j)))
    # Find the on-diagonal block `p` in column `j`; afterwards `j <= 0` is the
    # offset from the last column of block `p`.
    p = 0
    @inbounds while j > 0
        p += 1
        j -= size(blocks(B)[p], 2)
    end
    # Shift `i` by the rows of all preceding blocks. A plain loop handles `p == 1`
    # (nothing to subtract) without reducing over an empty collection and without
    # allocating temporaries.
    @inbounds for k in 1:(p - 1)
        i -= size(blocks(B)[k], 1)
    end
    # If row `i` is outside of block `p`, set `p` to the place-holder value `-1`.
    # The row index must be compared against the block's number of ROWS (dim 1);
    # comparing against the number of columns is wrong for non-square blocks.
    if i <= 0 || i > size(blocks(B)[p], 1)
        p = -1
    end
    return p, i, j
end
# Scalar indexing: look up the entry inside its diagonal block, or return a zero of
# the element type when `(i, j)` lies outside every diagonal block.
Base.@propagate_inbounds function Base.getindex(
    B::MFBD{T}, i::Integer, j::Integer
) where {T}
    blk, row, col_offset = _block_indices(B, i, j)
    if blk > 0
        # `col_offset` is measured from the last column of the block.
        @inbounds block = blocks(B)[blk]
        return @inbounds block[row, lastindex(block, 2) + col_offset]
    else
        # Not inside an on-diagonal block, so the value must be zero.
        return zero(T)
    end
end

# Views are not supported for this type: any `view` call on an `MFBD` raises an
# `ErrorException`.
function Base.view(::MFBD, idxs...)
    return error("`MinimalAndFastBlockDiagonal` does not support views!")
end

# Return a new `MFBD` whose blocks are copies of the blocks of `B`.
function copy(B::MFBD)
    copied_blocks = copy.(B.blocks)
    return MFBD(copied_blocks)
end

# Copy `A` into `B` block by block and return `B`. Both must have the same number
# of blocks.
function copy!(B::MFBD, A::MFBD)
    @assert length(A.blocks) == length(B.blocks)
    @simd ivdep for k in eachindex(B.blocks)
        copy!(blocks(B)[k], blocks(A)[k])
    end
    return B
end

_matmul!(
C::BlockDiagonal{T},
A::BlockDiagonal{T},
B::BlockDiagonal{T},
C::MFBD{T},
A::MFBD{T},
B::MFBD{T},
) where {T<:LinearAlgebra.BlasFloat} = begin
@assert length(C.blocks) == length(A.blocks) == length(B.blocks)
@simd ivdep for i in eachindex(blocks(C))
Expand All @@ -14,9 +64,9 @@ _matmul!(
end

_matmul!(
C::BlockDiagonal{T},
A::BlockDiagonal{T},
B::BlockDiagonal{T},
C::MFBD{T},
A::MFBD{T},
B::MFBD{T},
alpha::Number,
beta::Number,
) where {T<:LinearAlgebra.BlasFloat} = begin
Expand All @@ -28,9 +78,9 @@ _matmul!(
end

_matmul!(
C::BlockDiagonal{T},
A::BlockDiagonal{T},
B::Adjoint{T,<:BlockDiagonal{T}},
C::MFBD{T},
A::MFBD{T},
B::Adjoint{T,<:MFBD{T}},
) where {T<:LinearAlgebra.BlasFloat} = begin
@assert length(C.blocks) == length(A.blocks) == length(B.parent.blocks)
@simd ivdep for i in eachindex(blocks(C))
Expand All @@ -40,9 +90,9 @@ _matmul!(
end

_matmul!(
C::BlockDiagonal{T},
A::Adjoint{T,<:BlockDiagonal{T}},
B::BlockDiagonal{T},
C::MFBD{T},
A::Adjoint{T,<:MFBD{T}},
B::MFBD{T},
) where {T<:LinearAlgebra.BlasFloat} = begin
@assert length(C.blocks) == length(A.parent.blocks) == length(B.blocks)
@simd ivdep for i in eachindex(blocks(C))
Expand All @@ -51,9 +101,21 @@ _matmul!(
return C
end

# In-place block-wise product `C = A * B'`, where `B` is a lazily adjointed `MFBD`:
# block `k` of `C` is overwritten with `A.blocks[k] * B.parent.blocks[k]'`.
# NOTE(review): a `_matmul!` method with this same signature already appears
# earlier in this file; this definition would overwrite it — confirm whether both
# are needed.
function _matmul!(
    C::MFBD{T}, A::MFBD{T}, B::Adjoint{T,<:MFBD{T}}
) where {T<:LinearAlgebra.BlasFloat}
    @assert length(C.blocks) == length(A.blocks) == length(B.parent.blocks)
    Bparent = parent(B)
    @simd ivdep for k in eachindex(C.blocks)
        @inbounds _matmul!(blocks(C)[k], blocks(A)[k], blocks(Bparent)[k]')
    end
    return C
end

_matmul!(
C::AbstractVector{T},
A::BlockDiagonal{T},
A::MFBD{T},
B::AbstractVector{T},
) where {T<:LinearAlgebra.BlasFloat} = begin
@assert size(A, 2) == length(B)
Expand All @@ -68,21 +130,21 @@ _matmul!(
return C
end

function BlockDiagonals.isequal_blocksizes(B1::BlockDiagonal, B2::BlockDiagonal)
@assert length(B1.blocks) == length(B2.blocks)
for i in eachindex(B1.blocks)
if size(B1.blocks[i]) != size(B2.blocks[i])
return false
end
end
return true
"""
    rmul!(B::MFBD, n::Number)

Scale every diagonal block of `B` by `n` in place and return `B`.
"""
function LinearAlgebra.rmul!(B::MFBD, n::Number)
    @simd ivdep for i in eachindex(B.blocks)
        rmul!(B.blocks[i], n)
    end
    # Return the mutated matrix, as `rmul!` does for ordinary arrays. Written as a
    # bare `for` loop, this method used to return `nothing`.
    return B
end
# The adjoint is lazy: `B` is wrapped in an `Adjoint` rather than building
# adjointed blocks.
function LinearAlgebra.adjoint(B::MFBD)
    return Adjoint(B)
end

LinearAlgebra.rmul!(B::BlockDiagonal, n::Number) = @simd ivdep for i in eachindex(B.blocks)
rmul!(B.blocks[i], n)
# Block-wise product: block `k` of the result is `A.blocks[k] * B.blocks[k]`.
function Base.:*(A::MFBD, B::MFBD)
    @assert length(A.blocks) == length(B.blocks)
    products = map(eachindex(B.blocks)) do k
        A.blocks[k] * B.blocks[k]
    end
    return MFBD(products)
end
LinearAlgebra.adjoint(B::BlockDiagonal) = Adjoint(B)
Base.:*(A::Adjoint{T,<:BlockDiagonal}, B::BlockDiagonal) where {T} = begin
Base.:*(A::Adjoint{T,<:MFBD}, B::MFBD) where {T} = begin
@assert length(A.parent.blocks) == length(B.blocks)
return BlockDiagonal([A.parent.blocks[i]' * B.blocks[i] for i in eachindex(B.blocks)])
return MFBD([A.parent.blocks[i]' * B.blocks[i] for i in eachindex(B.blocks)])
end
# Block-wise product with a lazily adjointed right factor: block `k` of the result
# is `A.blocks[k] * B.parent.blocks[k]'`.
function Base.:*(A::MFBD, B::Adjoint{T,<:MFBD}) where {T}
    @assert length(A.blocks) == length(B.parent.blocks)
    rhs_blocks = B.parent.blocks
    products = map(eachindex(rhs_blocks)) do k
        A.blocks[k] * adjoint(rhs_blocks[k])
    end
    return MFBD(products)
end
# Left-multiplication by a `UniformScaling`: each block is multiplied by `A` and
# the results are collected into a new `MFBD`.
function Base.:*(A::UniformScaling, B::MFBD)
    return MFBD([A * blk for blk in blocks(B)])
end
2 changes: 1 addition & 1 deletion src/caches.jl
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ function OrdinaryDiffEq.alg_cache(
# Diffusion Model
diffmodel = alg.diffusionmodel
initdiff = initial_diffusion(diffmodel, d, q, uEltypeNoUnits)
copy!(x0.Σ, apply_diffusion(x0.Σ, initdiff))
apply_diffusion!(x0.Σ, initdiff)

# Measurement model related things
R =
Expand Down
6 changes: 3 additions & 3 deletions src/covariance_structure.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,19 +34,19 @@ factorized_zeros(C::BlockDiagonalCovariance{T}, sizes...) where {T} = begin
for s in sizes
@assert s % C.d == 0
end
return BlockDiagonal([Array{T}(calloc, (s ÷ C.d for s in sizes)...) for _ in 1:C.d])
return MFBD([Array{T}(calloc, (s ÷ C.d for s in sizes)...) for _ in 1:C.d])
end
factorized_similar(C::BlockDiagonalCovariance{T}, size1, size2) where {T} = begin
for s in (size1, size2)
@assert s % C.d == 0
end
return BlockDiagonal([similar(Matrix{T}, size1 ÷ C.d, size2 ÷ C.d) for _ in 1:C.d])
return MFBD([similar(Matrix{T}, size1 ÷ C.d, size2 ÷ C.d) for _ in 1:C.d])
end

to_factorized_matrix(::DenseCovariance, M::AbstractMatrix) = Matrix(M)
to_factorized_matrix(::IsometricKroneckerCovariance, M::IsometricKroneckerProduct) = M
to_factorized_matrix(C::BlockDiagonalCovariance, M::IsometricKroneckerProduct) =
BlockDiagonal([M.B for _ in 1:C.d])
MFBD([M.B for _ in 1:C.d])

for FT in [:DenseCovariance, :IsometricKroneckerCovariance, :BlockDiagonalCovariance]
@eval to_factorized_matrix(FAC::$FT, M::PSDMatrix) =
Expand Down
2 changes: 1 addition & 1 deletion src/derivative_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ function calc_H!(H, integ, cache)
OrdinaryDiffEq.calc_J!(ddu, integ, cache, true)

@unpack C_dxd = cache
if C_dxd isa BlockDiagonal
if C_dxd isa MFBD
@simd ivdep for i in eachindex(blocks(C_dxd))
@assert length(C_dxd.blocks[i]) == 1
C_dxd.blocks[i][1] = ddu[i, i]
Expand Down
12 changes: 6 additions & 6 deletions src/diffusions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,9 @@ apply_diffusion(
) where {T} = begin
PSDMatrix(Q.R * sqrt.(diffusion.diag.value))
end
apply_diffusion(Q::PSDMatrix{T,<:BlockDiagonal}, diffusion::Diagonal) where {T} = begin
apply_diffusion(Q::PSDMatrix{T,<:MFBD}, diffusion::Diagonal) where {T} = begin
PSDMatrix(
BlockDiagonal([
MFBD([
Q.R.blocks[i] * sqrt.(diffusion.diag[i]) for i in eachindex(Q.R.blocks)
]),
)
Expand All @@ -28,7 +28,7 @@ end
apply_diffusion!(Q::PSDMatrix, diffusion::Diagonal{T,<:FillArrays.Fill}) where {T} =
rmul!(Q.R, sqrt.(diffusion.diag.value))
apply_diffusion!(
Q::PSDMatrix{T,<:BlockDiagonal},
Q::PSDMatrix{T,<:MFBD},
diffusion::Diagonal{T,<:Vector},
) where {T} =
@simd ivdep for i in eachindex(blocks(Q.R))
Expand Down Expand Up @@ -105,7 +105,7 @@ function estimate_global_diffusion(::FixedDiffusion, integ)
diffusion_t = if S isa IsometricKroneckerProduct
@assert length(S.B) == 1
dot(v, e) / d / S.B[1]
elseif S isa BlockDiagonal
elseif S isa MFBD
@assert length(S.blocks) == d
@assert length(S.blocks[1]) == 1
@simd ivdep for i in eachindex(e)
Expand Down Expand Up @@ -204,7 +204,7 @@ function local_scalar_diffusion(cache)
σ² = if HQH isa IsometricKroneckerProduct
@assert length(HQH.B) == 1
dot(z, e) / d / HQH.B[1]
elseif HQH isa BlockDiagonal
elseif HQH isa MFBD
@assert length(HQH.blocks) == d
@assert length(HQH.blocks[1]) == 1
for i in eachindex(e)
Expand Down Expand Up @@ -245,7 +245,7 @@ function local_diagonal_diffusion(cache)
# Q_11 = dot(c1, c1)

# @assert
Q_11 = if Qh.R isa BlockDiagonal
Q_11 = if Qh.R isa MFBD
for i in 1:d
c1 = _matmul!(
view(cache.C_Dxd.blocks[i], :, 1:1),
Expand Down
18 changes: 9 additions & 9 deletions src/filtering/markov_kernel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -121,12 +121,12 @@ function marginalize_cov!(
end

function marginalize_cov!(
Σ_out::PSDMatrix{T,<:BlockDiagonal},
Σ_curr::PSDMatrix{T,<:BlockDiagonal},
Σ_out::PSDMatrix{T,<:MFBD},
Σ_curr::PSDMatrix{T,<:MFBD},
K::AffineNormalKernel{
<:AbstractMatrix,
<:Any,
<:PSDMatrix{S,<:BlockDiagonal},
<:PSDMatrix{S,<:MFBD},
};
C_DxD::AbstractMatrix,
C_3DxD::AbstractMatrix,
Expand Down Expand Up @@ -268,22 +268,22 @@ end

function compute_backward_kernel!(
Kout::KT1,
xpred::SRGaussian{T,<:BlockDiagonal},
x::SRGaussian{T,<:BlockDiagonal},
xpred::SRGaussian{T,<:MFBD},
x::SRGaussian{T,<:MFBD},
K::KT2;
C_DxD::AbstractMatrix,
diffusion=1,
) where {
T,
KT1<:AffineNormalKernel{
<:BlockDiagonal,
<:MFBD,
<:AbstractVector,
<:PSDMatrix{T,<:BlockDiagonal},
<:PSDMatrix{T,<:MFBD},
},
KT2<:AffineNormalKernel{
<:BlockDiagonal,
<:MFBD,
<:Any,
<:PSDMatrix{T,<:BlockDiagonal},
<:PSDMatrix{T,<:MFBD},
},
}
d = length(blocks(xpred.Σ.R))
Expand Down
12 changes: 6 additions & 6 deletions src/filtering/predict.jl
Original file line number Diff line number Diff line change
Expand Up @@ -129,12 +129,12 @@ end

# Block-diagonal (`MFBD`) version
function predict_cov!(
Σ_out::PSDMatrix{T,<:BlockDiagonal},
Σ_curr::PSDMatrix{T,<:BlockDiagonal},
Ah::BlockDiagonal,
Qh::PSDMatrix{S,<:BlockDiagonal},
C_DxD::BlockDiagonal,
C_2DxD::BlockDiagonal,
Σ_out::PSDMatrix{T,<:MFBD},
Σ_curr::PSDMatrix{T,<:MFBD},
Ah::MFBD,
Qh::PSDMatrix{S,<:MFBD},
C_DxD::MFBD,
C_2DxD::MFBD,
diffusion::Diagonal,
) where {T,S}
for i in eachindex(blocks(Σ_out.R))
Expand Down
18 changes: 9 additions & 9 deletions src/filtering/update.jl
Original file line number Diff line number Diff line change
Expand Up @@ -194,19 +194,19 @@ function update!(
end

function update!(
x_out::SRGaussian{T,<:BlockDiagonal},
x_pred::SRGaussian{T,<:BlockDiagonal},
x_out::SRGaussian{T,<:MFBD},
x_pred::SRGaussian{T,<:MFBD},
measurement::Gaussian{
<:AbstractVector,
<:Union{<:PSDMatrix{T,<:BlockDiagonal},<:BlockDiagonal},
<:Union{<:PSDMatrix{T,<:MFBD},<:MFBD},
},
H::BlockDiagonal,
K1_cache::BlockDiagonal,
K2_cache::BlockDiagonal,
M_cache::BlockDiagonal,
C_dxd::BlockDiagonal,
H::MFBD,
K1_cache::MFBD,
K2_cache::MFBD,
M_cache::MFBD,
C_dxd::MFBD,
C_d::AbstractVector;
R::Union{Nothing,PSDMatrix{T,<:BlockDiagonal}}=nothing,
R::Union{Nothing,PSDMatrix{T,<:MFBD}}=nothing,
) where {T}
d = length(blocks(x_out.Σ.R))
q = size(blocks(x_out.Σ.R)[1], 1) - 1
Expand Down
4 changes: 2 additions & 2 deletions src/perform_step.jl
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ function estimate_errors!(cache::AbstractODEFilterCache)
_Q = apply_diffusion(Qh, local_diffusion)
_matmul!(R, _Q.R, H')
error_estimate = view(cache.tmp, 1:d)
if R isa BlockDiagonal
if R isa MFBD
for i in eachindex(R.blocks)
error_estimate[i] = sum(abs2, R.blocks[i])
end
Expand All @@ -247,7 +247,7 @@ function estimate_errors!(cache::AbstractODEFilterCache)
error_estimate = view(cache.tmp, 1:d)
if R isa IsometricKroneckerProduct
error_estimate .= sum(abs2, R.B)
elseif R isa BlockDiagonal
elseif R isa MFBD
for i in eachindex(blocks(R))
error_estimate[i] = sum(abs2, R.blocks[i])
end
Expand Down
Loading

0 comments on commit 2fb44be

Please sign in to comment.