Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/master' into mhauru/tapir-tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mhauru committed Sep 2, 2024
2 parents fab002d + a26ce11 commit 2f93d75
Show file tree
Hide file tree
Showing 10 changed files with 315 additions and 71 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Turing"
uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
version = "0.33.3"
version = "0.34.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
8 changes: 5 additions & 3 deletions src/mcmc/Inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -213,7 +213,9 @@ end
# Extended in contrib/inference/abstractmcmc.jl
getstats(t) = nothing

struct Transition{T, F<:AbstractFloat, S<:Union{NamedTuple, Nothing}}
abstract type AbstractTransition end

struct Transition{T, F<:AbstractFloat, S<:Union{NamedTuple, Nothing}} <: AbstractTransition
θ :: T
lp :: F # TODO: merge `lp` with `stat`
stat :: S
Expand Down Expand Up @@ -409,7 +411,7 @@ getlogevidence(transitions, sampler, state) = missing
# Default MCMCChains.Chains constructor.
# This is type piracy (at least for SampleFromPrior).
function AbstractMCMC.bundle_samples(
ts::Vector,
ts::Vector{<:Union{AbstractTransition,AbstractVarInfo}},
model::AbstractModel,
spl::Union{Sampler{<:InferenceAlgorithm},SampleFromPrior},
state,
Expand Down Expand Up @@ -472,7 +474,7 @@ end

# This is type piracy (for SampleFromPrior).
function AbstractMCMC.bundle_samples(
ts::Vector,
ts::Vector{<:Union{AbstractTransition,AbstractVarInfo}},
model::AbstractModel,
spl::Union{Sampler{<:InferenceAlgorithm},SampleFromPrior},
state,
Expand Down
55 changes: 0 additions & 55 deletions src/mcmc/hmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -563,58 +563,3 @@ end
function AHMCAdaptor(::Hamiltonian, ::AHMC.AbstractMetric; kwargs...)
return AHMC.Adaptation.NoAdaptation()
end

##########################
# HMC State Constructors #
##########################

function HMCState(
rng::AbstractRNG,
model::Model,
spl::Sampler{<:Hamiltonian},
vi::AbstractVarInfo;
kwargs...,
)
# Link everything if needed.
waslinked = islinked(vi, spl)
if !waslinked
vi = link!!(vi, spl, model)
end

# Get the initial log pdf and gradient functions.
∂logπ∂θ = gen_∂logπ∂θ(vi, spl, model)
logπ = Turing.LogDensityFunction(
vi,
model,
DynamicPPL.SamplingContext(rng, spl, DynamicPPL.leafcontext(model.context)),
)

# Get the metric type.
metricT = getmetricT(spl.alg)

# Create a Hamiltonian.
θ_init = Vector{Float64}(spl.state.vi[spl])
metric = metricT(length(θ_init))
h = AHMC.Hamiltonian(metric, logπ, ∂logπ∂θ)

# Find good eps if not provided one
if iszero(spl.alg.ϵ)
ϵ = AHMC.find_good_stepsize(rng, h, θ_init)
@info "Found initial step size" ϵ
else
ϵ = spl.alg.ϵ
end

# Generate a kernel.
kernel = make_ahmc_kernel(spl.alg, ϵ)

# Generate a phasepoint. Replaced during sample_init!
h, t = AHMC.sample_init(rng, h, θ_init) # this also ensure AHMC has the same dim as θ.

# Unlink everything, if it was indeed linked before.
if waslinked
vi = invlink!!(vi, spl, model)
end

return HMCState(vi, 0, 0, kernel.τ, h, AHMCAdaptor(spl.alg, metric; ϵ=ϵ), t.z)
end
4 changes: 2 additions & 2 deletions src/mcmc/particle_mcmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ end
SMC(space::Symbol...) = SMC(space)
SMC(space::Tuple) = SMC(AdvancedPS.ResampleWithESSThreshold(), space)

struct SMCTransition{T,F<:AbstractFloat}
struct SMCTransition{T,F<:AbstractFloat} <: AbstractTransition
"The parameters for any given sample."
θ::T
"The joint log probability of the sample (NOTE: does not work, always set to zero)."
Expand Down Expand Up @@ -222,7 +222,7 @@ end

const CSMC = PG # type alias of PG as Conditional SMC

struct PGTransition{T,F<:AbstractFloat}
struct PGTransition{T,F<:AbstractFloat} <: AbstractTransition
"The parameters for any given sample."
θ::T
"The joint log probability of the sample (NOTE: does not work, always set to zero)."
Expand Down
2 changes: 1 addition & 1 deletion src/mcmc/sghmc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ function SGLD(
return SGLD{typeof(adtype),space,typeof(stepsize)}(stepsize, adtype)
end

struct SGLDTransition{T,F<:Real}
struct SGLDTransition{T,F<:Real} <: AbstractTransition
"The parameters for any given sample."
θ::T
"The joint log probability of the sample."
Expand Down
8 changes: 4 additions & 4 deletions src/optimisation/Optimisation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -277,13 +277,13 @@ StatsBase.loglikelihood(m::ModeResult) = m.lp

"""
Base.get(m::ModeResult, var_symbol::Symbol)
Base.get(m::ModeResult, var_symbols)
Base.get(m::ModeResult, var_symbols::AbstractVector{Symbol})
Return the values of all the variables with the symbol(s) `var_symbol` in the mode result
`m`. The return value is a `NamedTuple` with `var_symbols` as the key(s). The second
argument should be either a `Symbol` or an iterator of `Symbol`s.
argument should be either a `Symbol` or a vector of `Symbol`s.
"""
function Base.get(m::ModeResult, var_symbols)
function Base.get(m::ModeResult, var_symbols::AbstractVector{Symbol})
log_density = m.f
# Get all the variable names in the model. This is the same as the list of keys in
# m.values, but they are more convenient to filter when they are VarNames rather than
Expand All @@ -304,7 +304,7 @@ function Base.get(m::ModeResult, var_symbols)
return (; zip(var_symbols, value_vectors)...)
end

Base.get(m::ModeResult, var_symbol::Symbol) = get(m, (var_symbol,))
Base.get(m::ModeResult, var_symbol::Symbol) = get(m, [var_symbol])

"""
ModeResult(log_density::OptimLogDensity, solution::SciMLBase.OptimizationSolution)
Expand Down
5 changes: 3 additions & 2 deletions test/Aqua.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@ module AquaTests
using Aqua: Aqua
using Turing

# TODO(mhauru) We skip testing for method ambiguities because it catches a lot of problems
# in dependencies. Would like to check it for just Turing.jl itself though.
# We test ambiguities separately because it catches a lot of problems
# in dependencies but we test it for Turing.
Aqua.test_ambiguities([Turing])
Aqua.test_all(Turing; ambiguities=false)

end
10 changes: 10 additions & 0 deletions test/mcmc/hmc.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module HMCTests

using ..Models: gdemo_default
using ..ADUtils: ADTypeCheckContext
#using ..Models: gdemo
using ..NumericalTests: check_gdemo, check_numerical
import ..ADUtils
Expand Down Expand Up @@ -324,6 +325,15 @@ ADUtils.install_tapir && import Tapir
# KS will compare the empirical CDFs, which seems like a reasonable thing to do here.
@test pvalue(ApproximateTwoSampleKSTest(vec(results), vec(results_prior))) > 0.001
end

@testset "Check ADType" begin
alg = HMC(0.1, 10; adtype=adbackend)
m = DynamicPPL.contextualize(
gdemo_default, ADTypeCheckContext(adbackend, gdemo_default.context)
)
# These will error if the adbackend being used is not the one set.
sample(rng, m, alg, 10)
end
end

end
14 changes: 13 additions & 1 deletion test/optimisation/Optimisation.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
module OptimisationTests

using ..Models: gdemo, gdemo_default
using ..ADUtils: ADTypeCheckContext
using Distributions
using Distributions.FillArrays: Zeros
using DynamicPPL: DynamicPPL
Expand Down Expand Up @@ -140,7 +141,6 @@ using Turing
gdemo_default, OptimizationOptimJL.LBFGS(); initial_params=true_value
)
m3 = maximum_likelihood(gdemo_default, OptimizationOptimJL.Newton())
# TODO(mhauru) How can we check that the adtype is actually AutoReverseDiff?
m4 = maximum_likelihood(
gdemo_default, OptimizationOptimJL.BFGS(); adtype=AutoReverseDiff()
)
Expand Down Expand Up @@ -616,6 +616,18 @@ using Turing
@assert vcat(get_a[:a], get_b[:b]) == result.values.array
@assert get(result, :c) == (; :c => Array{Float64}[])
end

@testset "ADType" begin
Random.seed!(222)
for adbackend in (AutoReverseDiff(), AutoForwardDiff(), AutoTracker())
m = DynamicPPL.contextualize(
gdemo_default, ADTypeCheckContext(adbackend, gdemo_default.context)
)
# These will error if the adbackend being used is not the one set.
maximum_likelihood(m; adtype=adbackend)
maximum_a_posteriori(m; adtype=adbackend)
end
end
end

end
Loading

0 comments on commit 2f93d75

Please sign in to comment.