From 576259ad2a4de343244ce8b6cce8b50e9f20dcd4 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 8 Dec 2023 03:30:22 -0500 Subject: [PATCH] Basic rewrite of the package 2023 edition Part I: ADVI (#49) * refactor ADVI, change gradient operation interface * remove unused file, remove unused dependency * fix ADVI elbo computation more efficiently * fix missing entropy regularization term * add LogDensityProblem interface * refactor use bijectors directly instead of transformed distributions This is to avoid having to reconstruct transformed distributions all the time. The direct use of bijectors also avoids going through lots of abstraction layers that could break. Instead, transformed distributions could be constructed only once when returing the VI result. * fix type restrictions * remove unused file * fix use of with_logabsdet_jacobian * restructure project; move the main VI routine to its own file * remove redundant import * restructure project into more modular objective estimators * migrate to AbstractDifferentiation * add location scale pre-packaged variational family, add functors * Revert "migrate to AbstractDifferentiation" This reverts commit 2a4514e4ff0ab0459b7ed78dcdee2f61be61c691. * fix use optimized MvNormal specialization, add logpdf for Loc.Scale. * remove dead code * fix location-scale logpdf - Full Monte Carlo ELBO estimation now works. I checked. * add sticking-the-landing (STL) estimator * migrate to Optimisers.jl * remove execution time measurement (replace later with somethin else) * fix use multiple dispatch for deciding whether to stop entropy grad. * add termination decision, callback arguments * add Base.show to modules * add interface calling `restructure`, rename rebuild -> restructure * add estimator state interface, add control variate interface to ADVI * fix `show(advi)` to show control variate * fix simplify `show(advi.control_variate)` * fix type piracy by wrapping location-scale bijected distribution * remove old AdvancedVI custom optimizers * fix Location Scale to not depend on Bijectors * fix RNG namespace * fix location scale logpdf bug * add Accessors dependency * add location scale, autodiff tests * add Accessors import statement * remove optimiser tests * refactor slightly generalize the distribution tests for the future * migrate to SimpleUnPack, migrate to ADTypes * rename vi.jl to optimize.jl * fix estimate_gradient to use adtypes * add exact inference tests * remove Turing dependency in tests * remove unused projection * remove redundant `ADVIEnergy` object (now baked into `ADVI`) * add more tests, fix rng seed for tests * add more tests, fix seed for tests * fix non-determinism bug * fix test hyperparameters so that tests pass, minor cleanups * fix minor reorganization * add missing files * fix add missing file, rename adbackend argument * fix errors * rename test suite * refactor renamed arguments for ADVI to be shorter * fix compile error in advi test * add initial doc * remove unused epsilon argument in location scale * add project file for documenter * refactor STL gradient calculation to use multiple dispatch * fix type bugs, relax test threshold for the exact inference tests * refactor derivative utils to match NormalizingFlows.jl with extras * add documentation, refactor optimize * fix bug missing extension * remove tracker from tests * remove export for internal derivative utils * fix test errors, old interface * fix wrong derivative interface, add documentation * update documentation * add doc build CI * remove convergence criterion for now * remove outdated export * update documentation * update documentation * update documentation * fix type error in test * remove default ADType argument * update README * update make getting started example actually run Julia * fix remove Float32 tests for inference tests * update version * add documentation publishing url * fix wrong uuid for ForwardDiff * Update CI.yml * refactor use `sum` and `mean` instead of abusing `mapreduce` * remove tests for `FullMonteCarlo` * add tests for the `optimize` interface * fix turn off Zygote tests for now * remove unused function * refactor change bijector field name, simplify STL estimator * update documentation * update STL documentation * update STL documentation * update location scale documentation * fix README * fix math in README * add gradient to arguments of callback!, remove `gradient_norm` info * fix math in README.md Co-authored-by: David Widmann * fix type constraint in `ZygoteExt` Co-authored-by: David Widmann * fix import of `Random` Co-authored-by: David Widmann * refactor `__init__()` Co-authored-by: David Widmann * fix type constraint in definition of `value_and_gradient!` Co-authored-by: David Widmann * refactor `ZygoteExt`; use `only` instead of `first` Co-authored-by: David Widmann * refactor type constraint in `ReverseDiffExt` Co-authored-by: David Widmann * refactor remove outdated debug mode macro * fix remove outdated DEBUG mechanism * fix LaTeX in README: `operatorname` is currently broken * remove `SimpleUnPack` dependency * fix LaTeX in docs and README * add warning about forward-mode AD when using `LocationScale` * fix documentation * fix remove reamining use of `@unpack` * Revert "remove `SimpleUnPack` dependency" This reverts commit 29d7d27ca227413275174e12f9258b13b8276fd0. * Revert "fix remove reamining use of `@unpack`" This reverts commit 817374403e58cb11e4e0e3aaee045c350d5bdfdc. * fix documentation for `optimize` * add specializations of `Optimise.destructure` for mean-field * This fixes the poor performance of `ForwardDiff` * This prevents the zero elements of the mean-field scale being extracted * add test for `Optimisers.destructure` specializations * add specialization of `rand` for meanfield resulting in faster AD * add argument checks for `VIMeanFieldGaussian`, `VIFullRankGaussian` * update documentation * fix type instability, bug in argument check in `LocationScale` * add missing import bug * refactor test, fix type bug in tests for `LocationScale` * add missing compat entries * fix missing package import in test * add additional tests for sampling `LocationScale` * fix bug in batch in-place `rand!` for `LocationScale` * fix bug in inference test initialization * add missing file * fix remove use of for 1.6 * refactor adjust inference test hyperparameters to be more robust * refactor `optimize` to return `obj_state`, add warm start kwargs * refactor make tests more robust, reduce amount of tests * fix remove a cholesky in test model * fix compat bounds, remove unused package * bump compat for ADTypes 0.2 * fix broken LaTeX in README * remove redundant use of PDMats in docs * fix use `Cholesky` signature supported in 1.6 * revert custom variational families and docs * remove doc action for now * revert README for now * refactor remove redundant `rng` argument to `ADVI`, improve docs * fix wrong whitespace in tests * refactor `estimate_gradient` to `estimate_gradient!`, add docs * refactor add default `init` impl, update docs * merge (manually) commit ff32ac642d6aa3a08d371ed895aa6b4026b06b92 * fix test for new interface, change interface for `optimize`, `advi` * fix integer subtype error in documentation of advi Co-authored-by: Tor Erlend Fjelde * fix remove redundant argument for `advi` * remove manifest * refactor remove imports and use fully qualified names * update documentation for `AbstractVariationalObjective` Co-authored-by: Tor Erlend Fjelde * refactor use StableRNG instead of Random123 * refactor migrate to Test, re-enable x86 tests * refactor remove inner constructor for `ADVI` * fix swap `export`s and `include`s Co-authored-by: Tor Erlend Fjelde * fix doscs for `ADVI` Co-authored-by: Tor Erlend Fjelde * fix use `FillArrays` in the test problems Co-authored-by: Tor Erlend Fjelde * fix `optimize` docs Co-authored-by: Tor Erlend Fjelde * fix improve argument names and docs for `optimize` * fix tests to match new interface of `optimize` * refactor move utility functions to new file * fix docs for `optimize` Co-authored-by: Tor Erlend Fjelde * refactor advi internal objective Co-authored-by: Tor Erlend Fjelde * refactor move `rng` to be an optional first argument * fix docs for optimize * add compat bounds to test dependencies * update compat bound for `Optimisers` * fix test compat * fix remove `!` in callback Co-authored-by: Tor Erlend Fjelde * fix rng argument position in `advi` * fix callback signature in `optimize` * refactor reorganize test files and naming * fix simplify description for `optimize` Co-authored-by: Tor Erlend Fjelde * fix remove redundant `Nothing` type signature for `maybe_init` * fix remove "internal use" warning in documentation * refactor change `estimate_gradient!` signature to be type stable * add signature for computing `advi` over a fixed set of samples * fix change test tolerance * fix update documentation for `estimate_gradient!` * refactor remove type constraint for variational parameters * fix remove dead code * add compat entry for stdlib * add compat entry for stdlib in `test/` * fix rng argument position in tests * refactor change name of inference test * fix documentation for `optimize` * refactor rewrite the documentation for the global interfaces * fix compat error * fix documentation for `optimize` to be single line * refactor remove begin end for one-liner * refactor create unified interface for estimating objectives * refactor unify interface for entropy estimator, fix advi docs * fix STL estimator to use manually stopped gradients instead * add inference test for a non-bijector model * refactor add indirections to handle STL and bijectors in ADVI * refactor split inference tests for advi+distributionsad * refactor rename advi to repgradelbo and not use bijectors directly * fix documentation for estimate_objective * refactor add indirection in repgradelbo for interacting with `q` * add TransformedDistribution support as extension * Update src/objectives/elbo/repgradelbo.jl Co-authored-by: Tor Erlend Fjelde * fix docstring for entropy estimator * fix `reparam_with_entropy` specialization for bijectors * enable Zygote for non-bijector tests --------- Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com> Co-authored-by: David Widmann Co-authored-by: Tor Erlend Fjelde --- Project.toml | 62 ++- ext/AdvancedVIBijectorsExt.jl | 45 +++ ext/AdvancedVIEnzymeExt.jl | 26 ++ ext/AdvancedVIForwardDiffExt.jl | 29 ++ ext/AdvancedVIReverseDiffExt.jl | 23 ++ ext/AdvancedVIZygoteExt.jl | 24 ++ src/AdvancedVI.jl | 362 +++++++----------- src/ad.jl | 46 --- src/advi.jl | 99 ----- src/compat/enzyme.jl | 5 - src/compat/reversediff.jl | 16 - src/compat/zygote.jl | 5 - src/objectives.jl | 7 - src/objectives/elbo/entropy.jl | 48 +++ src/objectives/elbo/repgradelbo.jl | 125 ++++++ src/optimisers.jl | 94 ----- src/optimize.jl | 161 ++++++++ src/utils.jl | 39 +- test/Project.toml | 45 +++ test/inference/repgradelbo_distributionsad.jl | 78 ++++ .../repgradelbo_distributionsad_bijectors.jl | 81 ++++ test/interface/ad.jl | 22 ++ test/interface/optimize.jl | 98 +++++ test/interface/repgradelbo.jl | 28 ++ test/models/normal.jl | 43 +++ test/models/normallognormal.jl | 65 ++++ test/optimisers.jl | 17 - test/runtests.jl | 58 +-- 28 files changed, 1191 insertions(+), 560 deletions(-) create mode 100644 ext/AdvancedVIBijectorsExt.jl create mode 100644 ext/AdvancedVIEnzymeExt.jl create mode 100644 ext/AdvancedVIForwardDiffExt.jl create mode 100644 ext/AdvancedVIReverseDiffExt.jl create mode 100644 ext/AdvancedVIZygoteExt.jl delete mode 100644 src/ad.jl delete mode 100644 src/advi.jl delete mode 100644 src/compat/enzyme.jl delete mode 100644 src/compat/reversediff.jl delete mode 100644 src/compat/zygote.jl delete mode 100644 src/objectives.jl create mode 100644 src/objectives/elbo/entropy.jl create mode 100644 src/objectives/elbo/repgradelbo.jl delete mode 100644 src/optimisers.jl create mode 100644 src/optimize.jl create mode 100644 test/Project.toml create mode 100644 test/inference/repgradelbo_distributionsad.jl create mode 100644 test/inference/repgradelbo_distributionsad_bijectors.jl create mode 100644 test/interface/ad.jl create mode 100644 test/interface/optimize.jl create mode 100644 test/interface/repgradelbo.jl create mode 100644 test/models/normal.jl create mode 100644 test/models/normallognormal.jl delete mode 100644 test/optimisers.jl diff --git a/Project.toml b/Project.toml index 800fa34f..f4ea1bcc 100644 --- a/Project.toml +++ b/Project.toml @@ -1,37 +1,71 @@ name = "AdvancedVI" uuid = "b5ca4192-6429-45e5-a2d9-87aec30a685c" -version = "0.2.4" +version = "0.3.0" [deps] -Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" +ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" +Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697" +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" -DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae" -ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" +Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" +Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Requires = "ae029012-a4dd-5104-9daa-d747884805df" +SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" -StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" -Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" + +[weakdeps] +Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" + +[extensions] +AdvancedVIEnzymeExt = "Enzyme" +AdvancedVIForwardDiffExt = "ForwardDiff" +AdvancedVIReverseDiffExt = "ReverseDiff" +AdvancedVIZygoteExt = "Zygote" +AdvancedVIBijectorsExt = "Bijectors" [compat] -Bijectors = "0.11, 0.12, 0.13" -Distributions = "0.21, 0.22, 0.23, 0.24, 0.25" -DistributionsAD = "0.2, 0.3, 0.4, 0.5, 0.6" +ADTypes = "0.1, 0.2" +Accessors = "0.1" +Bijectors = "0.13" +ChainRulesCore = "1.16" +DiffResults = "1" +Distributions = "0.25.87" DocStringExtensions = "0.8, 0.9" -ForwardDiff = "0.10.3" -ProgressMeter = "1.0.0" -Requires = "0.5, 1.0" +Enzyme = "0.11.7" +FillArrays = "1.3" +ForwardDiff = "0.10.36" +Functors = "0.4" +LinearAlgebra = "1" +LogDensityProblems = "2" +Optimisers = "0.2.16, 0.3" +ProgressMeter = "1.6" +Random = "1" +Requires = "1.0" +ReverseDiff = "1.15.1" +SimpleUnPack = "1.1.0" StatsBase = "0.32, 0.33, 0.34" -StatsFuns = "0.8, 0.9, 1" -Tracker = "0.2.3" +Zygote = "0.6.63" julia = "1.6" [extras] +Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" +ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] test = ["Pkg", "Test"] diff --git a/ext/AdvancedVIBijectorsExt.jl b/ext/AdvancedVIBijectorsExt.jl new file mode 100644 index 00000000..1b200ac5 --- /dev/null +++ b/ext/AdvancedVIBijectorsExt.jl @@ -0,0 +1,45 @@ + +module AdvancedVIBijectorsExt + +if isdefined(Base, :get_extension) + using AdvancedVI + using Bijectors + using Random +else + using ..AdvancedVI + using ..Bijectors + using ..Random +end + +function AdvancedVI.reparam_with_entropy( + rng ::Random.AbstractRNG, + q ::Bijectors.TransformedDistribution, + q_stop ::Bijectors.TransformedDistribution, + n_samples::Int, + ent_est ::AdvancedVI.AbstractEntropyEstimator +) + transform = q.transform + q_base = q.dist + q_base_stop = q_stop.dist + base_samples = rand(rng, q_base, n_samples) + it = AdvancedVI.eachsample(base_samples) + sample_init = first(it) + + samples_and_logjac = mapreduce( + AdvancedVI.catsamples_and_acc, + Iterators.drop(it, 1); + init=with_logabsdet_jacobian(transform, sample_init) + ) do sample + with_logabsdet_jacobian(transform, sample) + end + samples = first(samples_and_logjac) + logjac = last(samples_and_logjac) + + entropy_base = AdvancedVI.estimate_entropy_maybe_stl( + ent_est, base_samples, q_base, q_base_stop + ) + + entropy = entropy_base + logjac/n_samples + samples, entropy +end +end diff --git a/ext/AdvancedVIEnzymeExt.jl b/ext/AdvancedVIEnzymeExt.jl new file mode 100644 index 00000000..8333299f --- /dev/null +++ b/ext/AdvancedVIEnzymeExt.jl @@ -0,0 +1,26 @@ + +module AdvancedVIEnzymeExt + +if isdefined(Base, :get_extension) + using Enzyme + using AdvancedVI + using AdvancedVI: ADTypes, DiffResults +else + using ..Enzyme + using ..AdvancedVI + using ..AdvancedVI: ADTypes, DiffResults +end + +# Enzyme doesn't support f::Bijectors (see https://github.com/EnzymeAD/Enzyme.jl/issues/916) +function AdvancedVI.value_and_gradient!( + ad::ADTypes.AutoEnzyme, f, θ::AbstractVector{T}, out::DiffResults.MutableDiffResult +) where {T<:Real} + y = f(θ) + DiffResults.value!(out, y) + ∇θ = DiffResults.gradient(out) + fill!(∇θ, zero(T)) + Enzyme.autodiff(Enzyme.ReverseWithPrimal, f, Enzyme.Active, Enzyme.Duplicated(θ, ∇θ)) + return out +end + +end diff --git a/ext/AdvancedVIForwardDiffExt.jl b/ext/AdvancedVIForwardDiffExt.jl new file mode 100644 index 00000000..5949bdf8 --- /dev/null +++ b/ext/AdvancedVIForwardDiffExt.jl @@ -0,0 +1,29 @@ + +module AdvancedVIForwardDiffExt + +if isdefined(Base, :get_extension) + using ForwardDiff + using AdvancedVI + using AdvancedVI: ADTypes, DiffResults +else + using ..ForwardDiff + using ..AdvancedVI + using ..AdvancedVI: ADTypes, DiffResults +end + +getchunksize(::ADTypes.AutoForwardDiff{chunksize}) where {chunksize} = chunksize + +function AdvancedVI.value_and_gradient!( + ad::ADTypes.AutoForwardDiff, f, θ::AbstractVector{T}, out::DiffResults.MutableDiffResult +) where {T<:Real} + chunk_size = getchunksize(ad) + config = if isnothing(chunk_size) + ForwardDiff.GradientConfig(f, θ) + else + ForwardDiff.GradientConfig(f, θ, ForwardDiff.Chunk(length(θ), chunk_size)) + end + ForwardDiff.gradient!(out, f, θ, config) + return out +end + +end diff --git a/ext/AdvancedVIReverseDiffExt.jl b/ext/AdvancedVIReverseDiffExt.jl new file mode 100644 index 00000000..520cd9ff --- /dev/null +++ b/ext/AdvancedVIReverseDiffExt.jl @@ -0,0 +1,23 @@ + +module AdvancedVIReverseDiffExt + +if isdefined(Base, :get_extension) + using AdvancedVI + using AdvancedVI: ADTypes, DiffResults + using ReverseDiff +else + using ..AdvancedVI + using ..AdvancedVI: ADTypes, DiffResults + using ..ReverseDiff +end + +# ReverseDiff without compiled tape +function AdvancedVI.value_and_gradient!( + ad::ADTypes.AutoReverseDiff, f, θ::AbstractVector{<:Real}, out::DiffResults.MutableDiffResult +) + tp = ReverseDiff.GradientTape(f, θ) + ReverseDiff.gradient!(out, tp, θ) + return out +end + +end diff --git a/ext/AdvancedVIZygoteExt.jl b/ext/AdvancedVIZygoteExt.jl new file mode 100644 index 00000000..7b8f8817 --- /dev/null +++ b/ext/AdvancedVIZygoteExt.jl @@ -0,0 +1,24 @@ + +module AdvancedVIZygoteExt + +if isdefined(Base, :get_extension) + using AdvancedVI + using AdvancedVI: ADTypes, DiffResults + using Zygote +else + using ..AdvancedVI + using ..AdvancedVI: ADTypes, DiffResults + using ..Zygote +end + +function AdvancedVI.value_and_gradient!( + ad::ADTypes.AutoZygote, f, θ::AbstractVector{<:Real}, out::DiffResults.MutableDiffResult +) + y, back = Zygote.pullback(f, θ) + ∇θ = back(one(y)) + DiffResults.value!(out, y) + DiffResults.gradient!(out, only(∇θ)) + return out +end + +end diff --git a/src/AdvancedVI.jl b/src/AdvancedVI.jl index e203a13c..89f86696 100644 --- a/src/AdvancedVI.jl +++ b/src/AdvancedVI.jl @@ -1,270 +1,180 @@ + module AdvancedVI -using Random: Random +using SimpleUnPack: @unpack, @pack! +using Accessors -using Distributions, DistributionsAD, Bijectors -using DocStringExtensions +using Random +using Distributions -using ProgressMeter, LinearAlgebra +using Functors +using Optimisers -using ForwardDiff -using Tracker +using DocStringExtensions +using ProgressMeter +using LinearAlgebra -const PROGRESS = Ref(true) -function turnprogress(switch::Bool) - @info("[AdvancedVI]: global PROGRESS is set as $switch") - PROGRESS[] = switch -end +using LogDensityProblems -const DEBUG = Bool(parse(Int, get(ENV, "DEBUG_ADVANCEDVI", "0"))) +using ADTypes, DiffResults +using ChainRulesCore -include("ad.jl") -include("utils.jl") +using FillArrays -using Requires -function __init__() - @require Flux="587475ba-b771-5e3f-ad9e-33799f191a9c" begin - apply!(o, x, Δ) = Flux.Optimise.apply!(o, x, Δ) - Flux.Optimise.apply!(o::TruncatedADAGrad, x, Δ) = apply!(o, x, Δ) - Flux.Optimise.apply!(o::DecayedADAGrad, x, Δ) = apply!(o, x, Δ) - end - @require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" begin - include("compat/zygote.jl") - export ZygoteAD - - function AdvancedVI.grad!( - vo, - alg::VariationalInference{<:AdvancedVI.ZygoteAD}, - q, - model, - θ::AbstractVector{<:Real}, - out::DiffResults.MutableDiffResult, - args... - ) - f(θ) = if (q isa Distribution) - - vo(alg, update(q, θ), model, args...) - else - - vo(alg, q(θ), model, args...) - end - y, back = Zygote.pullback(f, θ) - dy = first(back(1.0)) - DiffResults.value!(out, y) - DiffResults.gradient!(out, dy) - return out - end - end - @require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" begin - include("compat/reversediff.jl") - export ReverseDiffAD - - function AdvancedVI.grad!( - vo, - alg::VariationalInference{<:AdvancedVI.ReverseDiffAD{false}}, - q, - model, - θ::AbstractVector{<:Real}, - out::DiffResults.MutableDiffResult, - args... - ) - f(θ) = if (q isa Distribution) - - vo(alg, update(q, θ), model, args...) - else - - vo(alg, q(θ), model, args...) - end - tp = AdvancedVI.tape(f, θ) - ReverseDiff.gradient!(out, tp, θ) - return out - end - end - @require Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" begin - include("compat/enzyme.jl") - export EnzymeAD - - function AdvancedVI.grad!( - vo, - alg::VariationalInference{<:AdvancedVI.EnzymeAD}, - q, - model, - θ::AbstractVector{<:Real}, - out::DiffResults.MutableDiffResult, - args... - ) - f(θ) = if (q isa Distribution) - - vo(alg, update(q, θ), model, args...) - else - - vo(alg, q(θ), model, args...) - end - # Use `Enzyme.ReverseWithPrimal` once it is released: - # https://github.com/EnzymeAD/Enzyme.jl/pull/598 - y = f(θ) - DiffResults.value!(out, y) - dy = DiffResults.gradient(out) - fill!(dy, 0) - Enzyme.autodiff(Enzyme.ReverseWithPrimal, f, Enzyme.Active, Enzyme.Duplicated(θ, dy)) - return out - end - end -end +using StatsBase -export - vi, - ADVI, - ELBO, - elbo, - TruncatedADAGrad, - DecayedADAGrad, - VariationalInference +# derivatives +""" + value_and_gradient!(ad, f, θ, out) -abstract type VariationalInference{AD} end +Evaluate the value and gradient of a function `f` at `θ` using the automatic differentiation backend `ad` and store the result in `out`. -getchunksize(::Type{<:VariationalInference{AD}}) where AD = getchunksize(AD) -getADtype(::VariationalInference{AD}) where AD = AD +# Arguments +- `ad::ADTypes.AbstractADType`: Automatic differentiation backend. +- `f`: Function subject to differentiation. +- `θ`: The point to evaluate the gradient. +- `out::DiffResults.MutableDiffResult`: Buffer to contain the output gradient and function value. +""" +function value_and_gradient! end -abstract type VariationalObjective end +# estimators +""" + AbstractVariationalObjective -const VariationalPosterior = Distribution{Multivariate, Continuous} +Abstract type for the VI algorithms supported by `AdvancedVI`. +# Implementations +To be supported by `AdvancedVI`, a VI algorithm must implement `AbstractVariationalObjective` and `estimate_objective`. +Also, it should provide gradients by implementing the function `estimate_gradient!`. +If the estimator is stateful, it can implement `init` to initialize the state. +""" +abstract type AbstractVariationalObjective end """ - grad!(vo, alg::VariationalInference, q, model::Model, θ, out, args...) + init(rng, obj, λ, restructure) -Computes the gradients used in `optimize!`. Default implementation is provided for -`VariationalInference{AD}` where `AD` is either `ForwardDiffAD` or `TrackerAD`. -This implicitly also gives a default implementation of `optimize!`. +Initialize a state of the variational objective `obj` given the initial variational parameters `λ`. +This function needs to be implemented only if `obj` is stateful. -Variance reduction techniques, e.g. control variates, should be implemented in this function. +# Arguments +- `rng::Random.AbstractRNG`: Random number generator. +- `obj::AbstractVariationalObjective`: Variational objective. +- `λ`: Initial variational parameters. +- `restructure`: Function that reconstructs the variational approximation from `λ`. """ -function grad! end +init( + ::Random.AbstractRNG, + ::AbstractVariationalObjective, + ::AbstractVector, + ::Any +) = nothing """ - vi(model, alg::VariationalInference) - vi(model, alg::VariationalInference, q::VariationalPosterior) - vi(model, alg::VariationalInference, getq::Function, θ::AbstractArray) + estimate_objective([rng,] obj, q, prob; kwargs...) -Constructs the variational posterior from the `model` and performs the optimization -following the configuration of the given `VariationalInference` instance. +Estimate the variational objective `obj` targeting `prob` with respect to the variational approximation `q`. # Arguments -- `model`: `Turing.Model` or `Function` z ↦ log p(x, z) where `x` denotes the observations -- `alg`: the VI algorithm used -- `q`: a `VariationalPosterior` for which it is assumed a specialized implementation of the variational objective used exists. -- `getq`: function taking parameters `θ` as input and returns a `VariationalPosterior` -- `θ`: only required if `getq` is used, in which case it is the initial parameters for the variational posterior +- `rng::Random.AbstractRNG`: Random number generator. +- `obj::AbstractVariationalObjective`: Variational objective. +- `prob`: The target log-joint likelihood implementing the `LogDensityProblem` interface. +- `q`: Variational approximation. + +# Keyword Arguments +Depending on the objective, additional keyword arguments may apply. +Please refer to the respective documentation of each variational objective for more info. + +# Returns +- `obj_est`: Estimate of the objective value. """ -function vi end - -function update end - -# default implementations -function grad!( - vo, - alg::VariationalInference{<:ForwardDiffAD}, - q, - model, - θ::AbstractVector{<:Real}, - out::DiffResults.MutableDiffResult, - args... -) - f(θ_) = if (q isa Distribution) - - vo(alg, update(q, θ_), model, args...) - else - - vo(alg, q(θ_), model, args...) - end +function estimate_objective end - # Set chunk size and do ForwardMode. - chunk_size = getchunksize(typeof(alg)) - config = if chunk_size == 0 - ForwardDiff.GradientConfig(f, θ) - else - ForwardDiff.GradientConfig(f, θ, ForwardDiff.Chunk(length(θ), chunk_size)) - end - ForwardDiff.gradient!(out, f, θ, config) -end +export estimate_objective -function grad!( - vo, - alg::VariationalInference{<:TrackerAD}, - q, - model, - θ::AbstractVector{<:Real}, - out::DiffResults.MutableDiffResult, - args... -) - θ_tracked = Tracker.param(θ) - y = if (q isa Distribution) - - vo(alg, update(q, θ_tracked), model, args...) - else - - vo(alg, q(θ_tracked), model, args...) - end - Tracker.back!(y, 1.0) - DiffResults.value!(out, Tracker.data(y)) - DiffResults.gradient!(out, Tracker.grad(θ_tracked)) -end +""" + estimate_gradient!(rng, obj, adbackend, out, prob, λ, restructure, obj_state) +Estimate (possibly stochastic) gradients of the variational objective `obj` targeting `prob` with respect to the variational parameters `λ` +# Arguments +- `rng::Random.AbstractRNG`: Random number generator. +- `obj::AbstractVariationalObjective`: Variational objective. +- `adbackend::ADTypes.AbstractADType`: Automatic differentiation backend. +- `out::DiffResults.MutableDiffResult`: Buffer containing the objective value and gradient estimates. +- `prob`: The target log-joint likelihood implementing the `LogDensityProblem` interface. +- `λ`: Variational parameters to evaluate the gradient on. +- `restructure`: Function that reconstructs the variational approximation from `λ`. +- `obj_state`: Previous state of the objective. + +# Returns +- `out::MutableDiffResult`: Buffer containing the objective value and gradient estimates. +- `obj_state`: The updated state of the objective. +- `stat::NamedTuple`: Statistics and logs generated during estimation. """ - optimize!(vo, alg::VariationalInference{AD}, q::VariationalPosterior, model::Model, θ; optimizer = TruncatedADAGrad()) +function estimate_gradient! end + +# ELBO-specific interfaces +abstract type AbstractEntropyEstimator end -Iteratively updates parameters by calling `grad!` and using the given `optimizer` to compute -the steps. """ -function optimize!( - vo, - alg::VariationalInference, - q, - model, - θ::AbstractVector{<:Real}; - optimizer = TruncatedADAGrad() -) - # TODO: should we always assume `samples_per_step` and `max_iters` for all algos? - alg_name = alg_str(alg) - samples_per_step = alg.samples_per_step - max_iters = alg.max_iters - - num_params = length(θ) - - # TODO: really need a better way to warn the user about potentially - # not using the correct accumulator - if (optimizer isa TruncatedADAGrad) && (θ ∉ keys(optimizer.acc)) - # this message should only occurr once in the optimization process - @info "[$alg_name] Should only be seen once: optimizer created for θ" objectid(θ) - end + estimate_entropy(entropy_estimator, mc_samples, q) - diff_result = DiffResults.GradientResult(θ) +Estimate the entropy of `q`. - i = 0 - prog = if PROGRESS[] - ProgressMeter.Progress(max_iters, 1, "[$alg_name] Optimizing...", 0) - else - 0 - end +# Arguments +- `entropy_estimator`: Entropy estimation strategy. +- `q`: Variational approximation. +- `mc_samples`: Monte Carlo samples used to estimate the entropy. (Only used for Monte Carlo strategies.) - # add criterion? A running mean maybe? - time_elapsed = @elapsed while (i < max_iters) # & converged - grad!(vo, alg, q, model, θ, diff_result, samples_per_step) +# Returns +- `obj_est`: Estimate of the objective value. +""" +function estimate_entropy end - # apply update rule - Δ = DiffResults.gradient(diff_result) - Δ = apply!(optimizer, θ, Δ) - @. θ = θ - Δ - - AdvancedVI.DEBUG && @debug "Step $i" Δ DiffResults.value(diff_result) - PROGRESS[] && (ProgressMeter.next!(prog)) +export + RepGradELBO, + ClosedFormEntropy, + StickingTheLandingEntropy, + MonteCarloEntropy - i += 1 - end +include("objectives/elbo/entropy.jl") +include("objectives/elbo/repgradelbo.jl") - return θ -end +# Optimization Routine + +function optimize end -# objectives -include("objectives.jl") +export optimize -# optimisers -include("optimisers.jl") +include("utils.jl") +include("optimize.jl") -# VI algorithms -include("advi.jl") -end # module +# optional dependencies +if !isdefined(Base, :get_extension) # check whether :get_extension is defined in Base + using Requires +end + +@static if !isdefined(Base, :get_extension) + function __init__() + @require Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" begin + include("../ext/AdvancedVIBijectorsExt.jl") + end + @require Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" begin + include("../ext/AdvancedVIEnzymeExt.jl") + end + @require ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" begin + include("../ext/AdvancedVIForwardDiffExt.jl") + end + @require ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" begin + include("../ext/AdvancedVIReverseDiffExt.jl") + end + @require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" begin + include("../ext/AdvancedVIZygoteExt.jl") + end + end +end + +end + diff --git a/src/ad.jl b/src/ad.jl deleted file mode 100644 index 62e785e1..00000000 --- a/src/ad.jl +++ /dev/null @@ -1,46 +0,0 @@ -############################## -# Global variables/constants # -############################## -const ADBACKEND = Ref(:forwarddiff) -setadbackend(backend_sym::Symbol) = setadbackend(Val(backend_sym)) -function setadbackend(::Val{:forward_diff}) - Base.depwarn("`AdvancedVI.setadbackend(:forward_diff)` is deprecated. Please use `AdvancedVI.setadbackend(:forwarddiff)` to use `ForwardDiff`.", :setadbackend) - setadbackend(Val(:forwarddiff)) -end -function setadbackend(::Val{:forwarddiff}) - ADBACKEND[] = :forwarddiff -end - -function setadbackend(::Val{:reverse_diff}) - Base.depwarn("`AdvancedVI.setadbackend(:reverse_diff)` is deprecated. Please use `AdvancedVI.setadbackend(:tracker)` to use `Tracker` or `AdvancedVI.setadbackend(:reversediff)` to use `ReverseDiff`. To use `ReverseDiff`, please make sure it is loaded separately with `using ReverseDiff`.", :setadbackend) - setadbackend(Val(:tracker)) -end -function setadbackend(::Val{:tracker}) - ADBACKEND[] = :tracker -end - -const ADSAFE = Ref(false) -function setadsafe(switch::Bool) - @info("[AdvancedVI]: global ADSAFE is set as $switch") - ADSAFE[] = switch -end - -const CHUNKSIZE = Ref(0) # 0 means letting ForwardDiff set it automatically - -function setchunksize(chunk_size::Int) - @info("[AdvancedVI]: AD chunk size is set as $chunk_size") - CHUNKSIZE[] = chunk_size -end - -abstract type ADBackend end -struct ForwardDiffAD{chunk} <: ADBackend end -getchunksize(::Type{<:ForwardDiffAD{chunk}}) where chunk = chunk - -struct TrackerAD <: ADBackend end - -ADBackend() = ADBackend(ADBACKEND[]) -ADBackend(T::Symbol) = ADBackend(Val(T)) - -ADBackend(::Val{:forwarddiff}) = ForwardDiffAD{CHUNKSIZE[]} -ADBackend(::Val{:tracker}) = TrackerAD -ADBackend(::Val) = error("The requested AD backend is not available. Make sure to load all required packages.") diff --git a/src/advi.jl b/src/advi.jl deleted file mode 100644 index 7f9e7346..00000000 --- a/src/advi.jl +++ /dev/null @@ -1,99 +0,0 @@ -using StatsFuns -using DistributionsAD -using Bijectors -using Bijectors: TransformedDistribution - - -""" -$(TYPEDEF) - -Automatic Differentiation Variational Inference (ADVI) with automatic differentiation -backend `AD`. - -# Fields - -$(TYPEDFIELDS) -""" -struct ADVI{AD} <: VariationalInference{AD} - "Number of samples used to estimate the ELBO in each optimization step." - samples_per_step::Int - "Maximum number of gradient steps." - max_iters::Int -end - -function ADVI(samples_per_step::Int=1, max_iters::Int=1000) - return ADVI{ADBackend()}(samples_per_step, max_iters) -end - -alg_str(::ADVI) = "ADVI" - -function vi(model, alg::ADVI, q, θ_init; optimizer = TruncatedADAGrad()) - θ = copy(θ_init) - optimize!(elbo, alg, q, model, θ; optimizer = optimizer) - - # If `q` is a mean-field approx we use the specialized `update` function - if q isa Distribution - return update(q, θ) - else - # Otherwise we assume it's a mapping θ → q - return q(θ) - end -end - - -function optimize(elbo::ELBO, alg::ADVI, q, model, θ_init; optimizer = TruncatedADAGrad()) - θ = copy(θ_init) - - # `model` assumed to be callable z ↦ p(x, z) - optimize!(elbo, alg, q, model, θ; optimizer = optimizer) - - return θ -end - -# WITHOUT updating parameters inside ELBO -function (elbo::ELBO)( - rng::Random.AbstractRNG, - alg::ADVI, - q::VariationalPosterior, - logπ::Function, - num_samples -) - # 𝔼_q(z)[log p(xᵢ, z)] - # = ∫ log p(xᵢ, z) q(z) dz - # = ∫ log p(xᵢ, f(ϕ)) q(f(ϕ)) |det J_f(ϕ)| dϕ (since change of variables) - # = ∫ log p(xᵢ, f(ϕ)) q̃(ϕ) dϕ (since q(f(ϕ)) |det J_f(ϕ)| = q̃(ϕ)) - # = 𝔼_q̃(ϕ)[log p(xᵢ, z)] - - # 𝔼_q(z)[log q(z)] - # = ∫ q(f(ϕ)) log (q(f(ϕ))) |det J_f(ϕ)| dϕ (since q(f(ϕ)) |det J_f(ϕ)| = q̃(ϕ)) - # = 𝔼_q̃(ϕ) [log q(f(ϕ))] - # = 𝔼_q̃(ϕ) [log q̃(ϕ) - log |det J_f(ϕ)|] - # = 𝔼_q̃(ϕ) [log q̃(ϕ)] - 𝔼_q̃(ϕ) [log |det J_f(ϕ)|] - # = - ℍ(q̃(ϕ)) - 𝔼_q̃(ϕ) [log |det J_f(ϕ)|] - - # Finally, the ELBO is given by - # ELBO = 𝔼_q(z)[log p(xᵢ, z)] - 𝔼_q(z)[log q(z)] - # = 𝔼_q̃(ϕ)[log p(xᵢ, z)] + 𝔼_q̃(ϕ) [log |det J_f(ϕ)|] + ℍ(q̃(ϕ)) - - # If f: supp(p(z | x)) → ℝ then - # ELBO = 𝔼[log p(x, z) - log q(z)] - # = 𝔼[log p(x, f⁻¹(z̃)) + logabsdet(J(f⁻¹(z̃)))] + ℍ(q̃(z̃)) - # = 𝔼[log p(x, z) - logabsdetjac(J(f(z)))] + ℍ(q̃(z̃)) - - # But our `rand_and_logjac(q)` is using f⁻¹: ℝ → supp(p(z | x)) going forward → `+ logjac` - z, logjac = rand_and_logjac(rng, q) - res = (logπ(z) + logjac) / num_samples - - if q isa TransformedDistribution - res += entropy(q.dist) - else - res += entropy(q) - end - - for i = 2:num_samples - z, logjac = rand_and_logjac(rng, q) - res += (logπ(z) + logjac) / num_samples - end - - return res -end diff --git a/src/compat/enzyme.jl b/src/compat/enzyme.jl deleted file mode 100644 index c6bb9ac3..00000000 --- a/src/compat/enzyme.jl +++ /dev/null @@ -1,5 +0,0 @@ -struct EnzymeAD <: ADBackend end -ADBackend(::Val{:enzyme}) = EnzymeAD -function setadbackend(::Val{:enzyme}) - ADBACKEND[] = :enzyme -end diff --git a/src/compat/reversediff.jl b/src/compat/reversediff.jl deleted file mode 100644 index 721d0361..00000000 --- a/src/compat/reversediff.jl +++ /dev/null @@ -1,16 +0,0 @@ -using .ReverseDiff: compile, GradientTape -using .ReverseDiff.DiffResults: GradientResult - -struct ReverseDiffAD{cache} <: ADBackend end -const RDCache = Ref(false) -setcache(b::Bool) = RDCache[] = b -getcache() = RDCache[] -ADBackend(::Val{:reversediff}) = ReverseDiffAD{getcache()} -function setadbackend(::Val{:reversediff}) - ADBACKEND[] = :reversediff -end - -tape(f, x) = GradientTape(f, x) -function taperesult(f, x) - return tape(f, x), GradientResult(x) -end diff --git a/src/compat/zygote.jl b/src/compat/zygote.jl deleted file mode 100644 index 40022e21..00000000 --- a/src/compat/zygote.jl +++ /dev/null @@ -1,5 +0,0 @@ -struct ZygoteAD <: ADBackend end -ADBackend(::Val{:zygote}) = ZygoteAD -function setadbackend(::Val{:zygote}) - ADBACKEND[] = :zygote -end diff --git a/src/objectives.jl b/src/objectives.jl deleted file mode 100644 index 5a6b61b0..00000000 --- a/src/objectives.jl +++ /dev/null @@ -1,7 +0,0 @@ -struct ELBO <: VariationalObjective end - -function (elbo::ELBO)(alg, q, logπ, num_samples; kwargs...) - return elbo(Random.default_rng(), alg, q, logπ, num_samples; kwargs...) -end - -const elbo = ELBO() diff --git a/src/objectives/elbo/entropy.jl b/src/objectives/elbo/entropy.jl new file mode 100644 index 00000000..6c5b4739 --- /dev/null +++ b/src/objectives/elbo/entropy.jl @@ -0,0 +1,48 @@ + +""" + ClosedFormEntropy() + +Use closed-form expression of entropy. + +# Requirements +- The variational approximation implements `entropy`. + +# References +* Titsias, M., & Lázaro-Gredilla, M. (2014, June). Doubly stochastic variational Bayes for non-conjugate inference. In International conference on machine learning (pp. 1971-1979). PMLR. +* Kucukelbir, A., Tran, D., Ranganath, R., Gelman, A., & Blei, D. M. (2017). Automatic differentiation variational inference. Journal of machine learning research. +""" +struct ClosedFormEntropy <: AbstractEntropyEstimator end + +maybe_stop_entropy_score(::AbstractEntropyEstimator, q, q_stop) = q + +function estimate_entropy(::ClosedFormEntropy, ::Any, q) + entropy(q) +end + +""" + StickingTheLandingEntropy() + +The "sticking the landing" entropy estimator. + +# Requirements +- The variational approximation `q` implements `logpdf`. +- `logpdf(q, η)` must be differentiable by the selected AD framework. + +# References +* Roeder, G., Wu, Y., & Duvenaud, D. K. (2017). Sticking the landing: Simple, lower-variance gradient estimators for variational inference. Advances in Neural Information Processing Systems, 30. +""" +struct StickingTheLandingEntropy <: AbstractEntropyEstimator end + +struct MonteCarloEntropy <: AbstractEntropyEstimator end + +maybe_stop_entropy_score(::StickingTheLandingEntropy, q, q_stop) = q_stop + +function estimate_entropy( + ::Union{MonteCarloEntropy, StickingTheLandingEntropy}, + mc_samples::AbstractMatrix, + q +) + mean(eachcol(mc_samples)) do mc_sample + -logpdf(q, mc_sample) + end +end diff --git a/src/objectives/elbo/repgradelbo.jl b/src/objectives/elbo/repgradelbo.jl new file mode 100644 index 00000000..04f35320 --- /dev/null +++ b/src/objectives/elbo/repgradelbo.jl @@ -0,0 +1,125 @@ + +""" + RepGradELBO(n_samples; kwargs...) + +Evidence lower-bound objective with the reparameterization gradient formulation[^TL2014][^RMW2014][^KW2014]. +This computes the evidence lower-bound (ELBO) through the formulation: +```math +\\begin{aligned} +\\mathrm{ELBO}\\left(\\lambda\\right) +&\\triangleq +\\mathbb{E}_{z \\sim q_{\\lambda}}\\left[ + \\log \\pi\\left(z\\right) +\\right] ++ \\mathbb{H}\\left(q_{\\lambda}\\right), +\\end{aligned} +``` + +# Arguments +- `n_samples::Int`: Number of Monte Carlo samples used to estimate the ELBO. + +# Keyword Arguments +- `entropy`: The estimator for the entropy term. (Type `<: AbstractEntropyEstimator`; Default: ClosedFormEntropy()) + +# Requirements +- ``q_{\\lambda}`` implements `rand`. +- The target `logdensity(prob, x)` must be differentiable wrt. `x` by the selected AD backend. + +Depending on the options, additional requirements on ``q_{\\lambda}`` may apply. + +# References +[^TL2014]: Titsias, M., & Lázaro-Gredilla, M. (2014, June). Doubly stochastic variational Bayes for non-conjugate inference. In ICML. +[^RMW2014]: Rezende, D. J., Mohamed, S., & Wierstra, D. (2014, June). Stochastic backpropagation and approximate inference in deep generative models. In ICML. +[^KW2014]: Kingma, D. P., & Welling, M. (2014). Auto-encoding variational bayes. In ICLR. +""" +struct RepGradELBO{EntropyEst <: AbstractEntropyEstimator} <: AbstractVariationalObjective + entropy ::EntropyEst + n_samples::Int +end + +RepGradELBO( + n_samples::Int; + entropy ::AbstractEntropyEstimator = ClosedFormEntropy() +) = RepGradELBO(entropy, n_samples) + +function Base.show(io::IO, obj::RepGradELBO) + print(io, "RepGradELBO(entropy=") + print(io, obj.entropy) + print(io, ", n_samples=") + print(io, obj.n_samples) + print(io, ")") +end + +function estimate_entropy_maybe_stl(entropy_estimator::AbstractEntropyEstimator, samples, q, q_stop) + q_maybe_stop = maybe_stop_entropy_score(entropy_estimator, q, q_stop) + estimate_entropy(entropy_estimator, samples, q_maybe_stop) +end + +function estimate_energy_with_samples(prob, samples) + mean(Base.Fix1(LogDensityProblems.logdensity, prob), eachsample(samples)) +end + +""" + reparam_with_entropy(rng, q, q_stop, n_samples, ent_est) + +Draw `n_samples` from `q` and compute its entropy. + +# Arguments +- `rng::Random.AbstractRNG`: Random number generator. +- `q`: Variational approximation. +- `q_stop`: `q` but with its gradient stopped. +- `n_samples::Int`: Number of Monte Carlo samples +- `ent_est`: The entropy estimation strategy. (See `estimate_entropy`.) + +# Returns +- `samples`: Monte Carlo samples generated through reparameterization. Their support matches that of the target distribution. +- `entropy`: An estimate (or exact value) of the differential entropy of `q`. +""" +function reparam_with_entropy( + rng::Random.AbstractRNG, q, q_stop, n_samples::Int, ent_est::AbstractEntropyEstimator +) + samples = rand(rng, q, n_samples) + entropy = estimate_entropy_maybe_stl(ent_est, samples, q, q_stop) + samples, entropy +end + +function estimate_objective( + rng::Random.AbstractRNG, + obj::RepGradELBO, + q, + prob; + n_samples::Int = obj.n_samples +) + samples, entropy = reparam_with_entropy(rng, q, q, n_samples, obj.entropy) + energy = estimate_energy_with_samples(prob, samples) + energy + entropy +end + +estimate_objective(obj::RepGradELBO, q, prob; n_samples::Int = obj.n_samples) = + estimate_objective(Random.default_rng(), obj, q, prob; n_samples) + +function estimate_gradient!( + rng ::Random.AbstractRNG, + obj ::RepGradELBO, + adbackend::ADTypes.AbstractADType, + out ::DiffResults.MutableDiffResult, + prob, + λ, + restructure, + state, +) + q_stop = restructure(λ) + function f(λ′) + q = restructure(λ′) + samples, entropy = reparam_with_entropy(rng, q, q_stop, obj.n_samples, obj.entropy) + energy = estimate_energy_with_samples(prob, samples) + elbo = energy + entropy + -elbo + end + value_and_gradient!(adbackend, f, λ, out) + + nelbo = DiffResults.value(out) + stat = (elbo=-nelbo,) + + out, nothing, stat +end diff --git a/src/optimisers.jl b/src/optimisers.jl deleted file mode 100644 index 8077f98c..00000000 --- a/src/optimisers.jl +++ /dev/null @@ -1,94 +0,0 @@ -const ϵ = 1e-8 - -""" - TruncatedADAGrad(η=0.1, τ=1.0, n=100) - -Implements a truncated version of AdaGrad in the sense that only the `n` previous gradient norms are used to compute the scaling rather than *all* previous. It has parameter specific learning rates based on how frequently it is updated. - -## Parameters - - η: learning rate - - τ: constant scale factor - - n: number of previous gradient norms to use in the scaling. -``` -## References -[ADAGrad](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) optimiser. -Parameters don't need tuning. - -[TruncatedADAGrad](https://arxiv.org/abs/1506.03431v2) (Appendix E). -""" -mutable struct TruncatedADAGrad - eta::Float64 - tau::Float64 - n::Int - - iters::IdDict - acc::IdDict -end - -function TruncatedADAGrad(η = 0.1, τ = 1.0, n = 100) - TruncatedADAGrad(η, τ, n, IdDict(), IdDict()) -end - -function apply!(o::TruncatedADAGrad, x, Δ) - T = eltype(Tracker.data(Δ)) - - η = o.eta - τ = o.tau - - g² = get!( - o.acc, - x, - [zeros(T, size(x)) for j = 1:o.n] - )::Array{typeof(Tracker.data(Δ)), 1} - i = get!(o.iters, x, 1)::Int - - # Example: suppose i = 12 and o.n = 10 - idx = mod(i - 1, o.n) + 1 # => idx = 2 - - # set the current - @inbounds @. g²[idx] = Δ^2 # => g²[2] = Δ^2 where Δ is the (o.n + 2)-th Δ - - # TODO: make more efficient and stable - s = sum(g²) - - # increment - o.iters[x] += 1 - - # TODO: increment (but "truncate") - # o.iters[x] = i > o.n ? o.n + mod(i, o.n) : i + 1 - - @. Δ *= η / (τ + sqrt(s) + ϵ) -end - -""" - DecayedADAGrad(η=0.1, pre=1.0, post=0.9) - -Implements a decayed version of AdaGrad. It has parameter specific learning rates based on how frequently it is updated. - -## Parameters - - η: learning rate - - pre: weight of new gradient norm - - post: weight of histroy of gradient norms -``` -## References -[ADAGrad](http://www.jmlr.org/papers/volume12/duchi11a/duchi11a.pdf) optimiser. -Parameters don't need tuning. -""" -mutable struct DecayedADAGrad - eta::Float64 - pre::Float64 - post::Float64 - - acc::IdDict -end - -DecayedADAGrad(η = 0.1, pre = 1.0, post = 0.9) = DecayedADAGrad(η, pre, post, IdDict()) - -function apply!(o::DecayedADAGrad, x, Δ) - T = eltype(Tracker.data(Δ)) - - η = o.eta - acc = get!(o.acc, x, fill(T(ϵ), size(x)))::typeof(Tracker.data(x)) - @. acc = o.post * acc + o.pre * Δ^2 - @. Δ *= η / (√acc + ϵ) -end diff --git a/src/optimize.jl b/src/optimize.jl new file mode 100644 index 00000000..7e0032dc --- /dev/null +++ b/src/optimize.jl @@ -0,0 +1,161 @@ + +""" + optimize(problem, objective, restructure, param_init, max_iter, objargs...; kwargs...) + optimize(problem, objective, variational_dist_init, max_iter, objargs...; kwargs...) + +Optimize the variational objective `objective` targeting the problem `problem` by estimating (stochastic) gradients. + +The variational approximation can be constructed by passing the variational parameters `param_init` or the initial variational approximation `variational_dist_init` to the function `restructure`. + +# Arguments +- `objective::AbstractVariationalObjective`: Variational Objective. +- `param_init`: Initial value of the variational parameters. +- `restruct`: Function that reconstructs the variational approximation from the flattened parameters. +- `variational_dist_init`: Initial variational distribution. The variational parameters must be extractable through `Optimisers.destructure`. +- `max_iter::Int`: Maximum number of iterations. +- `objargs...`: Arguments to be passed to `objective`. + +# Keyword Arguments +- `adbackend::ADtypes.AbstractADType`: Automatic differentiation backend. +- `optimizer::Optimisers.AbstractRule`: Optimizer used for inference. (Default: `Adam`.) +- `rng::AbstractRNG`: Random number generator. (Default: `Random.default_rng()`.) +- `show_progress::Bool`: Whether to show the progress bar. (Default: `true`.) +- `callback`: Callback function called after every iteration. See further information below. (Default: `nothing`.) +- `prog`: Progress bar configuration. (Default: `ProgressMeter.Progress(n_max_iter; desc="Optimizing", barlen=31, showspeed=true, enabled=prog)`.) +- `state::NamedTuple`: Initial value for the internal state of optimization. Used to warm-start from the state of a previous run. (See the returned values below.) + +# Returns +- `params`: Variational parameters optimizing the variational objective. +- `stats`: Statistics gathered during optimization. +- `state`: Collection of the final internal states of optimization. This can used later to warm-start from the last iteration of the corresponding run. + +# Callback +The callback function `callback` has a signature of + + callback(; stat, state, param, restructure, gradient) + +The arguments are as follows: +- `stat`: Statistics gathered during the current iteration. The content will vary depending on `objective`. +- `state`: Collection of the internal states used for optimization. +- `param`: Variational parameters. +- `restructure`: Function that restructures the variational approximation from the variational parameters. Calling `restructure(param)` reconstructs the variational approximation. +- `gradient`: The estimated (possibly stochastic) gradient. + +`cb` can return a `NamedTuple` containing some additional information computed within `cb`. +This will be appended to the statistic of the current corresponding iteration. +Otherwise, just return `nothing`. + +""" + +function optimize( + rng ::Random.AbstractRNG, + problem, + objective ::AbstractVariationalObjective, + restructure, + params_init, + max_iter ::Int, + objargs...; + adbackend ::ADTypes.AbstractADType, + optimizer ::Optimisers.AbstractRule = Optimisers.Adam(), + show_progress::Bool = true, + state_init ::NamedTuple = NamedTuple(), + callback = nothing, + prog = ProgressMeter.Progress( + max_iter; + desc = "Optimizing", + barlen = 31, + showspeed = true, + enabled = show_progress + ) +) + λ = copy(params_init) + opt_st = maybe_init_optimizer(state_init, optimizer, λ) + obj_st = maybe_init_objective(state_init, rng, objective, λ, restructure) + grad_buf = DiffResults.DiffResult(zero(eltype(λ)), similar(λ)) + stats = NamedTuple[] + + for t = 1:max_iter + stat = (iteration=t,) + + grad_buf, obj_st, stat′ = estimate_gradient!( + rng, objective, adbackend, grad_buf, problem, + λ, restructure, obj_st, objargs... + ) + stat = merge(stat, stat′) + + g = DiffResults.gradient(grad_buf) + opt_st, λ = Optimisers.update!(opt_st, λ, g) + + if !isnothing(callback) + stat′ = callback( + ; stat, restructure, params=λ, gradient=g, + state=(optimizer=opt_st, objective=obj_st) + ) + stat = !isnothing(stat′) ? merge(stat′, stat) : stat + end + + @debug "Iteration $t" stat... + + pm_next!(prog, stat) + push!(stats, stat) + end + state = (optimizer=opt_st, objective=obj_st) + stats = map(identity, stats) + params = λ + params, stats, state +end + +function optimize( + problem, + objective ::AbstractVariationalObjective, + restructure, + params_init, + max_iter ::Int, + objargs...; + kwargs... +) + optimize( + Random.default_rng(), + problem, + objective, + restructure, + params_init, + max_iter, + objargs...; + kwargs... + ) +end + +function optimize(rng ::Random.AbstractRNG, + problem, + objective ::AbstractVariationalObjective, + variational_dist_init, + n_max_iter ::Int, + objargs...; + kwargs...) + λ, restructure = Optimisers.destructure(variational_dist_init) + λ, logstats, state = optimize( + rng, problem, objective, restructure, λ, n_max_iter, objargs...; kwargs... + ) + restructure(λ), logstats, state +end + + +function optimize( + problem, + objective ::AbstractVariationalObjective, + variational_dist_init, + max_iter ::Int, + objargs...; + kwargs... +) + optimize( + Random.default_rng(), + problem, + objective, + variational_dist_init, + max_iter, + objargs...; + kwargs... + ) +end diff --git a/src/utils.jl b/src/utils.jl index bb4c1f18..8e67ff1a 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,15 +1,36 @@ -using Distributions -using Bijectors: Bijectors +function pm_next!(pm, stats::NamedTuple) + ProgressMeter.next!(pm; showvalues=[tuple(s...) for s in pairs(stats)]) +end +function maybe_init_optimizer( + state_init::NamedTuple, + optimizer ::Optimisers.AbstractRule, + λ ::AbstractVector +) + haskey(state_init, :optimizer) ? state_init.optimizer : Optimisers.setup(optimizer, λ) +end -function rand_and_logjac(rng::Random.AbstractRNG, dist::Distribution) - x = rand(rng, dist) - return x, zero(eltype(x)) +function maybe_init_objective( + state_init::NamedTuple, + rng ::Random.AbstractRNG, + objective ::AbstractVariationalObjective, + λ ::AbstractVector, + restructure +) + haskey(state_init, :objective) ? state_init.objective : init(rng, objective, λ, restructure) end -function rand_and_logjac(rng::Random.AbstractRNG, dist::Bijectors.TransformedDistribution) - x = rand(rng, dist.dist) - y, logjac = Bijectors.with_logabsdet_jacobian(dist.transform, x) - return y, logjac +eachsample(samples::AbstractMatrix) = eachcol(samples) + +eachsample(samples::AbstractVector) = samples + +function catsamples_and_acc( + state_curr::Tuple{<:AbstractArray, <:Real}, + state_new ::Tuple{<:AbstractVector, <:Real} +) + x = hcat(first(state_curr), first(state_new)) + ∑y = last(state_curr) + last(state_new) + return (x, ∑y) end + diff --git a/test/Project.toml b/test/Project.toml new file mode 100644 index 00000000..a751b89d --- /dev/null +++ b/test/Project.toml @@ -0,0 +1,45 @@ +[deps] +ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" +Bijectors = "76274a88-744f-5084-9051-94815aaf08c4" +Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" +FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c" +Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" +PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" +Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" +SimpleUnPack = "ce78b400-467f-4804-87d8-8f486da07d0a" +StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" +Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" + +[compat] +ADTypes = "0.2.1" +Bijectors = "0.13" +Distributions = "0.25.100" +DistributionsAD = "0.6.45" +Enzyme = "0.11.7" +FillArrays = "1.6.1" +ForwardDiff = "0.10.36" +Functors = "0.4.5" +LinearAlgebra = "1" +LogDensityProblems = "2.1.1" +Optimisers = "0.2.16, 0.3" +PDMats = "0.11.7" +Random = "1" +ReverseDiff = "1.15.1" +SimpleUnPack = "1.1.0" +StableRNGs = "1.0.0" +Statistics = "1" +Test = "1" +Tracker = "0.2.20" +Zygote = "0.6.63" +julia = "1.6" diff --git a/test/inference/repgradelbo_distributionsad.jl b/test/inference/repgradelbo_distributionsad.jl new file mode 100644 index 00000000..b6db22a6 --- /dev/null +++ b/test/inference/repgradelbo_distributionsad.jl @@ -0,0 +1,78 @@ + +const PROGRESS = length(ARGS) > 0 && ARGS[1] == "--progress" ? true : false + +using Test + +@testset "inference RepGradELBO DistributionsAD" begin + @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for + realtype ∈ [Float64, Float32], + (modelname, modelconstr) ∈ Dict( + :Normal=> normal_meanfield, + ), + (objname, objective) ∈ Dict( + :RepGradELBOClosedFormEntropy => RepGradELBO(10), + :RepGradELBOStickingTheLanding => RepGradELBO(10, entropy = StickingTheLandingEntropy()), + ), + (adbackname, adbackend) ∈ Dict( + :ForwarDiff => AutoForwardDiff(), + #:ReverseDiff => AutoReverseDiff(), + :Zygote => AutoZygote(), + #:Enzyme => AutoEnzyme(), + ) + + seed = (0x38bef07cf9cc549d) + rng = StableRNG(seed) + + modelstats = modelconstr(rng, realtype) + @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats + + T, η = is_meanfield ? (5_000, 1e-2) : (30_000, 1e-3) + + μ0 = Zeros(realtype, n_dims) + L0 = Diagonal(Ones(realtype, n_dims)) + q0 = TuringDiagMvNormal(μ0, diag(L0)) + + @testset "convergence" begin + Δλ₀ = sum(abs2, μ0 - μ_true) + sum(abs2, L0 - L_true) + q, stats, _ = optimize( + rng, model, objective, q0, T; + optimizer = Optimisers.Adam(realtype(η)), + show_progress = PROGRESS, + adbackend = adbackend, + ) + + μ = mean(q) + L = sqrt(cov(q)) + Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true) + + @test Δλ ≤ Δλ₀/T^(1/4) + @test eltype(μ) == eltype(μ_true) + @test eltype(L) == eltype(L_true) + end + + @testset "determinism" begin + rng = StableRNG(seed) + q, stats, _ = optimize( + rng, model, objective, q0, T; + optimizer = Optimisers.Adam(realtype(η)), + show_progress = PROGRESS, + adbackend = adbackend, + ) + μ = mean(q) + L = sqrt(cov(q)) + + rng_repl = StableRNG(seed) + q, stats, _ = optimize( + rng_repl, model, objective, q0, T; + optimizer = Optimisers.Adam(realtype(η)), + show_progress = PROGRESS, + adbackend = adbackend, + ) + μ_repl = mean(q) + L_repl = sqrt(cov(q)) + @test μ == μ_repl + @test L == L_repl + end + end +end + diff --git a/test/inference/repgradelbo_distributionsad_bijectors.jl b/test/inference/repgradelbo_distributionsad_bijectors.jl new file mode 100644 index 00000000..9f1e3cc4 --- /dev/null +++ b/test/inference/repgradelbo_distributionsad_bijectors.jl @@ -0,0 +1,81 @@ + +const PROGRESS = length(ARGS) > 0 && ARGS[1] == "--progress" ? true : false + +using Test + +@testset "inference RepGradELBO DistributionsAD Bijectors" begin + @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for + realtype ∈ [Float64, Float32], + (modelname, modelconstr) ∈ Dict( + :NormalLogNormalMeanField => normallognormal_meanfield, + ), + (objname, objective) ∈ Dict( + :RepGradELBOClosedFormEntropy => RepGradELBO(10), + :RepGradELBOStickingTheLanding => RepGradELBO(10, entropy = StickingTheLandingEntropy()), + ), + (adbackname, adbackend) ∈ Dict( + :ForwarDiff => AutoForwardDiff(), + #:ReverseDiff => AutoReverseDiff(), + #:Zygote => AutoZygote(), + #:Enzyme => AutoEnzyme(), + ) + + seed = (0x38bef07cf9cc549d) + rng = StableRNG(seed) + + modelstats = modelconstr(rng, realtype) + @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats + + T, η = is_meanfield ? (5_000, 1e-2) : (30_000, 1e-3) + + b = Bijectors.bijector(model) + b⁻¹ = inverse(b) + μ₀ = Zeros(realtype, n_dims) + L₀ = Diagonal(Ones(realtype, n_dims)) + + q₀_η = TuringDiagMvNormal(μ₀, diag(L₀)) + q₀_z = Bijectors.transformed(q₀_η, b⁻¹) + + @testset "convergence" begin + Δλ₀ = sum(abs2, μ₀ - μ_true) + sum(abs2, L₀ - L_true) + q, stats, _ = optimize( + rng, model, objective, q₀_z, T; + optimizer = Optimisers.Adam(realtype(η)), + show_progress = PROGRESS, + adbackend = adbackend, + ) + + μ = mean(q.dist) + L = sqrt(cov(q.dist)) + Δλ = sum(abs2, μ - μ_true) + sum(abs2, L - L_true) + + @test Δλ ≤ Δλ₀/T^(1/4) + @test eltype(μ) == eltype(μ_true) + @test eltype(L) == eltype(L_true) + end + + @testset "determinism" begin + rng = StableRNG(seed) + q, stats, _ = optimize( + rng, model, objective, q₀_z, T; + optimizer = Optimisers.Adam(realtype(η)), + show_progress = PROGRESS, + adbackend = adbackend, + ) + μ = mean(q.dist) + L = sqrt(cov(q.dist)) + + rng_repl = StableRNG(seed) + q, stats, _ = optimize( + rng_repl, model, objective, q₀_z, T; + optimizer = Optimisers.Adam(realtype(η)), + show_progress = PROGRESS, + adbackend = adbackend, + ) + μ_repl = mean(q.dist) + L_repl = sqrt(cov(q.dist)) + @test μ == μ_repl + @test L == L_repl + end + end +end diff --git a/test/interface/ad.jl b/test/interface/ad.jl new file mode 100644 index 00000000..b716ca2f --- /dev/null +++ b/test/interface/ad.jl @@ -0,0 +1,22 @@ + +using Test + +@testset "ad" begin + @testset "$(adname)" for (adname, adsymbol) ∈ Dict( + :ForwardDiff => AutoForwardDiff(), + :ReverseDiff => AutoReverseDiff(), + :Zygote => AutoZygote(), + # :Enzyme => AutoEnzyme(), # Currently not tested against. + ) + D = 10 + A = randn(D, D) + λ = randn(D) + grad_buf = DiffResults.GradientResult(λ) + f(λ′) = λ′'*A*λ′ / 2 + AdvancedVI.value_and_gradient!(adsymbol, f, λ, grad_buf) + ∇ = DiffResults.gradient(grad_buf) + f = DiffResults.value(grad_buf) + @test ∇ ≈ (A + A')*λ/2 + @test f ≈ λ'*A*λ / 2 + end +end diff --git a/test/interface/optimize.jl b/test/interface/optimize.jl new file mode 100644 index 00000000..6e69616b --- /dev/null +++ b/test/interface/optimize.jl @@ -0,0 +1,98 @@ + +using Test + +@testset "interface optimize" begin + seed = (0x38bef07cf9cc549d) + rng = StableRNG(seed) + + T = 1000 + modelstats = normal_meanfield(rng, Float64) + + @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats + + # Global Test Configurations + q0 = TuringDiagMvNormal(zeros(Float64, n_dims), ones(Float64, n_dims)) + obj = RepGradELBO(10) + + adbackend = AutoForwardDiff() + optimizer = Optimisers.Adam(1e-2) + + rng = StableRNG(seed) + q_ref, stats_ref, _ = optimize( + rng, model, obj, q0, T; + optimizer, + show_progress = false, + adbackend, + ) + λ_ref, _ = Optimisers.destructure(q_ref) + + @testset "default_rng" begin + optimize( + model, obj, q0, T; + optimizer, + show_progress = false, + adbackend, + ) + + λ₀, re = Optimisers.destructure(q0) + optimize( + model, obj, re, λ₀, T; + optimizer, + show_progress = false, + adbackend, + ) + end + + @testset "restructure" begin + λ₀, re = Optimisers.destructure(q0) + + rng = StableRNG(seed) + λ, stats, _ = optimize( + rng, model, obj, re, λ₀, T; + optimizer, + show_progress = false, + adbackend, + ) + @test λ == λ_ref + @test stats == stats_ref + end + + @testset "callback" begin + rng = StableRNG(seed) + test_values = rand(rng, T) + + callback(; stat, args...) = (test_value = test_values[stat.iteration],) + + rng = StableRNG(seed) + _, stats, _ = optimize( + rng, model, obj, q0, T; + show_progress = false, + adbackend, + callback + ) + @test [stat.test_value for stat ∈ stats] == test_values + end + + @testset "warm start" begin + rng = StableRNG(seed) + + T_first = div(T,2) + T_last = T - T_first + + q_first, _, state = optimize( + rng, model, obj, q0, T_first; + optimizer, + show_progress = false, + adbackend + ) + + q, stats, _ = optimize( + rng, model, obj, q_first, T_last; + optimizer, + show_progress = false, + state_init = state, + adbackend + ) + @test q == q_ref + end +end diff --git a/test/interface/repgradelbo.jl b/test/interface/repgradelbo.jl new file mode 100644 index 00000000..61ff0111 --- /dev/null +++ b/test/interface/repgradelbo.jl @@ -0,0 +1,28 @@ + +using Test + +@testset "interface RepGradELBO" begin + seed = (0x38bef07cf9cc549d) + rng = StableRNG(seed) + + modelstats = normal_meanfield(rng, Float64) + + @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats + + q0 = TuringDiagMvNormal(zeros(Float64, n_dims), ones(Float64, n_dims)) + + obj = RepGradELBO(10) + rng = StableRNG(seed) + elbo_ref = estimate_objective(rng, obj, q0, model; n_samples=10^4) + + @testset "determinism" begin + rng = StableRNG(seed) + elbo = estimate_objective(rng, obj, q0, model; n_samples=10^4) + @test elbo == elbo_ref + end + + @testset "default_rng" begin + elbo = estimate_objective(obj, q0, model; n_samples=10^4) + @test elbo ≈ elbo_ref rtol=0.1 + end +end diff --git a/test/models/normal.jl b/test/models/normal.jl new file mode 100644 index 00000000..3f305e1a --- /dev/null +++ b/test/models/normal.jl @@ -0,0 +1,43 @@ + +struct TestNormal{M,S} + μ::M + Σ::S +end + +function LogDensityProblems.logdensity(model::TestNormal, θ) + @unpack μ, Σ = model + logpdf(MvNormal(μ, Σ), θ) +end + +function LogDensityProblems.dimension(model::TestNormal) + length(model.μ) +end + +function LogDensityProblems.capabilities(::Type{<:TestNormal}) + LogDensityProblems.LogDensityOrder{0}() +end + +function normal_fullrank(rng::Random.AbstractRNG, realtype::Type) + n_dims = 5 + + μ = randn(rng, realtype, n_dims) + L = tril(I + ones(realtype, n_dims, n_dims))/2 + Σ = L*L' |> Hermitian + + model = TestNormal(μ, PDMat(Σ, Cholesky(L, 'L', 0))) + + TestModel(model, μ, L, n_dims, false) +end + +function normal_meanfield(rng::Random.AbstractRNG, realtype::Type) + n_dims = 5 + + μ = randn(rng, realtype, n_dims) + σ = log.(exp.(randn(rng, realtype, n_dims)) .+ 1) + + model = TestNormal(μ, Diagonal(σ.^2)) + + L = σ |> Diagonal + + TestModel(model, μ, L, n_dims, true) +end diff --git a/test/models/normallognormal.jl b/test/models/normallognormal.jl new file mode 100644 index 00000000..6615084b --- /dev/null +++ b/test/models/normallognormal.jl @@ -0,0 +1,65 @@ + +struct NormalLogNormal{MX,SX,MY,SY} + μ_x::MX + σ_x::SX + μ_y::MY + Σ_y::SY +end + +function LogDensityProblems.logdensity(model::NormalLogNormal, θ) + @unpack μ_x, σ_x, μ_y, Σ_y = model + logpdf(LogNormal(μ_x, σ_x), θ[1]) + logpdf(MvNormal(μ_y, Σ_y), θ[2:end]) +end + +function LogDensityProblems.dimension(model::NormalLogNormal) + length(model.μ_y) + 1 +end + +function LogDensityProblems.capabilities(::Type{<:NormalLogNormal}) + LogDensityProblems.LogDensityOrder{0}() +end + +function Bijectors.bijector(model::NormalLogNormal) + @unpack μ_x, σ_x, μ_y, Σ_y = model + Bijectors.Stacked( + Bijectors.bijector.([LogNormal(μ_x, σ_x), MvNormal(μ_y, Σ_y)]), + [1:1, 2:1+length(μ_y)]) +end + +function normallognormal_fullrank(rng::Random.AbstractRNG, realtype::Type) + n_dims = 5 + + μ_x = randn(rng, realtype) + σ_x = ℯ + μ_y = randn(rng, realtype, n_dims) + L_y = tril(I + ones(realtype, n_dims, n_dims))/2 + Σ_y = L_y*L_y' |> Hermitian + + model = NormalLogNormal(μ_x, σ_x, μ_y, PDMat(Σ_y, Cholesky(L_y, 'L', 0))) + + Σ = Matrix{realtype}(undef, n_dims+1, n_dims+1) + Σ[1,1] = σ_x^2 + Σ[2:end,2:end] = Σ_y + Σ = Σ |> Hermitian + + μ = vcat(μ_x, μ_y) + L = cholesky(Σ).L + + TestModel(model, μ, L, n_dims+1, false) +end + +function normallognormal_meanfield(rng::Random.AbstractRNG, realtype::Type) + n_dims = 5 + + μ_x = randn(rng, realtype) + σ_x = ℯ + μ_y = randn(rng, realtype, n_dims) + σ_y = log.(exp.(randn(rng, realtype, n_dims)) .+ 1) + + model = NormalLogNormal(μ_x, σ_x, μ_y, Diagonal(σ_y.^2)) + + μ = vcat(μ_x, μ_y) + L = vcat(σ_x, σ_y) |> Diagonal + + TestModel(model, μ, L, n_dims+1, true) +end diff --git a/test/optimisers.jl b/test/optimisers.jl deleted file mode 100644 index fae652ed..00000000 --- a/test/optimisers.jl +++ /dev/null @@ -1,17 +0,0 @@ -using Random, Test, LinearAlgebra, ForwardDiff -using AdvancedVI: TruncatedADAGrad, DecayedADAGrad, apply! - -θ = randn(10, 10) -@testset for opt in [TruncatedADAGrad(), DecayedADAGrad(1e-2)] - θ_fit = randn(10, 10) - loss(x, θ_) = mean(sum(abs2, θ*x - θ_*x; dims = 1)) - for t = 1:10^4 - x = rand(10) - Δ = ForwardDiff.gradient(θ_ -> loss(x, θ_), θ_fit) - Δ = apply!(opt, θ_fit, Δ) - @. θ_fit = θ_fit - Δ - end - @test loss(rand(10, 100), θ_fit) < 0.01 - @test length(opt.acc) == 1 -end - diff --git a/test/runtests.jl b/test/runtests.jl index a305c25e..b14b8b2e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,28 +1,42 @@ -using Test -using Distributions, DistributionsAD -using AdvancedVI - -include("optimisers.jl") - -target = MvNormal(ones(2)) -logπ(z) = logpdf(target, z) -advi = ADVI(10, 1000) -# Using a function z ↦ q(⋅∣z) -getq(θ) = TuringDiagMvNormal(θ[1:2], exp.(θ[3:4])) -q = vi(logπ, advi, getq, randn(4)) +using Test +using Test: @testset, @test + +using Bijectors +using Random, StableRNGs +using Statistics +using Distributions +using LinearAlgebra +using SimpleUnPack: @unpack +using FillArrays +using PDMats + +using Functors +using DistributionsAD +@functor TuringDiagMvNormal + +using LogDensityProblems +using Optimisers +using ADTypes +using ForwardDiff, ReverseDiff, Zygote -xs = rand(target, 10) -@test mean(abs2, logpdf(q, xs) - logpdf(target, xs)) ≤ 0.05 +using AdvancedVI -# OR: implement `update` and pass a `Distribution` -function AdvancedVI.update(d::TuringDiagMvNormal, θ::AbstractArray{<:Real}) - return TuringDiagMvNormal(θ[1:length(q)], exp.(θ[length(q) + 1:end])) +# Models for Inference Tests +struct TestModel{M,L,S} + model::M + μ_true::L + L_true::S + n_dims::Int + is_meanfield::Bool end +include("models/normal.jl") +include("models/normallognormal.jl") -q0 = TuringDiagMvNormal(zeros(2), ones(2)) -q = vi(logπ, advi, q0, randn(4)) - -xs = rand(target, 10) -@test mean(abs2, logpdf(q, xs) - logpdf(target, xs)) ≤ 0.05 +# Tests +include("interface/ad.jl") +include("interface/optimize.jl") +include("interface/repgradelbo.jl") +include("inference/repgradelbo_distributionsad.jl") +include("inference/repgradelbo_distributionsad_bijectors.jl")