Skip to content

Commit

Permalink
Fix Enzyme tests on older Julia versions (#88)
Browse files Browse the repository at this point in the history
* fix load Enzyme only in Julia>=1.10

* fix add Enzyme to AD interface tests only for Julia>=1.10

* fix mistakenly added Enzyme test for repgradelbo interface test

* fix formatting

* decrement version
  • Loading branch information
Red-Portal authored Sep 8, 2024
1 parent 9222f63 commit 95bd86d
Show file tree
Hide file tree
Showing 7 changed files with 39 additions and 51 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "AdvancedVI"
uuid = "b5ca4192-6429-45e5-a2d9-87aec30a685c"
version = "0.3.1"
version = "0.3.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
23 changes: 9 additions & 14 deletions test/inference/repgradelbo_distributionsad.jl
Original file line number Diff line number Diff line change
@@ -1,23 +1,18 @@

AD_distributionsad = if VERSION >= v"1.10"
Dict(
:ForwarDiff => AutoForwardDiff(),
#:ReverseDiff => AutoReverseDiff(), # DistributionsAD doesn't support ReverseDiff at the moment
:Zygote => AutoZygote(),
:Enzyme => AutoEnzyme(),
)
else
Dict(
:ForwarDiff => AutoForwardDiff(),
#:ReverseDiff => AutoReverseDiff(), # DistributionsAD doesn't support ReverseDiff at the moment
:Zygote => AutoZygote(),
)
end
AD_distributionsad = Dict(
:ForwarDiff => AutoForwardDiff(),
#:ReverseDiff => AutoReverseDiff(), # DistributionsAD doesn't support ReverseDiff at the moment
:Zygote => AutoZygote(),
)

if @isdefined(Tapir)
AD_distributionsad[:Tapir] = AutoTapir(; safe_mode=false)
end

if @isdefined(Enzyme)
AD_distributionsad[:Enzyme] = AutoEnzyme()
end

@testset "inference RepGradELBO DistributionsAD" begin
@testset "$(modelname) $(objname) $(realtype) $(adbackname)" for realtype in
[Float64, Float32],
Expand Down
24 changes: 9 additions & 15 deletions test/inference/repgradelbo_locationscale.jl
Original file line number Diff line number Diff line change
@@ -1,24 +1,18 @@

AD_locationscale = if VERSION >= v"1.10"
Dict(
:ForwarDiff => AutoForwardDiff(),
:ReverseDiff => AutoReverseDiff(),
:Zygote => AutoZygote(),
:Enzyme => AutoEnzyme(),
:Tapir => AutoTapir(; safe_mode=false),
)
else
Dict(
:ForwarDiff => AutoForwardDiff(),
:ReverseDiff => AutoReverseDiff(),
:Zygote => AutoZygote(),
)
end
AD_locationscale = Dict(
:ForwarDiff => AutoForwardDiff(),
:ReverseDiff => AutoReverseDiff(),
:Zygote => AutoZygote(),
)

if @isdefined(Tapir)
AD_locationscale[:Tapir] = AutoTapir(; safe_mode=false)
end

if @isdefined(Enzyme)
AD_locationscale[:Enzyme] = AutoEnzyme()
end

@testset "inference RepGradELBO VILocationScale" begin
@testset "$(modelname) $(objname) $(realtype) $(adbackname)" for realtype in
[Float64, Float32],
Expand Down
24 changes: 9 additions & 15 deletions test/inference/repgradelbo_locationscale_bijectors.jl
Original file line number Diff line number Diff line change
@@ -1,24 +1,18 @@

AD_locationscale_bijectors = if VERSION >= v"1.10"
Dict(
:ForwarDiff => AutoForwardDiff(),
:ReverseDiff => AutoReverseDiff(),
:Zygote => AutoZygote(),
:Enzyme => AutoEnzyme(),
:Tapir => AutoTapir(; safe_mode=false),
)
else
Dict(
:ForwarDiff => AutoForwardDiff(),
:ReverseDiff => AutoReverseDiff(),
:Zygote => AutoZygote(),
)
end
AD_locationscale_bijectors = Dict(
:ForwarDiff => AutoForwardDiff(),
:ReverseDiff => AutoReverseDiff(),
:Zygote => AutoZygote(),
)

if @isdefined(Tapir)
AD_locationscale_bijectors[:Tapir] = AutoTapir(; safe_mode=false)
end

if @isdefined(Enzyme)
AD_locationscale_bijectors[:Enzyme] = AutoEnzyme()
end

@testset "inference RepGradELBO VILocationScale Bijectors" begin
@testset "$(modelname) $(objname) $(realtype) $(adbackname)" for realtype in
[Float64, Float32],
Expand Down
6 changes: 5 additions & 1 deletion test/interface/ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,16 @@ const interface_ad_backends = Dict(
:ForwardDiff => AutoForwardDiff(),
:ReverseDiff => AutoReverseDiff(),
:Zygote => AutoZygote(),
:Enzyme => AutoEnzyme(),
)

if @isdefined(Tapir)
interface_ad_backends[:Tapir] = AutoTapir(; safe_mode=false)
end

if @isdefined(Enzyme)
interface_ad_backends[:Enzyme] = AutoEnzyme()
end

@testset "ad" begin
@testset "$(adname)" for (adname, adtype) in interface_ad_backends
D = 10
Expand Down
8 changes: 4 additions & 4 deletions test/interface/repgradelbo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,14 @@ end
@unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats

ad_backends = [
ADTypes.AutoForwardDiff(),
ADTypes.AutoReverseDiff(),
ADTypes.AutoZygote(),
ADTypes.AutoEnzyme(),
ADTypes.AutoForwardDiff(), ADTypes.AutoReverseDiff(), ADTypes.AutoZygote()
]
if @isdefined(Tapir)
push!(ad_backends, AutoTapir(; safe_mode=false))
end
if @isdefined(Enzyme)
push!(ad_backends, AutoEnzyme())
end

@testset for ad in ad_backends
q_true = MeanFieldGaussian(
Expand Down
3 changes: 2 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,12 @@ using DistributionsAD
using LogDensityProblems
using Optimisers
using ADTypes
using ForwardDiff, ReverseDiff, Zygote, Enzyme
using ForwardDiff, ReverseDiff, Zygote

if VERSION >= v"1.10"
Pkg.add("Tapir")
using Tapir
using Enzyme
end

using AdvancedVI
Expand Down

0 comments on commit 95bd86d

Please sign in to comment.