Skip to content

Commit

Permalink
Check that the correct AD backend is being used (#2291)
Browse files Browse the repository at this point in the history
* Add ADTypeCheckContext

* Check ADType use in optimisation

* Use ADTypeCheckContext with hmc tests

* using A: A instead of import A

* More robust ADTypeCheckContext checks for Zygote
  • Loading branch information
mhauru authored Aug 7, 2024
1 parent 4766fdd commit 803d2f5
Show file tree
Hide file tree
Showing 4 changed files with 294 additions and 1 deletion.
10 changes: 10 additions & 0 deletions test/mcmc/hmc.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module HMCTests

using ..Models: gdemo_default
using ..ADUtils: ADTypeCheckContext
#using ..Models: gdemo
using ..NumericalTests: check_gdemo, check_numerical
using Distributions: Bernoulli, Beta, Categorical, Dirichlet, Normal, Wishart, sample
Expand Down Expand Up @@ -321,6 +322,15 @@ using Turing
# KS will compare the empirical CDFs, which seems like a reasonable thing to do here.
@test pvalue(ApproximateTwoSampleKSTest(vec(results), vec(results_prior))) > 0.001
end

@testset "Check ADType" begin
alg = HMC(0.1, 10; adtype=adbackend)
m = DynamicPPL.contextualize(
gdemo_default, ADTypeCheckContext(adbackend, gdemo_default.context)
)
# These will error if the adbackend being used is not the one set.
sample(rng, m, alg, 10)
end
end

end
14 changes: 13 additions & 1 deletion test/optimisation/Optimisation.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module OptimisationTests

using ..Models: gdemo, gdemo_default
using ..ADUtils: ADTypeCheckContext
using Distributions
using Distributions.FillArrays: Zeros
using DynamicPPL: DynamicPPL
Expand Down Expand Up @@ -140,7 +141,6 @@ using Turing
gdemo_default, OptimizationOptimJL.LBFGS(); initial_params=true_value
)
m3 = maximum_likelihood(gdemo_default, OptimizationOptimJL.Newton())
# TODO(mhauru) How can we check that the adtype is actually AutoReverseDiff?
m4 = maximum_likelihood(
gdemo_default, OptimizationOptimJL.BFGS(); adtype=AutoReverseDiff()
)
Expand Down Expand Up @@ -616,6 +616,18 @@ using Turing
@assert vcat(get_a[:a], get_b[:b]) == result.values.array
@assert get(result, :c) == (; :c => Array{Float64}[])
end

@testset "ADType" begin
Random.seed!(222)
for adbackend in (AutoReverseDiff(), AutoForwardDiff(), AutoTracker())
m = DynamicPPL.contextualize(
gdemo_default, ADTypeCheckContext(adbackend, gdemo_default.context)
)
# These will error if the adbackend being used is not the one set.
maximum_likelihood(m; adtype=adbackend)
maximum_a_posteriori(m; adtype=adbackend)
end
end
end

end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import Turing

include(pkgdir(Turing) * "/test/test_utils/models.jl")
include(pkgdir(Turing) * "/test/test_utils/numerical_tests.jl")
include(pkgdir(Turing) * "/test/test_utils/ad_utils.jl")

Turing.setprogress!(false)

Expand Down
270 changes: 270 additions & 0 deletions test/test_utils/ad_utils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,270 @@
module ADUtils

using ForwardDiff: ForwardDiff
using ReverseDiff: ReverseDiff
using Test: Test
using Tracker: Tracker
using Turing: Turing
using Turing: DynamicPPL
using Zygote: Zygote

export ADTypeCheckContext

"""Element types that are always valid for a VarInfo regardless of ADType."""
const always_valid_eltypes = (AbstractFloat, AbstractIrrational, Integer, Rational)

"""A dictionary mapping ADTypes to the element types they use."""
const eltypes_by_adtype = Dict(
Turing.AutoForwardDiff => (ForwardDiff.Dual,),
Turing.AutoReverseDiff => (
ReverseDiff.TrackedArray,
ReverseDiff.TrackedMatrix,
ReverseDiff.TrackedReal,
ReverseDiff.TrackedStyle,
ReverseDiff.TrackedType,
ReverseDiff.TrackedVecOrMat,
ReverseDiff.TrackedVector,
),
# Zygote.Dual is actually the same as ForwardDiff.Dual, so can't distinguish between the
# two by element type. However, we have other checks for Zygote, see check_adtype.
Turing.AutoZygote => (Zygote.Dual,),
Turing.AutoTracker => (
Tracker.Tracked,
Tracker.TrackedArray,
Tracker.TrackedMatrix,
Tracker.TrackedReal,
Tracker.TrackedStyle,
Tracker.TrackedVecOrMat,
Tracker.TrackedVector,
),
)

"""
AbstractWrongADBackendError
An abstract error thrown when we seem to be using a different AD backend than expected.
"""
abstract type AbstractWrongADBackendError <: Exception end

"""
WrongADBackendError
An error thrown when we seem to be using a different AD backend than expected.
"""
struct WrongADBackendError <: AbstractWrongADBackendError
actual_adtype::Type
expected_adtype::Type
end

function Base.showerror(io::IO, e::WrongADBackendError)
return print(
io, "Expected to use $(e.expected_adtype), but using $(e.actual_adtype) instead."
)
end

"""
IncompatibleADTypeError
An error thrown when an element type is encountered that is unexpected for the given ADType.
"""
struct IncompatibleADTypeError <: AbstractWrongADBackendError
valtype::Type
adtype::Type
end

function Base.showerror(io::IO, e::IncompatibleADTypeError)
return print(
io,
"Incompatible ADType: Did not expect element of type $(e.valtype) with $(e.adtype)",
)
end

"""
ADTypeCheckContext{ADType,ChildContext}
A context for checking that the expected ADType is being used.
Evaluating a model with this context will check that the types of values in a `VarInfo` are
compatible with the ADType of the context. If the check fails, an `IncompatibleADTypeError`
is thrown.
For instance, evaluating a model with
`ADTypeCheckContext(AutoForwardDiff(), child_context)`
would throw an error if within the model a type associated with e.g. ReverseDiff was
encountered.
As a current short-coming, this context can not distinguish between ForwardDiff and Zygote.
"""
struct ADTypeCheckContext{ADType,ChildContext<:DynamicPPL.AbstractContext} <:
DynamicPPL.AbstractContext
child::ChildContext

function ADTypeCheckContext(adbackend, child)
adtype = adbackend isa Type ? adbackend : typeof(adbackend)
if !any(adtype <: k for k in keys(eltypes_by_adtype))
throw(ArgumentError("Unsupported ADType: $adtype"))
end
return new{adtype,typeof(child)}(child)
end
end

adtype(_::ADTypeCheckContext{ADType}) where {ADType} = ADType

DynamicPPL.NodeTrait(::ADTypeCheckContext) = DynamicPPL.IsParent()
DynamicPPL.childcontext(c::ADTypeCheckContext) = c.child
function DynamicPPL.setchildcontext(c::ADTypeCheckContext, child)
return ADTypeCheckContext(adtype(c), child)
end

"""
valid_eltypes(context::ADTypeCheckContext)
Return the element types that are valid for the ADType of `context` as a tuple.
"""
function valid_eltypes(context::ADTypeCheckContext)
context_at = adtype(context)
for at in keys(eltypes_by_adtype)
if context_at <: at
return (eltypes_by_adtype[at]..., always_valid_eltypes...)
end
end
# This should never be reached due to the check in the inner constructor.
throw(ArgumentError("Unsupported ADType: $(adtype(context))"))
end

"""
check_adtype(context::ADTypeCheckContext, vi::DynamicPPL.VarInfo)
Check that the element types in `vi` are compatible with the ADType of `context`.
When Zygote is being used, we also more explicitly check that `adtype(context)` is
`AutoZygote`. This is because Zygote uses the same element type as ForwardDiff, so we can't
discriminate between the two based on element type alone. This function will still fail to
catch cases where Zygote is supposed to be used, but ForwardDiff is used instead.
Throw an `IncompatibleADTypeError` if an incompatible element type is encountered, or
`WrongADBackendError` if Zygote is used unexpectedly.
"""
function check_adtype(context::ADTypeCheckContext, vi::DynamicPPL.AbstractVarInfo)
Zygote.hook(vi) do _
if !(adtype(context) <: Turing.AutoZygote)
throw(WrongADBackendError(Turing.AutoZygote, adtype(context)))
end
end

valids = valid_eltypes(context)
for val in vi[:]
valtype = typeof(val)
if !any(valtype .<: valids)
throw(IncompatibleADTypeError(valtype, adtype(context)))
end
end
return nothing
end

# A bunch of tilde_assume/tilde_observe methods that just call the same method on the child
# context, and then call check_adtype on the result before returning the results from the
# child context.

function DynamicPPL.tilde_assume(context::ADTypeCheckContext, right, vn, vi)
value, logp, vi = DynamicPPL.tilde_assume(
DynamicPPL.childcontext(context), right, vn, vi
)
check_adtype(context, vi)
return value, logp, vi
end

function DynamicPPL.tilde_assume(rng, context::ADTypeCheckContext, sampler, right, vn, vi)
value, logp, vi = DynamicPPL.tilde_assume(
rng, DynamicPPL.childcontext(context), sampler, right, vn, vi
)
check_adtype(context, vi)
return value, logp, vi
end

function DynamicPPL.tilde_observe(context::ADTypeCheckContext, right, left, vi)
logp, vi = DynamicPPL.tilde_observe(DynamicPPL.childcontext(context), right, left, vi)
check_adtype(context, vi)
return logp, vi
end

function DynamicPPL.tilde_observe(context::ADTypeCheckContext, sampler, right, left, vi)
logp, vi = DynamicPPL.tilde_observe(
DynamicPPL.childcontext(context), sampler, right, left, vi
)
check_adtype(context, vi)
return logp, vi
end

function DynamicPPL.dot_tilde_assume(context::ADTypeCheckContext, right, left, vn, vi)
value, logp, vi = DynamicPPL.dot_tilde_assume(
DynamicPPL.childcontext(context), right, left, vn, vi
)
check_adtype(context, vi)
return value, logp, vi
end

function DynamicPPL.dot_tilde_assume(
rng, context::ADTypeCheckContext, sampler, right, left, vn, vi
)
value, logp, vi = DynamicPPL.dot_tilde_assume(
rng, DynamicPPL.childcontext(context), sampler, right, left, vn, vi
)
check_adtype(context, vi)
return value, logp, vi
end

function DynamicPPL.dot_tilde_observe(context::ADTypeCheckContext, right, left, vi)
logp, vi = DynamicPPL.dot_tilde_observe(
DynamicPPL.childcontext(context), right, left, vi
)
check_adtype(context, vi)
return logp, vi
end

function DynamicPPL.dot_tilde_observe(context::ADTypeCheckContext, sampler, right, left, vi)
logp, vi = DynamicPPL.dot_tilde_observe(
DynamicPPL.childcontext(context), sampler, right, left, vi
)
check_adtype(context, vi)
return logp, vi
end

# Check that the ADTypeCheckContext works as expected.
Test.@testset "ADTypeCheckContext" begin
Turing.@model test_model() = x ~ Turing.Normal(0, 1)
tm = test_model()
adtypes = (
Turing.AutoForwardDiff(),
Turing.AutoReverseDiff(),
Turing.AutoZygote(),
Turing.AutoTracker(),
)
for actual_adtype in adtypes
sampler = Turing.HMC(0.1, 5; adtype=actual_adtype)
for expected_adtype in adtypes
if (
actual_adtype == Turing.AutoForwardDiff() &&
expected_adtype == Turing.AutoZygote()
)
# TODO(mhauru) We are currently unable to check this case.
continue
end
contextualised_tm = DynamicPPL.contextualize(
tm, ADTypeCheckContext(expected_adtype, tm.context)
)
Test.@testset "Expected: $expected_adtype, Actual: $actual_adtype" begin
if actual_adtype == expected_adtype
# Check that this does not throw an error.
Turing.sample(contextualised_tm, sampler, 2)
else
Test.@test_throws AbstractWrongADBackendError Turing.sample(
contextualised_tm, sampler, 2
)
end
end
end
end
end

end

0 comments on commit 803d2f5

Please sign in to comment.