From 803d2f5672b7483c768b9987bcec0dc20257ebda Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Wed, 7 Aug 2024 10:09:35 +0100 Subject: [PATCH 1/4] Check that the correct AD backend is being used (#2291) * Add ADTypeCheckContext * Check ADType use in optimisation * Use ADTypeCheckContext with hmc tests * using A: A instead of import A * More robust ADTypeCheckContext checks for Zygote --- test/mcmc/hmc.jl | 10 ++ test/optimisation/Optimisation.jl | 14 +- test/runtests.jl | 1 + test/test_utils/ad_utils.jl | 270 ++++++++++++++++++++++++++++++ 4 files changed, 294 insertions(+), 1 deletion(-) create mode 100644 test/test_utils/ad_utils.jl diff --git a/test/mcmc/hmc.jl b/test/mcmc/hmc.jl index 968f24d7b7..b589d46873 100644 --- a/test/mcmc/hmc.jl +++ b/test/mcmc/hmc.jl @@ -1,6 +1,7 @@ module HMCTests using ..Models: gdemo_default +using ..ADUtils: ADTypeCheckContext #using ..Models: gdemo using ..NumericalTests: check_gdemo, check_numerical using Distributions: Bernoulli, Beta, Categorical, Dirichlet, Normal, Wishart, sample @@ -321,6 +322,15 @@ using Turing # KS will compare the empirical CDFs, which seems like a reasonable thing to do here. @test pvalue(ApproximateTwoSampleKSTest(vec(results), vec(results_prior))) > 0.001 end + + @testset "Check ADType" begin + alg = HMC(0.1, 10; adtype=adbackend) + m = DynamicPPL.contextualize( + gdemo_default, ADTypeCheckContext(adbackend, gdemo_default.context) + ) + # These will error if the adbackend being used is not the one set. + sample(rng, m, alg, 10) + end end end diff --git a/test/optimisation/Optimisation.jl b/test/optimisation/Optimisation.jl index 76d3a940d6..1ba073864f 100644 --- a/test/optimisation/Optimisation.jl +++ b/test/optimisation/Optimisation.jl @@ -1,6 +1,7 @@ module OptimisationTests using ..Models: gdemo, gdemo_default +using ..ADUtils: ADTypeCheckContext using Distributions using Distributions.FillArrays: Zeros using DynamicPPL: DynamicPPL @@ -140,7 +141,6 @@ using Turing gdemo_default, OptimizationOptimJL.LBFGS(); initial_params=true_value ) m3 = maximum_likelihood(gdemo_default, OptimizationOptimJL.Newton()) - # TODO(mhauru) How can we check that the adtype is actually AutoReverseDiff? m4 = maximum_likelihood( gdemo_default, OptimizationOptimJL.BFGS(); adtype=AutoReverseDiff() ) @@ -616,6 +616,18 @@ using Turing @assert vcat(get_a[:a], get_b[:b]) == result.values.array @assert get(result, :c) == (; :c => Array{Float64}[]) end + + @testset "ADType" begin + Random.seed!(222) + for adbackend in (AutoReverseDiff(), AutoForwardDiff(), AutoTracker()) + m = DynamicPPL.contextualize( + gdemo_default, ADTypeCheckContext(adbackend, gdemo_default.context) + ) + # These will error if the adbackend being used is not the one set. + maximum_likelihood(m; adtype=adbackend) + maximum_a_posteriori(m; adtype=adbackend) + end + end end end diff --git a/test/runtests.jl b/test/runtests.jl index 48a00122d4..1aa8bb635b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -7,6 +7,7 @@ import Turing include(pkgdir(Turing) * "/test/test_utils/models.jl") include(pkgdir(Turing) * "/test/test_utils/numerical_tests.jl") +include(pkgdir(Turing) * "/test/test_utils/ad_utils.jl") Turing.setprogress!(false) diff --git a/test/test_utils/ad_utils.jl b/test/test_utils/ad_utils.jl new file mode 100644 index 0000000000..7e47cd9eee --- /dev/null +++ b/test/test_utils/ad_utils.jl @@ -0,0 +1,270 @@ +module ADUtils + +using ForwardDiff: ForwardDiff +using ReverseDiff: ReverseDiff +using Test: Test +using Tracker: Tracker +using Turing: Turing +using Turing: DynamicPPL +using Zygote: Zygote + +export ADTypeCheckContext + +"""Element types that are always valid for a VarInfo regardless of ADType.""" +const always_valid_eltypes = (AbstractFloat, AbstractIrrational, Integer, Rational) + +"""A dictionary mapping ADTypes to the element types they use.""" +const eltypes_by_adtype = Dict( + Turing.AutoForwardDiff => (ForwardDiff.Dual,), + Turing.AutoReverseDiff => ( + ReverseDiff.TrackedArray, + ReverseDiff.TrackedMatrix, + ReverseDiff.TrackedReal, + ReverseDiff.TrackedStyle, + ReverseDiff.TrackedType, + ReverseDiff.TrackedVecOrMat, + ReverseDiff.TrackedVector, + ), + # Zygote.Dual is actually the same as ForwardDiff.Dual, so can't distinguish between the + # two by element type. However, we have other checks for Zygote, see check_adtype. + Turing.AutoZygote => (Zygote.Dual,), + Turing.AutoTracker => ( + Tracker.Tracked, + Tracker.TrackedArray, + Tracker.TrackedMatrix, + Tracker.TrackedReal, + Tracker.TrackedStyle, + Tracker.TrackedVecOrMat, + Tracker.TrackedVector, + ), +) + +""" + AbstractWrongADBackendError + +An abstract error thrown when we seem to be using a different AD backend than expected. +""" +abstract type AbstractWrongADBackendError <: Exception end + +""" + WrongADBackendError + +An error thrown when we seem to be using a different AD backend than expected. +""" +struct WrongADBackendError <: AbstractWrongADBackendError + actual_adtype::Type + expected_adtype::Type +end + +function Base.showerror(io::IO, e::WrongADBackendError) + return print( + io, "Expected to use $(e.expected_adtype), but using $(e.actual_adtype) instead." + ) +end + +""" + IncompatibleADTypeError + +An error thrown when an element type is encountered that is unexpected for the given ADType. +""" +struct IncompatibleADTypeError <: AbstractWrongADBackendError + valtype::Type + adtype::Type +end + +function Base.showerror(io::IO, e::IncompatibleADTypeError) + return print( + io, + "Incompatible ADType: Did not expect element of type $(e.valtype) with $(e.adtype)", + ) +end + +""" + ADTypeCheckContext{ADType,ChildContext} + +A context for checking that the expected ADType is being used. + +Evaluating a model with this context will check that the types of values in a `VarInfo` are +compatible with the ADType of the context. If the check fails, an `IncompatibleADTypeError` +is thrown. + +For instance, evaluating a model with +`ADTypeCheckContext(AutoForwardDiff(), child_context)` +would throw an error if within the model a type associated with e.g. ReverseDiff was +encountered. + +As a current short-coming, this context can not distinguish between ForwardDiff and Zygote. +""" +struct ADTypeCheckContext{ADType,ChildContext<:DynamicPPL.AbstractContext} <: + DynamicPPL.AbstractContext + child::ChildContext + + function ADTypeCheckContext(adbackend, child) + adtype = adbackend isa Type ? adbackend : typeof(adbackend) + if !any(adtype <: k for k in keys(eltypes_by_adtype)) + throw(ArgumentError("Unsupported ADType: $adtype")) + end + return new{adtype,typeof(child)}(child) + end +end + +adtype(_::ADTypeCheckContext{ADType}) where {ADType} = ADType + +DynamicPPL.NodeTrait(::ADTypeCheckContext) = DynamicPPL.IsParent() +DynamicPPL.childcontext(c::ADTypeCheckContext) = c.child +function DynamicPPL.setchildcontext(c::ADTypeCheckContext, child) + return ADTypeCheckContext(adtype(c), child) +end + +""" + valid_eltypes(context::ADTypeCheckContext) + +Return the element types that are valid for the ADType of `context` as a tuple. +""" +function valid_eltypes(context::ADTypeCheckContext) + context_at = adtype(context) + for at in keys(eltypes_by_adtype) + if context_at <: at + return (eltypes_by_adtype[at]..., always_valid_eltypes...) + end + end + # This should never be reached due to the check in the inner constructor. + throw(ArgumentError("Unsupported ADType: $(adtype(context))")) +end + +""" + check_adtype(context::ADTypeCheckContext, vi::DynamicPPL.VarInfo) + +Check that the element types in `vi` are compatible with the ADType of `context`. + +When Zygote is being used, we also more explicitly check that `adtype(context)` is +`AutoZygote`. This is because Zygote uses the same element type as ForwardDiff, so we can't +discriminate between the two based on element type alone. This function will still fail to +catch cases where Zygote is supposed to be used, but ForwardDiff is used instead. + +Throw an `IncompatibleADTypeError` if an incompatible element type is encountered, or +`WrongADBackendError` if Zygote is used unexpectedly. +""" +function check_adtype(context::ADTypeCheckContext, vi::DynamicPPL.AbstractVarInfo) + Zygote.hook(vi) do _ + if !(adtype(context) <: Turing.AutoZygote) + throw(WrongADBackendError(Turing.AutoZygote, adtype(context))) + end + end + + valids = valid_eltypes(context) + for val in vi[:] + valtype = typeof(val) + if !any(valtype .<: valids) + throw(IncompatibleADTypeError(valtype, adtype(context))) + end + end + return nothing +end + +# A bunch of tilde_assume/tilde_observe methods that just call the same method on the child +# context, and then call check_adtype on the result before returning the results from the +# child context. + +function DynamicPPL.tilde_assume(context::ADTypeCheckContext, right, vn, vi) + value, logp, vi = DynamicPPL.tilde_assume( + DynamicPPL.childcontext(context), right, vn, vi + ) + check_adtype(context, vi) + return value, logp, vi +end + +function DynamicPPL.tilde_assume(rng, context::ADTypeCheckContext, sampler, right, vn, vi) + value, logp, vi = DynamicPPL.tilde_assume( + rng, DynamicPPL.childcontext(context), sampler, right, vn, vi + ) + check_adtype(context, vi) + return value, logp, vi +end + +function DynamicPPL.tilde_observe(context::ADTypeCheckContext, right, left, vi) + logp, vi = DynamicPPL.tilde_observe(DynamicPPL.childcontext(context), right, left, vi) + check_adtype(context, vi) + return logp, vi +end + +function DynamicPPL.tilde_observe(context::ADTypeCheckContext, sampler, right, left, vi) + logp, vi = DynamicPPL.tilde_observe( + DynamicPPL.childcontext(context), sampler, right, left, vi + ) + check_adtype(context, vi) + return logp, vi +end + +function DynamicPPL.dot_tilde_assume(context::ADTypeCheckContext, right, left, vn, vi) + value, logp, vi = DynamicPPL.dot_tilde_assume( + DynamicPPL.childcontext(context), right, left, vn, vi + ) + check_adtype(context, vi) + return value, logp, vi +end + +function DynamicPPL.dot_tilde_assume( + rng, context::ADTypeCheckContext, sampler, right, left, vn, vi +) + value, logp, vi = DynamicPPL.dot_tilde_assume( + rng, DynamicPPL.childcontext(context), sampler, right, left, vn, vi + ) + check_adtype(context, vi) + return value, logp, vi +end + +function DynamicPPL.dot_tilde_observe(context::ADTypeCheckContext, right, left, vi) + logp, vi = DynamicPPL.dot_tilde_observe( + DynamicPPL.childcontext(context), right, left, vi + ) + check_adtype(context, vi) + return logp, vi +end + +function DynamicPPL.dot_tilde_observe(context::ADTypeCheckContext, sampler, right, left, vi) + logp, vi = DynamicPPL.dot_tilde_observe( + DynamicPPL.childcontext(context), sampler, right, left, vi + ) + check_adtype(context, vi) + return logp, vi +end + +# Check that the ADTypeCheckContext works as expected. +Test.@testset "ADTypeCheckContext" begin + Turing.@model test_model() = x ~ Turing.Normal(0, 1) + tm = test_model() + adtypes = ( + Turing.AutoForwardDiff(), + Turing.AutoReverseDiff(), + Turing.AutoZygote(), + Turing.AutoTracker(), + ) + for actual_adtype in adtypes + sampler = Turing.HMC(0.1, 5; adtype=actual_adtype) + for expected_adtype in adtypes + if ( + actual_adtype == Turing.AutoForwardDiff() && + expected_adtype == Turing.AutoZygote() + ) + # TODO(mhauru) We are currently unable to check this case. + continue + end + contextualised_tm = DynamicPPL.contextualize( + tm, ADTypeCheckContext(expected_adtype, tm.context) + ) + Test.@testset "Expected: $expected_adtype, Actual: $actual_adtype" begin + if actual_adtype == expected_adtype + # Check that this does not throw an error. + Turing.sample(contextualised_tm, sampler, 2) + else + Test.@test_throws AbstractWrongADBackendError Turing.sample( + contextualised_tm, sampler, 2 + ) + end + end + end + end +end + +end From 07cc40beb0c6caa60e945e204f0fbc88cd3d4362 Mon Sep 17 00:00:00 2001 From: Markus Hauru Date: Fri, 9 Aug 2024 18:47:04 +0200 Subject: [PATCH 2/4] Resolve ADTypeCheckContext method ambiguity (#2299) --- test/test_utils/ad_utils.jl | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/test/test_utils/ad_utils.jl b/test/test_utils/ad_utils.jl index 7e47cd9eee..27b469ed7e 100644 --- a/test/test_utils/ad_utils.jl +++ b/test/test_utils/ad_utils.jl @@ -1,6 +1,7 @@ module ADUtils using ForwardDiff: ForwardDiff +using Random: Random using ReverseDiff: ReverseDiff using Test: Test using Tracker: Tracker @@ -174,7 +175,9 @@ function DynamicPPL.tilde_assume(context::ADTypeCheckContext, right, vn, vi) return value, logp, vi end -function DynamicPPL.tilde_assume(rng, context::ADTypeCheckContext, sampler, right, vn, vi) +function DynamicPPL.tilde_assume( + rng::Random.AbstractRNG, context::ADTypeCheckContext, sampler, right, vn, vi +) value, logp, vi = DynamicPPL.tilde_assume( rng, DynamicPPL.childcontext(context), sampler, right, vn, vi ) @@ -205,7 +208,7 @@ function DynamicPPL.dot_tilde_assume(context::ADTypeCheckContext, right, left, v end function DynamicPPL.dot_tilde_assume( - rng, context::ADTypeCheckContext, sampler, right, left, vn, vi + rng::Random.AbstractRNG, context::ADTypeCheckContext, sampler, right, left, vn, vi ) value, logp, vi = DynamicPPL.dot_tilde_assume( rng, DynamicPPL.childcontext(context), sampler, right, left, vn, vi From 5b5da11dfcd7725dcf2afe705feff867a93e8b62 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Fri, 30 Aug 2024 10:08:00 +0200 Subject: [PATCH 3/4] Remove unused code (#2312) --- src/mcmc/hmc.jl | 55 ------------------------------------------------- 1 file changed, 55 deletions(-) diff --git a/src/mcmc/hmc.jl b/src/mcmc/hmc.jl index f17bca2e73..29cc8348c9 100644 --- a/src/mcmc/hmc.jl +++ b/src/mcmc/hmc.jl @@ -563,58 +563,3 @@ end function AHMCAdaptor(::Hamiltonian, ::AHMC.AbstractMetric; kwargs...) return AHMC.Adaptation.NoAdaptation() end - -########################## -# HMC State Constructors # -########################## - -function HMCState( - rng::AbstractRNG, - model::Model, - spl::Sampler{<:Hamiltonian}, - vi::AbstractVarInfo; - kwargs..., -) - # Link everything if needed. - waslinked = islinked(vi, spl) - if !waslinked - vi = link!!(vi, spl, model) - end - - # Get the initial log pdf and gradient functions. - ∂logπ∂θ = gen_∂logπ∂θ(vi, spl, model) - logπ = Turing.LogDensityFunction( - vi, - model, - DynamicPPL.SamplingContext(rng, spl, DynamicPPL.leafcontext(model.context)), - ) - - # Get the metric type. - metricT = getmetricT(spl.alg) - - # Create a Hamiltonian. - θ_init = Vector{Float64}(spl.state.vi[spl]) - metric = metricT(length(θ_init)) - h = AHMC.Hamiltonian(metric, logπ, ∂logπ∂θ) - - # Find good eps if not provided one - if iszero(spl.alg.ϵ) - ϵ = AHMC.find_good_stepsize(rng, h, θ_init) - @info "Found initial step size" ϵ - else - ϵ = spl.alg.ϵ - end - - # Generate a kernel. - kernel = make_ahmc_kernel(spl.alg, ϵ) - - # Generate a phasepoint. Replaced during sample_init! - h, t = AHMC.sample_init(rng, h, θ_init) # this also ensure AHMC has the same dim as θ. - - # Unlink everything, if it was indeed linked before. - if waslinked - vi = invlink!!(vi, spl, model) - end - - return HMCState(vi, 0, 0, kernel.τ, h, AHMCAdaptor(spl.alg, metric; ϵ=ϵ), t.z) -end From a26ce1198354cdb54b352f659369694b11bf489f Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Fri, 30 Aug 2024 15:59:19 +0100 Subject: [PATCH 4/4] Fix remaining method ambiguities (#2304) * Enabling aqua ambiguity testing for Turing We test ambiguities only for Turing and not its dependencies. * Format * Fix bundle_samples method ambiguity Concretely: 1. Creating an `AbstractTransition` type which all the Transitions in Turing subtype. 2. Modifying the type signature of bundle_samples to take a Vector{<:Union{AbstractTransition,AbstractVarInfo}} as the first argument. The AbstractVarInfo case occurs when sampling with Prior(), so the type signature of this argument mirrors that of the Sampler in the same function. * Fix get() ambiguities Done by: 1. Constraining the type parameter to AbstractVector{Symbol} 2. Modifying the method below it to use a vector instead of a tuple * Bump to 0.34.0 --------- Co-authored-by: Abhinav Singh <73834077+abhinavsns@users.noreply.github.com> Co-authored-by: Xianda Sun <5433119+sunxd3@users.noreply.github.com> --- Project.toml | 2 +- src/mcmc/Inference.jl | 8 +++++--- src/mcmc/particle_mcmc.jl | 4 ++-- src/mcmc/sghmc.jl | 2 +- src/optimisation/Optimisation.jl | 8 ++++---- test/Aqua.jl | 5 +++-- 6 files changed, 16 insertions(+), 13 deletions(-) diff --git a/Project.toml b/Project.toml index 33d4be908a..62dfe4b204 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Turing" uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" -version = "0.33.3" +version = "0.34.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/src/mcmc/Inference.jl b/src/mcmc/Inference.jl index 5f05121e74..b7bdf206b9 100644 --- a/src/mcmc/Inference.jl +++ b/src/mcmc/Inference.jl @@ -213,7 +213,9 @@ end # Extended in contrib/inference/abstractmcmc.jl getstats(t) = nothing -struct Transition{T, F<:AbstractFloat, S<:Union{NamedTuple, Nothing}} +abstract type AbstractTransition end + +struct Transition{T, F<:AbstractFloat, S<:Union{NamedTuple, Nothing}} <: AbstractTransition θ :: T lp :: F # TODO: merge `lp` with `stat` stat :: S @@ -409,7 +411,7 @@ getlogevidence(transitions, sampler, state) = missing # Default MCMCChains.Chains constructor. # This is type piracy (at least for SampleFromPrior). function AbstractMCMC.bundle_samples( - ts::Vector, + ts::Vector{<:Union{AbstractTransition,AbstractVarInfo}}, model::AbstractModel, spl::Union{Sampler{<:InferenceAlgorithm},SampleFromPrior}, state, @@ -472,7 +474,7 @@ end # This is type piracy (for SampleFromPrior). function AbstractMCMC.bundle_samples( - ts::Vector, + ts::Vector{<:Union{AbstractTransition,AbstractVarInfo}}, model::AbstractModel, spl::Union{Sampler{<:InferenceAlgorithm},SampleFromPrior}, state, diff --git a/src/mcmc/particle_mcmc.jl b/src/mcmc/particle_mcmc.jl index 8d38660c67..02a53766e0 100644 --- a/src/mcmc/particle_mcmc.jl +++ b/src/mcmc/particle_mcmc.jl @@ -45,7 +45,7 @@ end SMC(space::Symbol...) = SMC(space) SMC(space::Tuple) = SMC(AdvancedPS.ResampleWithESSThreshold(), space) -struct SMCTransition{T,F<:AbstractFloat} +struct SMCTransition{T,F<:AbstractFloat} <: AbstractTransition "The parameters for any given sample." θ::T "The joint log probability of the sample (NOTE: does not work, always set to zero)." @@ -222,7 +222,7 @@ end const CSMC = PG # type alias of PG as Conditional SMC -struct PGTransition{T,F<:AbstractFloat} +struct PGTransition{T,F<:AbstractFloat} <: AbstractTransition "The parameters for any given sample." θ::T "The joint log probability of the sample (NOTE: does not work, always set to zero)." diff --git a/src/mcmc/sghmc.jl b/src/mcmc/sghmc.jl index 84a9c18f3c..fbc4b11868 100644 --- a/src/mcmc/sghmc.jl +++ b/src/mcmc/sghmc.jl @@ -193,7 +193,7 @@ function SGLD( return SGLD{typeof(adtype),space,typeof(stepsize)}(stepsize, adtype) end -struct SGLDTransition{T,F<:Real} +struct SGLDTransition{T,F<:Real} <: AbstractTransition "The parameters for any given sample." θ::T "The joint log probability of the sample." diff --git a/src/optimisation/Optimisation.jl b/src/optimisation/Optimisation.jl index e069897920..eecfcad22b 100644 --- a/src/optimisation/Optimisation.jl +++ b/src/optimisation/Optimisation.jl @@ -277,13 +277,13 @@ StatsBase.loglikelihood(m::ModeResult) = m.lp """ Base.get(m::ModeResult, var_symbol::Symbol) - Base.get(m::ModeResult, var_symbols) + Base.get(m::ModeResult, var_symbols::AbstractVector{Symbol}) Return the values of all the variables with the symbol(s) `var_symbol` in the mode result `m`. The return value is a `NamedTuple` with `var_symbols` as the key(s). The second -argument should be either a `Symbol` or an iterator of `Symbol`s. +argument should be either a `Symbol` or a vector of `Symbol`s. """ -function Base.get(m::ModeResult, var_symbols) +function Base.get(m::ModeResult, var_symbols::AbstractVector{Symbol}) log_density = m.f # Get all the variable names in the model. This is the same as the list of keys in # m.values, but they are more convenient to filter when they are VarNames rather than @@ -304,7 +304,7 @@ function Base.get(m::ModeResult, var_symbols) return (; zip(var_symbols, value_vectors)...) end -Base.get(m::ModeResult, var_symbol::Symbol) = get(m, (var_symbol,)) +Base.get(m::ModeResult, var_symbol::Symbol) = get(m, [var_symbol]) """ ModeResult(log_density::OptimLogDensity, solution::SciMLBase.OptimizationSolution) diff --git a/test/Aqua.jl b/test/Aqua.jl index 0b536770b2..e159cae9ca 100644 --- a/test/Aqua.jl +++ b/test/Aqua.jl @@ -3,8 +3,9 @@ module AquaTests using Aqua: Aqua using Turing -# TODO(mhauru) We skip testing for method ambiguities because it catches a lot of problems -# in dependencies. Would like to check it for just Turing.jl itself though. +# We test ambiguities separately because it catches a lot of problems +# in dependencies but we test it for Turing. +Aqua.test_ambiguities([Turing]) Aqua.test_all(Turing; ambiguities=false) end