From c45af156315b2cc96a47a18af02712d1b2d244b7 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Tue, 14 Jan 2025 13:39:33 +0100 Subject: [PATCH] Fix `mode` of `LKJCholesky` and define `mean(::LKJCholesky)` --- Project.toml | 2 +- src/cholesky/lkjcholesky.jl | 11 ++++++++++- test/cholesky/lkjcholesky.jl | 15 +++++++++++++-- 3 files changed, 24 insertions(+), 4 deletions(-) diff --git a/Project.toml b/Project.toml index eba90ae18..50607efaa 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Distributions" uuid = "31c24e10-a181-5473-b8eb-7969acd0382f" authors = ["JuliaStats"] -version = "0.25.116" +version = "0.25.117" [deps] AliasTables = "66dad0bd-aa9a-41b7-9441-69ab47430ed8" diff --git a/src/cholesky/lkjcholesky.jl b/src/cholesky/lkjcholesky.jl index 7556620f6..fda164f0a 100644 --- a/src/cholesky/lkjcholesky.jl +++ b/src/cholesky/lkjcholesky.jl @@ -109,11 +109,20 @@ function insupport(d::LKJCholesky, R::LinearAlgebra.Cholesky) return true end -function StatsBase.mode(d::LKJCholesky) +function StatsBase.mean(d::LKJCholesky) factors = Matrix{eltype(d)}(LinearAlgebra.I, size(d)) return LinearAlgebra.Cholesky(factors, d.uplo, 0) end +function mode(d::LKJCholesky; check_args::Bool=true) + @check_args( + LKJCholesky, + @setup(η = d.η), + (η, η > 1, "mode is defined only when η > 1."), + ) + return mean(d) +end + StatsBase.params(d::LKJCholesky) = (d.d, d.η, d.uplo) @inline partype(::LKJCholesky{T}) where {T <: Real} = T diff --git a/test/cholesky/lkjcholesky.jl b/test/cholesky/lkjcholesky.jl index b9afb59be..b23cbd1de 100644 --- a/test/cholesky/lkjcholesky.jl +++ b/test/cholesky/lkjcholesky.jl @@ -124,14 +124,25 @@ using FiniteDifferences end @testset "properties" begin - @testset for p in (4, 5), η in (2, 3.5), uplo in ('L', 'U') + @testset for p in (4, 5), η in (0.5, 2, 3.5), uplo in ('L', 'U') d = LKJCholesky(p, η, uplo) @test d.d == p @test size(d) == (p, p) @test Distributions.params(d) == (d.d, d.η, d.uplo) @test partype(d) <: Float64 - m = mode(d) + if η > 1 + m = mode(d) + @test m isa Cholesky{eltype(d)} + @test Matrix(m) ≈ I + else + @test_throws DomainError(η, "LKJCholesky: mode is defined only when η > 1.") mode(d) + end + m = mode(d; check_args = false) + @test m isa Cholesky{eltype(d)} + @test Matrix(m) ≈ I + + m = mean(d) @test m isa Cholesky{eltype(d)} @test Matrix(m) ≈ I end