diff --git a/docs/pages.jl b/docs/pages.jl index 0f5b89b1..3a096203 100644 --- a/docs/pages.jl +++ b/docs/pages.jl @@ -30,6 +30,7 @@ pages = [ "Custom Generative Functions" => "ref/modeling/custom_gen_fns.md", ], "Inference Library" => [ + "Enumerative Inference" => "ref/inference/enumerative.md", "Importance Sampling" => "ref/inference/importance.md", "Markov Chain Monte Carlo" => "ref/inference/mcmc.md", "Particle Filtering & SMC" => "ref/inference/pf.md", diff --git a/docs/src/ref/inference/enumerative.md b/docs/src/ref/inference/enumerative.md new file mode 100644 index 00000000..c9e95b83 --- /dev/null +++ b/docs/src/ref/inference/enumerative.md @@ -0,0 +1,18 @@ +# Enumerative Inference + +Enumerative inference can be used to compute the exact posterior distribution for a generative model +with a finite number of discrete random choices, to compute a grid approximation of a continuous +posterior density, or to perform stratified sampling by enumerating over discrete random choices and sampling +the continuous random choices. This functionality is provided by [`enumerative_inference`](@ref). + +```@docs +enumerative_inference +``` + +To construct a rectangular grid of [choice maps](../core/choice_maps.md) and their associated log-volumes to iterate over, use the [`choice_vol_grid`](@ref) function. + +```@docs +choice_vol_grid +``` + +When the space of possible choice maps is not rectangular (e.g. some addresses only exist depending on the values of other addresses), iterators over choice maps and log-volumes can also be manually constructed. 
diff --git a/src/inference/enumerative.jl b/src/inference/enumerative.jl
new file mode 100644
index 00000000..5e24690f
--- /dev/null
+++ b/src/inference/enumerative.jl
@@ -0,0 +1,150 @@
+"""
+    (traces, log_norm_weights, lml_est) = enumerative_inference(
+        model::GenerativeFunction, model_args::Tuple,
+        observations::ChoiceMap, choice_vol_iter
+    )
+
+Run enumerative inference over a `model`, given `observations` and an iterator over
+choice maps and their associated log-volumes (`choice_vol_iter`), specifying the
+choices to be iterated over. An iterator over a grid of choice maps and log-volumes
+can be constructed with [`choice_vol_grid`](@ref).
+
+Return an array of traces and associated log-weights with the same shape as
+`choice_vol_iter`. The log-weight of each trace is normalized, and corresponds
+to the log probability of the volume of sample space that the trace represents.
+Also return an estimate of the log marginal likelihood of the observations (`lml_est`).
+
+All addresses in the `observations` choice map must be sampled by the model when
+given the model arguments. The same constraint applies to choice maps enumerated
+over by `choice_vol_iter`, which must also avoid sharing addresses with the
+`observations`. When the choice maps in `choice_vol_iter` do not fully specify
+the values of all unobserved random choices, the unspecified choices are sampled
+from the internal proposal distribution of the model.
+
+`choice_vol_iter` must be finite: iterators of unknown size are collected into a
+vector before enumeration, and infinite iterators raise an `ArgumentError`.
+"""
+function enumerative_inference(
+    model::GenerativeFunction{T,U}, model_args::Tuple,
+    observations::ChoiceMap, choice_vol_iter::I
+) where {T,U,I}
+    iter_size = Base.IteratorSize(I)
+    if iter_size isa Base.IsInfinite
+        # `collect` would never terminate on an infinite iterator
+        throw(ArgumentError("`choice_vol_iter` must be a finite iterator"))
+    elseif !(iter_size isa Union{Base.HasShape,Base.HasLength})
+        # Size is not known up front: materialize so outputs can be preallocated
+        choice_vol_iter = collect(choice_vol_iter)
+    end
+    # Shaped iterators yield arrays of the same shape; all others yield vectors
+    out_dims = if iter_size isa Base.HasShape
+        size(choice_vol_iter)
+    else
+        (length(choice_vol_iter),)
+    end
+    traces = Array{U}(undef, out_dims)
+    log_weights = Array{Float64}(undef, out_dims)
+    for (i, (choices, log_vol)) in enumerate(choice_vol_iter)
+        constraints = merge(observations, choices)
+        (traces[i], log_weight) = generate(model, model_args, constraints)
+        log_weights[i] = log_weight + log_vol
+    end
+    log_total_weight = logsumexp(log_weights)
+    log_normalized_weights = log_weights .- log_total_weight
+    return (traces, log_normalized_weights, log_total_weight)
+end
+
+"""
+    choice_vol_grid((addr, vals, [support, dims])::Tuple...; anchor=:midpoint)
+
+Given tuples of the form `(addr, vals, [support, dims])`, construct an iterator
+over tuples of the form `(choices::ChoiceMap, log_vol::Real)` via grid enumeration.
+
+Each `addr` is an address of a random choice, and `vals` are the corresponding
+values or intervals to enumerate over. The (optional) `support` denotes whether
+each random choice is `:discrete` (default) or `:continuous`. This controls how
+the grid is constructed:
+- `support = :discrete`: The grid iterates over each value in `vals`.
+- `support = :continuous` and `dims == Val(1)`: The grid iterates over the
+    anchors of 1D intervals whose endpoints are given by `vals`.
+- `support = :continuous` and `dims == Val(N)` where `N` > 1: The grid iterates
+    over the anchors of multi-dimensional regions defined by `vals`, which is a tuple
+    of interval endpoints for each dimension.
+
+Continuous choices are assumed to have `dims = Val(1)` dimensions by default.
+The `anchor` keyword argument controls which point in each interval is used as
+the anchor (`:left`, `:right`, or `:midpoint`). For continuous choices, any other
+`anchor` raises an `ArgumentError`.
+
+The log-volume `log_vol` associated with each set of `choices` in the grid is given
+by the log-product of the volumes of each continuous region used to construct those
+choices. If all addresses enumerated over are `:discrete`, then `log_vol = 0.0`.
+"""
+function choice_vol_grid(grid_specs::Tuple...; anchor::Symbol=:midpoint)
+    val_iters = (expand_grid_spec_to_values(spec...; anchor=anchor)
+                 for spec in grid_specs)
+    vol_iters = (expand_grid_spec_to_volumes(spec...) for spec in grid_specs)
+    val_grid = Iterators.product(val_iters...)
+    vol_grid = Iterators.product(vol_iters...)
+    return Iterators.map(zip(val_grid, vol_grid)) do (vals, vols)
+        return (choicemap(vals...), sum(vols))
+    end
+end
+
+# Return the anchor point (`:left`, `:right`, or `:midpoint`) of each interval
+# delimited by consecutive entries of `endpoints`.
+function grid_anchors(endpoints, anchor::Symbol)
+    if anchor == :left
+        return @view(endpoints[begin:end-1])
+    elseif anchor == :right
+        return @view(endpoints[begin+1:end])
+    elseif anchor == :midpoint
+        return @view(endpoints[begin:end-1]) .+ (diff(endpoints) ./ 2)
+    else
+        throw(ArgumentError("Anchor must be :left, :right, or :midpoint"))
+    end
+end
+
+# Expand one grid spec into an iterator over `(addr, value)` pairs.
+function expand_grid_spec_to_values(
+    addr, vals, support::Symbol = :discrete, dims::Val{N} = Val(1);
+    anchor::Symbol = :midpoint
+) where {N}
+    if support == :discrete
+        return ((addr, v) for v in vals)
+    elseif support == :continuous && N == 1
+        anchors = grid_anchors(vals, anchor)
+        return ((addr, v) for v in anchors)
+    elseif support == :continuous && N > 1
+        # `@assert` may be disabled, so validate user input explicitly
+        length(vals) == N || throw(
+            DimensionMismatch("Dimension mismatch between `vals` and `dims`")
+        )
+        dim_anchors = map(vs -> grid_anchors(vs, anchor), vals)
+        return ((addr, collect(v)) for v in Iterators.product(dim_anchors...))
+    else
+        error("Support must be :discrete or :continuous")
+    end
+end
+
+# Expand one grid spec into the log-volume of each grid cell it enumerates.
+function expand_grid_spec_to_volumes(
+    addr, vals, support::Symbol = :discrete, dims::Val{N} = Val(1)
+) where {N}
+    if support == :discrete
+        return zeros(length(vals))
+    elseif support == :continuous && N == 1
+        return log.(diff(vals))
+    elseif support == :continuous && N > 1
+        # `@assert` may be disabled, so validate user input explicitly
+        length(vals) == N || throw(
+            DimensionMismatch("Dimension mismatch between `vals` and `dims`")
+        )
+        diffs = Iterators.product((log.(diff(vs)) for vs in vals)...)
+        return (sum(ds) for ds in diffs)
+    else
+        error("Support must be :discrete or :continuous")
+    end
+end
+
+export enumerative_inference, choice_vol_grid
diff --git a/src/inference/inference.jl b/src/inference/inference.jl
index d37298e2..1792b38c 100644
--- a/src/inference/inference.jl
+++ b/src/inference/inference.jl
@@ -21,6 +21,7 @@
 include("hmc.jl")
 include("mala.jl")
 include("elliptical_slice.jl")
+include("enumerative.jl")
 include("importance.jl")
 include("particle_filter.jl")
 include("map_optimize.jl")
diff --git a/test/inference/enumerative.jl b/test/inference/enumerative.jl
new file mode 100644
index 00000000..e7ad57c9
--- /dev/null
+++ b/test/inference/enumerative.jl
@@ -0,0 +1,136 @@
+@testset "enumerative inference" begin
+
+    # polynomial regression model
+    @gen function poly_model(n::Int, xs)
+        degree ~ uniform_discrete(1, n)
+        coeffs = zeros(n+1)
+        for d in 0:n
+            coeffs[d+1] = {(:coeff, d)} ~ uniform(-1, 1)
+        end
+        ys = zeros(length(xs))
+        for (i, x) in enumerate(xs)
+            x_powers = x .^ (0:n)
+            y_mean = sum(coeffs[d+1] * x_powers[d+1] for d in 0:degree)
+            ys[i] = {(:y, i)} ~ normal(y_mean, 0.1)
+        end
+        return ys
+    end
+
+    # synthetic dataset
+    coeffs = [0.5, 0.1, -0.5]
+    xs = collect(0.5:0.5:3.0)
+    ys = [(coeffs' * [x .^ d for d in 0:2]) for x in xs]
+
+    observations = choicemap()
+    for (i, y) in enumerate(ys)
+        observations[(:y, i)] = y
+    end
+
+    # test construction of choicemap-volume grid
+    grid = choice_vol_grid(
+        (:degree, 1:2),
+        ((:coeff, 0), -1:0.2:1, :continuous),
+        ((:coeff, 1), -1:0.2:1, :continuous),
+        ((:coeff, 2), -1:0.2:1, :continuous),
+        anchor = :midpoint
+    )
+
+    @test size(grid) == (2, 10, 10, 10)
+    @test length(grid) == 2000
+
+    choices, log_vol = first(grid)
+    @test choices == choicemap(
+        (:degree, 1),
+        ((:coeff, 0), -0.9), ((:coeff, 1), -0.9), ((:coeff, 2), -0.9),
+    )
+    @test log_vol ≈ log(0.2^3)
+
+    test_choices(n::Int, cs) =
+        cs[:degree] in 1:n && all(-1.0 <= cs[(:coeff, d)] <= 1.0 for d in 0:n)
+
+    @test all(test_choices(2, choices) for (choices, _) in grid)
+    @test all(log_vol ≈ log(0.2^3) for (_, log_vol) in grid)
+
+    # run enumerative inference over grid
+    traces, log_norm_weights, lml_est =
+        enumerative_inference(poly_model, (2, xs), observations, grid)
+
+    @test size(traces) == (2, 10, 10, 10)
+    @test length(traces) == 2000
+    @test all(test_choices(2, tr) for tr in traces)
+
+    # test that log-weights are as expected
+    log_joint_weights = [get_score(tr) + log(0.2^3) for tr in traces]
+    lml_expected = logsumexp(log_joint_weights)
+    @test lml_est ≈ lml_expected
+    @test all((jw - lml_expected) ≈ w for (jw, w) in zip(log_joint_weights, log_norm_weights))
+
+    # test that polynomial is most likely quadratic
+    degree_probs = sum(exp.(log_norm_weights), dims=(2, 3, 4))
+    @test argmax(vec(degree_probs)) == 2
+
+    # test that MAP trace recovers the original coefficients (midpoints are computed floats, so use ≈)
+    map_trace_idx = argmax(log_norm_weights)
+    map_trace = traces[map_trace_idx]
+    @test map_trace[:degree] == 2
+    @test map_trace[(:coeff, 0)] ≈ 0.5
+    @test map_trace[(:coeff, 1)] ≈ 0.1
+    @test map_trace[(:coeff, 2)] ≈ -0.5
+
+    # 2D mixture of normals
+    @gen function mixture_model()
+        sign ~ bernoulli(0.5)
+        mu = sign ? fill(0.5, 2) : fill(-0.5, 2)
+        z ~ broadcasted_normal(mu, ones(2))
+    end
+
+    # test construction of grid with 2D random variable
+    grid = choice_vol_grid(
+        (:sign, [false, true]),
+        (:z, (-2.0:0.1:2.0, -2.0:0.1:2.0), :continuous, Val(2)),
+        anchor = :left
+    )
+
+    @test size(grid) == (2, 40, 40)
+    @test length(grid) == 3200
+
+    choices, log_vol = first(grid)
+    @test choices == choicemap((:sign, false), (:z, [-2.0, -2.0]))
+    @test log_vol ≈ log(0.1^2)
+
+    @test all(all([-2.0, -2.0] .<= choices[:z] .<= [2.0, 2.0]) for (choices, _) in grid)
+    @test all(log_vol ≈ log(0.1^2) for (_, log_vol) in grid)
+
+    # run enumerative inference over grid
+    traces, log_norm_weights, lml_est =
+        enumerative_inference(mixture_model, (), choicemap(), grid)
+
+    @test size(traces) == (2, 40, 40)
+    @test length(traces) == 3200
+    @test all(all([-2.0, -2.0] .<= tr[:z] .<= [2.0, 2.0]) for tr in traces)
+
+    # test that log-weights are as expected
+    function expected_logpdf(tr)
+        x, y = tr[:z]
+        mu = tr[:sign] ? 0.5 : -0.5
+        return log(0.5) + logpdf(normal, x, mu, 1.0) + logpdf(normal, y, mu, 1.0)
+    end
+
+    log_joint_weights = [expected_logpdf(tr) + log(0.1^2) for tr in traces]
+    lml_expected = logsumexp(log_joint_weights)
+    @test lml_est ≈ lml_expected
+    @test all((jw - lml_expected) ≈ w for (jw, w) in zip(log_joint_weights, log_norm_weights))
+
+    # test that maximal log-weights are at modes
+    max_log_weight = maximum(log_norm_weights)
+    max_idxs = findall(log_norm_weights .== max_log_weight)
+
+    max_trace_1 = traces[max_idxs[1]]
+    @test max_trace_1[:sign] == false
+    @test max_trace_1[:z] ≈ [-0.5, -0.5]
+
+    max_trace_2 = traces[max_idxs[2]]
+    @test max_trace_2[:sign] == true
+    @test max_trace_2[:z] ≈ [0.5, 0.5]
+
+end