Skip to content

Commit

Permalink
It works and it's (a little bit) faster than dense!
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanaelbosch committed Feb 13, 2024
1 parent 58c63eb commit d0a3eb0
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 3 deletions.
1 change: 1 addition & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ version = "0.15.0"

[deps]
ArrayAllocators = "c9d4266f-a5cb-439d-837c-c97b191379f5"
BlockDiagonals = "0a1fb500-61f7-11e9-3c65-f5ef3456f9f0"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def"
DiffEqDevTools = "f3b72e0c-5b89-59e1-b016-84e28bfd966d"
Expand Down
37 changes: 34 additions & 3 deletions src/blockdiagonals.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,30 @@ _matmul!(
return C
end

_matmul!(
C::BlockDiagonal{T},
A::BlockDiagonal{T},
B::Adjoint{T, <:BlockDiagonal{T}},
) where {T<:LinearAlgebra.BlasFloat} = begin
@assert length(C.blocks) == length(A.blocks) == length(B.parent.blocks)
@simd ivdep for i in eachindex(blocks(C))
@inbounds _matmul!(C.blocks[i], A.blocks[i], adjoint(B.parent.blocks[i]))
end
return C
end

_matmul!(
C::BlockDiagonal{T},
A::Adjoint{T, <:BlockDiagonal{T}},
B::BlockDiagonal{T},
) where {T<:LinearAlgebra.BlasFloat} = begin
@assert length(C.blocks) == length(A.parent.blocks) == length(B.blocks)
@simd ivdep for i in eachindex(blocks(C))
@inbounds _matmul!(C.blocks[i], adjoint(A.parent.blocks[i]), B.blocks[i])
end
return C
end

_matmul!(
C::AbstractVector{T},
A::BlockDiagonal{T},
Expand All @@ -42,7 +66,14 @@ _matmul!(
return C
end

function LinearAlgebra.cholesky!(B::BlockDiagonal)
C = BlockDiagonal(map(b -> parent(UpperTriangular(cholesky!(b).U)), blocks(B)))
return Cholesky(C, 'U', 0)
function BlockDiagonals.isequal_blocksizes(B1::BlockDiagonal, B2::BlockDiagonal)
@assert length(B1.blocks) == length(B2.blocks)
for i in eachindex(B1.blocks)
if size(B1.blocks[i]) != size(B2.blocks[i])
return false
end
end
return true
end

LinearAlgebra.adjoint(B::BlockDiagonal) = Adjoint(B)
7 changes: 7 additions & 0 deletions src/diffusions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,13 @@ function local_scalar_diffusion(cache)
σ² = if HQH isa IsometricKroneckerProduct
@assert length(HQH.B) == 1
dot(z, e) / d / HQH.B[1]
elseif HQH isa BlockDiagonal
@assert length(HQH.blocks) == d
@assert length(HQH.blocks[1]) == 1
for i in eachindex(e)
e[i] /= HQH.blocks[i][1]
end
dot(z, e) / d
else
C = cholesky!(HQH)
ldiv!(C, e)
Expand Down
45 changes: 45 additions & 0 deletions src/filtering/update.jl
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,51 @@ function update!(
return x_out, loglikelihood
end


function update!(
x_out::SRGaussian{T,<:BlockDiagonal},
x_pred::SRGaussian{T,<:BlockDiagonal},
measurement::Gaussian{
<:AbstractVector,
<:Union{<:PSDMatrix{T,<:BlockDiagonal},<:BlockDiagonal},
},
H::BlockDiagonal,
K1_cache::BlockDiagonal,
K2_cache::BlockDiagonal,
M_cache::BlockDiagonal,
C_dxd::BlockDiagonal,
C_d::AbstractVector;
R::Union{Nothing,PSDMatrix{T,<:BlockDiagonal}}=nothing,
) where {T}
d = length(blocks(x_out.Σ.R))
q = size(blocks(x_out.Σ.R)[1], 1) - 1

ll = zero(eltype(x_out.μ))
for i in eachindex(blocks(x_out.Σ.R))
_, _ll = update!(
Gaussian(view(x_out.μ, (i-1)*(q+1)+1:i*(q+1)),
PSDMatrix(x_out.Σ.R.blocks[i])),
Gaussian(view(x_pred.μ, (i-1)*(q+1)+1:i*(q+1)),
PSDMatrix(x_pred.Σ.R.blocks[i])),
Gaussian(view(measurement.μ, i:i),
if measurement.Σ isa PSDMatrix
PSDMatrix(measurement.Σ.R.blocks[i])
else
measurement.Σ.blocks[i]
end),
H.blocks[i],
K1_cache.blocks[i],
K2_cache.blocks[i],
M_cache.blocks[i],
C_dxd.blocks[i],
view(C_d, i:i);
R,
)
ll += _ll
end
return x_out, ll
end

# Short-hand with cache
function update!(x_out, x, measurement, H; cache, R=nothing)
@unpack K1, m_tmp, C_DxD, C_dxd, C_Dxd, C_d = cache
Expand Down
4 changes: 4 additions & 0 deletions src/perform_step.jl
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,10 @@ function estimate_errors!(cache::AbstractODEFilterCache)
error_estimate = view(cache.tmp, 1:d)
if R isa IsometricKroneckerProduct
error_estimate .= sum(abs2, R.B)
elseif R isa BlockDiagonal
for i in eachindex(blocks(R))
error_estimate[i] = sum(abs2, R.blocks[i])
end
else
sum!(abs2, error_estimate', view(R, :, 1:d))
end
Expand Down

0 comments on commit d0a3eb0

Please sign in to comment.