-
Notifications
You must be signed in to change notification settings - Fork 17
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add
covariance_factorizations.jl
to clean up the if/else KRONECKER …
…stuff
- Loading branch information
1 parent
7085943
commit 9d2b481
Showing
9 changed files
with
100 additions
and
77 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
abstract type CovarianceFactorization end | ||
struct KroneckerCovariance <: CovarianceFactorization end | ||
struct DenseCovariance <: CovarianceFactorization end | ||
|
||
function get_covariance_factorization(alg) | ||
if ( | ||
alg isa EK0 && | ||
!( | ||
alg.diffusionmodel isa DynamicMVDiffusion || | ||
alg.diffusionmodel isa FixedMVDiffusion | ||
) && | ||
alg.prior isa IWP | ||
) | ||
return KroneckerCovariance() | ||
else | ||
return DenseCovariance() | ||
end | ||
end | ||
|
||
factorized_zeros(::KroneckerCovariance, elType, sizes...; d, q) = begin | ||
for s in sizes | ||
@assert s % d == 0 | ||
end | ||
return IsoKroneckerProduct(d, zeros(elType, (s ÷ d for s in sizes)...)) | ||
end | ||
|
||
factorized_zeros(::DenseCovariance, elType, sizes...; d, q) = zeros(elType, sizes...) | ||
|
||
to_factorized_matrix(::DenseCovariance, M::AbstractMatrix) = Matrix(M) | ||
to_factorized_matrix(::KroneckerCovariance, M::AbstractMatrix) = IsoKroneckerProduct(M) # probably errors | ||
to_factorized_matrix(::KroneckerCovariance, M::IKP) = M |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +1,32 @@ | ||
function projection(d::Integer, q::Integer, ::Type{elType}=typeof(1.0)) where {elType} | ||
function projection( | ||
::DenseCovariance, | ||
d::Integer, | ||
q::Integer, | ||
::Type{elType}=typeof(1.0), | ||
) where {elType} | ||
D = d * (q + 1) | ||
Proj(deriv) = begin | ||
e_i = zeros(elType, q+1, 1) | ||
P = zeros(elType, d, D) | ||
@simd ivdep for i in deriv*d+1:D+1:d*D | ||
@inbounds P[i] = 1 | ||
end | ||
return P | ||
end | ||
return Proj | ||
end | ||
|
||
function projection( | ||
::KroneckerCovariance, | ||
d::Integer, | ||
q::Integer, | ||
::Type{elType}=typeof(1.0), | ||
) where {elType} | ||
Proj(deriv) = begin | ||
e_i = zeros(elType, q + 1, 1) | ||
if deriv <= q | ||
e_i[deriv+1] = 1 | ||
end | ||
IsoKroneckerProduct(d, e_i') | ||
return IsoKroneckerProduct(d, e_i') | ||
end | ||
return Proj | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters