Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Basic rewrite of the package 2023 edition Part II: Location-scale variational families #51

Merged
merged 9 commits into from
Dec 20, 2023
10 changes: 10 additions & 0 deletions src/AdvancedVI.jl
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,16 @@ export
include("objectives/elbo/entropy.jl")
include("objectives/elbo/repgradelbo.jl")


# Variational Families
export
VILocationScale,
Red-Portal marked this conversation as resolved.
Show resolved Hide resolved
MeanFieldGaussian,
FullRankGaussian

include("families/location_scale.jl")


# Optimization Routine

function optimize end
Expand Down
160 changes: 160 additions & 0 deletions src/families/location_scale.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@

"""
VILocationScale(location, scale, dist) <: ContinuousMultivariateDistribution

The location scale variational family broadly represents various variational
families using `location` and `scale` variational parameters.

It generally represents any distribution for which the sampling path can be
represented as follows:
```julia
d = length(location)
u = rand(dist, d)
z = scale*u + location
```
"""
struct VILocationScale{L, S, D} <: ContinuousMultivariateDistribution
location::L
scale ::S
dist ::D
Red-Portal marked this conversation as resolved.
Show resolved Hide resolved
end

Functors.@functor VILocationScale (location, scale)

# Specialization of `Optimisers.destructure` for mean-field location-scale families.
# These are necessary because we only want to extract the diagonal elements of
# `scale <: Diagonal`, which is not the default behavior. Otherwise, forward-mode AD
# is very inefficient.
# begin
struct RestructureMeanField{L, S<:Diagonal, D}
q::VILocationScale{L, S, D}
end

function (re::RestructureMeanField)(flat::AbstractVector)
n_dims = div(length(flat), 2)
location = first(flat, n_dims)
scale = Diagonal(last(flat, n_dims))
VILocationScale(location, scale, re.q.dist)
end

function Optimisers.destructure(
q::VILocationScale{L, <:Diagonal, D}
) where {L, D}
@unpack location, scale, dist = q
flat = vcat(location, diag(scale))
flat, RestructureMeanField(q)
end
# end

Base.length(q::VILocationScale) = length(q.location)

Base.size(q::VILocationScale) = size(q.location)

Base.eltype(::Type{<:VILocationScale{L, S, D}}) where {L, S, D} = eltype(D)

function StatsBase.entropy(q::VILocationScale)
@unpack location, scale, dist = q
n_dims = length(location)
n_dims*convert(eltype(location), entropy(dist)) + first(logabsdet(scale))
end

function Distributions.logpdf(q::VILocationScale, z::AbstractVector{<:Real})
@unpack location, scale, dist = q
sum(Base.Fix1(logpdf, dist), scale \ (z - location)) - first(logabsdet(scale))
end

function Distributions._logpdf(q::VILocationScale, z::AbstractVector{<:Real})
@unpack location, scale, dist = q
sum(Base.Fix1(logpdf, dist), scale \ (z - location)) - first(logabsdet(scale))
end

function Distributions.rand(q::VILocationScale)
@unpack location, scale, dist = q
n_dims = length(location)
scale*rand(dist, n_dims) + location
end

function Distributions.rand(rng::AbstractRNG, q::VILocationScale, num_samples::Int)
@unpack location, scale, dist = q
n_dims = length(location)
scale*rand(rng, dist, n_dims, num_samples) .+ location
end

# This specialization improves AD performance of the sampling path
function Distributions.rand(
rng::AbstractRNG, q::VILocationScale{L, <:Diagonal, D}, num_samples::Int
) where {L, D}
@unpack location, scale, dist = q
n_dims = length(location)
scale_diag = diag(scale)
scale_diag.*rand(rng, dist, n_dims, num_samples) .+ location
end

function Distributions._rand!(rng::AbstractRNG, q::VILocationScale, x::AbstractVecOrMat{<:Real})
@unpack location, scale, dist = q
rand!(rng, dist, x)
x[:] = scale*x
return x .+= location
end

Distributions.mean(q::VILocationScale) = q.location

function Distributions.var(q::VILocationScale)
C = q.scale
Diagonal(C*C')
end

function Distributions.cov(q::VILocationScale)
C = q.scale
Hermitian(C*C')
end

"""
FullRankGaussian(location, scale; check_args = true)

Construct a Gaussian variational approximation with a dense covariance matrix.

# Arguments
- `location::AbstractVector{T}`: Mean of the Gaussian.
- `scale::LinearAlgebra.AbstractTriangular{T}`: Cholesky factor of the covariance of the Gaussian.

# Keyword Arguments
- `check_args`: Check the conditioning of the initial scale (default: `true`).
"""
function FullRankGaussian(
μ::AbstractVector{T},
L::LinearAlgebra.AbstractTriangular{T};
check_args::Bool = true
) where {T <: Real}
@assert minimum(diag(L)) > eps(eltype(L)) "Scale must be positive definite"
if check_args && (minimum(diag(L)) < sqrt(eps(eltype(L))))
@warn "Initial scale is too small (minimum eigenvalue is $(minimum(diag(L)))). This might result in unstable optimization behavior."

Check warning on line 131 in src/families/location_scale.jl

View check run for this annotation

Codecov / codecov/patch

src/families/location_scale.jl#L131

Added line #L131 was not covered by tests
end
q_base = Normal{T}(zero(T), one(T))
VILocationScale(μ, L, q_base)
end

"""
MeanFieldGaussian(location, scale; check_args = true)

Construct a Gaussian variational approximation with a diagonal covariance matrix.

# Arguments
- `location::AbstractVector{T}`: Mean of the Gaussian.
- `scale::Diagonal{T}`: Diagonal Cholesky factor of the covariance of the Gaussian.

# Keyword Arguments
- `check_args`: Check the conditioning of the initial scale (default: `true`).
"""
function MeanFieldGaussian(
μ::AbstractVector{T},
L::Diagonal{T};
check_args::Bool = true
) where {T <: Real}
@assert minimum(diag(L)) > eps(eltype(L)) "Scale must be a Cholesky factor"
if check_args && (minimum(diag(L)) < sqrt(eps(eltype(L))))
@warn "Initial scale is too small (minimum eigenvalue is $(minimum(diag(L)))). This might result in unstable optimization behavior."

Check warning on line 156 in src/families/location_scale.jl

View check run for this annotation

Codecov / codecov/patch

src/families/location_scale.jl#L156

Added line #L156 was not covered by tests
end
q_base = Normal{T}(zero(T), one(T))
VILocationScale(μ, L, q_base)
end
82 changes: 82 additions & 0 deletions test/inference/repgradelbo_locationscale.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@

const PROGRESS = length(ARGS) > 0 && ARGS[1] == "--progress" ? true : false

using Test

@testset "inference RepGradELBO VILocationScale" begin
@testset "$(modelname) $(objname) $(realtype) $(adbackname)" for
realtype ∈ [Float64, Float32],
(modelname, modelconstr) ∈ Dict(
:Normal=> normal_meanfield,
:Normal=> normal_fullrank,
),
(objname, objective) ∈ Dict(
:RepGradELBOClosedFormEntropy => RepGradELBO(10),
:RepGradELBOStickingTheLanding => RepGradELBO(10, entropy = StickingTheLandingEntropy()),
),
(adbackname, adbackend) ∈ Dict(
:ForwarDiff => AutoForwardDiff(),
:ReverseDiff => AutoReverseDiff(),
:Zygote => AutoZygote(),
#:Enzyme => AutoEnzyme(),
)

seed = (0x38bef07cf9cc549d)
rng = StableRNG(seed)

modelstats = modelconstr(rng, realtype)
@unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats

T, η = is_meanfield ? (5_000, 1e-2) : (30_000, 1e-3)

q0 = if is_meanfield
MeanFieldGaussian(zeros(realtype, n_dims), Diagonal(ones(realtype, n_dims)))
else
L0 = Matrix{realtype}(I, n_dims, n_dims) |> LowerTriangular
FullRankGaussian(zeros(realtype, n_dims), L0)
end

@testset "convergence" begin
Δλ₀ = sum(abs2, q0.location - μ_true) + sum(abs2, q0.scale - L_true)
q, stats, _ = optimize(
rng, model, objective, q0, T;
optimizer = Optimisers.Adam(realtype(η)),
show_progress = PROGRESS,
adbackend = adbackend,
)

μ = q.location
L = q.scale
Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true)

@test Δλ ≤ Δλ₀/T^(1/4)
@test eltype(μ) == eltype(μ_true)
@test eltype(L) == eltype(L_true)
end

@testset "determinism" begin
rng = StableRNG(seed)
q, stats, _ = optimize(
rng, model, objective, q0, T;
optimizer = Optimisers.Adam(realtype(η)),
show_progress = PROGRESS,
adbackend = adbackend,
)
μ = q.location
L = q.scale

rng_repl = StableRNG(seed)
q, stats, _ = optimize(
rng_repl, model, objective, q0, T;
optimizer = Optimisers.Adam(realtype(η)),
show_progress = PROGRESS,
adbackend = adbackend,
)
μ_repl = q.location
L_repl = q.scale
@test μ == μ_repl
@test L == L_repl
end
end
end

Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ const PROGRESS = length(ARGS) > 0 && ARGS[1] == "--progress" ? true : false

using Test

@testset "inference RepGradELBO DistributionsAD Bijectors" begin
@testset "inference RepGradELBO VILocationScale Bijectors" begin
@testset "$(modelname) $(objname) $(realtype) $(adbackname)" for
realtype ∈ [Float64, Float32],
(modelname, modelconstr) ∈ Dict(
Expand All @@ -15,7 +15,7 @@ using Test
),
(adbackname, adbackend) ∈ Dict(
:ForwarDiff => AutoForwardDiff(),
#:ReverseDiff => AutoReverseDiff(),
:ReverseDiff => AutoReverseDiff(),
#:Zygote => AutoZygote(),
#:Enzyme => AutoEnzyme(),
)
Expand All @@ -30,23 +30,28 @@ using Test

b = Bijectors.bijector(model)
b⁻¹ = inverse(b)
μ₀ = Zeros(realtype, n_dims)
L₀ = Diagonal(Ones(realtype, n_dims))
μ0 = Zeros(realtype, n_dims)
L0 = Diagonal(Ones(realtype, n_dims))

q₀_η = TuringDiagMvNormal(μ₀, diag(L₀))
q₀_z = Bijectors.transformed(q₀_η, b⁻¹)
q0_η = if is_meanfield
MeanFieldGaussian(zeros(realtype, n_dims), Diagonal(ones(realtype, n_dims)))
else
L0 = Matrix{realtype}(I, n_dims, n_dims) |> LowerTriangular
FullRankGaussian(zeros(realtype, n_dims), L0)
end
q0_z = Bijectors.transformed(q0_η, b⁻¹)

@testset "convergence" begin
Δλ₀ = sum(abs2, μ₀ - μ_true) + sum(abs2, L₀ - L_true)
Δλ₀ = sum(abs2, μ0 - μ_true) + sum(abs2, L0 - L_true)
q, stats, _ = optimize(
rng, model, objective, q₀_z, T;
rng, model, objective, q0_z, T;
optimizer = Optimisers.Adam(realtype(η)),
show_progress = PROGRESS,
adbackend = adbackend,
)

μ = mean(q.dist)
L = sqrt(cov(q.dist))
μ = q.dist.location
L = q.dist.scale
Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true)

@test Δλ ≤ Δλ₀/T^(1/4)
Expand All @@ -57,23 +62,23 @@ using Test
@testset "determinism" begin
rng = StableRNG(seed)
q, stats, _ = optimize(
rng, model, objective, q₀_z, T;
rng, model, objective, q0_z, T;
optimizer = Optimisers.Adam(realtype(η)),
show_progress = PROGRESS,
adbackend = adbackend,
)
μ = mean(q.dist)
L = sqrt(cov(q.dist))
μ = q.dist.location
L = q.dist.scale

rng_repl = StableRNG(seed)
q, stats, _ = optimize(
rng_repl, model, objective, q₀_z, T;
rng_repl, model, objective, q0_z, T;
optimizer = Optimisers.Adam(realtype(η)),
show_progress = PROGRESS,
adbackend = adbackend,
)
μ_repl = mean(q.dist)
L_repl = sqrt(cov(q.dist))
μ_repl = q.dist.location
L_repl = q.dist.scale
@test μ == μ_repl
@test L == L_repl
end
Expand Down
Loading
Loading