diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl
index 89f866963..a69b8d893 100644
--- a/src/AdvancedVI.jl
+++ b/src/AdvancedVI.jl
@@ -141,6 +141,16 @@ export
 include("objectives/elbo/entropy.jl")
 include("objectives/elbo/repgradelbo.jl")
+
+# Variational Families
+export
+    MvLocationScale,
+    MeanFieldGaussian,
+    FullRankGaussian
+
+include("families/location_scale.jl")
+
+
 # Optimization Routine
 function optimize end
diff --git a/src/families/location_scale.jl b/src/families/location_scale.jl
new file mode 100644
index 000000000..152cd15da
--- /dev/null
+++ b/src/families/location_scale.jl
@@ -0,0 +1,164 @@
+
+"""
+    MvLocationScale(location, scale, dist) <: ContinuousMultivariateDistribution
+
+The location-scale variational family is a broad class of variational families
+parameterized by `location` and `scale` variational parameters.
+
+It represents any distribution whose sampling path can be written as:
+```julia
+    d = length(location)
+    u = rand(dist, d)
+    z = scale*u + location
+```
+"""
+struct MvLocationScale{
+    S, D <: ContinuousDistribution, L
+} <: ContinuousMultivariateDistribution
+    location::L
+    scale   ::S
+    dist    ::D
+end
+
+Functors.@functor MvLocationScale (location, scale)
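+
+# Illustrative sketch (not part of the implementation): a two-dimensional
+# mean-field family built by hand and sampled through the location-scale path
+# above. Assumes `Distributions`, `LinearAlgebra`, and `Random` are loaded.
+#
+#   q = MvLocationScale(zeros(2), Diagonal(ones(2)), Normal())
+#   z = rand(Random.default_rng(), q, 100)  # 2×100 matrix: z = scale*u .+ location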
+
+# Specialization of `Optimisers.destructure` for mean-field location-scale families.
+# These are necessary because we only want to extract the diagonal elements of
+# `scale <: Diagonal`, which is not the default behavior. Otherwise, forward-mode AD
+# is very inefficient.
+# begin
+struct RestructureMeanField{S <: Diagonal, D, L}
+    q::MvLocationScale{S, D, L}
+end
+
+function (re::RestructureMeanField)(flat::AbstractVector)
+    n_dims   = div(length(flat), 2)
+    location = first(flat, n_dims)
+    scale    = Diagonal(last(flat, n_dims))
+    MvLocationScale(location, scale, re.q.dist)
+end
+
+function Optimisers.destructure(
+    q::MvLocationScale{<:Diagonal, D, L}
+) where {D, L}
+    @unpack location, scale, dist = q
+    flat = vcat(location, diag(scale))
+    flat, RestructureMeanField(q)
+end
+# end
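+
+# Round-trip sketch for the specialization above, using the `MeanFieldGaussian`
+# constructor defined further below (illustrative only):
+#
+#   q     = MeanFieldGaussian(zeros(3), Diagonal(ones(3)))
+#   λ, re = Optimisers.destructure(q)  # λ == [0, 0, 0, 1, 1, 1]
+#   q′    = re(λ)                      # rebuilds the family from the flat vector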
+
+Base.length(q::MvLocationScale) = length(q.location)
+
+Base.size(q::MvLocationScale) = size(q.location)
+
+Base.eltype(::Type{<:MvLocationScale{S, D, L}}) where {S, D, L} = eltype(D)
+
+function StatsBase.entropy(q::MvLocationScale)
+    @unpack location, scale, dist = q
+    n_dims = length(location)
+    n_dims*convert(eltype(location), entropy(dist)) + first(logabsdet(scale))
+end
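+
+# Note on the entropy above: for z = scale*u + location, where u has n_dims
+# i.i.d. coordinates drawn from `dist`, the change-of-variables formula gives
+#
+#     H(q) = n_dims*H(dist) + log|det(scale)|,
+#
+# which is exactly the sum of the two terms computed here.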
+
+function Distributions.logpdf(q::MvLocationScale, z::AbstractVector{<:Real})
+    @unpack location, scale, dist = q
+    sum(Base.Fix1(logpdf, dist), scale \ (z - location)) - first(logabsdet(scale))
+end
+
+function Distributions._logpdf(q::MvLocationScale, z::AbstractVector{<:Real})
+    @unpack location, scale, dist = q
+    sum(Base.Fix1(logpdf, dist), scale \ (z - location)) - first(logabsdet(scale))
+end
+
+function Distributions.rand(q::MvLocationScale)
+    @unpack location, scale, dist = q
+    n_dims = length(location)
+    scale*rand(dist, n_dims) + location
+end
+
+function Distributions.rand(
+    rng::AbstractRNG, q::MvLocationScale{S, D, L}, num_samples::Int
+) where {S, D, L}
+    @unpack location, scale, dist = q
+    n_dims = length(location)
+    scale*rand(rng, dist, n_dims, num_samples) .+ location
+end
+
+# This specialization improves AD performance of the sampling path
+function Distributions.rand(
+    rng::AbstractRNG, q::MvLocationScale{<:Diagonal, D, L}, num_samples::Int
+) where {D, L}
+    @unpack location, scale, dist = q
+    n_dims     = length(location)
+    scale_diag = diag(scale)
+    scale_diag.*rand(rng, dist, n_dims, num_samples) .+ location
+end
+
+function Distributions._rand!(rng::AbstractRNG, q::MvLocationScale, x::AbstractVecOrMat{<:Real})
+    @unpack location, scale, dist = q
+    rand!(rng, dist, x)
+    x[:] = scale*x
+    return x .+= location
+end
+
+Distributions.mean(q::MvLocationScale) = q.location
+
+function Distributions.var(q::MvLocationScale)
+    C = q.scale
+    Diagonal(C*C')
+end
+
+function Distributions.cov(q::MvLocationScale)
+    C = q.scale
+    Hermitian(C*C')
+end
+
+"""
+    FullRankGaussian(location, scale; check_args = true)
+
+Construct a Gaussian variational approximation with a dense covariance matrix.
+
+# Arguments
+- `location::AbstractVector{T}`: Mean of the Gaussian.
+- `scale::LinearAlgebra.AbstractTriangular{T}`: Cholesky factor of the covariance of the Gaussian.
+
+# Keyword Arguments
+- `check_args`: Check the conditioning of the initial scale (default: `true`).
+"""
+function FullRankGaussian(
+    μ::AbstractVector{T},
+    L::LinearAlgebra.AbstractTriangular{T};
+    check_args::Bool = true
+) where {T <: Real}
+    @assert minimum(diag(L)) > eps(eltype(L)) "Scale must be positive definite"
+    if check_args && (minimum(diag(L)) < sqrt(eps(eltype(L))))
+        @warn "Initial scale is too small (minimum eigenvalue is $(minimum(diag(L)))). This might result in unstable optimization behavior."
+    end
+    q_base = Normal{T}(zero(T), one(T))
+    MvLocationScale(μ, L, q_base)
+end
+
+"""
+    MeanFieldGaussian(location, scale; check_args = true)
+
+Construct a Gaussian variational approximation with a diagonal covariance matrix.
+
+# Arguments
+- `location::AbstractVector{T}`: Mean of the Gaussian.
+- `scale::Diagonal{T}`: Diagonal Cholesky factor of the covariance of the Gaussian.
+
+# Keyword Arguments
+- `check_args`: Check the conditioning of the initial scale (default: `true`).
+"""
+function MeanFieldGaussian(
+    μ::AbstractVector{T},
+    L::Diagonal{T};
+    check_args::Bool = true
+) where {T <: Real}
+    @assert minimum(diag(L)) > eps(eltype(L)) "Scale must be positive definite"
+    if check_args && (minimum(diag(L)) < sqrt(eps(eltype(L))))
+        @warn "Initial scale is too small (minimum eigenvalue is $(minimum(diag(L)))). This might result in unstable optimization behavior."
+    end
+    q_base = Normal{T}(zero(T), one(T))
+    MvLocationScale(μ, L, q_base)
+end
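+
+# Minimal construction sketch for the two constructors above (illustrative
+# only; the dimension `d` is arbitrary):
+#
+#   d    = 5
+#   q_mf = MeanFieldGaussian(zeros(d), Diagonal(ones(d)))
+#   q_fr = FullRankGaussian(zeros(d), LowerTriangular(Matrix(1.0I, d, d)))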
AbstractVector" begin + res = map(1:n_montecarlo) do _ + z_sample = Array{realtype}(undef, n_dims) + z_sample_ret = rand!(q, z_sample) + (z_sample, z_sample_ret) + end + z_samples = mapreduce(first, hcat, res) + z_samples_ret = mapreduce(last, hcat, res) + @test z_samples == z_samples_ret + @test dropdims(mean(z_samples, dims=2), dims=2) ≈ μ rtol=realtype(1e-2) + @test dropdims(var(z_samples, dims=2), dims=2) ≈ diag(Σ) rtol=realtype(1e-2) + @test cov(z_samples, dims=2) ≈ Σ rtol=realtype(1e-2) + + z_sample_ref = Array{realtype}(undef, n_dims) + rand!(StableRNG(1), q, z_sample_ref) + + z_sample = Array{realtype}(undef, n_dims) + rand!(StableRNG(1), q, z_sample) + @test z_sample_ref == z_sample + end + + @testset "rand! AbstractMatrix" begin + z_samples = Array{realtype}(undef, n_dims, n_montecarlo) + z_samples_ret = rand!(q, z_samples) + @test z_samples == z_samples_ret + @test dropdims(mean(z_samples, dims=2), dims=2) ≈ μ rtol=realtype(1e-2) + @test dropdims(var(z_samples, dims=2), dims=2) ≈ diag(Σ) rtol=realtype(1e-2) + @test cov(z_samples, dims=2) ≈ Σ rtol=realtype(1e-2) + + z_samples_ref = Array{realtype}(undef, n_dims, n_montecarlo) + rand!(StableRNG(1), q, z_samples_ref) + + z_samples = Array{realtype}(undef, n_dims, n_montecarlo) + rand!(StableRNG(1), q, z_samples) + @test z_samples_ref == z_samples + end + end + end + + @testset "Diagonal destructure" begin + n_dims = 10 + μ = zeros(n_dims) + L = ones(n_dims) + q = MeanFieldGaussian(μ, L |> Diagonal) + λ, re = Optimisers.destructure(q) + + @test length(λ) == 2*n_dims + @test q == re(λ) + end +end diff --git a/test/runtests.jl b/test/runtests.jl index b14b8b2ed..8e540a088 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -37,6 +37,8 @@ include("models/normallognormal.jl") include("interface/ad.jl") include("interface/optimize.jl") include("interface/repgradelbo.jl") +include("interface/location_scale.jl") include("inference/repgradelbo_distributionsad.jl") -include("inference/repgradelbo_distributionsad_bijectors.jl") +include("inference/repgradelbo_locationscale.jl") +include("inference/repgradelbo_locationscale_bijectors.jl")