Skip to content

Commit

Permalink
add multi-threading to VI and importance sampling with tests
Browse files Browse the repository at this point in the history
  • Loading branch information
marcoct committed May 18, 2021
1 parent 22b9c98 commit 64ee5e0
Show file tree
Hide file tree
Showing 6 changed files with 179 additions and 102 deletions.
142 changes: 100 additions & 42 deletions src/inference/variational.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ function single_sample_gradient_estimate!(
accumulate_param_gradients!(var_trace, nothing, log_weight * scale_factor)

# unbiased estimate of objective function, and trace
(log_weight, var_trace, model_trace)
return (log_weight, var_trace, model_trace)
end

function vimco_geometric_baselines(log_weights)
Expand All @@ -29,12 +29,12 @@ function vimco_geometric_baselines(log_weights)
baselines[i] = logsumexp(log_weights) - log(num_samples)
log_weights[i] = temp
end
baselines
return baselines
end

function logdiffexp(x, y)
m = max(x, y)
m + log(exp(x - m) - exp(y - m))
return m + log(exp(x - m) - exp(y - m))
end

function vimco_arithmetic_baselines(log_weights)
Expand All @@ -46,7 +46,7 @@ function vimco_arithmetic_baselines(log_weights)
log_f_hat = log_sum_f_without_i - log(num_samples - 1)
baselines[i] = logsumexp(log_sum_f_without_i, log_f_hat) - log(num_samples)
end
baselines
return baselines
end

# black box, VIMCO gradient estimator
Expand Down Expand Up @@ -85,7 +85,7 @@ function multi_sample_gradient_estimate!(

# collection of traces and normalized importance weights, and estimate of
# objective function
(L, traces, weights_normalized)
return (L, traces, weights_normalized)
end

function _maybe_accumulate_param_grad!(trace, optimizer, scale_factor::Real)
Expand Down Expand Up @@ -117,6 +117,7 @@ update the parameters of `model`.
- `callback`: Callback function that takes `(iter, traces, elbo_estimate)`
as input, where `iter` is the iteration number and `traces` are samples
from `var_model` for that iteration.
- `multithreaded`: if `true`, gradient estimation may use multiple threads.
"""
function black_box_vi!(
model::GenerativeFunction, model_args::Tuple,
Expand All @@ -125,31 +126,32 @@ function black_box_vi!(
var_model::GenerativeFunction, var_model_args::Tuple,
var_model_optimizer;
iters=1000, samples_per_iter=100, verbose=false,
callback=(iter, traces, elbo_estimate) -> nothing)
callback=(iter, traces, elbo_estimate) -> nothing,
multithreaded=false)

var_traces = Vector{Any}(undef, samples_per_iter)
model_traces = Vector{Any}(undef, samples_per_iter)
log_weights = Vector{Float64}(undef, samples_per_iter)
elbo_history = Vector{Float64}(undef, iters)
for iter=1:iters

# compute gradient estimate and objective function estimate
elbo_estimate = 0.0
# TODO multithread (note that this would require accumulate_param_gradients! to be threadsafe)
for sample=1:samples_per_iter

# accumulate the variational family gradients
(log_weight, var_trace, model_trace) = single_sample_gradient_estimate!(
var_model, var_model_args,
model, model_args, observations, 1/samples_per_iter)
elbo_estimate += (log_weight / samples_per_iter)

# accumulate the generative model gradients
_maybe_accumulate_param_grad!(model_trace, model_optimizer, 1.0 / samples_per_iter)

# record the traces
var_traces[sample] = var_trace
model_traces[sample] = model_trace
if multithreaded
Threads.@threads for i in 1:samples_per_iter
black_box_vi_iter!(
var_traces, model_traces, log_weights, i, samples_per_iter,
var_model, var_model_args,
model, model_args, observations, model_optimizer)
end
else
for i in 1:samples_per_iter
black_box_vi_iter!(
var_traces, model_traces, log_weights, i, samples_per_iter,
var_model, var_model_args,
model, model_args, observations, model_optimizer)
end
end
elbo_estimate = sum(log_weights)
elbo_history[iter] = elbo_estimate

# print it
Expand All @@ -167,9 +169,34 @@ function black_box_vi!(
end
end

(elbo_history[end], var_traces, elbo_history, model_traces)
return (elbo_history[end], var_traces, elbo_history, model_traces)
end

function black_box_vi_iter!(
var_traces::Vector, model_traces::Vector, log_weights::Vector{Float64},
i::Int, n::Int,
var_model::GenerativeFunction, var_model_args::Tuple,
model::GenerativeFunction, model_args::Tuple,
observations::ChoiceMap,
model_optimizer)

# accumulate the variational family gradients
(log_weight, var_trace, model_trace) = single_sample_gradient_estimate!(
var_model, var_model_args,
model, model_args, observations, 1.0 / n)
log_weights[i] = log_weight / n

# accumulate the generative model gradients
_maybe_accumulate_param_grad!(model_trace, model_optimizer, 1.0 / n)

# record the traces
var_traces[i] = var_trace
model_traces[i] = model_trace

return nothing
end


black_box_vi!(model::GenerativeFunction, model_args::Tuple,
observations::ChoiceMap,
var_model::GenerativeFunction, var_model_args::Tuple,
Expand Down Expand Up @@ -205,42 +232,45 @@ update the parameters of `model`.
- `callback`: Callback function that takes `(iter, traces, elbo_estimate)`
as input, where `iter` is the iteration number and `traces` are samples
from `var_model` for that iteration.
- `multithreaded`: if `true`, gradient estimation may use multiple threads.
"""
function black_box_vimco!(
model::GenerativeFunction, model_args::Tuple,
model_optimizer::Union{CompositeOptimizer,Nothing}, observations::ChoiceMap,
var_model::GenerativeFunction, var_model_args::Tuple,
var_model_optimizer::CompositeOptimizer, grad_est_samples::Int;
iters=1000, samples_per_iter=100, geometric=true, verbose=false,
callback=(iter, traces, elbo_estimate) -> nothing)
callback=(iter, traces, elbo_estimate) -> nothing,
multithreaded=false)

resampled_var_traces = Vector{Any}(undef, samples_per_iter)
model_traces = Vector{Any}(undef, samples_per_iter)
log_weights = Vector{Float64}(undef, samples_per_iter)

iwelbo_history = Vector{Float64}(undef, iters)
for iter=1:iters

# compute gradient estimate and objective function estimate
iwelbo_estimate = 0.
for sample=1:samples_per_iter

# accumulate the variational family gradients
(est, original_var_traces, weights) = multi_sample_gradient_estimate!(
var_model, var_model_args,
model, model_args, observations, grad_est_samples,
1/samples_per_iter, geometric)
iwelbo_estimate += (est / samples_per_iter)

# record a variational trace obtained by resampling from the weighted collection
resampled_var_traces[sample] = original_var_traces[categorical(weights)]

# accumulate the generative model gradient estimator
for (var_trace, weight) in zip(original_var_traces, weights)
constraints = merge(observations, get_choices(var_trace))
(model_trace, _) = generate(model, model_args, constraints)
_maybe_accumulate_param_grad!(model_trace, model_optimizer, weight / samples_per_iter)
if multithreaded
Threads.@threads for i in 1:samples_per_iter
black_box_vimco_iter!(
resampled_var_traces, log_weights,
i, samples_per_iter,
var_model, var_model_args, model, model_args,
observations, geometric, grad_est_samples,
model_optimizer)
end
else
for i in 1:samples_per_iter
black_box_vimco_iter!(
resampled_var_traces, log_weights,
i, samples_per_iter,
var_model, var_model_args, model, model_args,
observations, geometric, grad_est_samples,
model_optimizer)
end
end
iwelbo_estimate = sum(log_weights)
iwelbo_history[iter] = iwelbo_estimate

# print it
Expand All @@ -262,6 +292,34 @@ function black_box_vimco!(
(iwelbo_history[end], resampled_var_traces, iwelbo_history, model_traces)
end

function black_box_vimco_iter!(
resampled_var_traces::Vector, log_weights::Vector{Float64},
i::Int, samples_per_iter::Int,
var_model::GenerativeFunction, var_model_args::Tuple,
model::GenerativeFunction, model_args::Tuple,
observations::ChoiceMap, geometric::Bool, grad_est_samples::Int,
model_optimizer)

# accumulate the variational family gradients
(est, original_var_traces, weights) = multi_sample_gradient_estimate!(
var_model, var_model_args,
model, model_args, observations, grad_est_samples,
1/samples_per_iter, geometric)
log_weights[i] = est / samples_per_iter

# record a variational trace obtained by resampling from the weighted collection
resampled_var_traces[i] = original_var_traces[categorical(weights)]

# accumulate the generative model gradient estimator
for (var_trace, weight) in zip(original_var_traces, weights)
constraints = merge(observations, get_choices(var_trace))
(model_trace, _) = generate(model, model_args, constraints)
_maybe_accumulate_param_grad!(model_trace, model_optimizer, weight / samples_per_iter)
end

return nothing
end

black_box_vimco!(model::GenerativeFunction, model_args::Tuple,
observations::ChoiceMap,
var_model::GenerativeFunction, var_model_args::Tuple,
Expand Down
4 changes: 1 addition & 3 deletions src/optimization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,7 @@ Thread-safe (multiple threads can increment the gradient of the same parameter c
function increment_gradient!(
id::Tuple{GenerativeFunction,Symbol}, increment,
store::JuliaParameterStore=default_julia_parameter_store)
accumulator = get_gradient_accumulator(store, id)
accumulator = get_gradient_accumulator(id, store)
in_place_add!(accumulator, increment)
return nothing
end
Expand Down Expand Up @@ -555,5 +555,3 @@ function assess(gen_fn::GenerativeFunction, args::Tuple, choices::ChoiceMap)
end

propose(gen_fn::GenerativeFunction, args::Tuple) = propose(gen_fn, args, default_parameter_context)


39 changes: 21 additions & 18 deletions test/inference/importance_sampling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,33 +15,36 @@

n = 4

(traces, log_weights, lml_est) = importance_sampling(model, (), observations, n)
@test length(traces) == n
@test length(log_weights) == n
@test isapprox(logsumexp(log_weights), 0., atol=1e-14)
@test !isnan(lml_est)
for trace in traces
@test get_choices(trace)[:y] == y
for multithreaded in [false, true]
(traces, log_weights, lml_est) = importance_sampling(
model, (), observations, n; multithreaded=multithreaded)
@test length(traces) == n
@test length(log_weights) == n
@test isapprox(logsumexp(log_weights), 0., atol=1e-14)
@test !isnan(lml_est)
for trace in traces
@test get_choices(trace)[:y] == y
end
end

(traces, log_weights, lml_est) = importance_sampling(model, (), observations, proposal, (), n)
@test length(traces) == n
@test length(log_weights) == n
@test isapprox(logsumexp(log_weights), 0., atol=1e-14)
@test !isnan(lml_est)
for trace in traces
@test get_choices(trace)[:y] == y
for multithreaded in [false, true]
(traces, log_weights, lml_est) = importance_sampling(
model, (), observations, proposal, (), n;
multithreaded=multithreaded)
@test length(traces) == n
@test length(log_weights) == n
@test isapprox(logsumexp(log_weights), 0., atol=1e-14)
@test !isnan(lml_est)
for trace in traces
@test get_choices(trace)[:y] == y
end
end

(trace, lml_est) = importance_resampling(model, (), observations, n)
@test isapprox(logsumexp(log_weights), 0., atol=1e-14)
@test !isnan(lml_est)
@test get_choices(trace)[:y] == y

(trace, lml_est) = importance_resampling(model, (), observations, proposal, (), n)
@test isapprox(logsumexp(log_weights), 0., atol=1e-14)
@test !isnan(lml_est)
@test get_choices(trace)[:y] == y
end


66 changes: 36 additions & 30 deletions test/inference/variational.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,40 +17,46 @@
end
register_parameters!(approx, [:slope_mu, :slope_log_std, :intercept_mu, :intercept_log_std])

# to regular black box variational inference
init_parameter!((approx, :slope_mu), 0.0)
init_parameter!((approx, :slope_log_std), 0.0)
init_parameter!((approx, :intercept_mu), 0.0)
init_parameter!((approx, :intercept_log_std), 0.0)

observations = choicemap()
optimizer = init_optimizer(DecayStepGradientDescent(1, 100000), approx)
optimizer = init_optimizer(DecayStepGradientDescent(1., 1000), approx)
black_box_vi!(model, (), observations, approx, (), optimizer;
iters=2000, samples_per_iter=100, verbose=false)
slope_mu = get_parameter_value((approx, :slope_mu))
slope_log_std = get_parameter_value((approx, :slope_log_std))
intercept_mu = get_parameter_value((approx, :intercept_mu))
intercept_log_std = get_parameter_value((approx, :intercept_log_std))
@test isapprox(slope_mu, -1., atol=0.001)
@test isapprox(slope_log_std, 0.5, atol=0.001)
@test isapprox(intercept_mu, 1., atol=0.001)
@test isapprox(intercept_log_std, 2.0, atol=0.001)

# test regular black box variational inference
for multithreaded in [false, true]
init_parameter!((approx, :slope_mu), 0.0)
init_parameter!((approx, :slope_log_std), 0.0)
init_parameter!((approx, :intercept_mu), 0.0)
init_parameter!((approx, :intercept_log_std), 0.0)
black_box_vi!(model, (), observations, approx, (), optimizer;
iters=2000, samples_per_iter=100, verbose=false, multithreaded=multithreaded)

slope_mu = get_parameter_value((approx, :slope_mu))
slope_log_std = get_parameter_value((approx, :slope_log_std))
intercept_mu = get_parameter_value((approx, :intercept_mu))
intercept_log_std = get_parameter_value((approx, :intercept_log_std))
@test isapprox(slope_mu, -1., atol=0.001)
@test isapprox(slope_log_std, 0.5, atol=0.001)
@test isapprox(intercept_mu, 1., atol=0.001)
@test isapprox(intercept_log_std, 2.0, atol=0.001)
end

# smoke test for black box variational inference with Monte Carlo objectives
init_parameter!((approx, :slope_mu), 0.0)
init_parameter!((approx, :slope_log_std), 0.0)
init_parameter!((approx, :intercept_mu), 0.0)
init_parameter!((approx, :intercept_log_std), 0.0)
black_box_vimco!(model, (), observations, approx, (), optimizer, 20;
iters=50, samples_per_iter=100, verbose=false, geometric=false)

init_parameter!((approx, :slope_mu), 0.0)
init_parameter!((approx, :slope_log_std), 0.0)
init_parameter!((approx, :intercept_mu), 0.0)
init_parameter!((approx, :intercept_log_std), 0.0)
black_box_vimco!(model, (), observations, approx, (), optimizer, 20;
iters=50, samples_per_iter=100, verbose=false, geometric=true)
for multithreaded in [false, true]
init_parameter!((approx, :slope_mu), 0.0)
init_parameter!((approx, :slope_log_std), 0.0)
init_parameter!((approx, :intercept_mu), 0.0)
init_parameter!((approx, :intercept_log_std), 0.0)
black_box_vimco!(model, (), observations, approx, (), optimizer, 20;
iters=50, samples_per_iter=100, verbose=false, geometric=false,
multithreaded=multithreaded)

init_parameter!((approx, :slope_mu), 0.0)
init_parameter!((approx, :slope_log_std), 0.0)
init_parameter!((approx, :intercept_mu), 0.0)
init_parameter!((approx, :intercept_log_std), 0.0)
black_box_vimco!(model, (), observations, approx, (), optimizer, 20;
iters=50, samples_per_iter=100, verbose=false, geometric=true,
multithreaded=multithreaded)
end

end

Expand Down
Loading

0 comments on commit 64ee5e0

Please sign in to comment.