From b0e9bfaebc7c8076511f6b627958a7b53028040e Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Sat, 9 Dec 2023 03:06:42 -0500 Subject: [PATCH 1/9] add location scale family --- src/AdvancedVI.jl | 10 ++ src/families/location_scale.jl | 167 ++++++++++++++++++++ test/inference/repgradelbo_locationscale.jl | 82 ++++++++++ 3 files changed, 259 insertions(+) create mode 100644 src/families/location_scale.jl create mode 100644 test/inference/repgradelbo_locationscale.jl diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index 89f86696..a69b8d89 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -141,6 +141,16 @@ export include("objectives/elbo/entropy.jl") include("objectives/elbo/repgradelbo.jl") + +# Variational Families +export + VILocationScale, + MeanFieldGaussian, + FullRankGaussian + +include("families/location_scale.jl") + + # Optimization Routine function optimize end diff --git a/src/families/location_scale.jl b/src/families/location_scale.jl new file mode 100644 index 00000000..7aed751e --- /dev/null +++ b/src/families/location_scale.jl @@ -0,0 +1,167 @@ + +""" + VILocationScale(location, scale, dist) <: ContinuousMultivariateDistribution + +The location scale variational family broadly represents various variational +families using `location` and `scale` variational parameters. + +It generally represents any distribution for which the sampling path can be +represented as follows: +```julia + d = length(location) + u = rand(dist, d) + z = scale*u + location +``` +""" +struct VILocationScale{L, S, D} <: ContinuousMultivariateDistribution + location::L + scale ::S + dist ::D +end + +Functors.@functor VILocationScale (location, scale) + +# Specialization of `Optimisers.destructure` for mean-field location-scale families. +# These are necessary because we only want to extract the diagonal elements of +# `scale <: Diagonal`, which is not the default behavior. Otherwise, forward-mode AD +# is very inefficient. +# begin +struct RestructureMeanField{L, S<:Diagonal, D} + q::VILocationScale{L, S, D} +end + +function (re::RestructureMeanField)(flat::AbstractVector) + n_dims = div(length(flat), 2) + location = first(flat, n_dims) + scale = Diagonal(last(flat, n_dims)) + VILocationScale(location, scale, re.q.dist) +end + +function Optimisers.destructure( + q::VILocationScale{L, <:Diagonal, D} +) where {L, D} + @unpack location, scale, dist = q + flat = vcat(location, diag(scale)) + flat, RestructureMeanField(q) +end +# end + +Base.length(q::VILocationScale) = length(q.location) + +Base.size(q::VILocationScale) = size(q.location) + +Base.eltype(::Type{<:VILocationScale{L, S, D}}) where {L, S, D} = eltype(D) + +function StatsBase.entropy(q::VILocationScale) + @unpack location, scale, dist = q + n_dims = length(location) + n_dims*convert(eltype(location), entropy(dist)) + first(logabsdet(scale)) +end + +function Distributions.logpdf(q::VILocationScale, z::AbstractVector{<:Real}) + @unpack location, scale, dist = q + sum(Base.Fix1(logpdf, dist), scale \ (z - location)) - first(logabsdet(scale)) +end + +function Distributions._logpdf(q::VILocationScale, z::AbstractVector{<:Real}) + @unpack location, scale, dist = q + sum(Base.Fix1(logpdf, dist), scale \ (z - location)) - first(logabsdet(scale)) +end + +function Distributions.rand(q::VILocationScale) + @unpack location, scale, dist = q + n_dims = length(location) + scale*rand(dist, n_dims) + location +end + +function Distributions.rand(rng::AbstractRNG, q::VILocationScale, num_samples::Int) + @unpack location, scale, dist = q + n_dims = length(location) + scale*rand(rng, dist, n_dims, num_samples) .+ location +end + +# This specialization improves AD performance of the sampling path +function Distributions.rand( + rng::AbstractRNG, q::VILocationScale{L, <:Diagonal, D}, num_samples::Int +) where {L, D} + @unpack location, scale, dist = q + n_dims = length(location) + scale_diag = diag(scale) + scale_diag.*rand(rng, dist, n_dims, num_samples) .+ location +end + +function Distributions._rand!(rng::AbstractRNG, q::VILocationScale, x::AbstractVector{<:Real}) + @unpack location, scale, dist = q + rand!(rng, dist, x) + x .= scale*x + return x += location +end + +function Distributions._rand!(rng::AbstractRNG, q::VILocationScale, x::AbstractMatrix{<:Real}) + @unpack location, scale, dist = q + rand!(rng, dist, x) + x[:] = scale*x + return x .+= location +end + +Distributions.mean(q::VILocationScale) = q.location + +function Distributions.var(q::VILocationScale) + C = q.scale + Diagonal(C*C') +end + +function Distributions.cov(q::VILocationScale) + C = q.scale + Hermitian(C*C') +end + +""" + FullRankGaussian(location, scale; check_args = true) + +Construct a Gaussian variational approximation with a dense covariance matrix. + +# Arguments +- `location::AbstractVector{T}`: Mean of the Gaussian. +- `scale::LinearAlgebra.AbstractTriangular{T}`: Cholesky factor of the covariance of the Gaussian. + +# Keyword Arguments +- `check_args`: Check the conditioning of the initial scale (default: `true`). +""" +function FullRankGaussian( + μ::AbstractVector{T}, + L::LinearAlgebra.AbstractTriangular{T}; + check_args::Bool = true +) where {T <: Real} + @assert minimum(diag(L)) > eps(eltype(L)) "Scale must be positive definite" + if check_args && (minimum(diag(L)) < sqrt(eps(eltype(L)))) + @warn "Initial scale is too small (minimum eigenvalue is $(minimum(diag(L)))). This might result in unstable optimization behavior." + end + q_base = Normal{T}(zero(T), one(T)) + VILocationScale(μ, L, q_base) +end + +""" + MeanFieldGaussian(location, scale; check_args = true) + +Construct a Gaussian variational approximation with a diagonal covariance matrix. + +# Arguments +- `location::AbstractVector{T}`: Mean of the Gaussian. +- `scale::Diagonal{T}`: Diagonal Cholesky factor of the covariance of the Gaussian. + +# Keyword Arguments +- `check_args`: Check the conditioning of the initial scale (default: `true`). +""" +function MeanFieldGaussian( + μ::AbstractVector{T}, + L::Diagonal{T}; + check_args::Bool = true +) where {T <: Real} + @assert minimum(diag(L)) > eps(eltype(L)) "Scale must be a Cholesky factor" + if check_args && (minimum(diag(L)) < sqrt(eps(eltype(L)))) + @warn "Initial scale is too small (minimum eigenvalue is $(minimum(diag(L)))). This might result in unstable optimization behavior." + end + q_base = Normal{T}(zero(T), one(T)) + VILocationScale(μ, L, q_base) +end diff --git a/test/inference/repgradelbo_locationscale.jl b/test/inference/repgradelbo_locationscale.jl new file mode 100644 index 00000000..d5177fb6 --- /dev/null +++ b/test/inference/repgradelbo_locationscale.jl @@ -0,0 +1,82 @@ + +const PROGRESS = length(ARGS) > 0 && ARGS[1] == "--progress" ? true : false + +using Test + +@testset "inference RepGradELBO DistributionsAD" begin + @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for + realtype ∈ [Float64, Float32], + (modelname, modelconstr) ∈ Dict( + :Normal=> normal_meanfield, + :Normal=> normal_fullrank, + ), + (objname, objective) ∈ Dict( + :RepGradELBOClosedFormEntropy => RepGradELBO(10), + :RepGradELBOStickingTheLanding => RepGradELBO(10, entropy = StickingTheLandingEntropy()), + ), + (adbackname, adbackend) ∈ Dict( + :ForwarDiff => AutoForwardDiff(), + :ReverseDiff => AutoReverseDiff(), + :Zygote => AutoZygote(), + #:Enzyme => AutoEnzyme(), + ) + + seed = (0x38bef07cf9cc549d) + rng = StableRNG(seed) + + modelstats = modelconstr(rng, realtype) + @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats + + T, η = is_meanfield ? (5_000, 1e-2) : (30_000, 1e-3) + + q0 = if is_meanfield + MeanFieldGaussian(zeros(realtype, n_dims), Diagonal(ones(realtype, n_dims))) + else + L0 = Matrix{realtype}(I, n_dims, n_dims) |> LowerTriangular + FullRankGaussian(zeros(realtype, n_dims), L0) + end + + @testset "convergence" begin + Δλ₀ = sum(abs2, q0.location - μ_true) + sum(abs2, q0.scale - L_true) + q, stats, _ = optimize( + rng, model, objective, q0, T; + optimizer = Optimisers.Adam(realtype(η)), + show_progress = PROGRESS, + adbackend = adbackend, + ) + + μ = q.location + L = q.scale + Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true) + + @test Δλ ≤ Δλ₀/T^(1/4) + @test eltype(μ) == eltype(μ_true) + @test eltype(L) == eltype(L_true) + end + + @testset "determinism" begin + rng = StableRNG(seed) + q, stats, _ = optimize( + rng, model, objective, q0, T; + optimizer = Optimisers.Adam(realtype(η)), + show_progress = PROGRESS, + adbackend = adbackend, + ) + μ = q.location + L = q.scale + + rng_repl = StableRNG(seed) + q, stats, _ = optimize( + rng_repl, model, objective, q0, T; + optimizer = Optimisers.Adam(realtype(η)), + show_progress = PROGRESS, + adbackend = adbackend, + ) + μ_repl = q.location + L_repl = q.scale + @test μ == μ_repl + @test L == L_repl + end + end +end + From 830b4a6b8a93990df6a7df5191ddde0b6f1fc34b Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Sat, 9 Dec 2023 03:09:33 -0500 Subject: [PATCH 2/9] refactor switch bijector tests to use locscale, enable ReverseDiff --- .../repgradelbo_distributionsad_bijectors.jl | 31 +++++++++++-------- 1 file changed, 18 insertions(+), 13 deletions(-) diff --git a/test/inference/repgradelbo_distributionsad_bijectors.jl b/test/inference/repgradelbo_distributionsad_bijectors.jl index 9f1e3cc4..53e9e62f 100644 --- a/test/inference/repgradelbo_distributionsad_bijectors.jl +++ b/test/inference/repgradelbo_distributionsad_bijectors.jl @@ -15,7 +15,7 @@ using Test ), (adbackname, adbackend) ∈ Dict( :ForwarDiff => AutoForwardDiff(), - #:ReverseDiff => AutoReverseDiff(), + :ReverseDiff => AutoReverseDiff(), #:Zygote => AutoZygote(), #:Enzyme => AutoEnzyme(), ) @@ -30,23 +30,28 @@ using Test b = Bijectors.bijector(model) b⁻¹ = inverse(b) - μ₀ = Zeros(realtype, n_dims) - L₀ = Diagonal(Ones(realtype, n_dims)) + μ0 = Zeros(realtype, n_dims) + L0 = Diagonal(Ones(realtype, n_dims)) - q₀_η = TuringDiagMvNormal(μ₀, diag(L₀)) - q₀_z = Bijectors.transformed(q₀_η, b⁻¹) + q0_η = if is_meanfield + MeanFieldGaussian(zeros(realtype, n_dims), Diagonal(ones(realtype, n_dims))) + else + L0 = Matrix{realtype}(I, n_dims, n_dims) |> LowerTriangular + FullRankGaussian(zeros(realtype, n_dims), L0) + end + q0_z = Bijectors.transformed(q0_η, b⁻¹) @testset "convergence" begin - Δλ₀ = sum(abs2, μ₀ - μ_true) + sum(abs2, L₀ - L_true) + Δλ₀ = sum(abs2, μ0 - μ_true) + sum(abs2, L0 - L_true) q, stats, _ = optimize( - rng, model, objective, q₀_z, T; + rng, model, objective, q0_z, T; optimizer = Optimisers.Adam(realtype(η)), show_progress = PROGRESS, adbackend = adbackend, ) - μ = mean(q.dist) - L = sqrt(cov(q.dist)) + μ = q.dist.location + L = q.dist.scale Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true) @test Δλ ≤ Δλ₀/T^(1/4) @@ -57,17 +62,17 @@ using Test @testset "determinism" begin rng = StableRNG(seed) q, stats, _ = optimize( - rng, model, objective, q₀_z, T; + rng, model, objective, q0_z, T; optimizer = Optimisers.Adam(realtype(η)), show_progress = PROGRESS, adbackend = adbackend, ) - μ = mean(q.dist) - L = sqrt(cov(q.dist)) + μ = q.dist.location + L = q.dist.scale rng_repl = StableRNG(seed) q, stats, _ = optimize( - rng_repl, model, objective, q₀_z, T; + rng_repl, model, objective, q0_z, T; optimizer = Optimisers.Adam(realtype(η)), show_progress = PROGRESS, adbackend = adbackend, From 9df544d9e16cc4c024187488c2f2c5c386f2d048 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Sat, 9 Dec 2023 03:37:30 -0500 Subject: [PATCH 3/9] fix test file name for location-scale plus bijector inference test --- ...nsad_bijectors.jl => repgradelbo_locationscale_bijectors.jl} | 0 test/runtests.jl | 2 +- 2 files changed, 1 insertion(+), 1 deletion(-) rename test/inference/{repgradelbo_distributionsad_bijectors.jl => repgradelbo_locationscale_bijectors.jl} (100%) diff --git a/test/inference/repgradelbo_distributionsad_bijectors.jl b/test/inference/repgradelbo_locationscale_bijectors.jl similarity index 100% rename from test/inference/repgradelbo_distributionsad_bijectors.jl rename to test/inference/repgradelbo_locationscale_bijectors.jl diff --git a/test/runtests.jl b/test/runtests.jl index b14b8b2e..6f8bcd1a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -39,4 +39,4 @@ include("interface/optimize.jl") include("interface/repgradelbo.jl") include("inference/repgradelbo_distributionsad.jl") -include("inference/repgradelbo_distributionsad_bijectors.jl") +include("inference/repgradelbo_locationscale_bijectors.jl") From ebb55efaa532917c858ada2e437389d8ee827f7f Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Sat, 9 Dec 2023 03:47:52 -0500 Subject: [PATCH 4/9] fix wrong testset names, add interface test for VILocationScale --- test/inference/repgradelbo_locationscale.jl | 2 +- .../repgradelbo_locationscale_bijectors.jl | 2 +- test/interface/location_scale.jl | 92 +++++++++++++++++++ test/runtests.jl | 2 + 4 files changed, 96 insertions(+), 2 deletions(-) create mode 100644 test/interface/location_scale.jl diff --git a/test/inference/repgradelbo_locationscale.jl b/test/inference/repgradelbo_locationscale.jl index d5177fb6..8ac9d2ca 100644 --- a/test/inference/repgradelbo_locationscale.jl +++ b/test/inference/repgradelbo_locationscale.jl @@ -3,7 +3,7 @@ const PROGRESS = length(ARGS) > 0 && ARGS[1] == "--progress" ? true : false using Test -@testset "inference RepGradELBO DistributionsAD" begin +@testset "inference RepGradELBO VILocationScale" begin @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for realtype ∈ [Float64, Float32], (modelname, modelconstr) ∈ Dict( diff --git a/test/inference/repgradelbo_locationscale_bijectors.jl b/test/inference/repgradelbo_locationscale_bijectors.jl index 53e9e62f..0bfe2ec9 100644 --- a/test/inference/repgradelbo_locationscale_bijectors.jl +++ b/test/inference/repgradelbo_locationscale_bijectors.jl @@ -3,7 +3,7 @@ const PROGRESS = length(ARGS) > 0 && ARGS[1] == "--progress" ? true : false using Test -@testset "inference RepGradELBO DistributionsAD Bijectors" begin +@testset "inference RepGradELBO VILocationScale Bijectors" begin @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for realtype ∈ [Float64, Float32], (modelname, modelconstr) ∈ Dict( diff --git a/test/interface/location_scale.jl b/test/interface/location_scale.jl new file mode 100644 index 00000000..4da1691a --- /dev/null +++ b/test/interface/location_scale.jl @@ -0,0 +1,92 @@ + +@testset "interface LocationScale" begin + @testset "$(string(covtype)) $(basedist) $(realtype)" for + basedist = [:gaussian], + covtype = [:meanfield, :fullrank], + realtype = [Float32, Float64] + + n_dims = 10 + n_montecarlo = 1000_000 + + μ = randn(realtype, n_dims) + L = if covtype == :fullrank + tril(I + ones(realtype, n_dims, n_dims)/2) |> LowerTriangular + else + Diagonal(realtype.(1:10)) + end + Σ = L*L' + + q = if covtype == :fullrank && basedist == :gaussian + FullRankGaussian(μ, L) + elseif covtype == :meanfield && basedist == :gaussian + MeanFieldGaussian(μ, L) + end + q_true = if basedist == :gaussian + MvNormal(μ, Σ) + end + + @testset "logpdf" begin + z = rand(q) + @test eltype(z) == realtype + @test logpdf(q, z) ≈ logpdf(q_true, z) rtol=realtype(1e-2) + @test eltype(logpdf(q, z)) == realtype + end + + @testset "entropy" begin + @test eltype(entropy(q)) == realtype + @test entropy(q) ≈ entropy(q_true) + end + + @testset "statistics" begin + @testset "mean" begin + @test eltype(mean(q)) == realtype + @test mean(q) == μ + end + @testset "var" begin + @test eltype(var(q)) == realtype + @test var(q) ≈ Diagonal(Σ) + end + @testset "cov" begin + @test eltype(cov(q)) == realtype + @test cov(q) ≈ Σ + end + end + + @testset "sampling" begin + @testset "rand" begin + z_samples = mapreduce(x -> rand(q), hcat, 1:n_montecarlo) + @test eltype(z_samples) == realtype + @test dropdims(mean(z_samples, dims=2), dims=2) ≈ μ rtol=realtype(1e-2) + @test dropdims(var(z_samples, dims=2), dims=2) ≈ diag(Σ) rtol=realtype(1e-2) + @test cov(z_samples, dims=2) ≈ Σ rtol=realtype(1e-2) + end + + @testset "rand batch" begin + z_samples = rand(q, n_montecarlo) + @test eltype(z_samples) == realtype + @test dropdims(mean(z_samples, dims=2), dims=2) ≈ μ rtol=realtype(1e-2) + @test dropdims(var(z_samples, dims=2), dims=2) ≈ diag(Σ) rtol=realtype(1e-2) + @test cov(z_samples, dims=2) ≈ Σ rtol=realtype(1e-2) + end + + @testset "rand!" begin + z_samples = Array{realtype}(undef, n_dims, n_montecarlo) + rand!(q, z_samples) + @test dropdims(mean(z_samples, dims=2), dims=2) ≈ μ rtol=realtype(1e-2) + @test dropdims(var(z_samples, dims=2), dims=2) ≈ diag(Σ) rtol=realtype(1e-2) + @test cov(z_samples, dims=2) ≈ Σ rtol=realtype(1e-2) + end + end + end + + @testset "Diagonal destructure" for + n_dims = 10 + μ = zeros(n_dims) + L = ones(n_dims) + q = MeanFieldGaussian(μ, L |> Diagonal) + λ, re = Optimisers.destructure(q) + + @test length(λ) == 2*n_dims + @test q == re(λ) + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 6f8bcd1a..8e540a08 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -37,6 +37,8 @@ include("models/normallognormal.jl") include("interface/ad.jl") include("interface/optimize.jl") include("interface/repgradelbo.jl") +include("interface/location_scale.jl") include("inference/repgradelbo_distributionsad.jl") +include("inference/repgradelbo_locationscale.jl") include("inference/repgradelbo_locationscale_bijectors.jl") From 3b9a07b6788ee0c794d44c6984790ba967e42878 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Sat, 9 Dec 2023 14:06:11 -0500 Subject: [PATCH 5/9] fix test parameters for `LocationScale` --- test/interface/location_scale.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/interface/location_scale.jl b/test/interface/location_scale.jl index 4da1691a..ef30cf44 100644 --- a/test/interface/location_scale.jl +++ b/test/interface/location_scale.jl @@ -12,7 +12,7 @@ L = if covtype == :fullrank tril(I + ones(realtype, n_dims, n_dims)/2) |> LowerTriangular else - Diagonal(realtype.(1:10)) + Diagonal(ones(realtype, n_dims)) end Σ = L*L' From 802a83cdba47c2777b8240593b4991411a829de3 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Sat, 9 Dec 2023 22:58:39 -0500 Subject: [PATCH 6/9] fix test for LocationScale with Bijectors --- test/inference/repgradelbo_locationscale_bijectors.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/inference/repgradelbo_locationscale_bijectors.jl b/test/inference/repgradelbo_locationscale_bijectors.jl index 0bfe2ec9..7154440c 100644 --- a/test/inference/repgradelbo_locationscale_bijectors.jl +++ b/test/inference/repgradelbo_locationscale_bijectors.jl @@ -77,8 +77,8 @@ using Test show_progress = PROGRESS, adbackend = adbackend, ) - μ_repl = mean(q.dist) - L_repl = sqrt(cov(q.dist)) + μ_repl = q.dist.location + L_repl = q.dist.scale @test μ == μ_repl @test L == L_repl end From 021fd46754a84795e386a3de1f571bfabb2e447a Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Sun, 10 Dec 2023 00:59:56 -0500 Subject: [PATCH 7/9] add tests to improve coverage, fix bug for `rand!` with vectors --- src/families/location_scale.jl | 9 +---- test/interface/location_scale.jl | 56 ++++++++++++++++++++++++++++++-- 2 files changed, 54 insertions(+), 11 deletions(-) diff --git a/src/families/location_scale.jl b/src/families/location_scale.jl index 7aed751e..b474ee3e 100644 --- a/src/families/location_scale.jl +++ b/src/families/location_scale.jl @@ -90,14 +90,7 @@ function Distributions.rand( scale_diag.*rand(rng, dist, n_dims, num_samples) .+ location end -function Distributions._rand!(rng::AbstractRNG, q::VILocationScale, x::AbstractVector{<:Real}) - @unpack location, scale, dist = q - rand!(rng, dist, x) - x .= scale*x - return x += location -end - -function Distributions._rand!(rng::AbstractRNG, q::VILocationScale, x::AbstractMatrix{<:Real}) +function Distributions._rand!(rng::AbstractRNG, q::VILocationScale, x::AbstractVecOrMat{<:Real}) @unpack location, scale, dist = q rand!(rng, dist, x) x[:] = scale*x diff --git a/test/interface/location_scale.jl b/test/interface/location_scale.jl index ef30cf44..6f1d1d1a 100644 --- a/test/interface/location_scale.jl +++ b/test/interface/location_scale.jl @@ -32,11 +32,26 @@ @test eltype(logpdf(q, z)) == realtype end + @testset "_logpdf" begin + z = rand(q) + @test eltype(z) == realtype + @test Distributions._logpdf(q, z) ≈ logpdf(q_true, z) rtol=realtype(1e-2) + @test eltype(Distributions.logpdf(q, z)) == realtype + end + @testset "entropy" begin @test eltype(entropy(q)) == realtype @test entropy(q) ≈ entropy(q_true) end + @testset "length" begin + @test length(q) == n_dims + end + + @testset "eltype" begin + @test eltype(q) == realtype + end + @testset "statistics" begin @testset "mean" begin @test eltype(mean(q)) == realtype @@ -59,6 +74,9 @@ @test dropdims(mean(z_samples, dims=2), dims=2) ≈ μ rtol=realtype(1e-2) @test dropdims(var(z_samples, dims=2), dims=2) ≈ diag(Σ) rtol=realtype(1e-2) @test cov(z_samples, dims=2) ≈ Σ rtol=realtype(1e-2) + + z_sample_ref = rand(StableRNG(1), q) + @test z_sample_ref == rand(StableRNG(1), q) end @testset "rand batch" begin @@ -67,14 +85,46 @@ @test dropdims(mean(z_samples, dims=2), dims=2) ≈ μ rtol=realtype(1e-2) @test dropdims(var(z_samples, dims=2), dims=2) ≈ diag(Σ) rtol=realtype(1e-2) @test cov(z_samples, dims=2) ≈ Σ rtol=realtype(1e-2) + + samples_ref = rand(StableRNG(1), q, n_montecarlo) + @test samples_ref == rand(StableRNG(1), q, n_montecarlo) end - @testset "rand!" begin - z_samples = Array{realtype}(undef, n_dims, n_montecarlo) - rand!(q, z_samples) + @testset "rand! AbstractVector" begin + res = map(1:n_montecarlo) do _ + z_sample = Array{realtype}(undef, n_dims) + z_sample_ret = rand!(q, z_sample) + (z_sample, z_sample_ret) + end + z_samples = mapreduce(first, hcat, res) + z_samples_ret = mapreduce(last, hcat, res) + @test z_samples == z_samples_ret @test dropdims(mean(z_samples, dims=2), dims=2) ≈ μ rtol=realtype(1e-2) @test dropdims(var(z_samples, dims=2), dims=2) ≈ diag(Σ) rtol=realtype(1e-2) @test cov(z_samples, dims=2) ≈ Σ rtol=realtype(1e-2) + + z_sample_ref = Array{realtype}(undef, n_dims) + rand!(StableRNG(1), q, z_sample_ref) + + z_sample = Array{realtype}(undef, n_dims) + rand!(StableRNG(1), q, z_sample) + @test z_sample_ref == z_sample + end + + @testset "rand! AbstractMatrix" begin + z_samples = Array{realtype}(undef, n_dims, n_montecarlo) + z_samples_ret = rand!(q, z_samples) + @test z_samples == z_samples_ret + @test dropdims(mean(z_samples, dims=2), dims=2) ≈ μ rtol=realtype(1e-2) + @test dropdims(var(z_samples, dims=2), dims=2) ≈ diag(Σ) rtol=realtype(1e-2) + @test cov(z_samples, dims=2) ≈ Σ rtol=realtype(1e-2) + + z_samples_ref = Array{realtype}(undef, n_dims, n_montecarlo) + rand!(StableRNG(1), q, z_samples_ref) + + z_samples = Array{realtype}(undef, n_dims, n_montecarlo) + rand!(StableRNG(1), q, z_samples) + @test z_samples_ref == z_samples end end end From 1c80dec61a20fe6f1f0d960cbdf55470186ff8f3 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Wed, 20 Dec 2023 00:26:10 -0500 Subject: [PATCH 8/9] rename location scale, fix type ambiguity for `rand` --- src/families/location_scale.jl | 50 ++++++++++++++++++---------------- 1 file changed, 27 insertions(+), 23 deletions(-) diff --git a/src/families/location_scale.jl b/src/families/location_scale.jl index b474ee3e..152cd15d 100644 --- a/src/families/location_scale.jl +++ b/src/families/location_scale.jl @@ -1,6 +1,6 @@ """ - VILocationScale(location, scale, dist) <: ContinuousMultivariateDistribution + MvLocationScale(location, scale, dist) <: ContinuousMultivariateDistribution The location scale variational family broadly represents various variational families using `location` and `scale` variational parameters. @@ -13,68 +13,72 @@ represented as follows: z = scale*u + location ``` """ -struct VILocationScale{L, S, D} <: ContinuousMultivariateDistribution +struct MvLocationScale{ + S, D <: ContinuousDistribution, L +} <: ContinuousMultivariateDistribution location::L scale ::S dist ::D end -Functors.@functor VILocationScale (location, scale) +Functors.@functor MvLocationScale (location, scale) # Specialization of `Optimisers.destructure` for mean-field location-scale families. # These are necessary because we only want to extract the diagonal elements of # `scale <: Diagonal`, which is not the default behavior. Otherwise, forward-mode AD # is very inefficient. # begin -struct RestructureMeanField{L, S<:Diagonal, D} - q::VILocationScale{L, S, D} +struct RestructureMeanField{S <: Diagonal, D, L} + q::MvLocationScale{S, D, L} end function (re::RestructureMeanField)(flat::AbstractVector) n_dims = div(length(flat), 2) location = first(flat, n_dims) scale = Diagonal(last(flat, n_dims)) - VILocationScale(location, scale, re.q.dist) + MvLocationScale(location, scale, re.q.dist) end function Optimisers.destructure( - q::VILocationScale{L, <:Diagonal, D} -) where {L, D} + q::MvLocationScale{<:Diagonal, D, L} +) where {D, L} @unpack location, scale, dist = q flat = vcat(location, diag(scale)) flat, RestructureMeanField(q) end # end -Base.length(q::VILocationScale) = length(q.location) +Base.length(q::MvLocationScale) = length(q.location) -Base.size(q::VILocationScale) = size(q.location) +Base.size(q::MvLocationScale) = size(q.location) -Base.eltype(::Type{<:VILocationScale{L, S, D}}) where {L, S, D} = eltype(D) +Base.eltype(::Type{<:MvLocationScale{S, D, L}}) where {S, D, L} = eltype(D) -function StatsBase.entropy(q::VILocationScale) +function StatsBase.entropy(q::MvLocationScale) @unpack location, scale, dist = q n_dims = length(location) n_dims*convert(eltype(location), entropy(dist)) + first(logabsdet(scale)) end -function Distributions.logpdf(q::VILocationScale, z::AbstractVector{<:Real}) +function Distributions.logpdf(q::MvLocationScale, z::AbstractVector{<:Real}) @unpack location, scale, dist = q sum(Base.Fix1(logpdf, dist), scale \ (z - location)) - first(logabsdet(scale)) end -function Distributions._logpdf(q::VILocationScale, z::AbstractVector{<:Real}) +function Distributions._logpdf(q::MvLocationScale, z::AbstractVector{<:Real}) @unpack location, scale, dist = q sum(Base.Fix1(logpdf, dist), scale \ (z - location)) - first(logabsdet(scale)) end -function Distributions.rand(q::VILocationScale) +function Distributions.rand(q::MvLocationScale) @unpack location, scale, dist = q n_dims = length(location) scale*rand(dist, n_dims) + location end -function Distributions.rand(rng::AbstractRNG, q::VILocationScale, num_samples::Int) +function Distributions.rand( + rng::AbstractRNG, q::MvLocationScale{S, D, L}, num_samples::Int +) where {S, D, L} @unpack location, scale, dist = q n_dims = length(location) scale*rand(rng, dist, n_dims, num_samples) .+ location @@ -82,7 +86,7 @@ end # This specialization improves AD performance of the sampling path function Distributions.rand( - rng::AbstractRNG, q::VILocationScale{L, <:Diagonal, D}, num_samples::Int + rng::AbstractRNG, q::MvLocationScale{<:Diagonal, D, L}, num_samples::Int ) where {L, D} @unpack location, scale, dist = q n_dims = length(location) @@ -90,21 +94,21 @@ function Distributions.rand( scale_diag.*rand(rng, dist, n_dims, num_samples) .+ location end -function Distributions._rand!(rng::AbstractRNG, q::VILocationScale, x::AbstractVecOrMat{<:Real}) +function Distributions._rand!(rng::AbstractRNG, q::MvLocationScale, x::AbstractVecOrMat{<:Real}) @unpack location, scale, dist = q rand!(rng, dist, x) x[:] = scale*x return x .+= location end -Distributions.mean(q::VILocationScale) = q.location +Distributions.mean(q::MvLocationScale) = q.location -function Distributions.var(q::VILocationScale) +function Distributions.var(q::MvLocationScale) C = q.scale Diagonal(C*C') end -function Distributions.cov(q::VILocationScale) +function Distributions.cov(q::MvLocationScale) C = q.scale Hermitian(C*C') end @@ -131,7 +135,7 @@ function FullRankGaussian( @warn "Initial scale is too small (minimum eigenvalue is $(minimum(diag(L)))). This might result in unstable optimization behavior." end q_base = Normal{T}(zero(T), one(T)) - VILocationScale(μ, L, q_base) + MvLocationScale(μ, L, q_base) end """ @@ -156,5 +160,5 @@ function MeanFieldGaussian( @warn "Initial scale is too small (minimum eigenvalue is $(minimum(diag(L)))). This might result in unstable optimization behavior." end q_base = Normal{T}(zero(T), one(T)) - VILocationScale(μ, L, q_base) + MvLocationScale(μ, L, q_base) end From bbfac2a2f2386fe49aab48d5076b34191f8bcb06 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Wed, 20 Dec 2023 01:29:58 -0500 Subject: [PATCH 9/9] remove duplicate type tests for `LocationScale` --- test/interface/location_scale.jl | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/test/interface/location_scale.jl b/test/interface/location_scale.jl index 6f1d1d1a..6670f5c2 100644 --- a/test/interface/location_scale.jl +++ b/test/interface/location_scale.jl @@ -25,16 +25,18 @@ MvNormal(μ, Σ) end + @testset "eltype" begin + @test eltype(q) == realtype + end + @testset "logpdf" begin z = rand(q) - @test eltype(z) == realtype @test logpdf(q, z) ≈ logpdf(q_true, z) rtol=realtype(1e-2) @test eltype(logpdf(q, z)) == realtype end @testset "_logpdf" begin z = rand(q) - @test eltype(z) == realtype @test Distributions._logpdf(q, z) ≈ logpdf(q_true, z) rtol=realtype(1e-2) @test eltype(Distributions.logpdf(q, z)) == realtype end @@ -48,10 +50,6 @@ @test length(q) == n_dims end - @testset "eltype" begin - @test eltype(q) == realtype - end - @testset "statistics" begin @testset "mean" begin @test eltype(mean(q)) == realtype @@ -129,7 +127,7 @@ end end - @testset "Diagonal destructure" for + @testset "Diagonal destructure" begin n_dims = 10 μ = zeros(n_dims) L = ones(n_dims)