refactor rename advi to repgradelbo and not use bijectors directly
Red-Portal committed Nov 21, 2023
1 parent a063583 commit 316b629
Showing 13 changed files with 182 additions and 405 deletions.
2 changes: 0 additions & 2 deletions Project.toml
@@ -5,7 +5,6 @@ version = "0.3.0"
[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
@@ -36,7 +35,6 @@ AdvancedVIZygoteExt = "Zygote"
[compat]
ADTypes = "0.1, 0.2"
Accessors = "0.1"
Bijectors = "0.12, 0.13"
ChainRulesCore = "1.16"
DiffResults = "1"
Distributions = "0.25.87"
9 changes: 3 additions & 6 deletions src/AdvancedVI.jl
@@ -11,7 +11,6 @@ using Functors
using Optimisers

using DocStringExtensions

using ProgressMeter
using LinearAlgebra

@@ -21,7 +20,6 @@ using ADTypes, DiffResults
using ChainRulesCore

using FillArrays
using Bijectors

using StatsBase

@@ -115,18 +113,17 @@ Estimate (possibly stochastic) gradients of the variational objective `obj` targ
"""
function estimate_gradient! end

# ADVI-specific interfaces
# ELBO-specific interfaces
abstract type AbstractEntropyEstimator end

export
ADVI,
RepGradELBO,
ClosedFormEntropy,
StickingTheLandingEntropy,
MonteCarloEntropy

# entropy.jl must precede advi.jl
include("objectives/elbo/entropy.jl")
include("objectives/elbo/advi.jl")
include("objectives/elbo/repgradelbo.jl")

# Optimization Routine

166 changes: 0 additions & 166 deletions src/objectives/elbo/advi.jl

This file was deleted.

126 changes: 126 additions & 0 deletions src/objectives/elbo/repgradelbo.jl
@@ -0,0 +1,126 @@

"""
RepGradELBO(n_samples; kwargs...)
Evidence lower-bound objective with the reparameterization gradient formulation[^TL2014][^RMW2014][^KW2014].
This computes the evidence lower bound (ELBO) through the formulation:
```math
\\begin{aligned}
\\mathrm{ELBO}\\left(\\lambda\\right)
&\\triangleq
\\mathbb{E}_{z \\sim q_{\\lambda}}\\left[
\\log \\pi\\left(z\\right)
\\right]
+ \\mathbb{H}\\left(q_{\\lambda}\\right),
\\end{aligned}
```
# Arguments
- `n_samples::Int`: Number of Monte Carlo samples used to estimate the ELBO.
# Keyword Arguments
- `entropy`: The estimator for the entropy term. (Type `<: AbstractEntropyEstimator`; default: `ClosedFormEntropy()`.)
# Requirements
- ``q_{\\lambda}`` implements `rand`.
- The target `logdensity(prob, x)` must be differentiable with respect to `x` by the selected AD backend.
Depending on the options, additional requirements on ``q_{\\lambda}`` may apply.
# References
[^TL2014]: Titsias, M., & Lázaro-Gredilla, M. (2014, June). Doubly stochastic variational Bayes for non-conjugate inference. In ICML.
[^RMW2014]: Rezende, D. J., Mohamed, S., & Wierstra, D. (2014, June). Stochastic backpropagation and approximate inference in deep generative models. In ICML.
[^KW2014]: Kingma, D. P., & Welling, M. (2014). Auto-encoding variational Bayes. In ICLR.
"""
struct RepGradELBO{EntropyEst <: AbstractEntropyEstimator} <: AbstractVariationalObjective
entropy ::EntropyEst
n_samples::Int
end

RepGradELBO(
n_samples::Int;
entropy ::AbstractEntropyEstimator = ClosedFormEntropy()
) = RepGradELBO(entropy, n_samples)
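
For reference, a minimal construction sketch using the definitions above (illustration only, not part of the diff). `RepGradELBO` and both entropy estimators are exported by the module; the no-argument `StickingTheLandingEntropy()` constructor is an assumption.

```julia
using AdvancedVI

# Default: closed-form entropy, which assumes the variational family
# provides an analytic `entropy(q)`.
obj = RepGradELBO(10)

# Alternative: the sticking-the-landing estimator, which scores the entropy
# term against a gradient-stopped copy of q (see `maybe_stop_entropy_score` below).
obj_stl = RepGradELBO(10; entropy = StickingTheLandingEntropy())
```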

Base.show(io::IO, obj::RepGradELBO) =
    print(io, "RepGradELBO(entropy=$(obj.entropy), n_samples=$(obj.n_samples))")

maybe_stop_entropy_score(::StickingTheLandingEntropy, q, q_stop) = q_stop

maybe_stop_entropy_score(::AbstractEntropyEstimator, q, q_stop) = q

function estimate_entropy_maybe_stl(entropy_estimator::AbstractEntropyEstimator, samples, q, q_stop)
q_maybe_stop = maybe_stop_entropy_score(entropy_estimator, q, q_stop)
estimate_entropy(entropy_estimator, samples, q_maybe_stop)
end

function estimate_energy_with_samples(::RepGradELBO, samples, prob)
mean(Base.Fix1(LogDensityProblems.logdensity, prob), eachsample(samples))
end

function estimate_repgradelbo_maybe_stl_with_samples(
obj::RepGradELBO, q, q_stop, samples::AbstractMatrix, prob
)
energy = estimate_energy_with_samples(obj, samples, prob)
entropy = estimate_entropy_maybe_stl(obj.entropy, samples, q, q_stop)
energy + entropy
end

function estimate_repgradelbo_maybe_stl(rng::Random.AbstractRNG, obj::RepGradELBO, q, q_stop, prob)
samples = rand(rng, q, obj.n_samples)
estimate_repgradelbo_maybe_stl_with_samples(obj, q, q_stop, samples, prob)
end

"""
estimate_objective([rng,] obj, q, prob; n_samples)
Estimate the ELBO using the reparameterization gradient formulation.
# Arguments
- `obj::RepGradELBO`: The ELBO objective.
- `q`: Variational approximation.
- `prob`: The target log-joint likelihood, implementing the `LogDensityProblems` interface.
# Keyword Arguments
- `n_samples::Int = obj.n_samples`: Number of samples to be used to estimate the objective.
# Returns
- `obj_est`: Estimate of the objective value.
"""
function estimate_objective(
rng::Random.AbstractRNG,
obj::RepGradELBO,
q,
prob;
n_samples::Int = obj.n_samples
)
samples = rand(rng, q, n_samples)
estimate_repgradelbo_maybe_stl_with_samples(obj, q, q, samples, prob)
end

estimate_objective(obj::RepGradELBO, q, prob; n_samples::Int = obj.n_samples) =
estimate_objective(Random.default_rng(), obj, q, prob; n_samples)
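
To illustrate the interface documented above (again illustration only, not part of the diff), a self-contained sketch that estimates the ELBO against a made-up standard-normal target. `NormalTarget` and its log density are hypothetical, and the default `ClosedFormEntropy` is assumed to rely on the closed-form `entropy` of `q`.

```julia
using AdvancedVI, Distributions, LinearAlgebra, LogDensityProblems, Random

# Hypothetical target: a 2-dimensional standard normal.
struct NormalTarget end
LogDensityProblems.logdensity(::NormalTarget, x) = -sum(abs2, x) / 2 - length(x) * log(2π) / 2
LogDensityProblems.dimension(::NormalTarget) = 2
LogDensityProblems.capabilities(::Type{NormalTarget}) = LogDensityProblems.LogDensityOrder{0}()

q   = MvNormal(zeros(2), Diagonal(ones(2)))  # variational approximation; supports `rand` and `entropy`
obj = RepGradELBO(10)

elbo = AdvancedVI.estimate_objective(Random.default_rng(), obj, q, NormalTarget(); n_samples = 100)
```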

function estimate_gradient!(
rng ::Random.AbstractRNG,
obj ::RepGradELBO,
adbackend::ADTypes.AbstractADType,
out ::DiffResults.MutableDiffResult,
prob,
λ,
restructure,
est_state,
)
q_stop = restructure(λ)
function f(λ′)
q = restructure(λ′)
elbo = estimate_repgradelbo_maybe_stl(rng, obj, q, q_stop, prob)
-elbo
end
value_and_gradient!(adbackend, f, λ, out)

nelbo = DiffResults.value(out)
stat = (elbo=-nelbo,)

out, nothing, stat
end
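
The gradient routine above expects the caller to hold the variational parameters as a flat vector `λ` together with a `restructure` callable that rebuilds `q` from them (in practice such a pair typically comes from something like `Optimisers.destructure`). A hand-rolled, purely hypothetical pair showing that contract:

```julia
using Distributions, LinearAlgebra

# Hypothetical flat parameterization of a 2-d mean-field Gaussian: λ = [μ; log σ].
restructure(λ) = MvNormal(λ[1:2], Diagonal(exp.(λ[3:4]) .^ 2))

λ = zeros(4)        # μ = 0, σ = 1
q = restructure(λ)  # `estimate_gradient!` rebuilds q like this inside the differentiated closure
```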
3 changes: 3 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
@@ -21,3 +21,6 @@ function maybe_init_objective(
haskey(state_init, :objective) ? state_init.objective : init(rng, objective, λ, restructure)
end

eachsample(samples::AbstractMatrix) = eachcol(samples)

eachsample(samples::AbstractVector) = samples

2 changes: 0 additions & 2 deletions test/Project.toml
@@ -1,6 +1,5 @@
[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
@@ -23,7 +22,6 @@ Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
ADTypes = "0.2.1"
Bijectors = "0.13.6"
Distributions = "0.25.100"
DistributionsAD = "0.6.45"
Enzyme = "0.11.7"