Skip to content

Commit

Permalink
Small cleanup and better Gaussian printing
Browse files Browse the repository at this point in the history
  • Loading branch information
nathanaelbosch committed Jun 9, 2024
1 parent 9ab8d99 commit 527fe80
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 14 deletions.
3 changes: 2 additions & 1 deletion src/ProbNumDiffEq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ __precompile__()
module ProbNumDiffEq

import Base:
copy, copy!, show, size, ndims, similar, isapprox, isequal, iterate, ==, length, zero
copy, copy!, show, size, ndims, similar, isapprox, isequal, iterate, ==, length, zero,
eltype

using LinearAlgebra
import LinearAlgebra: mul!, norm_sqr
Expand Down
24 changes: 11 additions & 13 deletions src/gaussians.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ Implementation details: On `Σ` the functions `logdet`, `whiten` and `unwhiten`
struct Gaussian{T,S}
μ::T
Σ::S
Gaussian{T,S}(μ, Σ) where {T,S} = new(μ, Σ)
Gaussian::T, Σ::S) where {T,S} = new{T,S}(μ, Σ)
end
Base.convert(::Type{Gaussian{T,S}}, g::Gaussian) where {T,S} =
Expand All @@ -26,20 +25,16 @@ Base.isapprox(g1::Gaussian, g2::Gaussian; kwargs...) =
isapprox(g1.μ, g2.μ; kwargs...) && isapprox(g1.Σ, g2.Σ; kwargs...)
copy(P::Gaussian) = Gaussian(copy(P.μ), copy(P.Σ))
similar(P::Gaussian) = Gaussian(similar(P.μ), similar(P.Σ))
Base.copyto!(P::AbstractArray{<:Gaussian}, idx::Integer, el::Gaussian) = begin
P[idx] = copy(el)
P
end
Base.copyto!(P::AbstractArray{<:Gaussian}, idx::Integer, el::Gaussian) =
(P[idx] = copy(el); P)

Check warning on line 29 in src/gaussians.jl

View check run for this annotation

Codecov / codecov/patch

src/gaussians.jl#L29

Added line #L29 was not covered by tests
function Base.copy!(dst::Gaussian, src::Gaussian)
copy!(dst.μ, src.μ)
copy!(dst.Σ, src.Σ)
return dst
end
Base.iterate(::Gaussian) = error()
Base.iterate(::Gaussian, s) = error()
Base.length(P::Gaussian) = length(mean(P))
length(P::Gaussian) = length(mean(P))
size(g::Gaussian) = size(g.μ)
Base.eltype(::Type{G}) where {G<:Gaussian} = G
eltype(::Type{G}) where {G<:Gaussian} = G
Base.@propagate_inbounds Base.getindex(P::Gaussian, i::Integer) =
Gaussian(P.μ[i], diag(P.Σ)[i])

Expand All @@ -52,7 +47,6 @@ var(g::Gaussian) = diag(g.Σ)
std(g::Gaussian) = sqrt.(diag(g.Σ))

dim(P::Gaussian) = length(P.μ)
ndims(g::Gaussian) = ndims(g.μ)

# whiten(Σ::PSD, z) = Σ.σ\z
whiten(Σ, z) = cholesky(Σ).U' \ z
Expand Down Expand Up @@ -121,9 +115,13 @@ RecursiveArrayTools.recursivecopy(P::Gaussian) = copy(P)
RecursiveArrayTools.recursivecopy!(dst::Gaussian, src::Gaussian) = copy!(dst, src)

# Print
show(io::IO, g::Gaussian) = print(io, "Gaussian($(g.μ), $(g.Σ))")
show(io::IO, ::MIME"text/plain", g::Gaussian{T,S}) where {T,S} =
print(io, "Gaussian{$T,$S}($(g.μ), $(g.Σ))")
show(io::IO, g::Gaussian) = print(io, "Gaussian(μ=$(g.μ), Σ=$(g.Σ))")
show(io::IO, ::MIME"text/plain", g::Gaussian{T,S}) where {T,S} = begin
println(io, "Gaussian{$T,$S}(")
println(io, " μ=$(g.μ),")
println(io, " Σ=$(g.Σ)")
print(io, ")")

Check warning on line 123 in src/gaussians.jl

View check run for this annotation

Codecov / codecov/patch

src/gaussians.jl#L118-L123

Added lines #L118 - L123 were not covered by tests
end

############################################################################################
# `SRGaussian`: Gaussians with PDFMatrix covariances
Expand Down

0 comments on commit 527fe80

Please sign in to comment.