Skip to content

Commit

Permalink
fix bug in reparameterization with Bijectors.TransformedDistribution (
Browse files Browse the repository at this point in the history
#52)

* fix bug for bijector with 1 MC sample with tests

* fix remove redundant helpers for `reparam_with_entropy` for bijector
  • Loading branch information
Red-Portal authored Jan 3, 2024
1 parent 9ebfc3f commit 2f6dc4f
Show file tree
Hide file tree
Showing 5 changed files with 40 additions and 37 deletions.
30 changes: 16 additions & 14 deletions ext/AdvancedVIBijectorsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,28 +18,30 @@ function AdvancedVI.reparam_with_entropy(
n_samples::Int,
ent_est ::AdvancedVI.AbstractEntropyEstimator
)
transform = q.transform
q_base = q.dist
q_base_stop = q_stop.dist
base_samples = rand(rng, q_base, n_samples)
it = AdvancedVI.eachsample(base_samples)
sample_init = first(it)
transform = q.transform
q_unconst = q.dist
q_unconst_stop = q_stop.dist

# Draw samples and compute entropy of the uncontrained distribution
unconstr_samples, unconst_entropy = AdvancedVI.reparam_with_entropy(
rng, q_unconst, q_unconst_stop, n_samples, ent_est
)

# Apply bijector to samples while estimating its jacobian
unconstr_iter = AdvancedVI.eachsample(unconstr_samples)
unconstr_init = first(unconstr_iter)
samples_init, logjac_init = with_logabsdet_jacobian(transform, unconstr_init)
samples_and_logjac = mapreduce(
AdvancedVI.catsamples_and_acc,
Iterators.drop(it, 1);
init=with_logabsdet_jacobian(transform, sample_init)
Iterators.drop(unconstr_iter, 1);
init=(reshape(samples_init, (:,1)), logjac_init)
) do sample
with_logabsdet_jacobian(transform, sample)
end
samples = first(samples_and_logjac)
logjac = last(samples_and_logjac)

entropy_base = AdvancedVI.estimate_entropy_maybe_stl(
ent_est, base_samples, q_base, q_base_stop
)
logjac = last(samples_and_logjac)/n_samples

entropy = entropy_base + logjac/n_samples
entropy = unconst_entropy + logjac
samples, entropy
end
end
2 changes: 0 additions & 2 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@ end

eachsample(samples::AbstractMatrix) = eachcol(samples)

eachsample(samples::AbstractVector) = samples

function catsamples_and_acc(
state_curr::Tuple{<:AbstractArray, <:Real},
state_new ::Tuple{<:AbstractVector, <:Real}
Expand Down
11 changes: 6 additions & 5 deletions test/inference/repgradelbo_distributionsad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@ using Test
(modelname, modelconstr) Dict(
:Normal=> normal_meanfield,
),
(objname, objective) Dict(
:RepGradELBOClosedFormEntropy => RepGradELBO(10),
:RepGradELBOStickingTheLanding => RepGradELBO(10, entropy = StickingTheLandingEntropy()),
n_montecarlo in [1, 10],
(objname, objective) in Dict(
:RepGradELBOClosedFormEntropy => RepGradELBO(n_montecarlo),
:RepGradELBOStickingTheLanding => RepGradELBO(n_montecarlo, entropy = StickingTheLandingEntropy()),
),
(adbackname, adbackend) Dict(
:ForwarDiff => AutoForwardDiff(),
Expand All @@ -33,7 +34,7 @@ using Test
q0 = TuringDiagMvNormal(μ0, diag(L0))

@testset "convergence" begin
Δλ₀ = sum(abs2, μ0 - μ_true) + sum(abs2, L0 - L_true)
Δλ0 = sum(abs2, μ0 - μ_true) + sum(abs2, L0 - L_true)
q, stats, _ = optimize(
rng, model, objective, q0, T;
optimizer = Optimisers.Adam(realtype(η)),
Expand All @@ -45,7 +46,7 @@ using Test
L = sqrt(cov(q))
Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true)

@test Δλ Δλ₀/T^(1/4)
@test Δλ Δλ0/T^(1/4)
@test eltype(μ) == eltype(μ_true)
@test eltype(L) == eltype(L_true)
end
Expand Down
17 changes: 9 additions & 8 deletions test/inference/repgradelbo_locationscale.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,17 @@ using Test

@testset "inference RepGradELBO VILocationScale" begin
@testset "$(modelname) $(objname) $(realtype) $(adbackname)" for
realtype [Float64, Float32],
(modelname, modelconstr) Dict(
realtype in [Float64, Float32],
(modelname, modelconstr) in Dict(
:Normal=> normal_meanfield,
:Normal=> normal_fullrank,
),
(objname, objective) Dict(
:RepGradELBOClosedFormEntropy => RepGradELBO(10),
:RepGradELBOStickingTheLanding => RepGradELBO(10, entropy = StickingTheLandingEntropy()),
n_montecarlo in [1, 10],
(objname, objective) in Dict(
:RepGradELBOClosedFormEntropy => RepGradELBO(n_montecarlo),
:RepGradELBOStickingTheLanding => RepGradELBO(n_montecarlo, entropy = StickingTheLandingEntropy()),
),
(adbackname, adbackend) Dict(
(adbackname, adbackend) in Dict(
:ForwarDiff => AutoForwardDiff(),
:ReverseDiff => AutoReverseDiff(),
:Zygote => AutoZygote(),
Expand All @@ -37,7 +38,7 @@ using Test
end

@testset "convergence" begin
Δλ₀ = sum(abs2, q0.location - μ_true) + sum(abs2, q0.scale - L_true)
Δλ0 = sum(abs2, q0.location - μ_true) + sum(abs2, q0.scale - L_true)
q, stats, _ = optimize(
rng, model, objective, q0, T;
optimizer = Optimisers.Adam(realtype(η)),
Expand All @@ -49,7 +50,7 @@ using Test
L = q.scale
Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true)

@test Δλ Δλ₀/T^(1/4)
@test Δλ Δλ0/T^(1/4)
@test eltype(μ) == eltype(μ_true)
@test eltype(L) == eltype(L_true)
end
Expand Down
17 changes: 9 additions & 8 deletions test/inference/repgradelbo_locationscale_bijectors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,16 @@ using Test

@testset "inference RepGradELBO VILocationScale Bijectors" begin
@testset "$(modelname) $(objname) $(realtype) $(adbackname)" for
realtype [Float64, Float32],
(modelname, modelconstr) Dict(
realtype in [Float64, Float32],
(modelname, modelconstr) in Dict(
:NormalLogNormalMeanField => normallognormal_meanfield,
),
(objname, objective) Dict(
:RepGradELBOClosedFormEntropy => RepGradELBO(10),
:RepGradELBOStickingTheLanding => RepGradELBO(10, entropy = StickingTheLandingEntropy()),
n_montecarlo in [1, 10],
(objname, objective) in Dict(
:RepGradELBOClosedFormEntropy => RepGradELBO(n_montecarlo),
:RepGradELBOStickingTheLanding => RepGradELBO(n_montecarlo, entropy = StickingTheLandingEntropy()),
),
(adbackname, adbackend) Dict(
(adbackname, adbackend) in Dict(
:ForwarDiff => AutoForwardDiff(),
:ReverseDiff => AutoReverseDiff(),
#:Zygote => AutoZygote(),
Expand Down Expand Up @@ -42,7 +43,7 @@ using Test
q0_z = Bijectors.transformed(q0_η, b⁻¹)

@testset "convergence" begin
Δλ₀ = sum(abs2, μ0 - μ_true) + sum(abs2, L0 - L_true)
Δλ0 = sum(abs2, μ0 - μ_true) + sum(abs2, L0 - L_true)
q, stats, _ = optimize(
rng, model, objective, q0_z, T;
optimizer = Optimisers.Adam(realtype(η)),
Expand All @@ -54,7 +55,7 @@ using Test
L = q.dist.scale
Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true)

@test Δλ Δλ₀/T^(1/4)
@test Δλ Δλ0/T^(1/4)
@test eltype(μ) == eltype(μ_true)
@test eltype(L) == eltype(L_true)
end
Expand Down

0 comments on commit 2f6dc4f

Please sign in to comment.