diff --git a/Project.toml b/Project.toml index 800fa34f..f4ea1bcc 100644 --- a/Project.toml +++ b/Project.toml @@ -1,37 +1,71 @@ name = "AdvancedVI" uuid = "b5ca4192-6429-45e5-a2d9-87aec30a685c" -version = "0.2.4" +version = "0.3.0" [deps] -Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" +ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" +Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" -DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" -ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" +Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" +Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Requires = "ae029012-a4dd-5104-9daa-d747884805df" +SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" -StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" -Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" + +[weakdeps] +Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" + +[extensions] +AdvancedVIEnzymeExt = "Enzyme" +AdvancedVIForwardDiffExt = "ForwardDiff" +AdvancedVIReverseDiffExt = "ReverseDiff" +AdvancedVIZygoteExt = "Zygote" +AdvancedVIBijectorsExt = "Bijectors" [compat] -Bijectors = "0.11, 0.12, 0.13" -Distributions = "0.21, 0.22, 0.23, 0.24, 0.25" -DistributionsAD = "0.2, 0.3, 0.4, 0.5, 0.6" +ADTypes = "0.1, 0.2" +Accessors = "0.1" +Bijectors = "0.13" +ChainRulesCore = "1.16" +DiffResults = "1" +Distributions = "0.25.87" DocStringExtensions = "0.8, 0.9" -ForwardDiff = "0.10.3" -ProgressMeter = "1.0.0" -Requires = "0.5, 1.0" +Enzyme = "0.11.7" +FillArrays = "1.3" +ForwardDiff = "0.10.36" +Functors = "0.4" +LinearAlgebra = "1" +LogDensityProblems = "2" +Optimisers = "0.2.16, 0.3" +ProgressMeter = "1.6" +Random = "1" +Requires = "1.0" +ReverseDiff = "1.15.1" +SimpleUnPack = "1.1.0" StatsBase = "0.32, 0.33, 0.34" -StatsFuns = "0.8, 0.9, 1" -Tracker = "0.2.3" +Zygote = "0.6.63" julia = "1.6" [extras] +Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" +ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] test = ["Pkg", "Test"] diff --git a/ext/AdvancedVIBijectorsExt.jl b/ext/AdvancedVIBijectorsExt.jl new file mode 100644 index 00000000..1b200ac5 --- /dev/null +++ b/ext/AdvancedVIBijectorsExt.jl @@ -0,0 +1,45 @@ + +module AdvancedVIBijectorsExt + +if isdefined(Base, :get_extension) + using AdvancedVI + using Bijectors + using Random +else + using ..AdvancedVI + using ..Bijectors + using ..Random +end + +function AdvancedVI.reparam_with_entropy( + rng ::Random.AbstractRNG, + q ::Bijectors.TransformedDistribution, + q_stop ::Bijectors.TransformedDistribution, + n_samples::Int, + ent_est ::AdvancedVI.AbstractEntropyEstimator +) + transform = q.transform + q_base = q.dist + q_base_stop = q_stop.dist + base_samples = rand(rng, q_base, n_samples) + it = AdvancedVI.eachsample(base_samples) + sample_init = first(it) + + samples_and_logjac = mapreduce( + AdvancedVI.catsamples_and_acc, + Iterators.drop(it, 1); + init=with_logabsdet_jacobian(transform, sample_init) + ) do sample + with_logabsdet_jacobian(transform, sample) + end + samples = first(samples_and_logjac) + logjac = last(samples_and_logjac) + + entropy_base = AdvancedVI.estimate_entropy_maybe_stl( + ent_est, base_samples, q_base, q_base_stop + ) + + entropy = entropy_base + logjac/n_samples + samples, entropy +end +end diff --git a/ext/AdvancedVIEnzymeExt.jl b/ext/AdvancedVIEnzymeExt.jl new file mode 100644 index 00000000..8333299f --- /dev/null +++ b/ext/AdvancedVIEnzymeExt.jl @@ -0,0 +1,26 @@ + +module AdvancedVIEnzymeExt + +if isdefined(Base, :get_extension) + using Enzyme + using AdvancedVI + using AdvancedVI: ADTypes, DiffResults +else + using ..Enzyme + using ..AdvancedVI + using ..AdvancedVI: ADTypes, DiffResults +end + +# Enzyme doesn't support f::Bijectors (see https://github.com/EnzymeAD/Enzyme.jl/issues/916) +function AdvancedVI.value_and_gradient!( + ad::ADTypes.AutoEnzyme, f, θ::AbstractVector{T}, out::DiffResults.MutableDiffResult +) where {T<:Real} + y = f(θ) + DiffResults.value!(out, y) + ∇θ = DiffResults.gradient(out) + fill!(∇θ, zero(T)) + Enzyme.autodiff(Enzyme.ReverseWithPrimal, f, Enzyme.Active, Enzyme.Duplicated(θ, ∇θ)) + return out +end + +end diff --git a/ext/AdvancedVIForwardDiffExt.jl b/ext/AdvancedVIForwardDiffExt.jl new file mode 100644 index 00000000..5949bdf8 --- /dev/null +++ b/ext/AdvancedVIForwardDiffExt.jl @@ -0,0 +1,29 @@ + +module AdvancedVIForwardDiffExt + +if isdefined(Base, :get_extension) + using ForwardDiff + using AdvancedVI + using AdvancedVI: ADTypes, DiffResults +else + using ..ForwardDiff + using ..AdvancedVI + using ..AdvancedVI: ADTypes, DiffResults +end + +getchunksize(::ADTypes.AutoForwardDiff{chunksize}) where {chunksize} = chunksize + +function AdvancedVI.value_and_gradient!( + ad::ADTypes.AutoForwardDiff, f, θ::AbstractVector{T}, out::DiffResults.MutableDiffResult +) where {T<:Real} + chunk_size = getchunksize(ad) + config = if isnothing(chunk_size) + ForwardDiff.GradientConfig(f, θ) + else + ForwardDiff.GradientConfig(f, θ, ForwardDiff.Chunk(length(θ), chunk_size)) + end + ForwardDiff.gradient!(out, f, θ, config) + return out +end + +end diff --git a/ext/AdvancedVIReverseDiffExt.jl b/ext/AdvancedVIReverseDiffExt.jl new file mode 100644 index 00000000..520cd9ff --- /dev/null +++ b/ext/AdvancedVIReverseDiffExt.jl @@ -0,0 +1,23 @@ + +module AdvancedVIReverseDiffExt + +if isdefined(Base, :get_extension) + using AdvancedVI + using AdvancedVI: ADTypes, DiffResults + using ReverseDiff +else + using ..AdvancedVI + using ..AdvancedVI: ADTypes, DiffResults + using ..ReverseDiff +end + +# ReverseDiff without compiled tape +function AdvancedVI.value_and_gradient!( + ad::ADTypes.AutoReverseDiff, f, θ::AbstractVector{<:Real}, out::DiffResults.MutableDiffResult +) + tp = ReverseDiff.GradientTape(f, θ) + ReverseDiff.gradient!(out, tp, θ) + return out +end + +end diff --git a/ext/AdvancedVIZygoteExt.jl b/ext/AdvancedVIZygoteExt.jl new file mode 100644 index 00000000..7b8f8817 --- /dev/null +++ b/ext/AdvancedVIZygoteExt.jl @@ -0,0 +1,24 @@ + +module AdvancedVIZygoteExt + +if isdefined(Base, :get_extension) + using AdvancedVI + using AdvancedVI: ADTypes, DiffResults + using Zygote +else + using ..AdvancedVI + using ..AdvancedVI: ADTypes, DiffResults + using ..Zygote +end + +function AdvancedVI.value_and_gradient!( + ad::ADTypes.AutoZygote, f, θ::AbstractVector{<:Real}, out::DiffResults.MutableDiffResult +) + y, back = Zygote.pullback(f, θ) + ∇θ = back(one(y)) + DiffResults.value!(out, y) + DiffResults.gradient!(out, only(∇θ)) + return out +end + +end diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index e203a13c..a69b8d89 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -1,270 +1,190 @@ + module AdvancedVI -using Random: Random +using SimpleUnPack: @unpack, @pack! +using Accessors -using Distributions, DistributionsAD, Bijectors -using DocStringExtensions +using Random +using Distributions -using ProgressMeter, LinearAlgebra +using Functors +using Optimisers -using ForwardDiff -using Tracker +using DocStringExtensions +using ProgressMeter +using LinearAlgebra -const PROGRESS = Ref(true) -function turnprogress(switch::Bool) - @info("[AdvancedVI]: global PROGRESS is set as $switch") - PROGRESS[] = switch -end +using LogDensityProblems -const DEBUG = Bool(parse(Int, get(ENV, "DEBUG_ADVANCEDVI", "0"))) +using ADTypes, DiffResults +using ChainRulesCore -include("ad.jl") -include("utils.jl") +using FillArrays -using Requires -function __init__() - @require Flux="587475ba-b771-5e3f-ad9e-33799f191a9c" begin - apply!(o, x, Δ) = Flux.Optimise.apply!(o, x, Δ) - Flux.Optimise.apply!(o::TruncatedADAGrad, x, Δ) = apply!(o, x, Δ) - Flux.Optimise.apply!(o::DecayedADAGrad, x, Δ) = apply!(o, x, Δ) - end - @require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" begin - include("compat/zygote.jl") - export ZygoteAD - - function AdvancedVI.grad!( - vo, - alg::VariationalInference{<:AdvancedVI.ZygoteAD}, - q, - model, - θ::AbstractVector{<:Real}, - out::DiffResults.MutableDiffResult, - args... - ) - f(θ) = if (q isa Distribution) - - vo(alg, update(q, θ), model, args...) - else - - vo(alg, q(θ), model, args...) - end - y, back = Zygote.pullback(f, θ) - dy = first(back(1.0)) - DiffResults.value!(out, y) - DiffResults.gradient!(out, dy) - return out - end - end - @require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" begin - include("compat/reversediff.jl") - export ReverseDiffAD - - function AdvancedVI.grad!( - vo, - alg::VariationalInference{<:AdvancedVI.ReverseDiffAD{false}}, - q, - model, - θ::AbstractVector{<:Real}, - out::DiffResults.MutableDiffResult, - args... - ) - f(θ) = if (q isa Distribution) - - vo(alg, update(q, θ), model, args...) - else - - vo(alg, q(θ), model, args...) - end - tp = AdvancedVI.tape(f, θ) - ReverseDiff.gradient!(out, tp, θ) - return out - end - end - @require Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" begin - include("compat/enzyme.jl") - export EnzymeAD - - function AdvancedVI.grad!( - vo, - alg::VariationalInference{<:AdvancedVI.EnzymeAD}, - q, - model, - θ::AbstractVector{<:Real}, - out::DiffResults.MutableDiffResult, - args... - ) - f(θ) = if (q isa Distribution) - - vo(alg, update(q, θ), model, args...) - else - - vo(alg, q(θ), model, args...) - end - # Use `Enzyme.ReverseWithPrimal` once it is released: - # https://github.com/EnzymeAD/Enzyme.jl/pull/598 - y = f(θ) - DiffResults.value!(out, y) - dy = DiffResults.gradient(out) - fill!(dy, 0) - Enzyme.autodiff(Enzyme.ReverseWithPrimal, f, Enzyme.Active, Enzyme.Duplicated(θ, dy)) - return out - end - end -end +using StatsBase -export - vi, - ADVI, - ELBO, - elbo, - TruncatedADAGrad, - DecayedADAGrad, - VariationalInference +# derivatives +""" + value_and_gradient!(ad, f, θ, out) -abstract type VariationalInference{AD} end +Evaluate the value and gradient of a function `f` at `θ` using the automatic differentiation backend `ad` and store the result in `out`. -getchunksize(::Type{<:VariationalInference{AD}}) where AD = getchunksize(AD) -getADtype(::VariationalInference{AD}) where AD = AD +# Arguments +- `ad::ADTypes.AbstractADType`: Automatic differentiation backend. +- `f`: Function subject to differentiation. +- `θ`: The point to evaluate the gradient. +- `out::DiffResults.MutableDiffResult`: Buffer to contain the output gradient and function value. +""" +function value_and_gradient! end -abstract type VariationalObjective end +# estimators +""" + AbstractVariationalObjective -const VariationalPosterior = Distribution{Multivariate, Continuous} +Abstract type for the VI algorithms supported by `AdvancedVI`. +# Implementations +To be supported by `AdvancedVI`, a VI algorithm must implement `AbstractVariationalObjective` and `estimate_objective`. +Also, it should provide gradients by implementing the function `estimate_gradient!`. +If the estimator is stateful, it can implement `init` to initialize the state. +""" +abstract type AbstractVariationalObjective end """ - grad!(vo, alg::VariationalInference, q, model::Model, θ, out, args...) + init(rng, obj, λ, restructure) -Computes the gradients used in `optimize!`. Default implementation is provided for -`VariationalInference{AD}` where `AD` is either `ForwardDiffAD` or `TrackerAD`. -This implicitly also gives a default implementation of `optimize!`. +Initialize a state of the variational objective `obj` given the initial variational parameters `λ`. +This function needs to be implemented only if `obj` is stateful. -Variance reduction techniques, e.g. control variates, should be implemented in this function. +# Arguments +- `rng::Random.AbstractRNG`: Random number generator. +- `obj::AbstractVariationalObjective`: Variational objective. +- `λ`: Initial variational parameters. +- `restructure`: Function that reconstructs the variational approximation from `λ`. """ -function grad! end +init( + ::Random.AbstractRNG, + ::AbstractVariationalObjective, + ::AbstractVector, + ::Any +) = nothing """ - vi(model, alg::VariationalInference) - vi(model, alg::VariationalInference, q::VariationalPosterior) - vi(model, alg::VariationalInference, getq::Function, θ::AbstractArray) + estimate_objective([rng,] obj, q, prob; kwargs...) -Constructs the variational posterior from the `model` and performs the optimization -following the configuration of the given `VariationalInference` instance. +Estimate the variational objective `obj` targeting `prob` with respect to the variational approximation `q`. # Arguments -- `model`: `Turing.Model` or `Function` z ↦ log p(x, z) where `x` denotes the observations -- `alg`: the VI algorithm used -- `q`: a `VariationalPosterior` for which it is assumed a specialized implementation of the variational objective used exists. -- `getq`: function taking parameters `θ` as input and returns a `VariationalPosterior` -- `θ`: only required if `getq` is used, in which case it is the initial parameters for the variational posterior +- `rng::Random.AbstractRNG`: Random number generator. +- `obj::AbstractVariationalObjective`: Variational objective. +- `prob`: The target log-joint likelihood implementing the `LogDensityProblem` interface. +- `q`: Variational approximation. + +# Keyword Arguments +Depending on the objective, additional keyword arguments may apply. +Please refer to the respective documentation of each variational objective for more info. + +# Returns +- `obj_est`: Estimate of the objective value. """ -function vi end - -function update end - -# default implementations -function grad!( - vo, - alg::VariationalInference{<:ForwardDiffAD}, - q, - model, - θ::AbstractVector{<:Real}, - out::DiffResults.MutableDiffResult, - args... -) - f(θ_) = if (q isa Distribution) - - vo(alg, update(q, θ_), model, args...) - else - - vo(alg, q(θ_), model, args...) - end +function estimate_objective end - # Set chunk size and do ForwardMode. - chunk_size = getchunksize(typeof(alg)) - config = if chunk_size == 0 - ForwardDiff.GradientConfig(f, θ) - else - ForwardDiff.GradientConfig(f, θ, ForwardDiff.Chunk(length(θ), chunk_size)) - end - ForwardDiff.gradient!(out, f, θ, config) -end +export estimate_objective -function grad!( - vo, - alg::VariationalInference{<:TrackerAD}, - q, - model, - θ::AbstractVector{<:Real}, - out::DiffResults.MutableDiffResult, - args... -) - θ_tracked = Tracker.param(θ) - y = if (q isa Distribution) - - vo(alg, update(q, θ_tracked), model, args...) - else - - vo(alg, q(θ_tracked), model, args...) - end - Tracker.back!(y, 1.0) - DiffResults.value!(out, Tracker.data(y)) - DiffResults.gradient!(out, Tracker.grad(θ_tracked)) -end +""" + estimate_gradient!(rng, obj, adbackend, out, prob, λ, restructure, obj_state) +Estimate (possibly stochastic) gradients of the variational objective `obj` targeting `prob` with respect to the variational parameters `λ` +# Arguments +- `rng::Random.AbstractRNG`: Random number generator. +- `obj::AbstractVariationalObjective`: Variational objective. +- `adbackend::ADTypes.AbstractADType`: Automatic differentiation backend. +- `out::DiffResults.MutableDiffResult`: Buffer containing the objective value and gradient estimates. +- `prob`: The target log-joint likelihood implementing the `LogDensityProblem` interface. +- `λ`: Variational parameters to evaluate the gradient on. +- `restructure`: Function that reconstructs the variational approximation from `λ`. +- `obj_state`: Previous state of the objective. + +# Returns +- `out::MutableDiffResult`: Buffer containing the objective value and gradient estimates. +- `obj_state`: The updated state of the objective. +- `stat::NamedTuple`: Statistics and logs generated during estimation. """ - optimize!(vo, alg::VariationalInference{AD}, q::VariationalPosterior, model::Model, θ; optimizer = TruncatedADAGrad()) +function estimate_gradient! end + +# ELBO-specific interfaces +abstract type AbstractEntropyEstimator end -Iteratively updates parameters by calling `grad!` and using the given `optimizer` to compute -the steps. """ -function optimize!( - vo, - alg::VariationalInference, - q, - model, - θ::AbstractVector{<:Real}; - optimizer = TruncatedADAGrad() -) - # TODO: should we always assume `samples_per_step` and `max_iters` for all algos? - alg_name = alg_str(alg) - samples_per_step = alg.samples_per_step - max_iters = alg.max_iters - - num_params = length(θ) - - # TODO: really need a better way to warn the user about potentially - # not using the correct accumulator - if (optimizer isa TruncatedADAGrad) && (θ ∉ keys(optimizer.acc)) - # this message should only occurr once in the optimization process - @info "[$alg_name] Should only be seen once: optimizer created for θ" objectid(θ) - end + estimate_entropy(entropy_estimator, mc_samples, q) - diff_result = DiffResults.GradientResult(θ) +Estimate the entropy of `q`. - i = 0 - prog = if PROGRESS[] - ProgressMeter.Progress(max_iters, 1, "[$alg_name] Optimizing...", 0) - else - 0 - end +# Arguments +- `entropy_estimator`: Entropy estimation strategy. +- `q`: Variational approximation. +- `mc_samples`: Monte Carlo samples used to estimate the entropy. (Only used for Monte Carlo strategies.) + +# Returns +- `obj_est`: Estimate of the objective value. +""" +function estimate_entropy end - # add criterion? A running mean maybe? - time_elapsed = @elapsed while (i < max_iters) # & converged - grad!(vo, alg, q, model, θ, diff_result, samples_per_step) +export + RepGradELBO, + ClosedFormEntropy, + StickingTheLandingEntropy, + MonteCarloEntropy - # apply update rule - Δ = DiffResults.gradient(diff_result) - Δ = apply!(optimizer, θ, Δ) - @. θ = θ - Δ - - AdvancedVI.DEBUG && @debug "Step $i" Δ DiffResults.value(diff_result) - PROGRESS[] && (ProgressMeter.next!(prog)) +include("objectives/elbo/entropy.jl") +include("objectives/elbo/repgradelbo.jl") - i += 1 - end - return θ -end +# Variational Families +export + VILocationScale, + MeanFieldGaussian, + FullRankGaussian + +include("families/location_scale.jl") + + +# Optimization Routine -# objectives -include("objectives.jl") +function optimize end -# optimisers -include("optimisers.jl") +export optimize -# VI algorithms -include("advi.jl") +include("utils.jl") +include("optimize.jl") + + +# optional dependencies +if !isdefined(Base, :get_extension) # check whether :get_extension is defined in Base + using Requires +end + +@static if !isdefined(Base, :get_extension) + function __init__() + @require Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" begin + include("../ext/AdvancedVIBijectorsExt.jl") + end + @require Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" begin + include("../ext/AdvancedVIEnzymeExt.jl") + end + @require ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" begin + include("../ext/AdvancedVIForwardDiffExt.jl") + end + @require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" begin + include("../ext/AdvancedVIReverseDiffExt.jl") + end + @require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" begin + include("../ext/AdvancedVIZygoteExt.jl") + end + end +end + +end -end # module diff --git a/src/ad.jl b/src/ad.jl deleted file mode 100644 index 62e785e1..00000000 --- a/src/ad.jl +++ /dev/null @@ -1,46 +0,0 @@ -############################## -# Global variables/constants # -############################## -const ADBACKEND = Ref(:forwarddiff) -setadbackend(backend_sym::Symbol) = setadbackend(Val(backend_sym)) -function setadbackend(::Val{:forward_diff}) - Base.depwarn("`AdvancedVI.setadbackend(:forward_diff)` is deprecated. Please use `AdvancedVI.setadbackend(:forwarddiff)` to use `ForwardDiff`.", :setadbackend) - setadbackend(Val(:forwarddiff)) -end -function setadbackend(::Val{:forwarddiff}) - ADBACKEND[] = :forwarddiff -end - -function setadbackend(::Val{:reverse_diff}) - Base.depwarn("`AdvancedVI.setadbackend(:reverse_diff)` is deprecated. Please use `AdvancedVI.setadbackend(:tracker)` to use `Tracker` or `AdvancedVI.setadbackend(:reversediff)` to use `ReverseDiff`. To use `ReverseDiff`, please make sure it is loaded separately with `using ReverseDiff`.", :setadbackend) - setadbackend(Val(:tracker)) -end -function setadbackend(::Val{:tracker}) - ADBACKEND[] = :tracker -end - -const ADSAFE = Ref(false) -function setadsafe(switch::Bool) - @info("[AdvancedVI]: global ADSAFE is set as $switch") - ADSAFE[] = switch -end - -const CHUNKSIZE = Ref(0) # 0 means letting ForwardDiff set it automatically - -function setchunksize(chunk_size::Int) - @info("[AdvancedVI]: AD chunk size is set as $chunk_size") - CHUNKSIZE[] = chunk_size -end - -abstract type ADBackend end -struct ForwardDiffAD{chunk} <: ADBackend end -getchunksize(::Type{<:ForwardDiffAD{chunk}}) where chunk = chunk - -struct TrackerAD <: ADBackend end - -ADBackend() = ADBackend(ADBACKEND[]) -ADBackend(T::Symbol) = ADBackend(Val(T)) - -ADBackend(::Val{:forwarddiff}) = ForwardDiffAD{CHUNKSIZE[]} -ADBackend(::Val{:tracker}) = TrackerAD -ADBackend(::Val) = error("The requested AD backend is not available. Make sure to load all required packages.") diff --git a/src/advi.jl b/src/advi.jl deleted file mode 100644 index 7f9e7346..00000000 --- a/src/advi.jl +++ /dev/null @@ -1,99 +0,0 @@ -using StatsFuns -using DistributionsAD -using Bijectors -using Bijectors: TransformedDistribution - - -""" -$(TYPEDEF) - -Automatic Differentiation Variational Inference (ADVI) with automatic differentiation -backend `AD`. - -# Fields - -$(TYPEDFIELDS) -""" -struct ADVI{AD} <: VariationalInference{AD} - "Number of samples used to estimate the ELBO in each optimization step." - samples_per_step::Int - "Maximum number of gradient steps." - max_iters::Int -end - -function ADVI(samples_per_step::Int=1, max_iters::Int=1000) - return ADVI{ADBackend()}(samples_per_step, max_iters) -end - -alg_str(::ADVI) = "ADVI" - -function vi(model, alg::ADVI, q, θ_init; optimizer = TruncatedADAGrad()) - θ = copy(θ_init) - optimize!(elbo, alg, q, model, θ; optimizer = optimizer) - - # If `q` is a mean-field approx we use the specialized `update` function - if q isa Distribution - return update(q, θ) - else - # Otherwise we assume it's a mapping θ → q - return q(θ) - end -end - - -function optimize(elbo::ELBO, alg::ADVI, q, model, θ_init; optimizer = TruncatedADAGrad()) - θ = copy(θ_init) - - # `model` assumed to be callable z ↦ p(x, z) - optimize!(elbo, alg, q, model, θ; optimizer = optimizer) - - return θ -end - -# WITHOUT updating parameters inside ELBO -function (elbo::ELBO)( - rng::Random.AbstractRNG, - alg::ADVI, - q::VariationalPosterior, - logπ::Function, - num_samples -) - # 𝔼_q(z)[log p(xᵢ, z)] - # = ∫ log p(xᵢ, z) q(z) dz - # = ∫ log p(xᵢ, f(ϕ)) q(f(ϕ)) |det J_f(ϕ)| dϕ (since change of variables) - # = ∫ log p(xᵢ, f(ϕ)) q̃(ϕ) dϕ (since q(f(ϕ)) |det J_f(ϕ)| = q̃(ϕ)) - # = 𝔼_q̃(ϕ)[log p(xᵢ, z)] - - # 𝔼_q(z)[log q(z)] - # = ∫ q(f(ϕ)) log (q(f(ϕ))) |det J_f(ϕ)| dϕ (since q(f(ϕ)) |det J_f(ϕ)| = q̃(ϕ)) - # = 𝔼_q̃(ϕ) [log q(f(ϕ))] - # = 𝔼_q̃(ϕ) [log q̃(ϕ) - log |det J_f(ϕ)|] - # = 𝔼_q̃(ϕ) [log q̃(ϕ)] - 𝔼_q̃(ϕ) [log |det J_f(ϕ)|] - # = - ℍ(q̃(ϕ)) - 𝔼_q̃(ϕ) [log |det J_f(ϕ)|] - - # Finally, the ELBO is given by - # ELBO = 𝔼_q(z)[log p(xᵢ, z)] - 𝔼_q(z)[log q(z)] - # = 𝔼_q̃(ϕ)[log p(xᵢ, z)] + 𝔼_q̃(ϕ) [log |det J_f(ϕ)|] + ℍ(q̃(ϕ)) - - # If f: supp(p(z | x)) → ℝ then - # ELBO = 𝔼[log p(x, z) - log q(z)] - # = 𝔼[log p(x, f⁻¹(z̃)) + logabsdet(J(f⁻¹(z̃)))] + ℍ(q̃(z̃)) - # = 𝔼[log p(x, z) - logabsdetjac(J(f(z)))] + ℍ(q̃(z̃)) - - # But our `rand_and_logjac(q)` is using f⁻¹: ℝ → supp(p(z | x)) going forward → `+ logjac` - z, logjac = rand_and_logjac(rng, q) - res = (logπ(z) + logjac) / num_samples - - if q isa TransformedDistribution - res += entropy(q.dist) - else - res += entropy(q) - end - - for i = 2:num_samples - z, logjac = rand_and_logjac(rng, q) - res += (logπ(z) + logjac) / num_samples - end - - return res -end diff --git a/src/compat/enzyme.jl b/src/compat/enzyme.jl deleted file mode 100644 index c6bb9ac3..00000000 --- a/src/compat/enzyme.jl +++ /dev/null @@ -1,5 +0,0 @@ -struct EnzymeAD <: ADBackend end -ADBackend(::Val{:enzyme}) = EnzymeAD -function setadbackend(::Val{:enzyme}) - ADBACKEND[] = :enzyme -end diff --git a/src/compat/reversediff.jl b/src/compat/reversediff.jl deleted file mode 100644 index 721d0361..00000000 --- a/src/compat/reversediff.jl +++ /dev/null @@ -1,16 +0,0 @@ -using .ReverseDiff: compile, GradientTape -using .ReverseDiff.DiffResults: GradientResult - -struct ReverseDiffAD{cache} <: ADBackend end -const RDCache = Ref(false) -setcache(b::Bool) = RDCache[] = b -getcache() = RDCache[] -ADBackend(::Val{:reversediff}) = ReverseDiffAD{getcache()} -function setadbackend(::Val{:reversediff}) - ADBACKEND[] = :reversediff -end - -tape(f, x) = GradientTape(f, x) -function taperesult(f, x) - return tape(f, x), GradientResult(x) -end diff --git a/src/compat/zygote.jl b/src/compat/zygote.jl deleted file mode 100644 index 40022e21..00000000 --- a/src/compat/zygote.jl +++ /dev/null @@ -1,5 +0,0 @@ -struct ZygoteAD <: ADBackend end -ADBackend(::Val{:zygote}) = ZygoteAD -function setadbackend(::Val{:zygote}) - ADBACKEND[] = :zygote -end diff --git a/src/objectives.jl b/src/objectives.jl deleted file mode 100644 index 5a6b61b0..00000000 --- a/src/objectives.jl +++ /dev/null @@ -1,7 +0,0 @@ -struct ELBO <: VariationalObjective end - -function (elbo::ELBO)(alg, q, logπ, num_samples; kwargs...) - return elbo(Random.default_rng(), alg, q, logπ, num_samples; kwargs...) -end - -const elbo = ELBO() diff --git a/src/objectives/elbo/entropy.jl b/src/objectives/elbo/entropy.jl new file mode 100644 index 00000000..6c5b4739 --- /dev/null +++ b/src/objectives/elbo/entropy.jl @@ -0,0 +1,48 @@ + +""" + ClosedFormEntropy() + +Use closed-form expression of entropy. + +# Requirements +- The variational approximation implements `entropy`. + +# References +* Titsias, M., & Lázaro-Gredilla, M. (2014, June). Doubly stochastic variational Bayes for non-conjugate inference. In International conference on machine learning (pp. 1971-1979). PMLR. +* Kucukelbir, A., Tran, D., Ranganath, R., Gelman, A., & Blei, D. M. (2017). Automatic differentiation variational inference. Journal of machine learning research. +""" +struct ClosedFormEntropy <: AbstractEntropyEstimator end + +maybe_stop_entropy_score(::AbstractEntropyEstimator, q, q_stop) = q + +function estimate_entropy(::ClosedFormEntropy, ::Any, q) + entropy(q) +end + +""" + StickingTheLandingEntropy() + +The "sticking the landing" entropy estimator. + +# Requirements +- The variational approximation `q` implements `logpdf`. +- `logpdf(q, η)` must be differentiable by the selected AD framework. + +# References +* Roeder, G., Wu, Y., & Duvenaud, D. K. (2017). Sticking the landing: Simple, lower-variance gradient estimators for variational inference. Advances in Neural Information Processing Systems, 30. +""" +struct StickingTheLandingEntropy <: AbstractEntropyEstimator end + +struct MonteCarloEntropy <: AbstractEntropyEstimator end + +maybe_stop_entropy_score(::StickingTheLandingEntropy, q, q_stop) = q_stop + +function estimate_entropy( + ::Union{MonteCarloEntropy, StickingTheLandingEntropy}, + mc_samples::AbstractMatrix, + q +) + mean(eachcol(mc_samples)) do mc_sample + -logpdf(q, mc_sample) + end +end diff --git a/src/objectives/elbo/repgradelbo.jl b/src/objectives/elbo/repgradelbo.jl new file mode 100644 index 00000000..04f35320 --- /dev/null +++ b/src/objectives/elbo/repgradelbo.jl @@ -0,0 +1,125 @@ + +""" + RepGradELBO(n_samples; kwargs...) + +Evidence lower-bound objective with the reparameterization gradient formulation[^TL2014][^RMW2014][^KW2014]. +This computes the evidence lower-bound (ELBO) through the formulation: +```math +\\begin{aligned} +\\mathrm{ELBO}\\left(\\lambda\\right) +&\\triangleq +\\mathbb{E}_{z \\sim q_{\\lambda}}\\left[ + \\log \\pi\\left(z\\right) +\\right] ++ \\mathbb{H}\\left(q_{\\lambda}\\right), +\\end{aligned} +``` + +# Arguments +- `n_samples::Int`: Number of Monte Carlo samples used to estimate the ELBO. + +# Keyword Arguments +- `entropy`: The estimator for the entropy term. (Type `<: AbstractEntropyEstimator`; Default: ClosedFormEntropy()) + +# Requirements +- ``q_{\\lambda}`` implements `rand`. +- The target `logdensity(prob, x)` must be differentiable wrt. `x` by the selected AD backend. + +Depending on the options, additional requirements on ``q_{\\lambda}`` may apply. + +# References +[^TL2014]: Titsias, M., & Lázaro-Gredilla, M. (2014, June). Doubly stochastic variational Bayes for non-conjugate inference. In ICML. +[^RMW2014]: Rezende, D. J., Mohamed, S., & Wierstra, D. (2014, June). Stochastic backpropagation and approximate inference in deep generative models. In ICML. +[^KW2014]: Kingma, D. P., & Welling, M. (2014). Auto-encoding variational bayes. In ICLR. +""" +struct RepGradELBO{EntropyEst <: AbstractEntropyEstimator} <: AbstractVariationalObjective + entropy ::EntropyEst + n_samples::Int +end + +RepGradELBO( + n_samples::Int; + entropy ::AbstractEntropyEstimator = ClosedFormEntropy() +) = RepGradELBO(entropy, n_samples) + +function Base.show(io::IO, obj::RepGradELBO) + print(io, "RepGradELBO(entropy=") + print(io, obj.entropy) + print(io, ", n_samples=") + print(io, obj.n_samples) + print(io, ")") +end + +function estimate_entropy_maybe_stl(entropy_estimator::AbstractEntropyEstimator, samples, q, q_stop) + q_maybe_stop = maybe_stop_entropy_score(entropy_estimator, q, q_stop) + estimate_entropy(entropy_estimator, samples, q_maybe_stop) +end + +function estimate_energy_with_samples(prob, samples) + mean(Base.Fix1(LogDensityProblems.logdensity, prob), eachsample(samples)) +end + +""" + reparam_with_entropy(rng, q, q_stop, n_samples, ent_est) + +Draw `n_samples` from `q` and compute its entropy. + +# Arguments +- `rng::Random.AbstractRNG`: Random number generator. +- `q`: Variational approximation. +- `q_stop`: `q` but with its gradient stopped. +- `n_samples::Int`: Number of Monte Carlo samples +- `ent_est`: The entropy estimation strategy. (See `estimate_entropy`.) + +# Returns +- `samples`: Monte Carlo samples generated through reparameterization. Their support matches that of the target distribution. +- `entropy`: An estimate (or exact value) of the differential entropy of `q`. +""" +function reparam_with_entropy( + rng::Random.AbstractRNG, q, q_stop, n_samples::Int, ent_est::AbstractEntropyEstimator +) + samples = rand(rng, q, n_samples) + entropy = estimate_entropy_maybe_stl(ent_est, samples, q, q_stop) + samples, entropy +end + +function estimate_objective( + rng::Random.AbstractRNG, + obj::RepGradELBO, + q, + prob; + n_samples::Int = obj.n_samples +) + samples, entropy = reparam_with_entropy(rng, q, q, n_samples, obj.entropy) + energy = estimate_energy_with_samples(prob, samples) + energy + entropy +end + +estimate_objective(obj::RepGradELBO, q, prob; n_samples::Int = obj.n_samples) = + estimate_objective(Random.default_rng(), obj, q, prob; n_samples) + +function estimate_gradient!( + rng ::Random.AbstractRNG, + obj ::RepGradELBO, + adbackend::ADTypes.AbstractADType, + out ::DiffResults.MutableDiffResult, + prob, + λ, + restructure, + state, +) + q_stop = restructure(λ) + function f(λ′) + q = restructure(λ′) + samples, entropy = reparam_with_entropy(rng, q, q_stop, obj.n_samples, obj.entropy) + energy = estimate_energy_with_samples(prob, samples) + elbo = energy + entropy + -elbo + end + value_and_gradient!(adbackend, f, λ, out) + + nelbo = DiffResults.value(out) + stat = (elbo=-nelbo,) + + out, nothing, stat +end diff --git a/src/optimisers.jl b/src/optimisers.jl deleted file mode 100644 index 8077f98c..00000000 --- a/src/optimisers.jl +++ /dev/null @@ -1,94 +0,0 @@ -const ϵ = 1e-8 - -""" - TruncatedADAGrad(η=0.1, τ=1.0, n=100) - -Implements a truncated version of AdaGrad in the sense that only the `n` previous gradient norms are used to compute the scaling rather than *all* previous. It has parameter specific learning rates based on how frequently it is updated. - -## Parameters - - η: learning rate - - τ: constant scale factor - - n: number of previous gradient norms to use in the scaling. -``` -## References -[ADAGrad](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) optimiser. -Parameters don't need tuning. - -[TruncatedADAGrad](https://arxiv.org/abs/1506.03431v2) (Appendix E). -""" -mutable struct TruncatedADAGrad - eta::Float64 - tau::Float64 - n::Int - - iters::IdDict - acc::IdDict -end - -function TruncatedADAGrad(η = 0.1, τ = 1.0, n = 100) - TruncatedADAGrad(η, τ, n, IdDict(), IdDict()) -end - -function apply!(o::TruncatedADAGrad, x, Δ) - T = eltype(Tracker.data(Δ)) - - η = o.eta - τ = o.tau - - g² = get!( - o.acc, - x, - [zeros(T, size(x)) for j = 1:o.n] - )::Array{typeof(Tracker.data(Δ)), 1} - i = get!(o.iters, x, 1)::Int - - # Example: suppose i = 12 and o.n = 10 - idx = mod(i - 1, o.n) + 1 # => idx = 2 - - # set the current - @inbounds @. g²[idx] = Δ^2 # => g²[2] = Δ^2 where Δ is the (o.n + 2)-th Δ - - # TODO: make more efficient and stable - s = sum(g²) - - # increment - o.iters[x] += 1 - - # TODO: increment (but "truncate") - # o.iters[x] = i > o.n ? o.n + mod(i, o.n) : i + 1 - - @. Δ *= η / (τ + sqrt(s) + ϵ) -end - -""" - DecayedADAGrad(η=0.1, pre=1.0, post=0.9) - -Implements a decayed version of AdaGrad. It has parameter specific learning rates based on how frequently it is updated. - -## Parameters - - η: learning rate - - pre: weight of new gradient norm - - post: weight of histroy of gradient norms -``` -## References -[ADAGrad](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) optimiser. -Parameters don't need tuning. -""" -mutable struct DecayedADAGrad - eta::Float64 - pre::Float64 - post::Float64 - - acc::IdDict -end - -DecayedADAGrad(η = 0.1, pre = 1.0, post = 0.9) = DecayedADAGrad(η, pre, post, IdDict()) - -function apply!(o::DecayedADAGrad, x, Δ) - T = eltype(Tracker.data(Δ)) - - η = o.eta - acc = get!(o.acc, x, fill(T(ϵ), size(x)))::typeof(Tracker.data(x)) - @. acc = o.post * acc + o.pre * Δ^2 - @. Δ *= η / (√acc + ϵ) -end diff --git a/src/optimize.jl b/src/optimize.jl new file mode 100644 index 00000000..7e0032dc --- /dev/null +++ b/src/optimize.jl @@ -0,0 +1,161 @@ + +""" + optimize(problem, objective, restructure, param_init, max_iter, objargs...; kwargs...) + optimize(problem, objective, variational_dist_init, max_iter, objargs...; kwargs...) + +Optimize the variational objective `objective` targeting the problem `problem` by estimating (stochastic) gradients. + +The variational approximation can be constructed by passing the variational parameters `param_init` or the initial variational approximation `variational_dist_init` to the function `restructure`. + +# Arguments +- `objective::AbstractVariationalObjective`: Variational Objective. +- `param_init`: Initial value of the variational parameters. +- `restruct`: Function that reconstructs the variational approximation from the flattened parameters. +- `variational_dist_init`: Initial variational distribution. The variational parameters must be extractable through `Optimisers.destructure`. +- `max_iter::Int`: Maximum number of iterations. +- `objargs...`: Arguments to be passed to `objective`. + +# Keyword Arguments +- `adbackend::ADtypes.AbstractADType`: Automatic differentiation backend. +- `optimizer::Optimisers.AbstractRule`: Optimizer used for inference. (Default: `Adam`.) +- `rng::AbstractRNG`: Random number generator. (Default: `Random.default_rng()`.) +- `show_progress::Bool`: Whether to show the progress bar. (Default: `true`.) +- `callback`: Callback function called after every iteration. See further information below. (Default: `nothing`.) +- `prog`: Progress bar configuration. (Default: `ProgressMeter.Progress(n_max_iter; desc="Optimizing", barlen=31, showspeed=true, enabled=prog)`.) +- `state::NamedTuple`: Initial value for the internal state of optimization. Used to warm-start from the state of a previous run. (See the returned values below.) + +# Returns +- `params`: Variational parameters optimizing the variational objective. +- `stats`: Statistics gathered during optimization. +- `state`: Collection of the final internal states of optimization. This can used later to warm-start from the last iteration of the corresponding run. + +# Callback +The callback function `callback` has a signature of + + callback(; stat, state, param, restructure, gradient) + +The arguments are as follows: +- `stat`: Statistics gathered during the current iteration. The content will vary depending on `objective`. +- `state`: Collection of the internal states used for optimization. +- `param`: Variational parameters. +- `restructure`: Function that restructures the variational approximation from the variational parameters. Calling `restructure(param)` reconstructs the variational approximation. +- `gradient`: The estimated (possibly stochastic) gradient. + +`cb` can return a `NamedTuple` containing some additional information computed within `cb`. +This will be appended to the statistic of the current corresponding iteration. +Otherwise, just return `nothing`. + +""" + +function optimize( + rng ::Random.AbstractRNG, + problem, + objective ::AbstractVariationalObjective, + restructure, + params_init, + max_iter ::Int, + objargs...; + adbackend ::ADTypes.AbstractADType, + optimizer ::Optimisers.AbstractRule = Optimisers.Adam(), + show_progress::Bool = true, + state_init ::NamedTuple = NamedTuple(), + callback = nothing, + prog = ProgressMeter.Progress( + max_iter; + desc = "Optimizing", + barlen = 31, + showspeed = true, + enabled = show_progress + ) +) + λ = copy(params_init) + opt_st = maybe_init_optimizer(state_init, optimizer, λ) + obj_st = maybe_init_objective(state_init, rng, objective, λ, restructure) + grad_buf = DiffResults.DiffResult(zero(eltype(λ)), similar(λ)) + stats = NamedTuple[] + + for t = 1:max_iter + stat = (iteration=t,) + + grad_buf, obj_st, stat′ = estimate_gradient!( + rng, objective, adbackend, grad_buf, problem, + λ, restructure, obj_st, objargs... + ) + stat = merge(stat, stat′) + + g = DiffResults.gradient(grad_buf) + opt_st, λ = Optimisers.update!(opt_st, λ, g) + + if !isnothing(callback) + stat′ = callback( + ; stat, restructure, params=λ, gradient=g, + state=(optimizer=opt_st, objective=obj_st) + ) + stat = !isnothing(stat′) ? merge(stat′, stat) : stat + end + + @debug "Iteration $t" stat... + + pm_next!(prog, stat) + push!(stats, stat) + end + state = (optimizer=opt_st, objective=obj_st) + stats = map(identity, stats) + params = λ + params, stats, state +end + +function optimize( + problem, + objective ::AbstractVariationalObjective, + restructure, + params_init, + max_iter ::Int, + objargs...; + kwargs... +) + optimize( + Random.default_rng(), + problem, + objective, + restructure, + params_init, + max_iter, + objargs...; + kwargs... + ) +end + +function optimize(rng ::Random.AbstractRNG, + problem, + objective ::AbstractVariationalObjective, + variational_dist_init, + n_max_iter ::Int, + objargs...; + kwargs...) + λ, restructure = Optimisers.destructure(variational_dist_init) + λ, logstats, state = optimize( + rng, problem, objective, restructure, λ, n_max_iter, objargs...; kwargs... + ) + restructure(λ), logstats, state +end + + +function optimize( + problem, + objective ::AbstractVariationalObjective, + variational_dist_init, + max_iter ::Int, + objargs...; + kwargs... +) + optimize( + Random.default_rng(), + problem, + objective, + variational_dist_init, + max_iter, + objargs...; + kwargs... + ) +end diff --git a/src/utils.jl b/src/utils.jl index bb4c1f18..8e67ff1a 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,15 +1,36 @@ -using Distributions -using Bijectors: Bijectors +function pm_next!(pm, stats::NamedTuple) + ProgressMeter.next!(pm; showvalues=[tuple(s...) for s in pairs(stats)]) +end +function maybe_init_optimizer( + state_init::NamedTuple, + optimizer ::Optimisers.AbstractRule, + λ ::AbstractVector +) + haskey(state_init, :optimizer) ? state_init.optimizer : Optimisers.setup(optimizer, λ) +end -function rand_and_logjac(rng::Random.AbstractRNG, dist::Distribution) - x = rand(rng, dist) - return x, zero(eltype(x)) +function maybe_init_objective( + state_init::NamedTuple, + rng ::Random.AbstractRNG, + objective ::AbstractVariationalObjective, + λ ::AbstractVector, + restructure +) + haskey(state_init, :objective) ? state_init.objective : init(rng, objective, λ, restructure) end -function rand_and_logjac(rng::Random.AbstractRNG, dist::Bijectors.TransformedDistribution) - x = rand(rng, dist.dist) - y, logjac = Bijectors.with_logabsdet_jacobian(dist.transform, x) - return y, logjac +eachsample(samples::AbstractMatrix) = eachcol(samples) + +eachsample(samples::AbstractVector) = samples + +function catsamples_and_acc( + state_curr::Tuple{<:AbstractArray, <:Real}, + state_new ::Tuple{<:AbstractVector, <:Real} +) + x = hcat(first(state_curr), first(state_new)) + ∑y = last(state_curr) + last(state_new) + return (x, ∑y) end + diff --git a/test/Project.toml b/test/Project.toml new file mode 100644 index 00000000..a751b89d --- /dev/null +++ b/test/Project.toml @@ -0,0 +1,45 @@ +[deps] +ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" +Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" +Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" +Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" +PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" +Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" +SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a" +StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" + +[compat] +ADTypes = "0.2.1" +Bijectors = "0.13" +Distributions = "0.25.100" +DistributionsAD = "0.6.45" +Enzyme = "0.11.7" +FillArrays = "1.6.1" +ForwardDiff = "0.10.36" +Functors = "0.4.5" +LinearAlgebra = "1" +LogDensityProblems = "2.1.1" +Optimisers = "0.2.16, 0.3" +PDMats = "0.11.7" +Random = "1" +ReverseDiff = "1.15.1" +SimpleUnPack = "1.1.0" +StableRNGs = "1.0.0" +Statistics = "1" +Test = "1" +Tracker = "0.2.20" +Zygote = "0.6.63" +julia = "1.6" diff --git a/test/inference/repgradelbo_distributionsad.jl b/test/inference/repgradelbo_distributionsad.jl new file mode 100644 index 00000000..b6db22a6 --- /dev/null +++ b/test/inference/repgradelbo_distributionsad.jl @@ -0,0 +1,78 @@ + +const PROGRESS = length(ARGS) > 0 && ARGS[1] == "--progress" ? true : false + +using Test + +@testset "inference RepGradELBO DistributionsAD" begin + @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for + realtype ∈ [Float64, Float32], + (modelname, modelconstr) ∈ Dict( + :Normal=> normal_meanfield, + ), + (objname, objective) ∈ Dict( + :RepGradELBOClosedFormEntropy => RepGradELBO(10), + :RepGradELBOStickingTheLanding => RepGradELBO(10, entropy = StickingTheLandingEntropy()), + ), + (adbackname, adbackend) ∈ Dict( + :ForwarDiff => AutoForwardDiff(), + #:ReverseDiff => AutoReverseDiff(), + :Zygote => AutoZygote(), + #:Enzyme => AutoEnzyme(), + ) + + seed = (0x38bef07cf9cc549d) + rng = StableRNG(seed) + + modelstats = modelconstr(rng, realtype) + @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats + + T, η = is_meanfield ? (5_000, 1e-2) : (30_000, 1e-3) + + μ0 = Zeros(realtype, n_dims) + L0 = Diagonal(Ones(realtype, n_dims)) + q0 = TuringDiagMvNormal(μ0, diag(L0)) + + @testset "convergence" begin + Δλ₀ = sum(abs2, μ0 - μ_true) + sum(abs2, L0 - L_true) + q, stats, _ = optimize( + rng, model, objective, q0, T; + optimizer = Optimisers.Adam(realtype(η)), + show_progress = PROGRESS, + adbackend = adbackend, + ) + + μ = mean(q) + L = sqrt(cov(q)) + Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true) + + @test Δλ ≤ Δλ₀/T^(1/4) + @test eltype(μ) == eltype(μ_true) + @test eltype(L) == eltype(L_true) + end + + @testset "determinism" begin + rng = StableRNG(seed) + q, stats, _ = optimize( + rng, model, objective, q0, T; + optimizer = Optimisers.Adam(realtype(η)), + show_progress = PROGRESS, + adbackend = adbackend, + ) + μ = mean(q) + L = sqrt(cov(q)) + + rng_repl = StableRNG(seed) + q, stats, _ = optimize( + rng_repl, model, objective, q0, T; + optimizer = Optimisers.Adam(realtype(η)), + show_progress = PROGRESS, + adbackend = adbackend, + ) + μ_repl = mean(q) + L_repl = sqrt(cov(q)) + @test μ == μ_repl + @test L == L_repl + end + end +end + diff --git a/test/inference/repgradelbo_distributionsad_bijectors.jl b/test/inference/repgradelbo_distributionsad_bijectors.jl new file mode 100644 index 00000000..53e9e62f --- /dev/null +++ b/test/inference/repgradelbo_distributionsad_bijectors.jl @@ -0,0 +1,86 @@ + +const PROGRESS = length(ARGS) > 0 && ARGS[1] == "--progress" ? true : false + +using Test + +@testset "inference RepGradELBO DistributionsAD Bijectors" begin + @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for + realtype ∈ [Float64, Float32], + (modelname, modelconstr) ∈ Dict( + :NormalLogNormalMeanField => normallognormal_meanfield, + ), + (objname, objective) ∈ Dict( + :RepGradELBOClosedFormEntropy => RepGradELBO(10), + :RepGradELBOStickingTheLanding => RepGradELBO(10, entropy = StickingTheLandingEntropy()), + ), + (adbackname, adbackend) ∈ Dict( + :ForwarDiff => AutoForwardDiff(), + :ReverseDiff => AutoReverseDiff(), + #:Zygote => AutoZygote(), + #:Enzyme => AutoEnzyme(), + ) + + seed = (0x38bef07cf9cc549d) + rng = StableRNG(seed) + + modelstats = modelconstr(rng, realtype) + @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats + + T, η = is_meanfield ? (5_000, 1e-2) : (30_000, 1e-3) + + b = Bijectors.bijector(model) + b⁻¹ = inverse(b) + μ0 = Zeros(realtype, n_dims) + L0 = Diagonal(Ones(realtype, n_dims)) + + q0_η = if is_meanfield + MeanFieldGaussian(zeros(realtype, n_dims), Diagonal(ones(realtype, n_dims))) + else + L0 = Matrix{realtype}(I, n_dims, n_dims) |> LowerTriangular + FullRankGaussian(zeros(realtype, n_dims), L0) + end + q0_z = Bijectors.transformed(q0_η, b⁻¹) + + @testset "convergence" begin + Δλ₀ = sum(abs2, μ0 - μ_true) + sum(abs2, L0 - L_true) + q, stats, _ = optimize( + rng, model, objective, q0_z, T; + optimizer = Optimisers.Adam(realtype(η)), + show_progress = PROGRESS, + adbackend = adbackend, + ) + + μ = q.dist.location + L = q.dist.scale + Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true) + + @test Δλ ≤ Δλ₀/T^(1/4) + @test eltype(μ) == eltype(μ_true) + @test eltype(L) == eltype(L_true) + end + + @testset "determinism" begin + rng = StableRNG(seed) + q, stats, _ = optimize( + rng, model, objective, q0_z, T; + optimizer = Optimisers.Adam(realtype(η)), + show_progress = PROGRESS, + adbackend = adbackend, + ) + μ = q.dist.location + L = q.dist.scale + + rng_repl = StableRNG(seed) + q, stats, _ = optimize( + rng_repl, model, objective, q0_z, T; + optimizer = Optimisers.Adam(realtype(η)), + show_progress = PROGRESS, + adbackend = adbackend, + ) + μ_repl = mean(q.dist) + L_repl = sqrt(cov(q.dist)) + @test μ == μ_repl + @test L == L_repl + end + end +end diff --git a/test/inference/repgradelbo_locationscale.jl b/test/inference/repgradelbo_locationscale.jl new file mode 100644 index 00000000..d5177fb6 --- /dev/null +++ b/test/inference/repgradelbo_locationscale.jl @@ -0,0 +1,82 @@ + +const PROGRESS = length(ARGS) > 0 && ARGS[1] == "--progress" ? true : false + +using Test + +@testset "inference RepGradELBO DistributionsAD" begin + @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for + realtype ∈ [Float64, Float32], + (modelname, modelconstr) ∈ Dict( + :Normal=> normal_meanfield, + :Normal=> normal_fullrank, + ), + (objname, objective) ∈ Dict( + :RepGradELBOClosedFormEntropy => RepGradELBO(10), + :RepGradELBOStickingTheLanding => RepGradELBO(10, entropy = StickingTheLandingEntropy()), + ), + (adbackname, adbackend) ∈ Dict( + :ForwarDiff => AutoForwardDiff(), + :ReverseDiff => AutoReverseDiff(), + :Zygote => AutoZygote(), + #:Enzyme => AutoEnzyme(), + ) + + seed = (0x38bef07cf9cc549d) + rng = StableRNG(seed) + + modelstats = modelconstr(rng, realtype) + @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats + + T, η = is_meanfield ? (5_000, 1e-2) : (30_000, 1e-3) + + q0 = if is_meanfield + MeanFieldGaussian(zeros(realtype, n_dims), Diagonal(ones(realtype, n_dims))) + else + L0 = Matrix{realtype}(I, n_dims, n_dims) |> LowerTriangular + FullRankGaussian(zeros(realtype, n_dims), L0) + end + + @testset "convergence" begin + Δλ₀ = sum(abs2, q0.location - μ_true) + sum(abs2, q0.scale - L_true) + q, stats, _ = optimize( + rng, model, objective, q0, T; + optimizer = Optimisers.Adam(realtype(η)), + show_progress = PROGRESS, + adbackend = adbackend, + ) + + μ = q.location + L = q.scale + Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true) + + @test Δλ ≤ Δλ₀/T^(1/4) + @test eltype(μ) == eltype(μ_true) + @test eltype(L) == eltype(L_true) + end + + @testset "determinism" begin + rng = StableRNG(seed) + q, stats, _ = optimize( + rng, model, objective, q0, T; + optimizer = Optimisers.Adam(realtype(η)), + show_progress = PROGRESS, + adbackend = adbackend, + ) + μ = q.location + L = q.scale + + rng_repl = StableRNG(seed) + q, stats, _ = optimize( + rng_repl, model, objective, q0, T; + optimizer = Optimisers.Adam(realtype(η)), + show_progress = PROGRESS, + adbackend = adbackend, + ) + μ_repl = q.location + L_repl = q.scale + @test μ == μ_repl + @test L == L_repl + end + end +end + diff --git a/test/interface/ad.jl b/test/interface/ad.jl new file mode 100644 index 00000000..b716ca2f --- /dev/null +++ b/test/interface/ad.jl @@ -0,0 +1,22 @@ + +using Test + +@testset "ad" begin + @testset "$(adname)" for (adname, adsymbol) ∈ Dict( + :ForwardDiff => AutoForwardDiff(), + :ReverseDiff => AutoReverseDiff(), + :Zygote => AutoZygote(), + # :Enzyme => AutoEnzyme(), # Currently not tested against. + ) + D = 10 + A = randn(D, D) + λ = randn(D) + grad_buf = DiffResults.GradientResult(λ) + f(λ′) = λ′'*A*λ′ / 2 + AdvancedVI.value_and_gradient!(adsymbol, f, λ, grad_buf) + ∇ = DiffResults.gradient(grad_buf) + f = DiffResults.value(grad_buf) + @test ∇ ≈ (A + A')*λ/2 + @test f ≈ λ'*A*λ / 2 + end +end diff --git a/test/interface/optimize.jl b/test/interface/optimize.jl new file mode 100644 index 00000000..6e69616b --- /dev/null +++ b/test/interface/optimize.jl @@ -0,0 +1,98 @@ + +using Test + +@testset "interface optimize" begin + seed = (0x38bef07cf9cc549d) + rng = StableRNG(seed) + + T = 1000 + modelstats = normal_meanfield(rng, Float64) + + @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats + + # Global Test Configurations + q0 = TuringDiagMvNormal(zeros(Float64, n_dims), ones(Float64, n_dims)) + obj = RepGradELBO(10) + + adbackend = AutoForwardDiff() + optimizer = Optimisers.Adam(1e-2) + + rng = StableRNG(seed) + q_ref, stats_ref, _ = optimize( + rng, model, obj, q0, T; + optimizer, + show_progress = false, + adbackend, + ) + λ_ref, _ = Optimisers.destructure(q_ref) + + @testset "default_rng" begin + optimize( + model, obj, q0, T; + optimizer, + show_progress = false, + adbackend, + ) + + λ₀, re = Optimisers.destructure(q0) + optimize( + model, obj, re, λ₀, T; + optimizer, + show_progress = false, + adbackend, + ) + end + + @testset "restructure" begin + λ₀, re = Optimisers.destructure(q0) + + rng = StableRNG(seed) + λ, stats, _ = optimize( + rng, model, obj, re, λ₀, T; + optimizer, + show_progress = false, + adbackend, + ) + @test λ == λ_ref + @test stats == stats_ref + end + + @testset "callback" begin + rng = StableRNG(seed) + test_values = rand(rng, T) + + callback(; stat, args...) = (test_value = test_values[stat.iteration],) + + rng = StableRNG(seed) + _, stats, _ = optimize( + rng, model, obj, q0, T; + show_progress = false, + adbackend, + callback + ) + @test [stat.test_value for stat ∈ stats] == test_values + end + + @testset "warm start" begin + rng = StableRNG(seed) + + T_first = div(T,2) + T_last = T - T_first + + q_first, _, state = optimize( + rng, model, obj, q0, T_first; + optimizer, + show_progress = false, + adbackend + ) + + q, stats, _ = optimize( + rng, model, obj, q_first, T_last; + optimizer, + show_progress = false, + state_init = state, + adbackend + ) + @test q == q_ref + end +end diff --git a/test/interface/repgradelbo.jl b/test/interface/repgradelbo.jl new file mode 100644 index 00000000..61ff0111 --- /dev/null +++ b/test/interface/repgradelbo.jl @@ -0,0 +1,28 @@ + +using Test + +@testset "interface RepGradELBO" begin + seed = (0x38bef07cf9cc549d) + rng = StableRNG(seed) + + modelstats = normal_meanfield(rng, Float64) + + @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats + + q0 = TuringDiagMvNormal(zeros(Float64, n_dims), ones(Float64, n_dims)) + + obj = RepGradELBO(10) + rng = StableRNG(seed) + elbo_ref = estimate_objective(rng, obj, q0, model; n_samples=10^4) + + @testset "determinism" begin + rng = StableRNG(seed) + elbo = estimate_objective(rng, obj, q0, model; n_samples=10^4) + @test elbo == elbo_ref + end + + @testset "default_rng" begin + elbo = estimate_objective(obj, q0, model; n_samples=10^4) + @test elbo ≈ elbo_ref rtol=0.1 + end +end diff --git a/test/models/normal.jl b/test/models/normal.jl new file mode 100644 index 00000000..3f305e1a --- /dev/null +++ b/test/models/normal.jl @@ -0,0 +1,43 @@ + +struct TestNormal{M,S} + μ::M + Σ::S +end + +function LogDensityProblems.logdensity(model::TestNormal, θ) + @unpack μ, Σ = model + logpdf(MvNormal(μ, Σ), θ) +end + +function LogDensityProblems.dimension(model::TestNormal) + length(model.μ) +end + +function LogDensityProblems.capabilities(::Type{<:TestNormal}) + LogDensityProblems.LogDensityOrder{0}() +end + +function normal_fullrank(rng::Random.AbstractRNG, realtype::Type) + n_dims = 5 + + μ = randn(rng, realtype, n_dims) + L = tril(I + ones(realtype, n_dims, n_dims))/2 + Σ = L*L' |> Hermitian + + model = TestNormal(μ, PDMat(Σ, Cholesky(L, 'L', 0))) + + TestModel(model, μ, L, n_dims, false) +end + +function normal_meanfield(rng::Random.AbstractRNG, realtype::Type) + n_dims = 5 + + μ = randn(rng, realtype, n_dims) + σ = log.(exp.(randn(rng, realtype, n_dims)) .+ 1) + + model = TestNormal(μ, Diagonal(σ.^2)) + + L = σ |> Diagonal + + TestModel(model, μ, L, n_dims, true) +end diff --git a/test/models/normallognormal.jl b/test/models/normallognormal.jl new file mode 100644 index 00000000..6615084b --- /dev/null +++ b/test/models/normallognormal.jl @@ -0,0 +1,65 @@ + +struct NormalLogNormal{MX,SX,MY,SY} + μ_x::MX + σ_x::SX + μ_y::MY + Σ_y::SY +end + +function LogDensityProblems.logdensity(model::NormalLogNormal, θ) + @unpack μ_x, σ_x, μ_y, Σ_y = model + logpdf(LogNormal(μ_x, σ_x), θ[1]) + logpdf(MvNormal(μ_y, Σ_y), θ[2:end]) +end + +function LogDensityProblems.dimension(model::NormalLogNormal) + length(model.μ_y) + 1 +end + +function LogDensityProblems.capabilities(::Type{<:NormalLogNormal}) + LogDensityProblems.LogDensityOrder{0}() +end + +function Bijectors.bijector(model::NormalLogNormal) + @unpack μ_x, σ_x, μ_y, Σ_y = model + Bijectors.Stacked( + Bijectors.bijector.([LogNormal(μ_x, σ_x), MvNormal(μ_y, Σ_y)]), + [1:1, 2:1+length(μ_y)]) +end + +function normallognormal_fullrank(rng::Random.AbstractRNG, realtype::Type) + n_dims = 5 + + μ_x = randn(rng, realtype) + σ_x = ℯ + μ_y = randn(rng, realtype, n_dims) + L_y = tril(I + ones(realtype, n_dims, n_dims))/2 + Σ_y = L_y*L_y' |> Hermitian + + model = NormalLogNormal(μ_x, σ_x, μ_y, PDMat(Σ_y, Cholesky(L_y, 'L', 0))) + + Σ = Matrix{realtype}(undef, n_dims+1, n_dims+1) + Σ[1,1] = σ_x^2 + Σ[2:end,2:end] = Σ_y + Σ = Σ |> Hermitian + + μ = vcat(μ_x, μ_y) + L = cholesky(Σ).L + + TestModel(model, μ, L, n_dims+1, false) +end + +function normallognormal_meanfield(rng::Random.AbstractRNG, realtype::Type) + n_dims = 5 + + μ_x = randn(rng, realtype) + σ_x = ℯ + μ_y = randn(rng, realtype, n_dims) + σ_y = log.(exp.(randn(rng, realtype, n_dims)) .+ 1) + + model = NormalLogNormal(μ_x, σ_x, μ_y, Diagonal(σ_y.^2)) + + μ = vcat(μ_x, μ_y) + L = vcat(σ_x, σ_y) |> Diagonal + + TestModel(model, μ, L, n_dims+1, true) +end diff --git a/test/optimisers.jl b/test/optimisers.jl deleted file mode 100644 index fae652ed..00000000 --- a/test/optimisers.jl +++ /dev/null @@ -1,17 +0,0 @@ -using Random, Test, LinearAlgebra, ForwardDiff -using AdvancedVI: TruncatedADAGrad, DecayedADAGrad, apply! - -θ = randn(10, 10) -@testset for opt in [TruncatedADAGrad(), DecayedADAGrad(1e-2)] - θ_fit = randn(10, 10) - loss(x, θ_) = mean(sum(abs2, θ*x - θ_*x; dims = 1)) - for t = 1:10^4 - x = rand(10) - Δ = ForwardDiff.gradient(θ_ -> loss(x, θ_), θ_fit) - Δ = apply!(opt, θ_fit, Δ) - @. θ_fit = θ_fit - Δ - end - @test loss(rand(10, 100), θ_fit) < 0.01 - @test length(opt.acc) == 1 -end - diff --git a/test/runtests.jl b/test/runtests.jl index a305c25e..1d99881b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,28 +1,43 @@ -using Test -using Distributions, DistributionsAD -using AdvancedVI - -include("optimisers.jl") - -target = MvNormal(ones(2)) -logπ(z) = logpdf(target, z) -advi = ADVI(10, 1000) -# Using a function z ↦ q(⋅∣z) -getq(θ) = TuringDiagMvNormal(θ[1:2], exp.(θ[3:4])) -q = vi(logπ, advi, getq, randn(4)) +using Test +using Test: @testset, @test + +using Bijectors +using Random, StableRNGs +using Statistics +using Distributions +using LinearAlgebra +using SimpleUnPack: @unpack +using FillArrays +using PDMats + +using Functors +using DistributionsAD +@functor TuringDiagMvNormal + +using LogDensityProblems +using Optimisers +using ADTypes +using ForwardDiff, ReverseDiff, Zygote -xs = rand(target, 10) -@test mean(abs2, logpdf(q, xs) - logpdf(target, xs)) ≤ 0.05 +using AdvancedVI -# OR: implement `update` and pass a `Distribution` -function AdvancedVI.update(d::TuringDiagMvNormal, θ::AbstractArray{<:Real}) - return TuringDiagMvNormal(θ[1:length(q)], exp.(θ[length(q) + 1:end])) +# Models for Inference Tests +struct TestModel{M,L,S} + model::M + μ_true::L + L_true::S + n_dims::Int + is_meanfield::Bool end +include("models/normal.jl") +include("models/normallognormal.jl") -q0 = TuringDiagMvNormal(zeros(2), ones(2)) -q = vi(logπ, advi, q0, randn(4)) - -xs = rand(target, 10) -@test mean(abs2, logpdf(q, xs) - logpdf(target, xs)) ≤ 0.05 +# Tests +include("interface/ad.jl") +include("interface/optimize.jl") +include("interface/repgradelbo.jl") +include("inference/repgradelbo_distributionsad.jl") +include("inference/repgradelbo_locationscale.jl") +include("inference/repgradelbo_distributionsad_bijectors.jl")