fix bug in reparameterization with Bijectors.TransformedDistribution #52

Merged 2 commits on Jan 3, 2024
ext/AdvancedVIBijectorsExt.jl (30 changes: 16 additions & 14 deletions)
@@ -18,28 +18,30 @@ function AdvancedVI.reparam_with_entropy(
     n_samples::Int,
     ent_est ::AdvancedVI.AbstractEntropyEstimator
 )
-    transform = q.transform
-    q_base = q.dist
-    q_base_stop = q_stop.dist
-    base_samples = rand(rng, q_base, n_samples)
-    it = AdvancedVI.eachsample(base_samples)
-    sample_init = first(it)
+    transform = q.transform
+    q_unconst = q.dist
+    q_unconst_stop = q_stop.dist
 
+    # Draw samples and compute entropy of the unconstrained distribution
+    unconstr_samples, unconst_entropy = AdvancedVI.reparam_with_entropy(
+        rng, q_unconst, q_unconst_stop, n_samples, ent_est
+    )
+
+    # Apply bijector to samples while estimating its jacobian
+    unconstr_iter = AdvancedVI.eachsample(unconstr_samples)
+    unconstr_init = first(unconstr_iter)
+    samples_init, logjac_init = with_logabsdet_jacobian(transform, unconstr_init)
     samples_and_logjac = mapreduce(
         AdvancedVI.catsamples_and_acc,
-        Iterators.drop(it, 1);
-        init=with_logabsdet_jacobian(transform, sample_init)
+        Iterators.drop(unconstr_iter, 1);
+        init=(reshape(samples_init, (:,1)), logjac_init)
     ) do sample
         with_logabsdet_jacobian(transform, sample)
     end
     samples = first(samples_and_logjac)
-    logjac = last(samples_and_logjac)
-
-    entropy_base = AdvancedVI.estimate_entropy_maybe_stl(
-        ent_est, base_samples, q_base, q_base_stop
-    )
+    logjac = last(samples_and_logjac)/n_samples
 
-    entropy = entropy_base + logjac/n_samples
+    entropy = unconst_entropy + logjac
     samples, entropy
 end
 end
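For reference, the identity the new code path relies on, H[T(z)] = H[z] + E[log|det J_T(z)|] for an invertible T and base samples z, can be checked in isolation. A minimal, self-contained sketch (not the package code; the elementwise exp transform and its log-Jacobian are hand-rolled here to avoid depending on a particular Bijectors version):

using Distributions, LinearAlgebra, Random, Statistics

# Elementwise exp and its log-Jacobian: J = Diagonal(exp.(z)), so
# log|det J| = sum(z). Only the log-Jacobian enters the entropy correction.
T(z)            = exp.(z)
logabsdetjac(z) = sum(z)

rng    = Random.default_rng()
q_base = MvNormal(zeros(2), Matrix(1.0I, 2, 2))
zs     = rand(rng, q_base, 10_000)

# Monte Carlo estimate of H[T(z)] = H[z] + E[log|det J_T(z)|].
H_mc = entropy(q_base) + mean(logabsdetjac(z) for z in eachcol(zs))

# For zero-mean z, E[sum(z)] = 0, so H_mc should be close to entropy(q_base).
@assert isapprox(H_mc, entropy(q_base); atol=0.1)

Dividing the accumulated log-Jacobian by n_samples, as the new `logjac = last(samples_and_logjac)/n_samples` line does, is exactly the Monte Carlo average in the expectation above.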
src/utils.jl (2 changes: 0 additions & 2 deletions)
@@ -23,8 +23,6 @@ end
 
 eachsample(samples::AbstractMatrix) = eachcol(samples)
 
-eachsample(samples::AbstractVector) = samples
-
 function catsamples_and_acc(
     state_curr::Tuple{<:AbstractArray, <:Real},
     state_new ::Tuple{<:AbstractVector, <:Real}
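The fold in the extension relies on `catsamples_and_acc`, whose body is truncated in this diff view. A hypothetical stand-in that matches the signature shown above and the way the extension calls it (the real implementation may differ): concatenate the new sample as a column and add the scalar accumulators.

# Hypothetical stand-in; the actual AdvancedVI.catsamples_and_acc body is
# elided above.
function catsamples_and_acc_sketch(
    state_curr::Tuple{<:AbstractArray, <:Real},
    state_new ::Tuple{<:AbstractVector, <:Real},
)
    samples = hcat(first(state_curr), first(state_new))
    acc     = last(state_curr) + last(state_new)
    (samples, acc)
end

# Usage mirroring the mapreduce in ext/AdvancedVIBijectorsExt.jl: fold
# per-sample (sample, logjac) pairs into a (sample_matrix, total_logjac) pair.
pairs = [(randn(3), rand()) for _ in 1:4]
out   = reduce(catsamples_and_acc_sketch, pairs[2:end];
               init = (reshape(first(pairs[1]), (:, 1)), last(pairs[1])))
@assert size(first(out)) == (3, 4)  # four samples of dimension three

This also makes clear why the new `init=(reshape(samples_init, (:,1)), logjac_init)` is needed: the accumulator expects a matrix on the left, so the first transformed sample must be reshaped into a single column before the fold starts.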
test/inference/repgradelbo_distributionsad.jl (11 changes: 6 additions & 5 deletions)
@@ -9,9 +9,10 @@ using Test
     (modelname, modelconstr) ∈ Dict(
         :Normal=> normal_meanfield,
     ),
-    (objname, objective) ∈ Dict(
-        :RepGradELBOClosedFormEntropy => RepGradELBO(10),
-        :RepGradELBOStickingTheLanding => RepGradELBO(10, entropy = StickingTheLandingEntropy()),
+    n_montecarlo in [1, 10],
+    (objname, objective) in Dict(
+        :RepGradELBOClosedFormEntropy => RepGradELBO(n_montecarlo),
+        :RepGradELBOStickingTheLanding => RepGradELBO(n_montecarlo, entropy = StickingTheLandingEntropy()),
     ),
     (adbackname, adbackend) ∈ Dict(
         :ForwarDiff => AutoForwardDiff(),
@@ -33,7 +34,7 @@ using Test
     q0 = TuringDiagMvNormal(μ0, diag(L0))
 
     @testset "convergence" begin
-        Δλ₀ = sum(abs2, μ0 - μ_true) + sum(abs2, L0 - L_true)
+        Δλ0 = sum(abs2, μ0 - μ_true) + sum(abs2, L0 - L_true)
         q, stats, _ = optimize(
             rng, model, objective, q0, T;
             optimizer = Optimisers.Adam(realtype(η)),
@@ -45,7 +46,7 @@ using Test
         L = sqrt(cov(q))
         Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true)
 
-        @test Δλ ≤ Δλ₀/T^(1/4)
+        @test Δλ ≤ Δλ0/T^(1/4)
         @test eltype(μ) == eltype(μ_true)
         @test eltype(L) == eltype(L_true)
     end
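Side note on the test change: the `@testset ... for` form iterates the Cartesian product of its axes, so adding `n_montecarlo in [1, 10]` reruns every (model, objective, AD backend) combination at both sample counts, with the objective now constructed per-combination from `n_montecarlo`. A minimal illustration of the pattern (hypothetical axes):

using Test

# Two axes → 2 × 2 = 4 test runs, one per combination.
@testset "cartesian ($(n), $(s))" for n in [1, 10], s in [:a, :b]
    @test n > 0  # placeholder standing in for the real convergence checks
end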
test/inference/repgradelbo_locationscale.jl (17 changes: 9 additions & 8 deletions)
@@ -5,16 +5,17 @@ using Test
 
 @testset "inference RepGradELBO VILocationScale" begin
     @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for
-        realtype ∈ [Float64, Float32],
-        (modelname, modelconstr) ∈ Dict(
+        realtype in [Float64, Float32],
+        (modelname, modelconstr) in Dict(
             :Normal=> normal_meanfield,
             :Normal=> normal_fullrank,
         ),
-        (objname, objective) ∈ Dict(
-            :RepGradELBOClosedFormEntropy => RepGradELBO(10),
-            :RepGradELBOStickingTheLanding => RepGradELBO(10, entropy = StickingTheLandingEntropy()),
+        n_montecarlo in [1, 10],
+        (objname, objective) in Dict(
+            :RepGradELBOClosedFormEntropy => RepGradELBO(n_montecarlo),
+            :RepGradELBOStickingTheLanding => RepGradELBO(n_montecarlo, entropy = StickingTheLandingEntropy()),
         ),
-        (adbackname, adbackend) ∈ Dict(
+        (adbackname, adbackend) in Dict(
             :ForwarDiff => AutoForwardDiff(),
             :ReverseDiff => AutoReverseDiff(),
             :Zygote => AutoZygote(),
@@ -37,7 +38,7 @@ using Test
     end
 
     @testset "convergence" begin
-        Δλ₀ = sum(abs2, q0.location - μ_true) + sum(abs2, q0.scale - L_true)
+        Δλ0 = sum(abs2, q0.location - μ_true) + sum(abs2, q0.scale - L_true)
         q, stats, _ = optimize(
             rng, model, objective, q0, T;
             optimizer = Optimisers.Adam(realtype(η)),
@@ -49,7 +50,7 @@ using Test
         L = q.scale
         Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true)
 
-        @test Δλ ≤ Δλ₀/T^(1/4)
+        @test Δλ ≤ Δλ0/T^(1/4)
         @test eltype(μ) == eltype(μ_true)
         @test eltype(L) == eltype(L_true)
     end
test/inference/repgradelbo_locationscale_bijectors.jl (17 changes: 9 additions & 8 deletions)
@@ -5,15 +5,16 @@ using Test
 
 @testset "inference RepGradELBO VILocationScale Bijectors" begin
     @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for
-        realtype ∈ [Float64, Float32],
-        (modelname, modelconstr) ∈ Dict(
+        realtype in [Float64, Float32],
+        (modelname, modelconstr) in Dict(
            :NormalLogNormalMeanField => normallognormal_meanfield,
        ),
-        (objname, objective) ∈ Dict(
-            :RepGradELBOClosedFormEntropy => RepGradELBO(10),
-            :RepGradELBOStickingTheLanding => RepGradELBO(10, entropy = StickingTheLandingEntropy()),
+        n_montecarlo in [1, 10],
+        (objname, objective) in Dict(
+            :RepGradELBOClosedFormEntropy => RepGradELBO(n_montecarlo),
+            :RepGradELBOStickingTheLanding => RepGradELBO(n_montecarlo, entropy = StickingTheLandingEntropy()),
        ),
-        (adbackname, adbackend) ∈ Dict(
+        (adbackname, adbackend) in Dict(
            :ForwarDiff => AutoForwardDiff(),
            :ReverseDiff => AutoReverseDiff(),
            #:Zygote => AutoZygote(),
@@ -42,7 +43,7 @@ using Test
     q0_z = Bijectors.transformed(q0_η, b⁻¹)
 
     @testset "convergence" begin
-        Δλ₀ = sum(abs2, μ0 - μ_true) + sum(abs2, L0 - L_true)
+        Δλ0 = sum(abs2, μ0 - μ_true) + sum(abs2, L0 - L_true)
         q, stats, _ = optimize(
             rng, model, objective, q0_z, T;
             optimizer = Optimisers.Adam(realtype(η)),
@@ -54,7 +55,7 @@ using Test
         L = q.dist.scale
         Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true)
 
-        @test Δλ ≤ Δλ₀/T^(1/4)
+        @test Δλ ≤ Δλ0/T^(1/4)
         @test eltype(μ) == eltype(μ_true)
         @test eltype(L) == eltype(L_true)
     end
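The construction this test exercises, `Bijectors.transformed(q0_η, b⁻¹)`, wraps an unconstrained Gaussian in the inverse of the model's bijector so the variational approximation lives on the model's constrained support; this is the `Bijectors.TransformedDistribution` path that the extension fix above targets. A small sketch of the same pattern (illustrative names; exact `bijector` return types vary across Bijectors versions):

using Distributions, Bijectors

p    = LogNormal()                      # target marginal with support (0, ∞)
b    = bijector(p)                      # maps the support to all of ℝ
binv = inverse(b)                       # maps ℝ back to (0, ∞)
q_η  = Normal()                         # unconstrained variational family
q_z  = Bijectors.transformed(q_η, binv) # constrained variational family
@assert rand(q_z) > 0                   # samples land in the support of p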