Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add enumerative inference to the inference library #545

Merged
merged 4 commits into from
Nov 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/pages.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ pages = [
"Custom Generative Functions" => "ref/modeling/custom_gen_fns.md",
],
"Inference Library" => [
"Enumerative Inference" => "ref/inference/enumerative.md",
"Importance Sampling" => "ref/inference/importance.md",
"Markov Chain Monte Carlo" => "ref/inference/mcmc.md",
"Particle Filtering & SMC" => "ref/inference/pf.md",
Expand Down
18 changes: 18 additions & 0 deletions docs/src/ref/inference/enumerative.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Enumerative Inference

Enumerative inference can be used to compute the exact posterior distribution for a generative model
with a finite number of discrete random choices, to compute a grid approximation of a continuous
posterior density, or to perform stratified sampling by enumerating over discrete random choices and sampling
the continuous random choices. This functionality is provided by [`enumerative_inference`](@ref).

```@docs
enumerative_inference
```

To construct a rectangular grid of [choice maps](../core/choice_maps.md) and their associated log-volumes to iterate over, use the [`choice_vol_grid`](@ref) function.

```@docs
choice_vol_grid
```

When the space of possible choice maps is not rectangular (e.g. some addresses only exist depending on the values of other addresses), iterators over choice maps and log-volumes can be also be manually constructed.
134 changes: 134 additions & 0 deletions src/inference/enumerative.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,134 @@
"""
(traces, log_norm_weights, lml_est) = enumerative_inference(
model::GenerativeFunction, model_args::Tuple,
observations::ChoiceMap, choice_vol_iter
)

Run enumerative inference over a `model`, given `observations` and an iterator over
choice maps and their associated log-volumes (`choice_vol_iter`), specifying the
choices to be iterated over. An iterator over a grid of choice maps and log-volumes
can be constructed with [`choice_vol_grid`](@ref).

Return an array of traces and associated log-weights with the same shape as
`choice_vol_iter`. The log-weight of each trace is normalized, and corresponds
to the log probability of the volume of sample space that the trace represents.
Also return an estimate of the log marginal likelihood of the observations (`lml_est`).

All addresses in the `observations` choice map must be sampled by the model when
given the model arguments. The same constraint applies to choice maps enumerated
over by `choice_vol_iter`, which must also avoid sharing addresses with the
`observations`. When the choice maps in `choice_vol_iter` do not fully specify
the values of all unobserved random choices, the unspecified choices are sampled
from the internal proposal distribution of the model.
"""
function enumerative_inference(
model::GenerativeFunction{T,U}, model_args::Tuple,
observations::ChoiceMap, choice_vol_iter::I
) where {T,U,I}
if Base.IteratorSize(I) isa Base.HasShape
traces = Array{U}(undef, size(choice_vol_iter))
log_weights = Array{Float64}(undef, size(choice_vol_iter))
elseif Base.IteratorSize(I) isa Base.HasLength
traces = Vector{U}(undef, length(choice_vol_iter))
log_weights = Vector{Float64}(undef, length(choice_vol_iter))
else
choice_vol_iter = collect(choice_vol_iter)
traces = Vector{U}(undef, length(choice_vol_iter))
log_weights = Vector{Float64}(undef, length(choice_vol_iter))
end
for (i, (choices, log_vol)) in enumerate(choice_vol_iter)
constraints = merge(observations, choices)
(traces[i], log_weight) = generate(model, model_args, constraints)
log_weights[i] = log_weight + log_vol
end
log_total_weight = logsumexp(log_weights)
log_normalized_weights = log_weights .- log_total_weight
return (traces, log_normalized_weights, log_total_weight)
end

"""
choice_vol_grid((addr, vals, [support, dims])::Tuple...; anchor=:midpoint)

Given tuples of the form `(addr, vals, [support, dims])`, construct an iterator
over tuples of the form `(choices::ChoiceMap, log_vol::Real)` via grid enumeration.

Each `addr` is an address of a random choice, and `vals` are the corresponding
values or intervals to enumerate over. The (optional) `support` denotes whether
each random choice is `:discrete` (default) or `:continuous`. This controls how
the grid is constructed:
- `support = :discrete`: The grid iterates over each value in `vals`.
- `support = :continuous` and `dims == Val(1)`: The grid iterates over the
anchors of 1D intervals whose endpoints are given by `vals`.
- `support = :continuous` and `dims == Val(N)` where `N` > 1: The grid iterates
over the anchors of multi-dimensional regions defined `vals`, which is a tuple
of interval endpoints for each dimension.
Continuous choices are assumed to have `dims = Val(1)` dimensions by default.
The `anchor` keyword argument controls which point in each interval is used as
the anchor (`:left`, `:right`, or `:midpoint`).

The log-volume `log_vol` associated with each set of `choices` in the grid is given
by the log-product of the volumes of each continuous region used to construct those
choices. If all addresses enumerated over are `:discrete`, then `log_vol = 0.0`.
"""
function choice_vol_grid(grid_specs::Tuple...; anchor::Symbol=:midpoint)
val_iter = (expand_grid_spec_to_values(spec...; anchor=anchor)
for spec in grid_specs)
val_iter = Iterators.product(val_iter...)
vol_iter = (expand_grid_spec_to_volumes(spec...) for spec in grid_specs)
vol_iter = Iterators.product(vol_iter...)
choice_vol_iter = Iterators.map(zip(val_iter, vol_iter)) do (vals, vols)
return (choicemap(vals...), sum(vols))
end
return choice_vol_iter
end

function expand_grid_spec_to_values(
addr, vals, support::Symbol = :discrete, dims::Val{N} = Val(1);
anchor::Symbol = :midpoint
) where {N}
if support == :discrete
return ((addr, v) for v in vals)
elseif support == :continuous && N == 1
if anchor == :left
vals = @view(vals[begin:end-1])
elseif anchor == :right
vals = @view(vals[begin+1:end])
else
vals = @view(vals[begin:end-1]) .+ (diff(vals) ./ 2)
end
return ((addr, v) for v in vals)
elseif support == :continuous && N > 1
@assert length(vals) == N "Dimension mismatch between `vals` and `dims`"
vals = map(vals) do vs
if anchor == :left
vs = @view(vs[begin:end-1])
elseif anchor == :right
vs = @view(vs[begin+1:end])
else
vs = @view(vs[begin:end-1]) .+ (diff(vs) ./ 2)
end
return vs
end
return ((addr, collect(v)) for v in Iterators.product(vals...))
else
error("Support must be :discrete or :continuous")
end
end

function expand_grid_spec_to_volumes(
addr, vals, support::Symbol = :discrete, dims::Val{N} = Val(1)
) where {N}
if support == :discrete
return zeros(length(vals))
elseif support == :continuous && N == 1
return log.(diff(vals))
elseif support == :continuous && N > 1
@assert length(vals) == N "Dimension mismatch between `vals` and `dims`"
diffs = Iterators.product((log.(diff(vs)) for vs in vals)...)
return (sum(ds) for ds in diffs)
else
error("Support must be :discrete or :continuous")
end
end

export enumerative_inference, choice_vol_grid
1 change: 1 addition & 0 deletions src/inference/inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ include("hmc.jl")
include("mala.jl")
include("elliptical_slice.jl")

include("enumerative.jl")
include("importance.jl")
include("particle_filter.jl")
include("map_optimize.jl")
Expand Down
136 changes: 136 additions & 0 deletions test/inference/enumerative.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
@testset "enumerative inference" begin

# polynomial regression model
@gen function poly_model(n::Int, xs)
degree ~ uniform_discrete(1, n)
coeffs = zeros(n+1)
for d in 0:n
coeffs[d+1] = {(:coeff, d)} ~ uniform(-1, 1)
end
ys = zeros(length(xs))
for (i, x) in enumerate(xs)
x_powers = x .^ (0:n)
y_mean = sum(coeffs[d+1] * x_powers[d+1] for d in 0:degree)
ys[i] = {(:y, i)} ~ normal(y_mean, 0.1)
end
return ys
end

# synthetic dataset
coeffs = [0.5, 0.1, -0.5]
xs = collect(0.5:0.5:3.0)
ys = [(coeffs' * [x .^ d for d in 0:2]) for x in xs]

observations = choicemap()
for (i, y) in enumerate(ys)
observations[(:y, i)] = y
end

# test construction of choicemap-volume grid
grid = choice_vol_grid(
(:degree, 1:2),
((:coeff, 0), -1:0.2:1, :continuous),
((:coeff, 1), -1:0.2:1, :continuous),
((:coeff, 2), -1:0.2:1, :continuous),
anchor = :midpoint
)

@test size(grid) == (2, 10, 10, 10)
@test length(grid) == 2000

choices, log_vol = first(grid)
@test choices == choicemap(
(:degree, 1),
((:coeff, 0), -0.9), ((:coeff, 1), -0.9), ((:coeff, 2), -0.9),
)
@test log_vol ≈ log(0.2^3)

test_choices(n::Int, cs) =
cs[:degree] in 1:n && all(-1.0 <= cs[(:coeff, d)] <= 1.0 for d in 1:n)

@test all(test_choices(2, choices) for (choices, _) in grid)
@test all(log_vol ≈ log(0.2^3) for (_, log_vol) in grid)

# run enumerative inference over grid
traces, log_norm_weights, lml_est =
enumerative_inference(poly_model, (2, xs), observations, grid)

@test size(traces) == (2, 10, 10, 10)
@test length(traces) == 2000
@test all(test_choices(2, tr) for tr in traces)

# test that log-weights are as expected
log_joint_weights = [get_score(tr) + log(0.2^3) for tr in traces]
lml_expected = logsumexp(log_joint_weights)
@test lml_est ≈ lml_expected
@test all((jw - lml_expected) ≈ w for (jw, w) in zip(log_joint_weights, log_norm_weights))

# test that polynomial is most likely quadratic
degree_probs = sum(exp.(log_norm_weights), dims=(2, 3, 4))
@test argmax(vec(degree_probs)) == 2

# test that MAP trace recovers the original coefficients
map_trace_idx = argmax(log_norm_weights)
map_trace = traces[map_trace_idx]
@test map_trace[:degree] == 2
@test map_trace[(:coeff, 0)] == 0.5
@test map_trace[(:coeff, 1)] == 0.1
@test map_trace[(:coeff, 2)] == -0.5

# 2D mixture of normals
@gen function mixture_model()
sign ~ bernoulli(0.5)
mu = sign ? fill(0.5, 2) : fill(-0.5, 2)
z ~ broadcasted_normal(mu, ones(2))
end

# test construction of grid with 2D random variable
grid = choice_vol_grid(
(:sign, [false, true]),
(:z, (-2.0:0.1:2.0, -2.0:0.1:2.0), :continuous, Val(2)),
anchor = :left
)

@test size(grid) == (2, 40, 40)
@test length(grid) == 3200

choices, log_vol = first(grid)
@test choices == choicemap((:sign, false), (:z, [-2.0, -2.0]))
@test log_vol ≈ log(0.1^2)

@test all(all([-2.0, -2.0] .<= choices[:z] .<= [2.0, 2.0]) for (choices, _) in grid)
@test all(log_vol ≈ log(0.1^2) for (_, log_vol) in grid)

# run enumerative inference over grid
traces, log_norm_weights, lml_est =
enumerative_inference(mixture_model, (), choicemap(), grid)

@test size(traces) == (2, 40, 40)
@test length(traces) == 3200
@test all(all([-2.0, -2.0] .<= tr[:z] .<= [2.0, 2.0]) for tr in traces)

# test that log-weights are as expected
function expected_logpdf(tr)
x, y = tr[:z]
mu = tr[:sign] ? 0.5 : -0.5
return log(0.5) + logpdf(normal, x, mu, 1.0) + logpdf(normal, y, mu, 1.0)
end

log_joint_weights = [expected_logpdf(tr) + log(0.1^2) for tr in traces]
lml_expected = logsumexp(log_joint_weights)
@test lml_est ≈ lml_expected
@test all((jw - lml_expected) ≈ w for (jw, w) in zip(log_joint_weights, log_norm_weights))

# test that maximal log-weights are at modes
max_log_weight = maximum(log_norm_weights)
max_idxs = findall(log_norm_weights .== max_log_weight)

max_trace_1 = traces[max_idxs[1]]
@test max_trace_1[:sign] == false
@test max_trace_1[:z] == [-0.5, -0.5]

max_trace_2 = traces[max_idxs[2]]
@test max_trace_2[:sign] == true
@test max_trace_2[:z] == [0.5, 0.5]

end
Loading