Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Basic rewrite of the package 2023 edition #45

Closed
wants to merge 147 commits into from
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
Show all changes
147 commits
Select commit Hold shift + click to select a range
b49cf3e
refactor ADVI, change gradient operation interface
Red-Portal Mar 14, 2023
88e0b79
remove unused file, remove unused dependency
Red-Portal Mar 14, 2023
c2fb3f8
fix ADVI elbo computation more efficiently
Red-Portal Mar 15, 2023
83161fd
fix missing entropy regularization term
Red-Portal Mar 15, 2023
efa8106
add LogDensityProblem interface
Red-Portal Mar 18, 2023
4ae2fbf
refactor use bijectors directly instead of transformed distributions
Red-Portal Mar 18, 2023
2bf2a42
Merge branch 'master' of https://github.com/TuringLang/AdvancedVI.jl …
Red-Portal Jun 7, 2023
1cadb51
fix type restrictions
Red-Portal Jun 7, 2023
3474e8d
remove unused file
Red-Portal Jun 7, 2023
03a2767
fix use of with_logabsdet_jacobian
Red-Portal Jun 8, 2023
09c44fb
restructure project; move the main VI routine to its own file
Red-Portal Jun 8, 2023
b7407ce
remove redundant import
Red-Portal Jun 8, 2023
4040149
restructure project into more modular objective estimators
Red-Portal Jun 8, 2023
2a4514e
migrate to AbstractDifferentiation
Red-Portal Jun 9, 2023
93a16d8
add location scale pre-packaged variational family, add functors
Red-Portal Jun 9, 2023
2b6e9eb
Revert "migrate to AbstractDifferentiation"
Red-Portal Jun 10, 2023
1bfec36
fix use optimized MvNormal specialization, add logpdf for Loc.Scale.
Red-Portal Jun 10, 2023
1003606
remove dead code
Red-Portal Jun 10, 2023
60a9987
fix location-scale logpdf
Red-Portal Jun 10, 2023
cd84f02
add sticking-the-landing (STL) estimator
Red-Portal Jun 10, 2023
768641b
migrate to Optimisers.jl
Red-Portal Jun 10, 2023
ca02fa3
remove execution time measurement (replace later with somethin else)
Red-Portal Jun 10, 2023
a48377f
fix use multiple dispatch for deciding whether to stop entropy grad.
Red-Portal Jun 12, 2023
0b40ccf
add termination decision, callback arguments
Red-Portal Jun 12, 2023
21db3fb
add Base.show to modules
Red-Portal Jun 12, 2023
25c51b4
add interface calling `restructure`, rename rebuild -> restructure
Red-Portal Jun 12, 2023
fc20046
add estimator state interface, add control variate interface to ADVI
Red-Portal Jun 12, 2023
6faa807
fix `show(advi)` to show control variate
Red-Portal Jun 12, 2023
7095d27
fix simplify `show(advi.control_variate)`
Red-Portal Jun 12, 2023
9169ae2
fix type piracy by wrapping location-scale bijected distribution
Red-Portal Jun 12, 2023
3db7301
remove old AdvancedVI custom optimizers
Red-Portal Jun 26, 2023
e6a082a
fix Location Scale to not depend on Bijectors
Red-Portal Jun 26, 2023
a034ebd
fix RNG namespace
Red-Portal Jul 12, 2023
e19abd3
fix location scale logpdf bug
Red-Portal Jul 13, 2023
680c186
add Accessors dependency
Red-Portal Jul 13, 2023
6c3efa8
Merge branch 'master' of https://github.com/TuringLang/AdvancedVI.jl …
Red-Portal Jul 13, 2023
4c6cabf
add location scale, autodiff tests
Red-Portal Jul 13, 2023
06db2f0
add Accessors import statement
Red-Portal Jul 13, 2023
12de2bd
remove optimiser tests
Red-Portal Jul 13, 2023
bbb2cc6
refactor slightly generalize the distribution tests for the future
Red-Portal Jul 13, 2023
1974846
migrate to SimpleUnPack, migrate to ADTypes
Red-Portal Jul 13, 2023
19c62c8
rename vi.jl to optimize.jl
Red-Portal Jul 13, 2023
63da51d
fix estimate_gradient to use adtypes
Red-Portal Jul 13, 2023
65ab473
add exact inference tests
Red-Portal Jul 13, 2023
3e5a452
remove Turing dependency in tests
Red-Portal Jul 13, 2023
3117cec
remove unused projection
Red-Portal Jul 14, 2023
b1ca9cf
remove redundant `ADVIEnergy` object (now baked into `ADVI`)
Red-Portal Jul 14, 2023
fcbb729
add more tests, fix rng seed for tests
Red-Portal Jul 14, 2023
0f6f6a4
add more tests, fix seed for tests
Red-Portal Jul 14, 2023
f5f5863
fix non-determinism bug
Red-Portal Jul 14, 2023
ade0d10
fix test hyperparameters so that tests pass, minor cleanups
Red-Portal Jul 14, 2023
0caf7a9
fix minor reorganization
Red-Portal Jul 14, 2023
5658cbf
add missing files
Red-Portal Jul 14, 2023
c712a97
fix add missing file, rename adbackend argument
Red-Portal Jul 14, 2023
bee839d
fix errors
Red-Portal Jul 14, 2023
913911e
rename test suite
Red-Portal Jul 14, 2023
d50cabb
refactor renamed arguments for ADVI to be shorter
Red-Portal Jul 15, 2023
b134f70
fix compile error in advi test
Red-Portal Jul 15, 2023
a6ba379
add initial doc
Red-Portal Jul 15, 2023
619b1c0
remove unused epsilon argument in location scale
Red-Portal Jul 15, 2023
f1c02f0
add project file for documenter
Red-Portal Jul 15, 2023
b0f259a
refactor STL gradient calculation to use multiple dispatch
Red-Portal Jul 16, 2023
b72c258
fix type bugs, relax test threshold for the exact inference tests
Red-Portal Jul 16, 2023
a8df9eb
refactor derivative utils to match NormalizingFlows.jl with extras
Red-Portal Aug 13, 2023
e8db6a7
add documentation, refactor optimize
Red-Portal Aug 13, 2023
65a2b37
fix bug missing extension
Red-Portal Aug 13, 2023
1a02051
remove tracker from tests
Red-Portal Aug 13, 2023
d8b5ea5
remove export for internal derivative utils
Red-Portal Aug 13, 2023
818bc2c
fix test errors, old interface
Red-Portal Aug 13, 2023
215abf3
fix wrong derivative interface, add documentation
Red-Portal Aug 13, 2023
88ad768
update documentation
Red-Portal Aug 13, 2023
e66935b
add doc build CI
Red-Portal Aug 13, 2023
9f1c647
remove convergence criterion for now
Red-Portal Aug 13, 2023
c8b3ee3
remove outdated export
Red-Portal Aug 13, 2023
afda1a1
update documentation
Red-Portal Aug 13, 2023
0d37ace
update documentation
Red-Portal Aug 13, 2023
b8b113d
update documentation
Red-Portal Aug 13, 2023
b78e713
fix type error in test
Red-Portal Aug 16, 2023
a0564b5
remove default ADType argument
Red-Portal Aug 16, 2023
3795d1e
update README
Red-Portal Aug 17, 2023
28a35bc
update make getting started example actually run Julia
Red-Portal Aug 17, 2023
620b38e
fix remove Float32 tests for inference tests
Red-Portal Aug 17, 2023
fa53398
update version
Red-Portal Aug 17, 2023
e909f41
add documentation publishing url
Red-Portal Aug 17, 2023
43f5b75
fix wrong uuid for ForwardDiff
Red-Portal Aug 17, 2023
468d5ca
Update CI.yml
yebai Aug 17, 2023
c07a511
refactor use `sum` and `mean` instead of abusing `mapreduce`
Red-Portal Aug 17, 2023
8256df1
Merge branch 'rewriting_advancedvi' of github.com:Red-Portal/Advanced…
Red-Portal Aug 17, 2023
13a8a44
remove tests for `FullMonteCarlo`
Red-Portal Aug 17, 2023
aadf8d3
add tests for the `optimize` interface
Red-Portal Aug 18, 2023
8c4e13d
fix turn off Zygote tests for now
Red-Portal Aug 18, 2023
0b708e6
remove unused function
Red-Portal Aug 18, 2023
be61acd
refactor change bijector field name, simplify STL estimator
Red-Portal Aug 18, 2023
fb519a5
update documentation
Red-Portal Aug 18, 2023
8682fd9
update STL documentation
Red-Portal Aug 18, 2023
9a16ee1
update STL documentation
Red-Portal Aug 18, 2023
fc74afa
update location scale documentation
Red-Portal Aug 18, 2023
4be30a1
fix README
Red-Portal Aug 19, 2023
c58309d
fix math in README
Red-Portal Aug 19, 2023
5b5bd3e
add gradient to arguments of callback!, remove `gradient_norm` info
Red-Portal Aug 20, 2023
967021d
fix math in README.md
Red-Portal Aug 21, 2023
4dab522
fix type constraint in `ZygoteExt`
Red-Portal Aug 21, 2023
8ab2f19
fix import of `Random`
Red-Portal Aug 21, 2023
83dec9f
refactor `__init__()`
Red-Portal Aug 21, 2023
a3e563c
fix type constraint in definition of `value_and_gradient!`
Red-Portal Aug 21, 2023
5553bb9
refactor `ZygoteExt`; use `only` instead of `first`
Red-Portal Aug 21, 2023
79b4557
refactor type constraint in `ReverseDiffExt`
Red-Portal Aug 21, 2023
656b44b
refactor remove outdated debug mode macro
Red-Portal Aug 21, 2023
c794063
fix remove outdated DEBUG mechanism
Red-Portal Aug 21, 2023
0c5cc1c
fix LaTeX in README: `operatorname` is currently broken
Red-Portal Aug 21, 2023
29d7d27
remove `SimpleUnPack` dependency
Red-Portal Aug 22, 2023
75eef44
fix LaTeX in docs and README
Red-Portal Aug 22, 2023
40574f4
add warning about forward-mode AD when using `LocationScale`
Red-Portal Aug 22, 2023
8738256
fix documentation
Red-Portal Aug 22, 2023
8173744
fix remove reamining use of `@unpack`
Red-Portal Aug 22, 2023
e0548ae
Revert "remove `SimpleUnPack` dependency"
Red-Portal Aug 22, 2023
6ab95a0
Revert "fix remove reamining use of `@unpack`"
Red-Portal Aug 22, 2023
f0ec242
fix documentation for `optimize`
Red-Portal Aug 22, 2023
1d4c1b6
add specializations of `Optimise.destructure` for mean-field
Red-Portal Aug 22, 2023
231835f
add test for `Optimisers.destructure` specializations
Red-Portal Aug 22, 2023
ea2d426
add specialization of `rand` for meanfield resulting in faster AD
Red-Portal Aug 22, 2023
3033d75
add argument checks for `VIMeanFieldGaussian`, `VIFullRankGaussian`
Red-Portal Aug 22, 2023
0cc36c0
update documentation
Red-Portal Aug 22, 2023
b7d3471
fix type instability, bug in argument check in `LocationScale`
Red-Portal Aug 22, 2023
df50e83
add missing import bug
Red-Portal Aug 22, 2023
ae3e9b0
refactor test, fix type bug in tests for `LocationScale`
Red-Portal Aug 22, 2023
e4002cf
add missing compat entries
Red-Portal Aug 22, 2023
8c82569
fix missing package import in test
Red-Portal Aug 22, 2023
c2e7517
add additional tests for sampling `LocationScale`
Red-Portal Aug 22, 2023
3a6f8bf
fix bug in batch in-place `rand!` for `LocationScale`
Red-Portal Aug 22, 2023
b78ef4b
fix bug in inference test initialization
Red-Portal Aug 22, 2023
a1f7e98
add missing file
Red-Portal Aug 23, 2023
8b783ec
fix remove use of for 1.6
Red-Portal Aug 23, 2023
12cd9f2
refactor adjust inference test hyperparameters to be more robust
Red-Portal Aug 23, 2023
837c729
refactor `optimize` to return `obj_state`, add warm start kwargs
Red-Portal Aug 24, 2023
95629a5
refactor make tests more robust, reduce amount of tests
Red-Portal Aug 24, 2023
0b4b865
fix remove a cholesky in test model
Red-Portal Aug 24, 2023
b49f4eb
fix compat bounds, remove unused package
Red-Portal Aug 24, 2023
947a070
bump compat for ADTypes 0.2
Red-Portal Aug 24, 2023
a9b3f48
fix broken LaTeX in README
Red-Portal Aug 24, 2023
54826eb
remove redundant use of PDMats in docs
Red-Portal Aug 24, 2023
1d1c8ff
fix use `Cholesky` signature supported in 1.6
Red-Portal Aug 24, 2023
a0de2cf
fix remove redundant cholesky operation in test
Red-Portal Aug 24, 2023
f593a67
add `mean`, `var`, `cov` to `LocationScale`
Red-Portal Aug 24, 2023
ff32ac6
refactor `optimize` warm-starting interface, add `objargs` argument
Red-Portal Aug 24, 2023
bc5cfd3
update documentation for `optimize`
Red-Portal Aug 24, 2023
de4284e
fix CUDA-compatibility bugs
Red-Portal Aug 25, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
181 changes: 62 additions & 119 deletions src/AdvancedVI.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,20 +33,12 @@ function __init__()
export ZygoteAD

function AdvancedVI.grad!(
vo,
alg::VariationalInference{<:AdvancedVI.ZygoteAD},
q,
model,
θ::AbstractVector{<:Real},
f::Function,
::Type{<:ZygoteAD},
λ::AbstractVector{<:Real},
out::DiffResults.MutableDiffResult,
args...
)
f(θ) = if (q isa Distribution)
- vo(alg, update(q, θ), model, args...)
else
- vo(alg, q(θ), model, args...)
end
y, back = Zygote.pullback(f, θ)
y, back = Zygote.pullback(f, λ)
dy = first(back(1.0))
DiffResults.value!(out, y)
DiffResults.gradient!(out, dy)
Expand All @@ -58,21 +50,13 @@ function __init__()
export ReverseDiffAD

function AdvancedVI.grad!(
vo,
alg::VariationalInference{<:AdvancedVI.ReverseDiffAD{false}},
q,
model,
θ::AbstractVector{<:Real},
f::Function,
::Type{<:ReverseDiffAD},
λ::AbstractVector{<:Real},
out::DiffResults.MutableDiffResult,
args...
)
f(θ) = if (q isa Distribution)
- vo(alg, update(q, θ), model, args...)
else
- vo(alg, q(θ), model, args...)
end
tp = AdvancedVI.tape(f, θ)
ReverseDiff.gradient!(out, tp, θ)
tp = AdvancedVI.tape(f, λ)
ReverseDiff.gradient!(out, tp, λ)
return out
end
end
Expand All @@ -81,26 +65,18 @@ function __init__()
export EnzymeAD

function AdvancedVI.grad!(
vo,
alg::VariationalInference{<:AdvancedVI.EnzymeAD},
q,
model,
θ::AbstractVector{<:Real},
f::Function,
::Type{<:EnzymeAD},
λ::AbstractVector{<:Real},
out::DiffResults.MutableDiffResult,
args...
)
f(θ) = if (q isa Distribution)
- vo(alg, update(q, θ), model, args...)
else
- vo(alg, q(θ), model, args...)
end
# Use `Enzyme.ReverseWithPrimal` once it is released:
# https://github.com/EnzymeAD/Enzyme.jl/pull/598
y = f(θ)
y = f(λ)
DiffResults.value!(out, y)
dy = DiffResults.gradient(out)
fill!(dy, 0)
Enzyme.autodiff(Enzyme.ReverseWithPrimal, f, Enzyme.Active, Enzyme.Duplicated(θ, dy))
Enzyme.autodiff(Enzyme.ReverseWithPrimal, f, Enzyme.Active, Enzyme.Duplicated(λ, dy))
return out
end
end
Expand All @@ -109,30 +85,20 @@ end
export
vi,
ADVI,
ELBO,
elbo,
TruncatedADAGrad,
DecayedADAGrad,
VariationalInference

abstract type VariationalInference{AD} end

getchunksize(::Type{<:VariationalInference{AD}}) where AD = getchunksize(AD)
getADtype(::VariationalInference{AD}) where AD = AD
DecayedADAGrad

abstract type VariationalObjective end

const VariationalPosterior = Distribution{Multivariate, Continuous}


"""
grad!(vo, alg::VariationalInference, q, model::Model, θ, out, args...)
grad!(f, λ, out)

Computes the gradients used in `optimize!`. Default implementation is provided for
Computes the gradients of the objective f. Default implementation is provided for
`VariationalInference{AD}` where `AD` is either `ForwardDiffAD` or `TrackerAD`.
This implicitly also gives a default implementation of `optimize!`.

Variance reduction techniques, e.g. control variates, should be implemented in this function.
"""
function grad! end

Expand All @@ -157,51 +123,36 @@ function update end

# default implementations
function grad!(
vo,
alg::VariationalInference{<:ForwardDiffAD},
q,
model,
θ::AbstractVector{<:Real},
out::DiffResults.MutableDiffResult,
args...
f::Function,
adtype::Type{<:ForwardDiffAD},
λ::AbstractVector{<:Real},
out::DiffResults.MutableDiffResult
)
f(θ_) = if (q isa Distribution)
- vo(alg, update(q, θ_), model, args...)
else
- vo(alg, q(θ_), model, args...)
end

# Set chunk size and do ForwardMode.
chunk_size = getchunksize(typeof(alg))
chunk_size = getchunksize(adtype)
config = if chunk_size == 0
ForwardDiff.GradientConfig(f, θ)
ForwardDiff.GradientConfig(f, λ)
else
ForwardDiff.GradientConfig(f, θ, ForwardDiff.Chunk(length(θ), chunk_size))
ForwardDiff.GradientConfig(f, λ, ForwardDiff.Chunk(length(λ), chunk_size))
end
ForwardDiff.gradient!(out, f, θ, config)
ForwardDiff.gradient!(out, f, λ, config)
end

function grad!(
vo,
alg::VariationalInference{<:TrackerAD},
q,
model,
θ::AbstractVector{<:Real},
out::DiffResults.MutableDiffResult,
args...
f::Function,
::Type{<:TrackerAD},
λ::AbstractVector{<:Real},
out::DiffResults.MutableDiffResult
)
θ_tracked = Tracker.param(θ)
y = if (q isa Distribution)
- vo(alg, update(q, θ_tracked), model, args...)
else
- vo(alg, q(θ_tracked), model, args...)
end
λ_tracked = Tracker.param(λ)
y = f(λ_tracked)
Tracker.back!(y, 1.0)

DiffResults.value!(out, Tracker.data(y))
DiffResults.gradient!(out, Tracker.grad(θ_tracked))
DiffResults.gradient!(out, Tracker.grad(λ_tracked))
end

abstract type AbstractGradientEstimator end

"""
optimize!(vo, alg::VariationalInference{AD}, q::VariationalPosterior, model::Model, θ; optimizer = TruncatedADAGrad())
Expand All @@ -210,61 +161,53 @@ Iteratively updates parameters by calling `grad!` and using the given `optimizer
the steps.
"""
function optimize!(
vo,
alg::VariationalInference,
q,
model,
θ::AbstractVector{<:Real};
optimizer = TruncatedADAGrad()
grad_estimator::AbstractGradientEstimator,
rebuild::Function,
Red-Portal marked this conversation as resolved.
Show resolved Hide resolved
ℓπ::Function,
n_max_iter::Int,
λ::AbstractVector{<:Real};
optimizer = TruncatedADAGrad(),
rng = Random.GLOBAL_RNG
)
# TODO: should we always assume `samples_per_step` and `max_iters` for all algos?
alg_name = alg_str(alg)
samples_per_step = alg.samples_per_step
max_iters = alg.max_iters

num_params = length(θ)
obj_name = objective(grad_estimator)

# TODO: really need a better way to warn the user about potentially
# not using the correct accumulator
if (optimizer isa TruncatedADAGrad) && (θ ∉ keys(optimizer.acc))
if (optimizer isa TruncatedADAGrad) && (λ ∉ keys(optimizer.acc))
# this message should only occurr once in the optimization process
@info "[$alg_name] Should only be seen once: optimizer created for θ" objectid(θ)
@info "[$obj_name] Should only be seen once: optimizer created for θ" objectid(λ)
end

diff_result = DiffResults.GradientResult(θ)
grad_buf = DiffResults.GradientResult(λ)

i = 0
prog = if PROGRESS[]
ProgressMeter.Progress(max_iters, 1, "[$alg_name] Optimizing...", 0)
else
0
end
prog = ProgressMeter.Progress(
n_max_iter; desc="[$obj_name] Optimizing...", barlen=0, enabled=PROGRESS[])

# add criterion? A running mean maybe?
time_elapsed = @elapsed while (i < max_iters) # & converged
grad!(vo, alg, q, model, θ, diff_result, samples_per_step)

# apply update rule
Δ = DiffResults.gradient(diff_result)
Δ = apply!(optimizer, θ, Δ)
@. θ = θ - Δ
time_elapsed = @elapsed begin
for i = 1:n_max_iter
stats = estimate_gradient!(rng, grad_estimator, λ, rebuild, ℓπ, grad_buf)

# apply update rule
Δλ = DiffResults.gradient(grad_buf)
Δλ = apply!(optimizer, λ, Δλ)
@. λ = λ - Δλ

stat′ = (Δλ=norm(Δλ),)
stats = merge(stats, stat′)
Red-Portal marked this conversation as resolved.
Show resolved Hide resolved

AdvancedVI.DEBUG && @debug "Step $i" Δ DiffResults.value(diff_result)
PROGRESS[] && (ProgressMeter.next!(prog))

i += 1
AdvancedVI.DEBUG && @debug "Step $i" stats...
pm_next!(prog, stats)
end
end

return θ
return λ
end

# objectives
include("objectives.jl")
include("estimators/advi.jl")

# optimisers
include("optimisers.jl")

# VI algorithms
include("advi.jl")

end # module
47 changes: 0 additions & 47 deletions src/advi.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,50 +50,3 @@ function optimize(elbo::ELBO, alg::ADVI, q, model, θ_init; optimizer = Truncate
return θ
end

# WITHOUT updating parameters inside ELBO
function (elbo::ELBO)(
rng::Random.AbstractRNG,
alg::ADVI,
q::VariationalPosterior,
logπ::Function,
num_samples
)
# 𝔼_q(z)[log p(xᵢ, z)]
# = ∫ log p(xᵢ, z) q(z) dz
# = ∫ log p(xᵢ, f(ϕ)) q(f(ϕ)) |det J_f(ϕ)| dϕ (since change of variables)
# = ∫ log p(xᵢ, f(ϕ)) q̃(ϕ) dϕ (since q(f(ϕ)) |det J_f(ϕ)| = q̃(ϕ))
# = 𝔼_q̃(ϕ)[log p(xᵢ, z)]

# 𝔼_q(z)[log q(z)]
# = ∫ q(f(ϕ)) log (q(f(ϕ))) |det J_f(ϕ)| dϕ (since q(f(ϕ)) |det J_f(ϕ)| = q̃(ϕ))
# = 𝔼_q̃(ϕ) [log q(f(ϕ))]
# = 𝔼_q̃(ϕ) [log q̃(ϕ) - log |det J_f(ϕ)|]
# = 𝔼_q̃(ϕ) [log q̃(ϕ)] - 𝔼_q̃(ϕ) [log |det J_f(ϕ)|]
# = - ℍ(q̃(ϕ)) - 𝔼_q̃(ϕ) [log |det J_f(ϕ)|]

# Finally, the ELBO is given by
# ELBO = 𝔼_q(z)[log p(xᵢ, z)] - 𝔼_q(z)[log q(z)]
# = 𝔼_q̃(ϕ)[log p(xᵢ, z)] + 𝔼_q̃(ϕ) [log |det J_f(ϕ)|] + ℍ(q̃(ϕ))

# If f: supp(p(z | x)) → ℝ then
# ELBO = 𝔼[log p(x, z) - log q(z)]
# = 𝔼[log p(x, f⁻¹(z̃)) + logabsdet(J(f⁻¹(z̃)))] + ℍ(q̃(z̃))
# = 𝔼[log p(x, z) - logabsdetjac(J(f(z)))] + ℍ(q̃(z̃))

# But our `rand_and_logjac(q)` is using f⁻¹: ℝ → supp(p(z | x)) going forward → `+ logjac`
z, logjac = rand_and_logjac(rng, q)
res = (logπ(z) + logjac) / num_samples

if q isa TransformedDistribution
res += entropy(q.dist)
else
res += entropy(q)
end

for i = 2:num_samples
z, logjac = rand_and_logjac(rng, q)
res += (logπ(z) + logjac) / num_samples
end

return res
end
29 changes: 29 additions & 0 deletions src/estimators/advi.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@

struct ADVI <: AbstractGradientEstimator
n_samples::Int
end

objective(::ADVI) = "ELBO"

function estimate_gradient!(
rng::Random.AbstractRNG,
estimator::ADVI,
λ::Vector{<:Real},
rebuild::Function,
logπ::Function,
out::DiffResults.MutableDiffResult)

n_samples = estimator.n_samples

grad!(ADBackend(), λ, out) do λ′
q = rebuild(λ′)
zs, ∑logjac = rand_and_logjac(rng, q, estimator.n_samples)

elbo = mapreduce(+, eachcol(zs)) do zᵢ
(logπ(zᵢ) + ∑logjac)
Red-Portal marked this conversation as resolved.
Show resolved Hide resolved
end / n_samples
-elbo
end
nelbo = DiffResults.value(out)
(elbo=-nelbo,)
end
7 changes: 0 additions & 7 deletions src/objectives.jl

This file was deleted.

15 changes: 15 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,18 @@ function rand_and_logjac(rng::Random.AbstractRNG, dist::Bijectors.TransformedDis
y, logjac = Bijectors.with_logabsdet_jacobian(dist.transform, x)
return y, logjac
end

function rand_and_logjac(rng::Random.AbstractRNG, dist::Distribution, n_samples::Int)
x = rand(rng, dist, n_samples)
return x, zero(eltype(x))
end

function rand_and_logjac(rng::Random.AbstractRNG, dist::Bijectors.TransformedDistribution, n_samples::Int)
x = rand(rng, dist.dist, n_samples)
y, logjac = Bijectors.with_logabsdet_jacobian(dist.transform, x)
return y, logjac
end

function pm_next!(pm, stats::NamedTuple)
ProgressMeter.next!(pm; showvalues=[tuple(s...) for s in pairs(stats)])
end