Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix gamma mixture #222

Closed
wants to merge 9 commits into from
Closed
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 86 additions & 26 deletions demo/Gamma Mixture.ipynb

Large diffs are not rendered by default.

23 changes: 22 additions & 1 deletion src/approximations/importance.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ convenient functions to generate samples and weights to approximate expectations
- `rng`: random number generator objects
- `nsamples`: number of samples generated by default
"""
struct ImportanceSamplingApproximation{T, R}
struct ImportanceSamplingApproximation{T, R} <: AbstractFormConstraint
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The ImportanceSamplingApproximation should not be declared as an AbstractFormConstraint. This is wrong from the semantic point of view. There should be an extra MomentsApproximationFormConstraint (or something similar).

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rng :: R
nsamples :: Int
bsamples :: Vector{T}
Expand All @@ -22,6 +22,10 @@ struct ImportanceSamplingApproximation{T, R}
rsamples :: Vector{T}
end

function ImportanceSamplingApproximation(nsamples::Int; resampling::Bool = true)
return ImportanceSamplingApproximation(Float64, Random.GLOBAL_RNG, nsamples; resampling = resampling)
end

function ImportanceSamplingApproximation(rng::R, nsamples::Int; resampling::Bool = true) where {R}
return ImportanceSamplingApproximation(Float64, rng, nsamples; resampling = resampling)
end
Expand Down Expand Up @@ -83,3 +87,20 @@ function approximate_meancov(approximation::ImportanceSamplingApproximation, g::

return m, v
end

is_point_mass_form_constraint(::ImportanceSamplingApproximation) = false
default_form_check_strategy(::ImportanceSamplingApproximation) = FormConstraintCheckEach()
default_prod_constraint(::ImportanceSamplingApproximation) = ProdGeneric()

function constrain_form(approximation::ImportanceSamplingApproximation, distribution::DistProduct)
m, v = approximate_meancov(approximation, x -> pdf(getright(distribution), x), getleft(distribution))

a = m^2 / v
b = m / v
return GammaShapeRate(a, b)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be changed somehow

end

constrain_form(::ImportanceSamplingApproximation, distribution::Any) = distribution

make_form_constraint(::Type{ImportanceSamplingApproximation}, args...; kwargs...) =
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should be removed in favor of separate MomentsApproximationFormConstraint

ImportanceSamplingApproximation(args..., kwargs...)
4 changes: 4 additions & 0 deletions src/distributions/gamma_shape_rate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -65,3 +65,7 @@ end

Distributions.pdf(dist::GammaShapeRate, x::Real) = (rate(dist)^shape(dist)) / gamma(shape(dist)) * x^(shape(dist) - 1) * exp(-rate(dist) * x)
Distributions.logpdf(dist::GammaShapeRate, x::Real) = shape(dist) * log(rate(dist)) - loggamma(shape(dist)) + (shape(dist) - 1) * log(x) - rate(dist) * x

function Random.rand(rng::AbstractRNG, dist::GammaShapeRate)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lets also add more rand methods (which accepts size and inplace versions).

return rand(rng, convert(GammaShapeScale, dist))
end
14 changes: 13 additions & 1 deletion src/distributions/sample_list.jl
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,7 @@ end
weightedmean(sl::SampleList) = first(weightedmean_precision(sl))

mean(::typeof(log), sl::SampleList) = sample_list_logmean(variate_form(sl), sl)
mean(::typeof(loggamma), sl::SampleList) = sample_list_loggammamean(variate_form(sl), sl)
mean(::typeof(xtlog), sl::SampleList) = sample_list_meanlogmean(variate_form(sl), sl)
mean(::typeof(mirrorlog), sl::SampleList) = sample_list_mirroredlogmean(variate_form(sl), sl)

Expand Down Expand Up @@ -425,7 +426,6 @@ function approximate_prod_with_sample_list(
else
SampleListMeta(nothing, nothing, nothing, log_integrand)
end

return SampleList(Val(D), get_linear_samples(rsamples), rweights, meta)
end

Expand Down Expand Up @@ -515,6 +515,8 @@ function sample_list_covm! end
function sample_list_mean_var end
# Compute E[log(x)]
function sample_list_logmean end
# Compute E[log Γ(x)]
function sample_list_loggammamean end
# Compute E[xlog(x)]
function sample_list_meanlogmean end
# Compute E[log(1 - x)]
Expand Down Expand Up @@ -554,6 +556,16 @@ function sample_list_logmean(::Type{Univariate}, sl::SampleList)
return logμ
end

function sample_list_loggammamean(::Type{Univariate}, sl::SampleList)
n, samples, weights = get_data(sl)
# @show weights
logμ = sample_list_zero_element(sl)
for i in 1:n
logμ += weights[i] * loggamma(samples[i])
end
return logμ
end

function sample_list_meanlogmean(::Type{Univariate}, sl::SampleList)
n, samples, weights = get_data(sl)
μlogμ = sample_list_zero_element(sl)
Expand Down
2 changes: 1 addition & 1 deletion src/nodes/gamma_mixture.jl
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ end
q_out::Any,
q_switch::Any,
q_a::NTuple{N, Any},
q_b::NTuple{N, GammaShapeRate}
q_b::NTuple{N, GammaDistributionsFamily}
) where {N} = begin
z_bar = probvec(q_switch)
return mapreduce(+, 1:N, init = 0.0) do i
Expand Down
12 changes: 5 additions & 7 deletions src/nodes/gamma_shape_rate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,8 @@ import SpecialFunctions: loggamma

@node GammaShapeRate Stochastic [out, (α, aliases = [shape]), (β, aliases = [rate])]

@average_energy GammaShapeRate (q_out::Any, q_α::PointMass, q_β::Any) = begin
mean(loggamma, q_α) - mean(q_α) * mean(log, q_β) - (mean(q_α) - 1.0) * mean(log, q_out) + mean(q_β) * mean(q_out)
end

@average_energy GammaShapeRate (q_out::Any, q_α::GammaDistributionsFamily, q_β::Any) = begin
mean(loggamma, q_α) - mean(q_α) * mean(log, q_β) - (mean(q_α) - 1.0) * mean(log, q_out) + mean(q_β) * mean(q_out)
end
@average_energy GammaShapeRate (q_out::Any, q_α::Union{PointMass, SampleList, GammaDistributionsFamily}, q_β::Any) =
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wonder if can simply put Any there

begin
mean(loggamma, q_α) - mean(q_α) * mean(log, q_β) - (mean(q_α) - 1.0) * mean(log, q_out) +
mean(q_β) * mean(q_out)
end