diff --git a/Project.toml b/Project.toml index 33d4be908a..62dfe4b204 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Turing" uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" -version = "0.33.3" +version = "0.34.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/src/mcmc/Inference.jl b/src/mcmc/Inference.jl index 5f05121e74..b7bdf206b9 100644 --- a/src/mcmc/Inference.jl +++ b/src/mcmc/Inference.jl @@ -213,7 +213,9 @@ end # Extended in contrib/inference/abstractmcmc.jl getstats(t) = nothing -struct Transition{T, F<:AbstractFloat, S<:Union{NamedTuple, Nothing}} +abstract type AbstractTransition end + +struct Transition{T, F<:AbstractFloat, S<:Union{NamedTuple, Nothing}} <: AbstractTransition θ :: T lp :: F # TODO: merge `lp` with `stat` stat :: S @@ -409,7 +411,7 @@ getlogevidence(transitions, sampler, state) = missing # Default MCMCChains.Chains constructor. # This is type piracy (at least for SampleFromPrior). function AbstractMCMC.bundle_samples( - ts::Vector, + ts::Vector{<:Union{AbstractTransition,AbstractVarInfo}}, model::AbstractModel, spl::Union{Sampler{<:InferenceAlgorithm},SampleFromPrior}, state, @@ -472,7 +474,7 @@ end # This is type piracy (for SampleFromPrior). function AbstractMCMC.bundle_samples( - ts::Vector, + ts::Vector{<:Union{AbstractTransition,AbstractVarInfo}}, model::AbstractModel, spl::Union{Sampler{<:InferenceAlgorithm},SampleFromPrior}, state, diff --git a/src/mcmc/particle_mcmc.jl b/src/mcmc/particle_mcmc.jl index 8d38660c67..02a53766e0 100644 --- a/src/mcmc/particle_mcmc.jl +++ b/src/mcmc/particle_mcmc.jl @@ -45,7 +45,7 @@ end SMC(space::Symbol...) = SMC(space) SMC(space::Tuple) = SMC(AdvancedPS.ResampleWithESSThreshold(), space) -struct SMCTransition{T,F<:AbstractFloat} +struct SMCTransition{T,F<:AbstractFloat} <: AbstractTransition "The parameters for any given sample." θ::T "The joint log probability of the sample (NOTE: does not work, always set to zero)." @@ -222,7 +222,7 @@ end const CSMC = PG # type alias of PG as Conditional SMC -struct PGTransition{T,F<:AbstractFloat} +struct PGTransition{T,F<:AbstractFloat} <: AbstractTransition "The parameters for any given sample." θ::T "The joint log probability of the sample (NOTE: does not work, always set to zero)." diff --git a/src/mcmc/sghmc.jl b/src/mcmc/sghmc.jl index 84a9c18f3c..fbc4b11868 100644 --- a/src/mcmc/sghmc.jl +++ b/src/mcmc/sghmc.jl @@ -193,7 +193,7 @@ function SGLD( return SGLD{typeof(adtype),space,typeof(stepsize)}(stepsize, adtype) end -struct SGLDTransition{T,F<:Real} +struct SGLDTransition{T,F<:Real} <: AbstractTransition "The parameters for any given sample." θ::T "The joint log probability of the sample." diff --git a/src/optimisation/Optimisation.jl b/src/optimisation/Optimisation.jl index e069897920..eecfcad22b 100644 --- a/src/optimisation/Optimisation.jl +++ b/src/optimisation/Optimisation.jl @@ -277,13 +277,13 @@ StatsBase.loglikelihood(m::ModeResult) = m.lp """ Base.get(m::ModeResult, var_symbol::Symbol) - Base.get(m::ModeResult, var_symbols) + Base.get(m::ModeResult, var_symbols::AbstractVector{Symbol}) Return the values of all the variables with the symbol(s) `var_symbol` in the mode result `m`. The return value is a `NamedTuple` with `var_symbols` as the key(s). The second -argument should be either a `Symbol` or an iterator of `Symbol`s. +argument should be either a `Symbol` or a vector of `Symbol`s. """ -function Base.get(m::ModeResult, var_symbols) +function Base.get(m::ModeResult, var_symbols::AbstractVector{Symbol}) log_density = m.f # Get all the variable names in the model. This is the same as the list of keys in # m.values, but they are more convenient to filter when they are VarNames rather than @@ -304,7 +304,7 @@ function Base.get(m::ModeResult, var_symbols) return (; zip(var_symbols, value_vectors)...) end -Base.get(m::ModeResult, var_symbol::Symbol) = get(m, (var_symbol,)) +Base.get(m::ModeResult, var_symbol::Symbol) = get(m, [var_symbol]) """ ModeResult(log_density::OptimLogDensity, solution::SciMLBase.OptimizationSolution) diff --git a/test/Aqua.jl b/test/Aqua.jl index 0b536770b2..e159cae9ca 100644 --- a/test/Aqua.jl +++ b/test/Aqua.jl @@ -3,8 +3,9 @@ module AquaTests using Aqua: Aqua using Turing -# TODO(mhauru) We skip testing for method ambiguities because it catches a lot of problems -# in dependencies. Would like to check it for just Turing.jl itself though. +# We test ambiguities separately because it catches a lot of problems +# in dependencies but we test it for Turing. +Aqua.test_ambiguities([Turing]) Aqua.test_all(Turing; ambiguities=false) end