diff --git a/src/inference/variational.jl b/src/inference/variational.jl
index 0be8f769..09713cc7 100644
--- a/src/inference/variational.jl
+++ b/src/inference/variational.jl
@@ -16,7 +16,7 @@ function single_sample_gradient_estimate!(
     accumulate_param_gradients!(var_trace, nothing, log_weight * scale_factor)
 
     # unbiased estimate of objective function, and trace
-    (log_weight, var_trace, model_trace)
+    return (log_weight, var_trace, model_trace)
 end
 
 function vimco_geometric_baselines(log_weights)
@@ -29,12 +29,12 @@ function vimco_geometric_baselines(log_weights)
         baselines[i] = logsumexp(log_weights) - log(num_samples)
         log_weights[i] = temp
     end
-    baselines
+    return baselines
 end
 
 function logdiffexp(x, y)
     m = max(x, y)
-    m + log(exp(x - m) - exp(y - m))
+    return m + log(exp(x - m) - exp(y - m))
 end
 
 function vimco_arithmetic_baselines(log_weights)
@@ -46,7 +46,7 @@ function vimco_arithmetic_baselines(log_weights)
         log_f_hat = log_sum_f_without_i - log(num_samples - 1)
         baselines[i] = logsumexp(log_sum_f_without_i, log_f_hat) - log(num_samples)
     end
-    baselines
+    return baselines
 end
 
 # black box, VIMCO gradient estimator
@@ -85,7 +85,7 @@ function multi_sample_gradient_estimate!(
 
     # collection of traces and normalized importance weights, and estimate of
     # objective function
-    (L, traces, weights_normalized)
+    return (L, traces, weights_normalized)
 end
 
 function _maybe_accumulate_param_grad!(trace, optimizer, scale_factor::Real)
@@ -117,6 +117,7 @@ update the parameters of `model`.
 - `callback`: Callback function that takes `(iter, traces, elbo_estimate)`
   as input, where `iter` is the iteration number and `traces` are samples
   from `var_model` for that iteration.
+- `multithreaded`: if `true`, gradient estimation may use multiple threads.
 """
 function black_box_vi!(
     model::GenerativeFunction, model_args::Tuple,
@@ -125,31 +126,32 @@ function black_box_vi!(
     var_model::GenerativeFunction, var_model_args::Tuple,
     var_model_optimizer;
     iters=1000, samples_per_iter=100, verbose=false,
-    callback=(iter, traces, elbo_estimate) -> nothing)
+    callback=(iter, traces, elbo_estimate) -> nothing,
+    multithreaded=false)
 
     var_traces = Vector{Any}(undef, samples_per_iter)
     model_traces = Vector{Any}(undef, samples_per_iter)
+    log_weights = Vector{Float64}(undef, samples_per_iter)
     elbo_history = Vector{Float64}(undef, iters)
     for iter=1:iters
 
         # compute gradient estimate and objective function estimate
-        elbo_estimate = 0.0
-        # TODO multithread (note that this would require accumulate_param_gradients! to be threadsafe)
-        for sample=1:samples_per_iter
-
-            # accumulate the variational family gradients
-            (log_weight, var_trace, model_trace) = single_sample_gradient_estimate!(
-                var_model, var_model_args,
-                model, model_args, observations, 1/samples_per_iter)
-            elbo_estimate += (log_weight / samples_per_iter)
-
-            # accumulate the generative model gradients
-            _maybe_accumulate_param_grad!(model_trace, model_optimizer, 1.0 / samples_per_iter)
-
-            # record the traces
-            var_traces[sample] = var_trace
-            model_traces[sample] = model_trace
+        if multithreaded
+            Threads.@threads for i in 1:samples_per_iter
+                black_box_vi_iter!(
+                    var_traces, model_traces, log_weights, i, samples_per_iter,
+                    var_model, var_model_args,
+                    model, model_args, observations, model_optimizer)
+            end
+        else
+            for i in 1:samples_per_iter
+                black_box_vi_iter!(
+                    var_traces, model_traces, log_weights, i, samples_per_iter,
+                    var_model, var_model_args,
+                    model, model_args, observations, model_optimizer)
+            end
         end
+        elbo_estimate = sum(log_weights)
         elbo_history[iter] = elbo_estimate
 
         # print it
@@ -167,9 +169,34 @@ function black_box_vi!(
         end
     end
 
-    (elbo_history[end], var_traces, elbo_history, model_traces)
+    return (elbo_history[end], var_traces, elbo_history, model_traces)
+end
+
+function black_box_vi_iter!(
+        var_traces::Vector, model_traces::Vector, log_weights::Vector{Float64},
+        i::Int, n::Int,
+        var_model::GenerativeFunction, var_model_args::Tuple,
+        model::GenerativeFunction, model_args::Tuple,
+        observations::ChoiceMap,
+        model_optimizer)
+
+    # accumulate the variational family gradients
+    (log_weight, var_trace, model_trace) = single_sample_gradient_estimate!(
+        var_model, var_model_args,
+        model, model_args, observations, 1.0 / n)
+    log_weights[i] = log_weight / n
+
+    # accumulate the generative model gradients
+    _maybe_accumulate_param_grad!(model_trace, model_optimizer, 1.0 / n)
+
+    # record the traces
+    var_traces[i] = var_trace
+    model_traces[i] = model_trace
+
+    return nothing
 end
+
 
 black_box_vi!(model::GenerativeFunction, model_args::Tuple,
               observations::ChoiceMap,
               var_model::GenerativeFunction, var_model_args::Tuple,
@@ -205,6 +232,7 @@ update the parameters of `model`.
 - `callback`: Callback function that takes `(iter, traces, elbo_estimate)`
   as input, where `iter` is the iteration number and `traces` are samples
   from `var_model` for that iteration.
+- `multithreaded`: if `true`, gradient estimation may use multiple threads.
 """
 function black_box_vimco!(
     model::GenerativeFunction, model_args::Tuple,
@@ -212,35 +240,37 @@ function black_box_vimco!(
     var_model::GenerativeFunction, var_model_args::Tuple,
     var_model_optimizer::CompositeOptimizer, grad_est_samples::Int;
     iters=1000, samples_per_iter=100, geometric=true, verbose=false,
-    callback=(iter, traces, elbo_estimate) -> nothing)
+    callback=(iter, traces, elbo_estimate) -> nothing,
+    multithreaded=false)
 
     resampled_var_traces = Vector{Any}(undef, samples_per_iter)
     model_traces = Vector{Any}(undef, samples_per_iter)
+    log_weights = Vector{Float64}(undef, samples_per_iter)
     iwelbo_history = Vector{Float64}(undef, iters)
     for iter=1:iters
 
         # compute gradient estimate and objective function estimate
-        iwelbo_estimate = 0.
-        for sample=1:samples_per_iter
-
-            # accumulate the variational family gradients
-            (est, original_var_traces, weights) = multi_sample_gradient_estimate!(
-                var_model, var_model_args,
-                model, model_args, observations, grad_est_samples,
-                1/samples_per_iter, geometric)
-            iwelbo_estimate += (est / samples_per_iter)
-
-            # record a variational trace obtained by resampling from the weighted collection
-            resampled_var_traces[sample] = original_var_traces[categorical(weights)]
-
-            # accumulate the generative model gradient estimator
-            for (var_trace, weight) in zip(original_var_traces, weights)
-                constraints = merge(observations, get_choices(var_trace))
-                (model_trace, _) = generate(model, model_args, constraints)
-                _maybe_accumulate_param_grad!(model_trace, model_optimizer, weight / samples_per_iter)
-            end
+        if multithreaded
+            Threads.@threads for i in 1:samples_per_iter
+                black_box_vimco_iter!(
+                    resampled_var_traces, log_weights,
+                    i, samples_per_iter,
+                    var_model, var_model_args, model, model_args,
+                    observations, geometric, grad_est_samples,
+                    model_optimizer)
+            end
+        else
+            for i in 1:samples_per_iter
+                black_box_vimco_iter!(
+                    resampled_var_traces, log_weights,
+                    i, samples_per_iter,
+                    var_model, var_model_args, model, model_args,
+                    observations, geometric, grad_est_samples,
+                    model_optimizer)
+            end
         end
+        iwelbo_estimate = sum(log_weights)
         iwelbo_history[iter] = iwelbo_estimate
 
         # print it
@@ -262,6 +292,34 @@ function black_box_vimco!(
     (iwelbo_history[end], resampled_var_traces, iwelbo_history, model_traces)
 end
 
+function black_box_vimco_iter!(
+        resampled_var_traces::Vector, log_weights::Vector{Float64},
+        i::Int, samples_per_iter::Int,
+        var_model::GenerativeFunction, var_model_args::Tuple,
+        model::GenerativeFunction, model_args::Tuple,
+        observations::ChoiceMap, geometric::Bool, grad_est_samples::Int,
+        model_optimizer)
+
+    # accumulate the variational family gradients
+    (est, original_var_traces, weights) = multi_sample_gradient_estimate!(
+        var_model, var_model_args,
+        model, model_args, observations, grad_est_samples,
+        1/samples_per_iter, geometric)
+    log_weights[i] = est / samples_per_iter
+
+    # record a variational trace obtained by resampling from the weighted collection
+    resampled_var_traces[i] = original_var_traces[categorical(weights)]
+
+    # accumulate the generative model gradient estimator
+    for (var_trace, weight) in zip(original_var_traces, weights)
+        constraints = merge(observations, get_choices(var_trace))
+        (model_trace, _) = generate(model, model_args, constraints)
+        _maybe_accumulate_param_grad!(model_trace, model_optimizer, weight / samples_per_iter)
+    end
+
+    return nothing
+end
+
 black_box_vimco!(model::GenerativeFunction, model_args::Tuple,
                  observations::ChoiceMap,
                  var_model::GenerativeFunction, var_model_args::Tuple,
diff --git a/src/optimization.jl b/src/optimization.jl
index 2f45d239..1721058a 100644
--- a/src/optimization.jl
+++ b/src/optimization.jl
@@ -375,7 +375,7 @@ Thread-safe (multiple threads can increment the gradient of the same parameter c
 function increment_gradient!(
         id::Tuple{GenerativeFunction,Symbol}, increment,
         store::JuliaParameterStore=default_julia_parameter_store)
-    accumulator = get_gradient_accumulator(store, id)
+    accumulator = get_gradient_accumulator(id, store)
     in_place_add!(accumulator, increment)
     return nothing
 end
@@ -555,5 +555,3 @@ function assess(gen_fn::GenerativeFunction, args::Tuple, choices::ChoiceMap)
 end
 
 propose(gen_fn::GenerativeFunction, args::Tuple) = propose(gen_fn, args, default_parameter_context)
-
-
diff --git a/test/inference/importance_sampling.jl b/test/inference/importance_sampling.jl
index 4a4842b2..c2e9ac41 100644
--- a/test/inference/importance_sampling.jl
+++ b/test/inference/importance_sampling.jl
@@ -15,33 +15,36 @@
     n = 4
 
-    (traces, log_weights, lml_est) = importance_sampling(model, (), observations, n)
-    @test length(traces) == n
-    @test length(log_weights) == n
-    @test isapprox(logsumexp(log_weights), 0., atol=1e-14)
-    @test !isnan(lml_est)
-    for trace in traces
-        @test get_choices(trace)[:y] == y
+    for multithreaded in [false, true]
+        (traces, log_weights, lml_est) = importance_sampling(
+            model, (), observations, n; multithreaded=multithreaded)
+        @test length(traces) == n
+        @test length(log_weights) == n
+        @test isapprox(logsumexp(log_weights), 0., atol=1e-14)
+        @test !isnan(lml_est)
+        for trace in traces
+            @test get_choices(trace)[:y] == y
+        end
     end
 
-    (traces, log_weights, lml_est) = importance_sampling(model, (), observations, proposal, (), n)
-    @test length(traces) == n
-    @test length(log_weights) == n
-    @test isapprox(logsumexp(log_weights), 0., atol=1e-14)
-    @test !isnan(lml_est)
-    for trace in traces
-        @test get_choices(trace)[:y] == y
+    for multithreaded in [false, true]
+        (traces, log_weights, lml_est) = importance_sampling(
+            model, (), observations, proposal, (), n;
+            multithreaded=multithreaded)
+        @test length(traces) == n
+        @test length(log_weights) == n
+        @test isapprox(logsumexp(log_weights), 0., atol=1e-14)
+        @test !isnan(lml_est)
+        for trace in traces
+            @test get_choices(trace)[:y] == y
+        end
     end
 
     (trace, lml_est) = importance_resampling(model, (), observations, n)
-    @test isapprox(logsumexp(log_weights), 0., atol=1e-14)
     @test !isnan(lml_est)
     @test get_choices(trace)[:y] == y
 
     (trace, lml_est) = importance_resampling(model, (), observations, proposal, (), n)
-    @test isapprox(logsumexp(log_weights), 0., atol=1e-14)
     @test !isnan(lml_est)
     @test get_choices(trace)[:y] == y
 end
-
-
diff --git a/test/inference/variational.jl b/test/inference/variational.jl
index f3e4037f..f4407087 100644
--- a/test/inference/variational.jl
+++ b/test/inference/variational.jl
@@ -17,40 +17,46 @@ end
 
     register_parameters!(approx, [:slope_mu, :slope_log_std, :intercept_mu, :intercept_log_std])
 
-    # to regular black box variational inference
-    init_parameter!((approx, :slope_mu), 0.0)
-    init_parameter!((approx, :slope_log_std), 0.0)
-    init_parameter!((approx, :intercept_mu), 0.0)
-    init_parameter!((approx, :intercept_log_std), 0.0)
-
     observations = choicemap()
-    optimizer = init_optimizer(DecayStepGradientDescent(1, 100000), approx)
     optimizer = init_optimizer(DecayStepGradientDescent(1., 1000), approx)
-    black_box_vi!(model, (), observations, approx, (), optimizer;
-        iters=2000, samples_per_iter=100, verbose=false)
-    slope_mu = get_parameter_value((approx, :slope_mu))
-    slope_log_std = get_parameter_value((approx, :slope_log_std))
-    intercept_mu = get_parameter_value((approx, :intercept_mu))
-    intercept_log_std = get_parameter_value((approx, :intercept_log_std))
-    @test isapprox(slope_mu, -1., atol=0.001)
-    @test isapprox(slope_log_std, 0.5, atol=0.001)
-    @test isapprox(intercept_mu, 1., atol=0.001)
-    @test isapprox(intercept_log_std, 2.0, atol=0.001)
+
+    # test regular black box variational inference
+    for multithreaded in [false, true]
+        init_parameter!((approx, :slope_mu), 0.0)
+        init_parameter!((approx, :slope_log_std), 0.0)
+        init_parameter!((approx, :intercept_mu), 0.0)
+        init_parameter!((approx, :intercept_log_std), 0.0)
+        black_box_vi!(model, (), observations, approx, (), optimizer;
+            iters=2000, samples_per_iter=100, verbose=false, multithreaded=multithreaded)
+
+        slope_mu = get_parameter_value((approx, :slope_mu))
+        slope_log_std = get_parameter_value((approx, :slope_log_std))
+        intercept_mu = get_parameter_value((approx, :intercept_mu))
+        intercept_log_std = get_parameter_value((approx, :intercept_log_std))
+        @test isapprox(slope_mu, -1., atol=0.001)
+        @test isapprox(slope_log_std, 0.5, atol=0.001)
+        @test isapprox(intercept_mu, 1., atol=0.001)
+        @test isapprox(intercept_log_std, 2.0, atol=0.001)
+    end
 
     # smoke test for black box variational inference with Monte Carlo objectives
-    init_parameter!((approx, :slope_mu), 0.0)
-    init_parameter!((approx, :slope_log_std), 0.0)
-    init_parameter!((approx, :intercept_mu), 0.0)
-    init_parameter!((approx, :intercept_log_std), 0.0)
-    black_box_vimco!(model, (), observations, approx, (), optimizer, 20;
-        iters=50, samples_per_iter=100, verbose=false, geometric=false)
-
-    init_parameter!((approx, :slope_mu), 0.0)
-    init_parameter!((approx, :slope_log_std), 0.0)
-    init_parameter!((approx, :intercept_mu), 0.0)
-    init_parameter!((approx, :intercept_log_std), 0.0)
-    black_box_vimco!(model, (), observations, approx, (), optimizer, 20;
-        iters=50, samples_per_iter=100, verbose=false, geometric=true)
+    for multithreaded in [false, true]
+        init_parameter!((approx, :slope_mu), 0.0)
+        init_parameter!((approx, :slope_log_std), 0.0)
+        init_parameter!((approx, :intercept_mu), 0.0)
+        init_parameter!((approx, :intercept_log_std), 0.0)
+        black_box_vimco!(model, (), observations, approx, (), optimizer, 20;
+            iters=50, samples_per_iter=100, verbose=false, geometric=false,
+            multithreaded=multithreaded)
+
+        init_parameter!((approx, :slope_mu), 0.0)
+        init_parameter!((approx, :slope_log_std), 0.0)
+        init_parameter!((approx, :intercept_mu), 0.0)
+        init_parameter!((approx, :intercept_log_std), 0.0)
+        black_box_vimco!(model, (), observations, approx, (), optimizer, 20;
+            iters=50, samples_per_iter=100, verbose=false, geometric=true,
+            multithreaded=multithreaded)
+    end
 end
diff --git a/test/optimization.jl b/test/optimization.jl
index ff6a7bb8..0fae5207 100644
--- a/test/optimization.jl
+++ b/test/optimization.jl
@@ -70,11 +70,30 @@ end
     @test get_parameter_value((foo, :phi), store) == [2.0, 3.0]
     @test Gen.get_value(Gen.get_gradient_accumulator((foo, :phi), store)) == [0.0, 0.0]
 
+    # check that the default global Julia store was unaffected
+    @test_throws KeyError get_parameter_value((foo, :theta))
+    @test_throws KeyError get_gradient((foo, :theta))
+    @test_throws KeyError increment_gradient!((foo, :theta), 1.0)
+
     # FixedStepGradientDescent
+    init_parameter!((foo, :theta), 1.0, store)
+    init_parameter!((foo, :phi), [1.0, 2.0], store)
+    increment_gradient!((foo, :theta), 2.0, store)
+    increment_gradient!((foo, :phi), [1.0, 3.0], store)
+    optimizer = init_optimizer(FixedStepGradientDescent(1e-2), [(foo, :theta)], store)
+    apply_update!(optimizer) # update just theta
+    @test get_gradient((foo, :theta), store) == 0.0
+    @test get_parameter_value((foo, :theta), store) == 1.0 + (2.0 * 1e-2)
+    @test get_gradient((foo, :phi), store) == [1.0, 3.0] # unchanged
+    @test get_parameter_value((foo, :phi), store) == [1.0, 2.0] # unchanged
+    optimizer = init_optimizer(FixedStepGradientDescent(1e-2), [(foo, :phi)], store)
+    apply_update!(optimizer) # update just phi
+    @test get_gradient((foo, :phi), store) == [0.0, 0.0]
+    @test get_parameter_value((foo, :phi), store) == ([1.0, 2.0] .+ 1e-2 * [1.0, 3.0])
 
     # DecayStepGradientDescent
+    # TODO
 
-    # init_optimizer and apply_update! for FixedStepGradientDescent and DecayStepGradientDescent
 
     # default_parameter_context and default_julia_parameter_store
 end
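
Usage sketch (not part of the patch above): how the new `multithreaded` keyword would be called from user code. It assumes `model` and `approx` are the generative function and variational family from test/inference/variational.jl (with their parameters registered), and that Julia was started with more than one thread (e.g. `julia --threads=4`); with a single thread, `Threads.@threads` still runs but provides no parallelism.

    # assumes `model` and `approx` are defined and registered as in
    # test/inference/variational.jl
    observations = choicemap()
    optimizer = init_optimizer(DecayStepGradientDescent(1., 1000), approx)
    init_parameter!((approx, :slope_mu), 0.0)
    init_parameter!((approx, :slope_log_std), 0.0)
    init_parameter!((approx, :intercept_mu), 0.0)
    init_parameter!((approx, :intercept_log_std), 0.0)

    # with multithreaded=true, the per-sample gradient estimates within each
    # iteration are computed in a Threads.@threads loop
    (elbo, var_traces, elbo_history, model_traces) = black_box_vi!(
        model, (), observations, approx, (), optimizer;
        iters=2000, samples_per_iter=100, multithreaded=true)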