Substantial updates to tutorial 01_gaussian-mixture-model #439
Conversation
Looks good, I made a few comments 🙂
> Where we sum the components with `logsumexp` from the [`StatsFuns.jl` package](https://github.com/JuliaStats/StatsFuns.jl).
>
> The manually incremented likelihood can be added to the log-probability with `Turing.@addlogprob!`, giving us the following model:
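For context (the diff itself isn't quoted here), a sketch of the kind of model this refers to, assuming a two-component mixture with unit variance; names like `gmm_addlogprob` are illustrative, not the tutorial's exact code:

```julia
using Turing, Distributions, StatsFuns, LinearAlgebra

# Hypothetical reconstruction of the @addlogprob! variant under discussion:
# the marginalized mixture likelihood is computed manually with logsumexp
# and added to the model's log-probability.
@model function gmm_addlogprob(x)
    μ ~ MvNormal(zeros(2), I)
    w ~ Dirichlet(2, 1.0)
    dists = [Normal(μ[k], 1.0) for k in 1:2]
    for i in eachindex(x)
        lls = [log(w[k]) + logpdf(dists[k], x[i]) for k in 1:2]
        Turing.@addlogprob! logsumexp(lls)
    end
end
```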
IMO we should not recommend the use of `Turing.@addlogprob!`, since it's so easy to misuse and to get (silently) wrong results: it operates completely outside of the `~` logic in Turing/DynamicPPL. Instead, I think one should usually use `~` with a (possibly custom) distribution.
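A sketch of that alternative, assuming the same two-component setup as in the sketch above:

```julia
using Turing, Distributions, LinearAlgebra

# Same model, but the marginalized mixture likelihood goes through `~`,
# so Turing/DynamicPPL sees it as an ordinary observation statement.
@model function gmm_marginalized(x)
    μ ~ MvNormal(zeros(2), I)
    w ~ Dirichlet(2, 1.0)
    dists = [Normal(μ[k], 1.0) for k in 1:2]
    for i in eachindex(x)
        x[i] ~ MixtureModel(dists, w)  # replaces the manual @addlogprob! line
    end
end
```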
Sounds good! I initially wasn't going to include that section, for basically the reasons you bring up, but I ended up including it (even though I don't actually sample from that model) to motivate what's going on with the `MixtureModel` logpdf.

I can replace it with a custom distribution (although this might be a little long for a model that's really just exposition), or omit it entirely.
> Now, re-running our model, we can see that the assigned means are consistent across chains:

```julia
chains = sample(model, sampler, MCMCThreads(), nsamples, nchains; discard_initial = burn);
```
Maybe let's keep the tutorial simple and avoid surprising warnings in single-threaded environments:

```diff
- chains = sample(model, sampler, MCMCThreads(), nsamples, nchains; discard_initial = burn);
+ chains = sample(model, sampler, nsamples, nchains; discard_initial = burn);
```
Actually un-resolving this because I don't think it works? As it is right now, I'm not sure Turing allows multiple chains without specifying a type of parallelism.

The documentation, if it's current, seems to suggest I should do something like:

```julia
chains = mapreduce(c -> sample(model_fun, sampler, 1000), chainscat, 1:num_chains)
```

I'm not sure if that's worth it just to get rid of the warning; let me know what you think, though.
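For what it's worth, one alternative sketch, assuming the pinned Turing version supports AbstractMCMC's `MCMCSerial` ensemble algorithm (an assumption worth checking against the manifest):

```julia
# Assumes MCMCSerial is available in the pinned Turing/AbstractMCMC version:
# it samples nchains chains sequentially, with no threads or processes,
# so no multithreading warning should be emitted.
chains = sample(model, sampler, MCMCSerial(), nsamples, nchains; discard_initial = burn);
```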
```julia
# Sample a class assignment for observation xi, for fixed μ and w.
function sample_class(xi)
    lvec = [(logpdf(d, xi) + log(w[i])) for (i, d) in enumerate(dists)]
    return rand(Categorical(exp.(lvec .- logsumexp(lvec))))
end
```
This should be defined outside of the model and probably use `softmax` or `softmax!` directly.
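A sketch of what that might look like, assuming `softmax` from StatsFuns.jl and passing `dists` and `w` in explicitly (the signature is illustrative):

```julia
using Distributions, StatsFuns

# Standalone version, per the suggestion above: sample a class assignment
# for observation xi given component distributions `dists` and weights `w`.
function sample_class(xi, dists, w)
    lvec = [logpdf(d, xi) + log(w[i]) for (i, d) in enumerate(dists)]
    return rand(Categorical(softmax(lvec)))  # softmax normalizes the log-weights stably
end
```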
- fix sample call (Co-authored-by: David Widmann <[email protected]>)
- remove use of MCMCThread() (Co-authored-by: David Widmann <[email protected]>)
- Remove Bijectors import (Co-authored-by: David Widmann <[email protected]>)
I think maybe this should be closed and revisited when #441 is done? FWIW, the current thing that's keeping this frozen is the multithreading stuff. If we want to stay away from

```julia
chains = sample(model, sampler, MCMCThreads(), nsamples, nchains; discard_initial = burn);
```

to avoid warnings in single-threaded environments, we'll need to update a bunch of tutorials, because this pattern is pretty common across all of them.
Thanks, @JasonPekos, for the PR. Would you like to migrate your changes here to #441?

Yup, will do.
First PR! Hope everything is ok.

As discussed in the Slack, this PR adds the following significant changes to this tutorial:

- `ordered()` from `Bijectors.jl`, making the model identifiable (currently it is multimodal, and the seed is just lucky); see the sketch after this list
- `Turing.@addlogprob!`
- `~ MixtureModel(dists, weights)`
- `generated_quantities()`
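A minimal sketch of the `ordered()` change from the first bullet, assuming two components with unit variance (the exact model in the diff may differ):

```julia
using Turing, Bijectors, Distributions, LinearAlgebra

# Sketch of the identifiability fix: swapping the unconstrained prior for
# an ordered one forces μ[1] ≤ μ[2] and removes the label-switching mode.
@model function gmm_ordered(x)
    μ ~ ordered(MvNormal(zeros(2), I))  # was: μ ~ MvNormal(zeros(2), I)
    w ~ Dirichlet(2, 1.0)
    dists = [Normal(μ[k], 1.0) for k in 1:2]
    for i in eachindex(x)
        x[i] ~ MixtureModel(dists, w)
    end
end
```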
There are also a few minor changes: