Merge pull request #545 from probcomp/enumerative_inference
Add enumerative inference to the inference library
ztangent authored Nov 4, 2024
2 parents a5fc8e3 + 358d3e4 commit 91d798f
Showing 5 changed files with 290 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/pages.jl
@@ -30,6 +30,7 @@ pages = [
"Custom Generative Functions" => "ref/modeling/custom_gen_fns.md",
],
"Inference Library" => [
"Enumerative Inference" => "ref/inference/enumerative.md",
"Importance Sampling" => "ref/inference/importance.md",
"Markov Chain Monte Carlo" => "ref/inference/mcmc.md",
"Particle Filtering & SMC" => "ref/inference/pf.md",
18 changes: 18 additions & 0 deletions docs/src/ref/inference/enumerative.md
@@ -0,0 +1,18 @@
# Enumerative Inference

Enumerative inference can be used to compute the exact posterior distribution for a generative model
with a finite number of discrete random choices, to compute a grid approximation of a continuous
posterior density, or to perform stratified sampling by enumerating over discrete random choices and sampling
the continuous random choices. This functionality is provided by [`enumerative_inference`](@ref).

```@docs
enumerative_inference
```

To construct a rectangular grid of [choice maps](../core/choice_maps.md) and their associated log-volumes to iterate over, use the [`choice_vol_grid`](@ref) function.

```@docs
choice_vol_grid
```
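
For example, a grid posterior approximation for a simple model with one discrete and one continuous address might be computed as follows (the model and addresses below are illustrative, not part of the library):

```julia
using Gen

@gen function model()
    cls ~ bernoulli(0.5)              # discrete choice
    mu = cls ? 1.0 : -1.0
    x ~ normal(mu, 1.0)               # continuous choice
    y ~ normal(x, 0.5)                # observed below
end

observations = choicemap((:y, 0.8))

# Enumerate `cls` exactly, and approximate `x` on a grid over [-3, 3].
grid = choice_vol_grid(
    (:cls, [false, true]),
    (:x, -3.0:0.1:3.0, :continuous),
    anchor = :midpoint
)

traces, log_norm_weights, lml_est =
    enumerative_inference(model, (), observations, grid)
```

The returned arrays have the same shape as `grid`, so posterior marginals (e.g., over `:cls`) can be computed by summing `exp.(log_norm_weights)` over the remaining dimensions.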

When the space of possible choice maps is not rectangular (e.g., some addresses only exist depending on the values of other addresses), iterators over choice maps and log-volumes can also be constructed manually.
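
As a sketch (the model, addresses, and grid spacing here are hypothetical), such an iterator can be any collection of `(choices, log_vol)` tuples, for instance built with `Iterators.flatten`:

```julia
# Suppose (:slope) exists only when (:line) is true, so the space of
# choice maps is not rectangular and `choice_vol_grid` does not apply.
slope_edges = -1.0:0.2:1.0   # interval endpoints for the (:slope) grid
choice_vol_iter = Iterators.flatten((
    # Branch without a slope: a single choice map with zero log-volume.
    [(choicemap((:line, false)), 0.0)],
    # Branch with a slope: midpoints of the intervals, each with log-width log(0.2).
    ((choicemap((:line, true), (:slope, e + 0.1)), log(0.2))
     for e in slope_edges[1:end-1]),
))
```

The resulting iterator can then be passed as `choice_vol_iter` to [`enumerative_inference`](@ref).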
134 changes: 134 additions & 0 deletions src/inference/enumerative.jl
@@ -0,0 +1,134 @@
"""
(traces, log_norm_weights, lml_est) = enumerative_inference(
model::GenerativeFunction, model_args::Tuple,
observations::ChoiceMap, choice_vol_iter
)
Run enumerative inference over a `model`, given `observations` and an iterator over
choice maps and their associated log-volumes (`choice_vol_iter`), specifying the
choices to be iterated over. An iterator over a grid of choice maps and log-volumes
can be constructed with [`choice_vol_grid`](@ref).
Return an array of traces and associated log-weights with the same shape as
`choice_vol_iter`. The log-weight of each trace is normalized, and corresponds
to the log probability of the volume of sample space that the trace represents.
Also return an estimate of the log marginal likelihood of the observations (`lml_est`).
All addresses in the `observations` choice map must be sampled by the model when
given the model arguments. The same constraint applies to choice maps enumerated
over by `choice_vol_iter`, which must also avoid sharing addresses with the
`observations`. When the choice maps in `choice_vol_iter` do not fully specify
the values of all unobserved random choices, the unspecified choices are sampled
from the internal proposal distribution of the model.
"""
function enumerative_inference(
model::GenerativeFunction{T,U}, model_args::Tuple,
observations::ChoiceMap, choice_vol_iter::I
) where {T,U,I}
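# Preallocate the output arrays, preserving the iterator's shape when it is
# known; if its length is unknown, collect the iterator first.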
if Base.IteratorSize(I) isa Base.HasShape
traces = Array{U}(undef, size(choice_vol_iter))
log_weights = Array{Float64}(undef, size(choice_vol_iter))
elseif Base.IteratorSize(I) isa Base.HasLength
traces = Vector{U}(undef, length(choice_vol_iter))
log_weights = Vector{Float64}(undef, length(choice_vol_iter))
else
choice_vol_iter = collect(choice_vol_iter)
traces = Vector{U}(undef, length(choice_vol_iter))
log_weights = Vector{Float64}(undef, length(choice_vol_iter))
end
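# Generate a trace constrained to the observations and each enumerated choice
# map; its log-weight from `generate` plus the log-volume gives an unnormalized log-weight.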
for (i, (choices, log_vol)) in enumerate(choice_vol_iter)
constraints = merge(observations, choices)
(traces[i], log_weight) = generate(model, model_args, constraints)
log_weights[i] = log_weight + log_vol
end
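# Normalize the log-weights; their total estimates the log marginal likelihood.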
log_total_weight = logsumexp(log_weights)
log_normalized_weights = log_weights .- log_total_weight
return (traces, log_normalized_weights, log_total_weight)
end

"""
choice_vol_grid((addr, vals, [support, dims])::Tuple...; anchor=:midpoint)
Given tuples of the form `(addr, vals, [support, dims])`, construct an iterator
over tuples of the form `(choices::ChoiceMap, log_vol::Real)` via grid enumeration.
Each `addr` is an address of a random choice, and `vals` are the corresponding
values or intervals to enumerate over. The (optional) `support` denotes whether
each random choice is `:discrete` (default) or `:continuous`. This controls how
the grid is constructed:
- `support = :discrete`: The grid iterates over each value in `vals`.
- `support = :continuous` and `dims == Val(1)`: The grid iterates over the
anchors of 1D intervals whose endpoints are given by `vals`.
- `support = :continuous` and `dims == Val(N)` where `N` > 1: The grid iterates
over the anchors of multi-dimensional regions defined by `vals`, which is a tuple
of interval endpoints for each dimension.
By default, continuous choices are assumed to be one-dimensional (`dims = Val(1)`).
The `anchor` keyword argument controls which point in each interval is used as
the anchor (`:left`, `:right`, or `:midpoint`).
The log-volume `log_vol` associated with each set of `choices` in the grid is given
by the log-product of the volumes of each continuous region used to construct those
choices. If all addresses enumerated over are `:discrete`, then `log_vol = 0.0`.
"""
function choice_vol_grid(grid_specs::Tuple...; anchor::Symbol=:midpoint)
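# Build the Cartesian product of per-address values and, in parallel, of
# per-address log-volumes, then pair each grid point with its total log-volume.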
val_iter = (expand_grid_spec_to_values(spec...; anchor=anchor)
for spec in grid_specs)
val_iter = Iterators.product(val_iter...)
vol_iter = (expand_grid_spec_to_volumes(spec...) for spec in grid_specs)
vol_iter = Iterators.product(vol_iter...)
choice_vol_iter = Iterators.map(zip(val_iter, vol_iter)) do (vals, vols)
return (choicemap(vals...), sum(vols))
end
return choice_vol_iter
end

function expand_grid_spec_to_values(
addr, vals, support::Symbol = :discrete, dims::Val{N} = Val(1);
anchor::Symbol = :midpoint
) where {N}
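# Discrete choices enumerate `vals` directly; continuous choices enumerate the
# anchor point (left edge, right edge, or midpoint) of each grid cell.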
if support == :discrete
return ((addr, v) for v in vals)
elseif support == :continuous && N == 1
if anchor == :left
vals = @view(vals[begin:end-1])
elseif anchor == :right
vals = @view(vals[begin+1:end])
else
vals = @view(vals[begin:end-1]) .+ (diff(vals) ./ 2)
end
return ((addr, v) for v in vals)
elseif support == :continuous && N > 1
@assert length(vals) == N "Dimension mismatch between `vals` and `dims`"
vals = map(vals) do vs
if anchor == :left
vs = @view(vs[begin:end-1])
elseif anchor == :right
vs = @view(vs[begin+1:end])
else
vs = @view(vs[begin:end-1]) .+ (diff(vs) ./ 2)
end
return vs
end
return ((addr, collect(v)) for v in Iterators.product(vals...))
else
error("Support must be :discrete or :continuous")
end
end

function expand_grid_spec_to_volumes(
addr, vals, support::Symbol = :discrete, dims::Val{N} = Val(1)
) where {N}
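# Discrete choices contribute zero log-volume; continuous choices contribute
# the log-width of each interval, summed across dimensions.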
if support == :discrete
return zeros(length(vals))
elseif support == :continuous && N == 1
return log.(diff(vals))
elseif support == :continuous && N > 1
@assert length(vals) == N "Dimension mismatch between `vals` and `dims`"
diffs = Iterators.product((log.(diff(vs)) for vs in vals)...)
return (sum(ds) for ds in diffs)
else
error("Support must be :discrete or :continuous")
end
end

export enumerative_inference, choice_vol_grid
1 change: 1 addition & 0 deletions src/inference/inference.jl
@@ -21,6 +21,7 @@ include("hmc.jl")
include("mala.jl")
include("elliptical_slice.jl")

include("enumerative.jl")
include("importance.jl")
include("particle_filter.jl")
include("map_optimize.jl")
136 changes: 136 additions & 0 deletions test/inference/enumerative.jl
@@ -0,0 +1,136 @@
@testset "enumerative inference" begin

# polynomial regression model
@gen function poly_model(n::Int, xs)
degree ~ uniform_discrete(1, n)
coeffs = zeros(n+1)
for d in 0:n
coeffs[d+1] = {(:coeff, d)} ~ uniform(-1, 1)
end
ys = zeros(length(xs))
for (i, x) in enumerate(xs)
x_powers = x .^ (0:n)
y_mean = sum(coeffs[d+1] * x_powers[d+1] for d in 0:degree)
ys[i] = {(:y, i)} ~ normal(y_mean, 0.1)
end
return ys
end

# synthetic dataset
coeffs = [0.5, 0.1, -0.5]
xs = collect(0.5:0.5:3.0)
ys = [(coeffs' * [x .^ d for d in 0:2]) for x in xs]

observations = choicemap()
for (i, y) in enumerate(ys)
observations[(:y, i)] = y
end

# test construction of choicemap-volume grid
grid = choice_vol_grid(
(:degree, 1:2),
((:coeff, 0), -1:0.2:1, :continuous),
((:coeff, 1), -1:0.2:1, :continuous),
((:coeff, 2), -1:0.2:1, :continuous),
anchor = :midpoint
)

@test size(grid) == (2, 10, 10, 10)
@test length(grid) == 2000

choices, log_vol = first(grid)
@test choices == choicemap(
(:degree, 1),
((:coeff, 0), -0.9), ((:coeff, 1), -0.9), ((:coeff, 2), -0.9),
)
@test log_vol ≈ log(0.2^3)

test_choices(n::Int, cs) =
cs[:degree] in 1:n && all(-1.0 <= cs[(:coeff, d)] <= 1.0 for d in 1:n)

@test all(test_choices(2, choices) for (choices, _) in grid)
@test all(log_vol ≈ log(0.2^3) for (_, log_vol) in grid)

# run enumerative inference over grid
traces, log_norm_weights, lml_est =
enumerative_inference(poly_model, (2, xs), observations, grid)

@test size(traces) == (2, 10, 10, 10)
@test length(traces) == 2000
@test all(test_choices(2, tr) for tr in traces)

# test that log-weights are as expected
log_joint_weights = [get_score(tr) + log(0.2^3) for tr in traces]
lml_expected = logsumexp(log_joint_weights)
@test lml_est ≈ lml_expected
@test all((jw - lml_expected) ≈ w for (jw, w) in zip(log_joint_weights, log_norm_weights))

# test that polynomial is most likely quadratic
degree_probs = sum(exp.(log_norm_weights), dims=(2, 3, 4))
@test argmax(vec(degree_probs)) == 2

# test that MAP trace recovers the original coefficients
map_trace_idx = argmax(log_norm_weights)
map_trace = traces[map_trace_idx]
@test map_trace[:degree] == 2
@test map_trace[(:coeff, 0)] == 0.5
@test map_trace[(:coeff, 1)] == 0.1
@test map_trace[(:coeff, 2)] == -0.5

# 2D mixture of normals
@gen function mixture_model()
sign ~ bernoulli(0.5)
mu = sign ? fill(0.5, 2) : fill(-0.5, 2)
z ~ broadcasted_normal(mu, ones(2))
end

# test construction of grid with 2D random variable
grid = choice_vol_grid(
(:sign, [false, true]),
(:z, (-2.0:0.1:2.0, -2.0:0.1:2.0), :continuous, Val(2)),
anchor = :left
)

@test size(grid) == (2, 40, 40)
@test length(grid) == 3200

choices, log_vol = first(grid)
@test choices == choicemap((:sign, false), (:z, [-2.0, -2.0]))
@test log_vol ≈ log(0.1^2)

@test all(all([-2.0, -2.0] .<= choices[:z] .<= [2.0, 2.0]) for (choices, _) in grid)
@test all(log_vol ≈ log(0.1^2) for (_, log_vol) in grid)

# run enumerative inference over grid
traces, log_norm_weights, lml_est =
enumerative_inference(mixture_model, (), choicemap(), grid)

@test size(traces) == (2, 40, 40)
@test length(traces) == 3200
@test all(all([-2.0, -2.0] .<= tr[:z] .<= [2.0, 2.0]) for tr in traces)

# test that log-weights are as expected
function expected_logpdf(tr)
x, y = tr[:z]
mu = tr[:sign] ? 0.5 : -0.5
return log(0.5) + logpdf(normal, x, mu, 1.0) + logpdf(normal, y, mu, 1.0)
end

log_joint_weights = [expected_logpdf(tr) + log(0.1^2) for tr in traces]
lml_expected = logsumexp(log_joint_weights)
@test lml_est ≈ lml_expected
@test all((jw - lml_expected) ≈ w for (jw, w) in zip(log_joint_weights, log_norm_weights))

# test that maximal log-weights are at modes
max_log_weight = maximum(log_norm_weights)
max_idxs = findall(log_norm_weights .== max_log_weight)

max_trace_1 = traces[max_idxs[1]]
@test max_trace_1[:sign] == false
@test max_trace_1[:z] == [-0.5, -0.5]

max_trace_2 = traces[max_idxs[2]]
@test max_trace_2[:sign] == true
@test max_trace_2[:z] == [0.5, 0.5]

end
