-
Notifications
You must be signed in to change notification settings - Fork 162
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #545 from probcomp/enumerative_inference
Add enumerative inference to the inference library
- Loading branch information
Showing
5 changed files
with
290 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
# Enumerative Inference | ||
|
||
Enumerative inference can be used to compute the exact posterior distribution for a generative model | ||
with a finite number of discrete random choices, to compute a grid approximation of a continuous | ||
posterior density, or to perform stratified sampling by enumerating over discrete random choices and sampling | ||
the continuous random choices. This functionality is provided by [`enumerative_inference`](@ref). | ||
|
||
```@docs | ||
enumerative_inference | ||
``` | ||
|
||
To construct a rectangular grid of [choice maps](../core/choice_maps.md) and their associated log-volumes to iterate over, use the [`choice_vol_grid`](@ref) function. | ||
|
||
```@docs | ||
choice_vol_grid | ||
``` | ||
|
||
When the space of possible choice maps is not rectangular (e.g. some addresses only exist depending on the values of other addresses), iterators over choice maps and log-volumes can be also be manually constructed. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,134 @@ | ||
""" | ||
(traces, log_norm_weights, lml_est) = enumerative_inference( | ||
model::GenerativeFunction, model_args::Tuple, | ||
observations::ChoiceMap, choice_vol_iter | ||
) | ||
Run enumerative inference over a `model`, given `observations` and an iterator over | ||
choice maps and their associated log-volumes (`choice_vol_iter`), specifying the | ||
choices to be iterated over. An iterator over a grid of choice maps and log-volumes | ||
can be constructed with [`choice_vol_grid`](@ref). | ||
Return an array of traces and associated log-weights with the same shape as | ||
`choice_vol_iter`. The log-weight of each trace is normalized, and corresponds | ||
to the log probability of the volume of sample space that the trace represents. | ||
Also return an estimate of the log marginal likelihood of the observations (`lml_est`). | ||
All addresses in the `observations` choice map must be sampled by the model when | ||
given the model arguments. The same constraint applies to choice maps enumerated | ||
over by `choice_vol_iter`, which must also avoid sharing addresses with the | ||
`observations`. When the choice maps in `choice_vol_iter` do not fully specify | ||
the values of all unobserved random choices, the unspecified choices are sampled | ||
from the internal proposal distribution of the model. | ||
""" | ||
function enumerative_inference( | ||
model::GenerativeFunction{T,U}, model_args::Tuple, | ||
observations::ChoiceMap, choice_vol_iter::I | ||
) where {T,U,I} | ||
if Base.IteratorSize(I) isa Base.HasShape | ||
traces = Array{U}(undef, size(choice_vol_iter)) | ||
log_weights = Array{Float64}(undef, size(choice_vol_iter)) | ||
elseif Base.IteratorSize(I) isa Base.HasLength | ||
traces = Vector{U}(undef, length(choice_vol_iter)) | ||
log_weights = Vector{Float64}(undef, length(choice_vol_iter)) | ||
else | ||
choice_vol_iter = collect(choice_vol_iter) | ||
traces = Vector{U}(undef, length(choice_vol_iter)) | ||
log_weights = Vector{Float64}(undef, length(choice_vol_iter)) | ||
end | ||
for (i, (choices, log_vol)) in enumerate(choice_vol_iter) | ||
constraints = merge(observations, choices) | ||
(traces[i], log_weight) = generate(model, model_args, constraints) | ||
log_weights[i] = log_weight + log_vol | ||
end | ||
log_total_weight = logsumexp(log_weights) | ||
log_normalized_weights = log_weights .- log_total_weight | ||
return (traces, log_normalized_weights, log_total_weight) | ||
end | ||
|
||
""" | ||
choice_vol_grid((addr, vals, [support, dims])::Tuple...; anchor=:midpoint) | ||
Given tuples of the form `(addr, vals, [support, dims])`, construct an iterator | ||
over tuples of the form `(choices::ChoiceMap, log_vol::Real)` via grid enumeration. | ||
Each `addr` is an address of a random choice, and `vals` are the corresponding | ||
values or intervals to enumerate over. The (optional) `support` denotes whether | ||
each random choice is `:discrete` (default) or `:continuous`. This controls how | ||
the grid is constructed: | ||
- `support = :discrete`: The grid iterates over each value in `vals`. | ||
- `support = :continuous` and `dims == Val(1)`: The grid iterates over the | ||
anchors of 1D intervals whose endpoints are given by `vals`. | ||
- `support = :continuous` and `dims == Val(N)` where `N` > 1: The grid iterates | ||
over the anchors of multi-dimensional regions defined `vals`, which is a tuple | ||
of interval endpoints for each dimension. | ||
Continuous choices are assumed to have `dims = Val(1)` dimensions by default. | ||
The `anchor` keyword argument controls which point in each interval is used as | ||
the anchor (`:left`, `:right`, or `:midpoint`). | ||
The log-volume `log_vol` associated with each set of `choices` in the grid is given | ||
by the log-product of the volumes of each continuous region used to construct those | ||
choices. If all addresses enumerated over are `:discrete`, then `log_vol = 0.0`. | ||
""" | ||
function choice_vol_grid(grid_specs::Tuple...; anchor::Symbol=:midpoint) | ||
val_iter = (expand_grid_spec_to_values(spec...; anchor=anchor) | ||
for spec in grid_specs) | ||
val_iter = Iterators.product(val_iter...) | ||
vol_iter = (expand_grid_spec_to_volumes(spec...) for spec in grid_specs) | ||
vol_iter = Iterators.product(vol_iter...) | ||
choice_vol_iter = Iterators.map(zip(val_iter, vol_iter)) do (vals, vols) | ||
return (choicemap(vals...), sum(vols)) | ||
end | ||
return choice_vol_iter | ||
end | ||
|
||
function expand_grid_spec_to_values( | ||
addr, vals, support::Symbol = :discrete, dims::Val{N} = Val(1); | ||
anchor::Symbol = :midpoint | ||
) where {N} | ||
if support == :discrete | ||
return ((addr, v) for v in vals) | ||
elseif support == :continuous && N == 1 | ||
if anchor == :left | ||
vals = @view(vals[begin:end-1]) | ||
elseif anchor == :right | ||
vals = @view(vals[begin+1:end]) | ||
else | ||
vals = @view(vals[begin:end-1]) .+ (diff(vals) ./ 2) | ||
end | ||
return ((addr, v) for v in vals) | ||
elseif support == :continuous && N > 1 | ||
@assert length(vals) == N "Dimension mismatch between `vals` and `dims`" | ||
vals = map(vals) do vs | ||
if anchor == :left | ||
vs = @view(vs[begin:end-1]) | ||
elseif anchor == :right | ||
vs = @view(vs[begin+1:end]) | ||
else | ||
vs = @view(vs[begin:end-1]) .+ (diff(vs) ./ 2) | ||
end | ||
return vs | ||
end | ||
return ((addr, collect(v)) for v in Iterators.product(vals...)) | ||
else | ||
error("Support must be :discrete or :continuous") | ||
end | ||
end | ||
|
||
function expand_grid_spec_to_volumes( | ||
addr, vals, support::Symbol = :discrete, dims::Val{N} = Val(1) | ||
) where {N} | ||
if support == :discrete | ||
return zeros(length(vals)) | ||
elseif support == :continuous && N == 1 | ||
return log.(diff(vals)) | ||
elseif support == :continuous && N > 1 | ||
@assert length(vals) == N "Dimension mismatch between `vals` and `dims`" | ||
diffs = Iterators.product((log.(diff(vs)) for vs in vals)...) | ||
return (sum(ds) for ds in diffs) | ||
else | ||
error("Support must be :discrete or :continuous") | ||
end | ||
end | ||
|
||
export enumerative_inference, choice_vol_grid |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,136 @@ | ||
@testset "enumerative inference" begin | ||
|
||
# polynomial regression model | ||
@gen function poly_model(n::Int, xs) | ||
degree ~ uniform_discrete(1, n) | ||
coeffs = zeros(n+1) | ||
for d in 0:n | ||
coeffs[d+1] = {(:coeff, d)} ~ uniform(-1, 1) | ||
end | ||
ys = zeros(length(xs)) | ||
for (i, x) in enumerate(xs) | ||
x_powers = x .^ (0:n) | ||
y_mean = sum(coeffs[d+1] * x_powers[d+1] for d in 0:degree) | ||
ys[i] = {(:y, i)} ~ normal(y_mean, 0.1) | ||
end | ||
return ys | ||
end | ||
|
||
# synthetic dataset | ||
coeffs = [0.5, 0.1, -0.5] | ||
xs = collect(0.5:0.5:3.0) | ||
ys = [(coeffs' * [x .^ d for d in 0:2]) for x in xs] | ||
|
||
observations = choicemap() | ||
for (i, y) in enumerate(ys) | ||
observations[(:y, i)] = y | ||
end | ||
|
||
# test construction of choicemap-volume grid | ||
grid = choice_vol_grid( | ||
(:degree, 1:2), | ||
((:coeff, 0), -1:0.2:1, :continuous), | ||
((:coeff, 1), -1:0.2:1, :continuous), | ||
((:coeff, 2), -1:0.2:1, :continuous), | ||
anchor = :midpoint | ||
) | ||
|
||
@test size(grid) == (2, 10, 10, 10) | ||
@test length(grid) == 2000 | ||
|
||
choices, log_vol = first(grid) | ||
@test choices == choicemap( | ||
(:degree, 1), | ||
((:coeff, 0), -0.9), ((:coeff, 1), -0.9), ((:coeff, 2), -0.9), | ||
) | ||
@test log_vol ≈ log(0.2^3) | ||
|
||
test_choices(n::Int, cs) = | ||
cs[:degree] in 1:n && all(-1.0 <= cs[(:coeff, d)] <= 1.0 for d in 1:n) | ||
|
||
@test all(test_choices(2, choices) for (choices, _) in grid) | ||
@test all(log_vol ≈ log(0.2^3) for (_, log_vol) in grid) | ||
|
||
# run enumerative inference over grid | ||
traces, log_norm_weights, lml_est = | ||
enumerative_inference(poly_model, (2, xs), observations, grid) | ||
|
||
@test size(traces) == (2, 10, 10, 10) | ||
@test length(traces) == 2000 | ||
@test all(test_choices(2, tr) for tr in traces) | ||
|
||
# test that log-weights are as expected | ||
log_joint_weights = [get_score(tr) + log(0.2^3) for tr in traces] | ||
lml_expected = logsumexp(log_joint_weights) | ||
@test lml_est ≈ lml_expected | ||
@test all((jw - lml_expected) ≈ w for (jw, w) in zip(log_joint_weights, log_norm_weights)) | ||
|
||
# test that polynomial is most likely quadratic | ||
degree_probs = sum(exp.(log_norm_weights), dims=(2, 3, 4)) | ||
@test argmax(vec(degree_probs)) == 2 | ||
|
||
# test that MAP trace recovers the original coefficients | ||
map_trace_idx = argmax(log_norm_weights) | ||
map_trace = traces[map_trace_idx] | ||
@test map_trace[:degree] == 2 | ||
@test map_trace[(:coeff, 0)] == 0.5 | ||
@test map_trace[(:coeff, 1)] == 0.1 | ||
@test map_trace[(:coeff, 2)] == -0.5 | ||
|
||
# 2D mixture of normals | ||
@gen function mixture_model() | ||
sign ~ bernoulli(0.5) | ||
mu = sign ? fill(0.5, 2) : fill(-0.5, 2) | ||
z ~ broadcasted_normal(mu, ones(2)) | ||
end | ||
|
||
# test construction of grid with 2D random variable | ||
grid = choice_vol_grid( | ||
(:sign, [false, true]), | ||
(:z, (-2.0:0.1:2.0, -2.0:0.1:2.0), :continuous, Val(2)), | ||
anchor = :left | ||
) | ||
|
||
@test size(grid) == (2, 40, 40) | ||
@test length(grid) == 3200 | ||
|
||
choices, log_vol = first(grid) | ||
@test choices == choicemap((:sign, false), (:z, [-2.0, -2.0])) | ||
@test log_vol ≈ log(0.1^2) | ||
|
||
@test all(all([-2.0, -2.0] .<= choices[:z] .<= [2.0, 2.0]) for (choices, _) in grid) | ||
@test all(log_vol ≈ log(0.1^2) for (_, log_vol) in grid) | ||
|
||
# run enumerative inference over grid | ||
traces, log_norm_weights, lml_est = | ||
enumerative_inference(mixture_model, (), choicemap(), grid) | ||
|
||
@test size(traces) == (2, 40, 40) | ||
@test length(traces) == 3200 | ||
@test all(all([-2.0, -2.0] .<= tr[:z] .<= [2.0, 2.0]) for tr in traces) | ||
|
||
# test that log-weights are as expected | ||
function expected_logpdf(tr) | ||
x, y = tr[:z] | ||
mu = tr[:sign] ? 0.5 : -0.5 | ||
return log(0.5) + logpdf(normal, x, mu, 1.0) + logpdf(normal, y, mu, 1.0) | ||
end | ||
|
||
log_joint_weights = [expected_logpdf(tr) + log(0.1^2) for tr in traces] | ||
lml_expected = logsumexp(log_joint_weights) | ||
@test lml_est ≈ lml_expected | ||
@test all((jw - lml_expected) ≈ w for (jw, w) in zip(log_joint_weights, log_norm_weights)) | ||
|
||
# test that maximal log-weights are at modes | ||
max_log_weight = maximum(log_norm_weights) | ||
max_idxs = findall(log_norm_weights .== max_log_weight) | ||
|
||
max_trace_1 = traces[max_idxs[1]] | ||
@test max_trace_1[:sign] == false | ||
@test max_trace_1[:z] == [-0.5, -0.5] | ||
|
||
max_trace_2 = traces[max_idxs[2]] | ||
@test max_trace_2[:sign] == true | ||
@test max_trace_2[:z] == [0.5, 0.5] | ||
|
||
end |