From fd21439dbe395c6252f5a43ba25a3376a86b4663 Mon Sep 17 00:00:00 2001 From: Abhinav Singh <73834077+abhinavsns@users.noreply.github.com> Date: Thu, 18 Jul 2024 11:57:36 +0200 Subject: [PATCH 1/5] Enabling aqua ambiguity testing for Turing We test ambiguities only for Turing and not its dependencies. --- test/Aqua.jl | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/test/Aqua.jl b/test/Aqua.jl index 0b536770b2..b3ff8be4b3 100644 --- a/test/Aqua.jl +++ b/test/Aqua.jl @@ -3,8 +3,10 @@ module AquaTests using Aqua: Aqua using Turing -# TODO(mhauru) We skip testing for method ambiguities because it catches a lot of problems -# in dependencies. Would like to check it for just Turing.jl itself though. +# We test ambiguities separately because it catches a lot of problems +# in dependencies but we test it for Turing. +Aqua.test_ambiguities([Turing]) Aqua.test_all(Turing; ambiguities=false) + end From 7d7116b01e08e82428de11d834a1285ed4d0d270 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Tue, 30 Jul 2024 17:42:46 +0100 Subject: [PATCH 2/5] Format --- test/Aqua.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/test/Aqua.jl b/test/Aqua.jl index b3ff8be4b3..e159cae9ca 100644 --- a/test/Aqua.jl +++ b/test/Aqua.jl @@ -8,5 +8,4 @@ using Turing Aqua.test_ambiguities([Turing]) Aqua.test_all(Turing; ambiguities=false) - end From a6e01f8517714df56b80f0c5fdb1baa850f64949 Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 14 Aug 2024 00:01:07 +0100 Subject: [PATCH 3/5] Fix bundle_samples method ambiguity Concretely: 1. Creating an `AbstractTransition` type which all the Transitions in Turing subtype. 2. Modifying the type signature of bundle_samples to take a Vector{<:Union{AbstractTransition,AbstractVarInfo}} as the first argument. The AbstractVarInfo case occurs when sampling with Prior(), so the type signature of this argument mirrors that of the Sampler in the same function. --- src/mcmc/Inference.jl | 8 +++++--- src/mcmc/particle_mcmc.jl | 4 ++-- src/mcmc/sghmc.jl | 2 +- 3 files changed, 8 insertions(+), 6 deletions(-) diff --git a/src/mcmc/Inference.jl b/src/mcmc/Inference.jl index 5f05121e74..b7bdf206b9 100644 --- a/src/mcmc/Inference.jl +++ b/src/mcmc/Inference.jl @@ -213,7 +213,9 @@ end # Extended in contrib/inference/abstractmcmc.jl getstats(t) = nothing -struct Transition{T, F<:AbstractFloat, S<:Union{NamedTuple, Nothing}} +abstract type AbstractTransition end + +struct Transition{T, F<:AbstractFloat, S<:Union{NamedTuple, Nothing}} <: AbstractTransition θ :: T lp :: F # TODO: merge `lp` with `stat` stat :: S @@ -409,7 +411,7 @@ getlogevidence(transitions, sampler, state) = missing # Default MCMCChains.Chains constructor. # This is type piracy (at least for SampleFromPrior). function AbstractMCMC.bundle_samples( - ts::Vector, + ts::Vector{<:Union{AbstractTransition,AbstractVarInfo}}, model::AbstractModel, spl::Union{Sampler{<:InferenceAlgorithm},SampleFromPrior}, state, @@ -472,7 +474,7 @@ end # This is type piracy (for SampleFromPrior). function AbstractMCMC.bundle_samples( - ts::Vector, + ts::Vector{<:Union{AbstractTransition,AbstractVarInfo}}, model::AbstractModel, spl::Union{Sampler{<:InferenceAlgorithm},SampleFromPrior}, state, diff --git a/src/mcmc/particle_mcmc.jl b/src/mcmc/particle_mcmc.jl index 8d38660c67..02a53766e0 100644 --- a/src/mcmc/particle_mcmc.jl +++ b/src/mcmc/particle_mcmc.jl @@ -45,7 +45,7 @@ end SMC(space::Symbol...) = SMC(space) SMC(space::Tuple) = SMC(AdvancedPS.ResampleWithESSThreshold(), space) -struct SMCTransition{T,F<:AbstractFloat} +struct SMCTransition{T,F<:AbstractFloat} <: AbstractTransition "The parameters for any given sample." θ::T "The joint log probability of the sample (NOTE: does not work, always set to zero)." @@ -222,7 +222,7 @@ end const CSMC = PG # type alias of PG as Conditional SMC -struct PGTransition{T,F<:AbstractFloat} +struct PGTransition{T,F<:AbstractFloat} <: AbstractTransition "The parameters for any given sample." θ::T "The joint log probability of the sample (NOTE: does not work, always set to zero)." diff --git a/src/mcmc/sghmc.jl b/src/mcmc/sghmc.jl index 84a9c18f3c..fbc4b11868 100644 --- a/src/mcmc/sghmc.jl +++ b/src/mcmc/sghmc.jl @@ -193,7 +193,7 @@ function SGLD( return SGLD{typeof(adtype),space,typeof(stepsize)}(stepsize, adtype) end -struct SGLDTransition{T,F<:Real} +struct SGLDTransition{T,F<:Real} <: AbstractTransition "The parameters for any given sample." θ::T "The joint log probability of the sample." From 36accade0ca4c3161fe2dfcfb3024758df5efb9e Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 14 Aug 2024 15:34:38 +0100 Subject: [PATCH 4/5] Fix get() ambiguities Done by: 1. Constraining the type parameter to AbstractVector{Symbol} 2. Modifying the method below it to use a vector instead of a tuple --- src/optimisation/Optimisation.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/optimisation/Optimisation.jl b/src/optimisation/Optimisation.jl index e069897920..eecfcad22b 100644 --- a/src/optimisation/Optimisation.jl +++ b/src/optimisation/Optimisation.jl @@ -277,13 +277,13 @@ StatsBase.loglikelihood(m::ModeResult) = m.lp """ Base.get(m::ModeResult, var_symbol::Symbol) - Base.get(m::ModeResult, var_symbols) + Base.get(m::ModeResult, var_symbols::AbstractVector{Symbol}) Return the values of all the variables with the symbol(s) `var_symbol` in the mode result `m`. The return value is a `NamedTuple` with `var_symbols` as the key(s). The second -argument should be either a `Symbol` or an iterator of `Symbol`s. +argument should be either a `Symbol` or a vector of `Symbol`s. """ -function Base.get(m::ModeResult, var_symbols) +function Base.get(m::ModeResult, var_symbols::AbstractVector{Symbol}) log_density = m.f # Get all the variable names in the model. This is the same as the list of keys in # m.values, but they are more convenient to filter when they are VarNames rather than @@ -304,7 +304,7 @@ function Base.get(m::ModeResult, var_symbols) return (; zip(var_symbols, value_vectors)...) end -Base.get(m::ModeResult, var_symbol::Symbol) = get(m, (var_symbol,)) +Base.get(m::ModeResult, var_symbol::Symbol) = get(m, [var_symbol]) """ ModeResult(log_density::OptimLogDensity, solution::SciMLBase.OptimizationSolution) From 7adc2ab1e59b0b40f9fe6bae81865ef8e8ca139e Mon Sep 17 00:00:00 2001 From: Penelope Yong Date: Wed, 14 Aug 2024 15:35:44 +0100 Subject: [PATCH 5/5] Bump to 0.34.0 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 33d4be908a..62dfe4b204 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Turing" uuid = "fce5fe82-541a-59a6-adf8-730c64b5f9a0" -version = "0.33.3" +version = "0.34.0" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"