From 87730a0d34dad5e4d0ccccb511828cc7061bd553 Mon Sep 17 00:00:00 2001 From: Will Tebbutt Date: Tue, 3 Sep 2024 06:57:59 +0100 Subject: [PATCH] Tapir.jl Integration (#86) * Initial pass at adding Tapir jl * Prevent Tapir from being installed on v1.7 * Using Pkg in tests * Update ext/AdvancedVITapirExt.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- Project.toml | 4 +- ext/AdvancedVITapirExt.jl | 37 +++++++++++++++++++ test/inference/repgradelbo_distributionsad.jl | 4 ++ test/inference/repgradelbo_locationscale.jl | 5 +++ .../repgradelbo_locationscale_bijectors.jl | 5 +++ test/interface/ad.jl | 28 +++++++------- test/interface/repgradelbo.jl | 7 +++- test/runtests.jl | 6 +++ 8 files changed, 80 insertions(+), 16 deletions(-) create mode 100644 ext/AdvancedVITapirExt.jl diff --git a/Project.toml b/Project.toml index fff721f2..262a2e96 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "AdvancedVI" uuid = "b5ca4192-6429-45e5-a2d9-87aec30a685c" -version = "0.3.0" +version = "0.3.1" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" @@ -26,12 +26,14 @@ Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" +Tapir = "07d77754-e150-4737-8c94-cd238a1fb45b" [extensions] AdvancedVIBijectorsExt = "Bijectors" AdvancedVIEnzymeExt = "Enzyme" AdvancedVIForwardDiffExt = "ForwardDiff" AdvancedVIReverseDiffExt = "ReverseDiff" +AdvancedVITapirExt = "Tapir" AdvancedVIZygoteExt = "Zygote" [compat] diff --git a/ext/AdvancedVITapirExt.jl b/ext/AdvancedVITapirExt.jl new file mode 100644 index 00000000..459ef7da --- /dev/null +++ b/ext/AdvancedVITapirExt.jl @@ -0,0 +1,37 @@ +module AdvancedVITapirExt + +if isdefined(Base, :get_extension) + using AdvancedVI + using AdvancedVI: ADTypes, DiffResults + using Tapir +else + using ..AdvancedVI + using ..AdvancedVI: ADTypes, DiffResults + using ..Tapir +end + +function AdvancedVI.value_and_gradient!( + ::ADTypes.AutoTapir, f, x::AbstractVector{<:Real}, out::DiffResults.MutableDiffResult +) + rule = Tapir.build_rrule(f, x) + y, g = Tapir.value_and_gradient!!(rule, f, x) + DiffResults.value!(out, y) + DiffResults.gradient!(out, last(g)) + return out +end + +function AdvancedVI.value_and_gradient!( + ::ADTypes.AutoTapir, + f, + x::AbstractVector{<:Real}, + aux, + out::DiffResults.MutableDiffResult, +) + rule = Tapir.build_rrule(f, x, aux) + y, g = Tapir.value_and_gradient!!(rule, f, x, aux) + DiffResults.value!(out, y) + DiffResults.gradient!(out, g[2]) + return out +end + +end diff --git a/test/inference/repgradelbo_distributionsad.jl b/test/inference/repgradelbo_distributionsad.jl index 3a458e38..815981c7 100644 --- a/test/inference/repgradelbo_distributionsad.jl +++ b/test/inference/repgradelbo_distributionsad.jl @@ -14,6 +14,10 @@ else ) end +if @isdefined(Tapir) + AD_distributionsad[:Tapir] = AutoTapir(; safe_mode=false) +end + @testset "inference RepGradELBO DistributionsAD" begin @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for realtype in [Float64, Float32], diff --git a/test/inference/repgradelbo_locationscale.jl b/test/inference/repgradelbo_locationscale.jl index b0007706..dc643c74 100644 --- a/test/inference/repgradelbo_locationscale.jl +++ b/test/inference/repgradelbo_locationscale.jl @@ -5,6 +5,7 @@ AD_locationscale = if VERSION >= v"1.10" :ReverseDiff => AutoReverseDiff(), :Zygote => AutoZygote(), :Enzyme => AutoEnzyme(), + :Tapir => AutoTapir(; safe_mode=false), ) else Dict( @@ -14,6 +15,10 @@ else ) end +if @isdefined(Tapir) + AD_locationscale[:Tapir] = AutoTapir(; safe_mode=false) +end + @testset "inference RepGradELBO VILocationScale" begin @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for realtype in [Float64, Float32], diff --git a/test/inference/repgradelbo_locationscale_bijectors.jl b/test/inference/repgradelbo_locationscale_bijectors.jl index 79c81c52..35355478 100644 --- a/test/inference/repgradelbo_locationscale_bijectors.jl +++ b/test/inference/repgradelbo_locationscale_bijectors.jl @@ -5,6 +5,7 @@ AD_locationscale_bijectors = if VERSION >= v"1.10" :ReverseDiff => AutoReverseDiff(), :Zygote => AutoZygote(), :Enzyme => AutoEnzyme(), + :Tapir => AutoTapir(; safe_mode=false), ) else Dict( @@ -14,6 +15,10 @@ else ) end +if @isdefined(Tapir) + AD_locationscale_bijectors[:Tapir] = AutoTapir(; safe_mode=false) +end + @testset "inference RepGradELBO VILocationScale Bijectors" begin @testset "$(modelname) $(objname) $(realtype) $(adbackname)" for realtype in [Float64, Float32], diff --git a/test/interface/ad.jl b/test/interface/ad.jl index 380c2b9b..791bcbb3 100644 --- a/test/interface/ad.jl +++ b/test/interface/ad.jl @@ -1,38 +1,38 @@ using Test +const interface_ad_backends = Dict( + :ForwardDiff => AutoForwardDiff(), + :ReverseDiff => AutoReverseDiff(), + :Zygote => AutoZygote(), + :Enzyme => AutoEnzyme(), +) +if @isdefined(Tapir) + interface_ad_backends[:Tapir] = AutoTapir(; safe_mode=false) +end + @testset "ad" begin - @testset "$(adname)" for (adname, adsymbol) in Dict( - :ForwardDiff => AutoForwardDiff(), - :ReverseDiff => AutoReverseDiff(), - :Zygote => AutoZygote(), - :Enzyme => AutoEnzyme(), - ) + @testset "$(adname)" for (adname, adtype) in interface_ad_backends D = 10 A = randn(D, D) λ = randn(D) grad_buf = DiffResults.GradientResult(λ) f(λ′) = λ′' * A * λ′ / 2 - AdvancedVI.value_and_gradient!(adsymbol, f, λ, grad_buf) + AdvancedVI.value_and_gradient!(adtype, f, λ, grad_buf) ∇ = DiffResults.gradient(grad_buf) f = DiffResults.value(grad_buf) @test ∇ ≈ (A + A') * λ / 2 @test f ≈ λ' * A * λ / 2 end - @testset "$(adname) with auxiliary input" for (adname, adsymbol) in Dict( - :ForwardDiff => AutoForwardDiff(), - :ReverseDiff => AutoReverseDiff(), - :Zygote => AutoZygote(), - :Enzyme => AutoEnzyme(), - ) + @testset "$(adname) with auxiliary input" for (adname, adtype) in interface_ad_backends D = 10 A = randn(D, D) λ = randn(D) b = randn(D) grad_buf = DiffResults.GradientResult(λ) f(λ′, aux) = λ′' * A * λ′ / 2 + dot(aux.b, λ′) - AdvancedVI.value_and_gradient!(adsymbol, f, λ, (b=b,), grad_buf) + AdvancedVI.value_and_gradient!(adtype, f, λ, (b=b,), grad_buf) ∇ = DiffResults.gradient(grad_buf) f = DiffResults.value(grad_buf) @test ∇ ≈ (A + A') * λ / 2 + b diff --git a/test/interface/repgradelbo.jl b/test/interface/repgradelbo.jl index 5fec46ff..00eb2d37 100644 --- a/test/interface/repgradelbo.jl +++ b/test/interface/repgradelbo.jl @@ -34,12 +34,17 @@ end modelstats = normal_meanfield(rng, Float64) @unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats - @testset for ad in [ + ad_backends = [ ADTypes.AutoForwardDiff(), ADTypes.AutoReverseDiff(), ADTypes.AutoZygote(), ADTypes.AutoEnzyme(), ] + if @isdefined(Tapir) + push!(ad_backends, AutoTapir(; safe_mode=false)) + end + + @testset for ad in ad_backends q_true = MeanFieldGaussian( Vector{eltype(μ_true)}(μ_true), Diagonal(Vector{eltype(L_true)}(diag(L_true))) ) diff --git a/test/runtests.jl b/test/runtests.jl index 80194a43..5e92a5e5 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -8,6 +8,7 @@ using Distributions using FillArrays using LinearAlgebra using PDMats +using Pkg using Random, StableRNGs using SimpleUnPack: @unpack using Statistics @@ -22,6 +23,11 @@ using Optimisers using ADTypes using ForwardDiff, ReverseDiff, Zygote, Enzyme +if VERSION >= v"1.10" + Pkg.add("Tapir") + using Tapir +end + using AdvancedVI const GROUP = get(ENV, "GROUP", "All")