From 05dbb51001999e5c71d624f4e2b5cd052049f033 Mon Sep 17 00:00:00 2001
From: Kyurae Kim
Date: Fri, 22 Dec 2023 03:53:48 -0500
Subject: [PATCH 1/2] fix bug for bijector with 1 MC sample with tests

---
 ext/AdvancedVIBijectorsExt.jl                 | 48 +++++++++++--------
 src/utils.jl                                  |  4 ++
 test/inference/repgradelbo_distributionsad.jl | 11 +++--
 test/inference/repgradelbo_locationscale.jl   | 17 +++----
 .../repgradelbo_locationscale_bijectors.jl    | 17 +++----
 5 files changed, 57 insertions(+), 40 deletions(-)

diff --git a/ext/AdvancedVIBijectorsExt.jl b/ext/AdvancedVIBijectorsExt.jl
index 1b200ac5..f1b8cd23 100644
--- a/ext/AdvancedVIBijectorsExt.jl
+++ b/ext/AdvancedVIBijectorsExt.jl
@@ -11,35 +11,45 @@ else
     using ..Random
 end
 
-function AdvancedVI.reparam_with_entropy(
-    rng      ::Random.AbstractRNG,
-    q        ::Bijectors.TransformedDistribution,
-    q_stop   ::Bijectors.TransformedDistribution,
-    n_samples::Int,
-    ent_est  ::AdvancedVI.AbstractEntropyEstimator
-)
-    transform    = q.transform
-    q_base       = q.dist
-    q_base_stop  = q_stop.dist
-    base_samples = rand(rng, q_base, n_samples)
-    it           = AdvancedVI.eachsample(base_samples)
-    sample_init  = first(it)
+function transform_samples_with_jacobian(unconst_samples, transform, n_samples)
+    unconst_iter = AdvancedVI.eachsample(unconst_samples)
+    unconst_init = first(unconst_iter)
+
+    samples_init, logjac_init = with_logabsdet_jacobian(transform, unconst_init)
 
     samples_and_logjac = mapreduce(
         AdvancedVI.catsamples_and_acc,
-        Iterators.drop(it, 1);
-        init=with_logabsdet_jacobian(transform, sample_init)
+        Iterators.drop(unconst_iter, 1);
+        init=(AdvancedVI.samples_expand_dim(samples_init), logjac_init)
     ) do sample
         with_logabsdet_jacobian(transform, sample)
     end
     samples = first(samples_and_logjac)
-    logjac  = last(samples_and_logjac)
+    logjac  = last(samples_and_logjac)/n_samples
+    samples, logjac
+end
 
-    entropy_base = AdvancedVI.estimate_entropy_maybe_stl(
-        ent_est, base_samples, q_base, q_base_stop
+function AdvancedVI.reparam_with_entropy(
+    rng      ::Random.AbstractRNG,
+    q        ::Bijectors.TransformedDistribution,
+    q_stop   ::Bijectors.TransformedDistribution,
+    n_samples::Int,
+    ent_est  ::AdvancedVI.AbstractEntropyEstimator
+)
+    transform      = q.transform
+    q_unconst      = q.dist
+    q_unconst_stop = q_stop.dist
+
+    # Draw samples and compute entropy of the unconstrained distribution
+    unconst_samples, unconst_entropy = AdvancedVI.reparam_with_entropy(
+        rng, q_unconst, q_unconst_stop, n_samples, ent_est
     )
-    entropy = entropy_base + logjac/n_samples
+
+    # Apply bijector to samples while estimating its Jacobian
+    samples, logjac = transform_samples_with_jacobian(
+        unconst_samples, transform, n_samples
+    )
+    entropy = unconst_entropy + logjac
     samples, entropy
 end
 end
diff --git a/src/utils.jl b/src/utils.jl
index 8e67ff1a..7d740f60 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -34,3 +34,7 @@ function catsamples_and_acc(
     return (x, ∑y)
 end
 
+function samples_expand_dim(x::AbstractVector)
+    reshape(x, (:,1))
+end
+
diff --git a/test/inference/repgradelbo_distributionsad.jl b/test/inference/repgradelbo_distributionsad.jl
index b6db22a6..53105a4a 100644
--- a/test/inference/repgradelbo_distributionsad.jl
+++ b/test/inference/repgradelbo_distributionsad.jl
@@ -9,9 +9,10 @@ using Test
         (modelname, modelconstr) ∈ Dict(
             :Normal=> normal_meanfield,
         ),
-        (objname, objective) ∈ Dict(
-            :RepGradELBOClosedFormEntropy  => RepGradELBO(10),
-            :RepGradELBOStickingTheLanding => RepGradELBO(10, entropy = StickingTheLandingEntropy()),
+        n_montecarlo in [1, 10],
+        (objname, objective) in Dict(
+            :RepGradELBOClosedFormEntropy  => RepGradELBO(n_montecarlo),
+            :RepGradELBOStickingTheLanding => RepGradELBO(n_montecarlo, entropy = StickingTheLandingEntropy()),
         ),
         (adbackname, adbackend) ∈ Dict(
             :ForwarDiff  => AutoForwardDiff(),
@@ -33,7 +34,7 @@ using Test
         q0 = TuringDiagMvNormal(μ0, diag(L0))
 
         @testset "convergence" begin
-            Δλ₀ = sum(abs2, μ0 - μ_true) + sum(abs2, L0 - L_true)
+            Δλ0 = sum(abs2, μ0 - μ_true) + sum(abs2, L0 - L_true)
             q, stats, _ = optimize(
                 rng, model, objective, q0, T;
                 optimizer     = Optimisers.Adam(realtype(η)),
@@ -45,7 +46,7 @@ using Test
             L  = sqrt(cov(q))
             Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true)
 
-            @test Δλ ≤ Δλ₀/T^(1/4)
+            @test Δλ ≤ Δλ0/T^(1/4)
             @test eltype(μ) == eltype(μ_true)
             @test eltype(L) == eltype(L_true)
         end
diff --git a/test/inference/repgradelbo_locationscale.jl b/test/inference/repgradelbo_locationscale.jl
index 8ac9d2ca..1a200474 100644
--- a/test/inference/repgradelbo_locationscale.jl
+++ b/test/inference/repgradelbo_locationscale.jl
@@ -5,16 +5,17 @@ using Test
 
 @testset "inference RepGradELBO VILocationScale" begin
     @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for
-        realtype ∈ [Float64, Float32],
-        (modelname, modelconstr) ∈ Dict(
+        realtype in [Float64, Float32],
+        (modelname, modelconstr) in Dict(
             :Normal=> normal_meanfield,
             :Normal=> normal_fullrank,
         ),
-        (objname, objective) ∈ Dict(
-            :RepGradELBOClosedFormEntropy  => RepGradELBO(10),
-            :RepGradELBOStickingTheLanding => RepGradELBO(10, entropy = StickingTheLandingEntropy()),
+        n_montecarlo in [1, 10],
+        (objname, objective) in Dict(
+            :RepGradELBOClosedFormEntropy  => RepGradELBO(n_montecarlo),
+            :RepGradELBOStickingTheLanding => RepGradELBO(n_montecarlo, entropy = StickingTheLandingEntropy()),
         ),
-        (adbackname, adbackend) ∈ Dict(
+        (adbackname, adbackend) in Dict(
             :ForwarDiff  => AutoForwardDiff(),
             :ReverseDiff => AutoReverseDiff(),
             :Zygote      => AutoZygote(),
@@ -37,7 +38,7 @@ using Test
         end
 
         @testset "convergence" begin
-            Δλ₀ = sum(abs2, q0.location - μ_true) + sum(abs2, q0.scale - L_true)
+            Δλ0 = sum(abs2, q0.location - μ_true) + sum(abs2, q0.scale - L_true)
             q, stats, _ = optimize(
                 rng, model, objective, q0, T;
                 optimizer     = Optimisers.Adam(realtype(η)),
@@ -49,7 +50,7 @@ using Test
             L  = q.scale
             Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true)
 
-            @test Δλ ≤ Δλ₀/T^(1/4)
+            @test Δλ ≤ Δλ0/T^(1/4)
             @test eltype(μ) == eltype(μ_true)
             @test eltype(L) == eltype(L_true)
         end
diff --git a/test/inference/repgradelbo_locationscale_bijectors.jl b/test/inference/repgradelbo_locationscale_bijectors.jl
index 7154440c..41a4d740 100644
--- a/test/inference/repgradelbo_locationscale_bijectors.jl
+++ b/test/inference/repgradelbo_locationscale_bijectors.jl
@@ -5,15 +5,16 @@ using Test
 
 @testset "inference RepGradELBO VILocationScale Bijectors" begin
     @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for
-        realtype ∈ [Float64, Float32],
-        (modelname, modelconstr) ∈ Dict(
+        realtype in [Float64, Float32],
+        (modelname, modelconstr) in Dict(
            :NormalLogNormalMeanField => normallognormal_meanfield,
         ),
-        (objname, objective) ∈ Dict(
-            :RepGradELBOClosedFormEntropy  => RepGradELBO(10),
-            :RepGradELBOStickingTheLanding => RepGradELBO(10, entropy = StickingTheLandingEntropy()),
+        n_montecarlo in [1, 10],
+        (objname, objective) in Dict(
+            :RepGradELBOClosedFormEntropy  => RepGradELBO(n_montecarlo),
+            :RepGradELBOStickingTheLanding => RepGradELBO(n_montecarlo, entropy = StickingTheLandingEntropy()),
         ),
-        (adbackname, adbackend) ∈ Dict(
+        (adbackname, adbackend) in Dict(
             :ForwarDiff  => AutoForwardDiff(),
             :ReverseDiff => AutoReverseDiff(),
            #:Zygote     => AutoZygote(),
@@ -42,7 +43,7 @@ using Test
         q0_z = Bijectors.transformed(q0_η, b⁻¹)
 
         @testset "convergence" begin
-            Δλ₀ = sum(abs2, μ0 - μ_true) + sum(abs2, L0 - L_true)
+            Δλ0 = sum(abs2, μ0 - μ_true) + sum(abs2, L0 - L_true)
             q, stats, _ = optimize(
                 rng, model, objective, q0_z, T;
                 optimizer     = Optimisers.Adam(realtype(η)),
@@ -54,7 +55,7 @@ using Test
             L  = q.dist.scale
             Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true)
 
-            @test Δλ ≤ Δλ₀/T^(1/4)
+            @test Δλ ≤ Δλ0/T^(1/4)
             @test eltype(μ) == eltype(μ_true)
             @test eltype(L) == eltype(L_true)
         end

From 31db7bc35a708a327893559d213d3096122e3479 Mon Sep 17 00:00:00 2001
From: Kyurae Kim
Date: Wed, 3 Jan 2024 00:47:44 -0500
Subject: [PATCH 2/2] fix remove redundant helpers for `reparam_with_entropy` for bijector

---
 ext/AdvancedVIBijectorsExt.jl | 36 ++++++++++++++----------------------
 src/utils.jl                  |  6 ------
 2 files changed, 14 insertions(+), 28 deletions(-)

diff --git a/ext/AdvancedVIBijectorsExt.jl b/ext/AdvancedVIBijectorsExt.jl
index f1b8cd23..29a877fd 100644
--- a/ext/AdvancedVIBijectorsExt.jl
+++ b/ext/AdvancedVIBijectorsExt.jl
@@ -11,24 +11,6 @@ else
     using ..Random
 end
 
-function transform_samples_with_jacobian(unconst_samples, transform, n_samples)
-    unconst_iter = AdvancedVI.eachsample(unconst_samples)
-    unconst_init = first(unconst_iter)
-
-    samples_init, logjac_init = with_logabsdet_jacobian(transform, unconst_init)
-
-    samples_and_logjac = mapreduce(
-        AdvancedVI.catsamples_and_acc,
-        Iterators.drop(unconst_iter, 1);
-        init=(AdvancedVI.samples_expand_dim(samples_init), logjac_init)
-    ) do sample
-        with_logabsdet_jacobian(transform, sample)
-    end
-    samples = first(samples_and_logjac)
-    logjac  = last(samples_and_logjac)/n_samples
-    samples, logjac
-end
-
 function AdvancedVI.reparam_with_entropy(
     rng      ::Random.AbstractRNG,
     q        ::Bijectors.TransformedDistribution,
@@ -41,14 +23,24 @@ function AdvancedVI.reparam_with_entropy(
     q_unconst_stop = q_stop.dist
 
     # Draw samples and compute entropy of the unconstrained distribution
-    unconst_samples, unconst_entropy = AdvancedVI.reparam_with_entropy(
+    unconstr_samples, unconst_entropy = AdvancedVI.reparam_with_entropy(
         rng, q_unconst, q_unconst_stop, n_samples, ent_est
     )
 
     # Apply bijector to samples while estimating its Jacobian
-    samples, logjac = transform_samples_with_jacobian(
-        unconst_samples, transform, n_samples
-    )
+    unconstr_iter = AdvancedVI.eachsample(unconstr_samples)
+    unconstr_init = first(unconstr_iter)
+    samples_init, logjac_init = with_logabsdet_jacobian(transform, unconstr_init)
+    samples_and_logjac = mapreduce(
+        AdvancedVI.catsamples_and_acc,
+        Iterators.drop(unconstr_iter, 1);
+        init=(reshape(samples_init, (:,1)), logjac_init)
+    ) do sample
+        with_logabsdet_jacobian(transform, sample)
+    end
+    samples = first(samples_and_logjac)
+    logjac  = last(samples_and_logjac)/n_samples
+
     entropy = unconst_entropy + logjac
     samples, entropy
 end
diff --git a/src/utils.jl b/src/utils.jl
index 7d740f60..92b5686f 100644
--- a/src/utils.jl
+++ b/src/utils.jl
@@ -23,8 +23,6 @@ end
 
 eachsample(samples::AbstractMatrix) = eachcol(samples)
 
-eachsample(samples::AbstractVector) = samples
-
 function catsamples_and_acc(
     state_curr::Tuple{<:AbstractArray, <:Real},
     state_new ::Tuple{<:AbstractVector, <:Real}
@@ -34,7 +32,3 @@ function catsamples_and_acc(
     return (x, ∑y)
 end
 
-function samples_expand_dim(x::AbstractVector)
-    reshape(x, (:,1))
-end
-
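
For reference, the entropy computation in ext/AdvancedVIBijectorsExt.jl relies on the change-of-variables identity for differential entropy: if z is a sample from the unconstrained distribution q and T is the bijector, then

    H[T(z)] = H[q] + E_{z ~ q}[ log |det J_T(z)| ].

The expectation is estimated by averaging the per-sample log-determinants, which is why `logjac = last(samples_and_logjac)/n_samples` divides by the number of Monte Carlo samples before `entropy = unconst_entropy + logjac` adds it to the entropy of the unconstrained distribution.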
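
A minimal sketch of the single-sample edge case these patches address, using only Base Julia (the values are made up, and `cat_and_acc` is an illustrative stand-in for `catsamples_and_acc` from src/utils.jl): with n_samples = 1, `Iterators.drop(it, 1)` is empty, so `mapreduce` returns its `init` unchanged, and unless the initial sample is expanded into a one-column matrix the returned samples have a different shape than in the multi-sample case.

    # One 3-dimensional Monte Carlo sample, stored column-wise as in AdvancedVI.
    samples = randn(3, 1)
    it      = eachcol(samples)          # what AdvancedVI.eachsample does for a matrix
    s1      = first(it)                 # a column view, i.e. an AbstractVector

    # Stand-in for catsamples_and_acc: concatenate samples, accumulate the scalars.
    cat_and_acc((x, acc), (x_new, acc_new)) = (hcat(x, x_new), acc + acc_new)

    # With a single sample, the dropped iterator is empty and mapreduce returns `init`.
    # Pre-fix behaviour: `init` holds the raw vector, so the samples come back as a vector.
    out_old = mapreduce(cat_and_acc, Iterators.drop(it, 1); init=(s1, 0.0)) do s
        (s, 0.0)                        # stand-in for with_logabsdet_jacobian(transform, s)
    end
    first(out_old) isa AbstractVector   # true: inconsistent with the n_samples > 1 case

    # Post-fix behaviour: expand the initial sample to a one-column matrix first.
    out_new = mapreduce(cat_and_acc, Iterators.drop(it, 1); init=(reshape(s1, (:, 1)), 0.0)) do s
        (s, 0.0)
    end
    first(out_new) isa AbstractMatrix   # true: same shape as the multi-sample path

Expanding the initial sample to a one-column matrix, via `samples_expand_dim` in the first commit and the inlined `reshape(samples_init, (:,1))` in the second, keeps the one-sample and multi-sample code paths shape-consistent.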