Move probability interface docs from DynamicPPL to this repo #528

Merged
merged 1 commit into from Oct 3, 2024

2 changes: 1 addition & 1 deletion Manifest.toml
@@ -2,7 +2,7 @@

julia_version = "1.10.5"
manifest_format = "2.0"
project_hash = "e0661388214f9e03749b3fccc86939dfd3853246"
project_hash = "24c19d9af05129bd51c95e6af97ff0385a185787"

[[deps.ADTypes]]
git-tree-sha1 = "016833eb52ba2d6bea9fcb50ca295980e728ee24"
2 changes: 2 additions & 0 deletions Project.toml
@@ -9,10 +9,12 @@ CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
DataDeps = "124859b0-ceae-595e-8997-d05f6a7a8dfe"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
+DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
DifferentialEquations = "0c46a032-eb83-5123-abaf-570d42b7fbaa"
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
DynamicHMC = "bbc10e6e-7c05-544b-b16e-64fede858acb"
+DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8"
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1 change: 1 addition & 0 deletions _quarto.yml
@@ -61,6 +61,7 @@ website:
contents:
- tutorials/docs-10-using-turing-autodiff/index.qmd
- tutorials/usage-custom-distribution/index.qmd
+ - tutorials/usage-probability-interface/index.qmd
- tutorials/usage-modifying-logprob/index.qmd
- tutorials/usage-generated-quantities/index.qmd
- tutorials/docs-17-mode-estimation/index.qmd
182 changes: 182 additions & 0 deletions tutorials/usage-probability-interface/index.qmd
@@ -0,0 +1,182 @@
---
title: Querying Model Probabilities
engine: julia
---

```{julia}
#| echo: false
#| output: false
using Pkg;
Pkg.instantiate();
```

The easiest way to manipulate and query Turing models is via the DynamicPPL probability interface.

Let's use a simple model of normally-distributed data as an example.

```{julia}
using Turing
using LinearAlgebra: I
using Random

@model function gdemo(n)
    μ ~ Normal(0, 1)
    x ~ MvNormal(fill(μ, n), I)
end
```

We generate some data using `μ = 0`:

```{julia}
Random.seed!(1776)
dataset = randn(100)
dataset[1:5]
```

## Conditioning and Deconditioning

Bayesian models can be transformed with two main operations: conditioning and deconditioning.
Conditioning takes a variable and fixes its value as known, so that it is treated as observed data.
We do this by passing a model and a collection of conditioned variables to `|`, or its alias `condition`:

```{julia}
# (equivalently)
# conditioned_model = condition(gdemo(length(dataset)), (x=dataset, μ=0))
conditioned_model = gdemo(length(dataset)) | (x=dataset, μ=0)
```

This operation can be reversed by applying `decondition`:

```{julia}
original_model = decondition(conditioned_model)
```

We can also decondition only some of the variables:

```{julia}
partially_conditioned = decondition(conditioned_model, :μ)
```

We can see which of the variables in a model have been conditioned with `DynamicPPL.conditioned`:

```{julia}
DynamicPPL.conditioned(partially_conditioned)
```
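
Applying the same query to the fully conditioned model should report both variables. A quick check (a sketch, using the models defined above):

```julia
# Both x and μ were conditioned, so both should be reported
DynamicPPL.conditioned(conditioned_model)
```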

::: {.callout-note}
Sometimes it is helpful to define convenience functions for conditioning on some variable(s).
For instance, in this example we might want to define a version of `gdemo` that conditions on some observations of `x`:

```julia
gdemo(x::AbstractVector{<:Real}) = gdemo(length(x)) | (; x)
```
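
With this helper, conditioning on observations becomes a single call; for instance (a sketch, with `conditioned_on_data` as an illustrative name):

```julia
# Equivalent to gdemo(length(dataset)) | (x=dataset,)
conditioned_on_data = gdemo(dataset)
```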

For illustrative purposes, however, we do not use this function in the examples below.
:::

## Probabilities and Densities

We often want to calculate the (unnormalized) probability density of an event.
This density might be a prior, a likelihood, or a joint density, which for fixed data is the unnormalized posterior.
DynamicPPL provides convenient functions for this.
To begin, let's condition `gdemo` on the dataset and draw a sample from the resulting model.
The returned sample only contains `μ`, since the value of `x` has already been fixed:

```{julia}
model = gdemo(length(dataset)) | (x=dataset,)

Random.seed!(124)
sample = rand(model)
```
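
As a quick sanity check, the sample should contain only the unconditioned variable:

```julia
keys(sample)  # (:μ,)
```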

We can then calculate the log joint probability of a set of sampled values (here drawn from the prior) with `logjoint`:

```{julia}
logjoint(model, sample)
```
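
For this simple model, we can verify the result by hand: the joint density factorizes into the prior on `μ` and the likelihood of `x`. A minimal sketch (importing `logpdf` explicitly in case it is not re-exported):

```julia
using Distributions: logpdf

manual_logjoint =
    logpdf(Normal(0, 1), sample.μ) +
    logpdf(MvNormal(fill(sample.μ, length(dataset)), I), dataset)
manual_logjoint ≈ logjoint(model, sample)
```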

For models with many variables, `rand(model)` can be prohibitively slow, since it returns a `NamedTuple` of samples from the prior distributions of the unconditioned variables.
In this case, we recommend working with samples of type `DataStructures.OrderedDict` instead:

```{julia}
using DataStructures: OrderedDict

Random.seed!(124)
sample_dict = rand(OrderedDict, model)
```

`logjoint` can also be used on this sample:

```{julia}
logjoint(model, sample_dict)
```
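
The values passed to `logjoint` need not come from `rand`; any assignment of the unconditioned variables works. For instance, to evaluate the joint density at the true value `μ = 0` (a sketch):

```julia
logjoint(model, (μ=0.0,))
```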

The prior probability and the likelihood of a set of samples can be calculated with the functions `logprior` and `loglikelihood` respectively.
The log joint probability is the sum of these two quantities:

```{julia}
logjoint(model, sample) ≈ loglikelihood(model, sample) + logprior(model, sample)
```

```{julia}
logjoint(model, sample_dict) ≈ loglikelihood(model, sample_dict) + logprior(model, sample_dict)
```

## Example: Cross-validation

As an example of the probability interface in use, let's estimate the performance of our model with cross-validation.
In cross-validation, we split the dataset into several equal parts; each part in turn serves as the validation set, while the model is trained on the remainder.
Here, we measure fit using the cross entropy (Bayes loss).[^1]
(For the sake of simplicity, the following code enforces that `nfolds` must divide the number of data points.
For a more robust implementation, see [MLUtils.jl](https://juliaml.github.io/MLUtils.jl/dev/api/#MLUtils.kfolds).)

```{julia}
# Calculate the train/validation splits across `nfolds` partitions,
# assuming `nfolds` divides `length(dataset)`
function kfolds(dataset::Array{<:Real}, nfolds::Int)
    fold_size, remaining = divrem(length(dataset), nfolds)
    if remaining != 0
        error("The number of folds must divide the number of data points.")
    end
    first_idx = firstindex(dataset)
    last_idx = lastindex(dataset)
    splits = map(0:(nfolds - 1)) do i
        start_idx = first_idx + i * fold_size
        end_idx = start_idx + fold_size
        train_set_indices = [first_idx:(start_idx - 1); end_idx:last_idx]
        return (view(dataset, train_set_indices), view(dataset, start_idx:(end_idx - 1)))
    end
    return splits
end

function cross_val(
    dataset::Vector{<:Real};
    nfolds::Int=5,
    nsamples::Int=1_000,
    rng::Random.AbstractRNG=Random.default_rng(),
)
    # Initialize `loss` in a way such that the loop below does not change its type
    model = gdemo(1) | (x=[first(dataset)],)
    loss = zero(logjoint(model, rand(rng, model)))

    for (train, validation) in kfolds(dataset, nfolds)
        # First, we train the model on the training set, i.e., we obtain samples from the posterior.
        # For this conjugate normal model, the posterior is available in closed form:
        # with n unit-variance observations and a standard normal prior,
        # μ | x ~ Normal(sum(x) / (n + 1), sqrt(1 / (n + 1))).
        # For general models, samples would typically be generated using MCMC with Turing.
        n = length(train)
        posterior = Normal(sum(train) / (n + 1), sqrt(1 / (n + 1)))
        samples = rand(rng, posterior, nsamples)

        # Evaluation on the validation set.
        validation_model = gdemo(length(validation)) | (x=validation,)
        loss += sum(samples) do sample
            logjoint(validation_model, (μ=sample,))
        end
    end

    return loss
end

cross_val(dataset)
```
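
The keyword arguments make it easy to vary the experiment; for instance, ten folds with a fixed RNG for reproducibility (a sketch; any seed works):

```julia
cross_val(dataset; nfolds=10, nsamples=2_000, rng=Random.Xoshiro(42))
```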

[^1]: See [ParetoSmooth.jl](https://github.com/TuringLang/ParetoSmooth.jl) for a faster and more accurate implementation of cross-validation than the one provided here.