From 3e4f6950ba8cfd0496c80db5b010af1ef794d3e9 Mon Sep 17 00:00:00 2001 From: femtomc Date: Mon, 16 Nov 2020 19:01:39 -0500 Subject: [PATCH 01/30] Initial work on a Switch combinator. --- src/modeling_library/switch/switch.jl | 51 +++++++++++++++++++++++++++ 1 file changed, 51 insertions(+) create mode 100644 src/modeling_library/switch/switch.jl diff --git a/src/modeling_library/switch/switch.jl b/src/modeling_library/switch/switch.jl new file mode 100644 index 00000000..e1737516 --- /dev/null +++ b/src/modeling_library/switch/switch.jl @@ -0,0 +1,51 @@ +# Trace used by Switch combinator. +struct SwitchTrace{G1, G2, T1, T2, Tr0, Tr} <: Trace + cond_fn::GenerativeFunction{Bool, Tr0} + a::GenerativeFunction{T1, Tr} + b::GenerativeFunction{T2, Tr} + cond::Tr0 + branch::Tr + retval::Union{T1, T2} + args::Tuple + score::Float64 + noise::Float64 +end + +function SwitchTrace{G1, G2, T1, T2, Tr0, Tr1, Tr2}(cond::Generativefunction{Bool}, + a::GenerativeFunction{T1}, + b::GenerativeFunction{T2}, + cond_subtrace::Tr0, + branch_subtrace::Union{Tr1, Tr2}, + retval::Union{T1, T2}, + args::Tuple, + score::Float64 + noise::Float64) where {G1, G2, T1, T2, Tr0, Tr1, Tr2} +end + +@inline get_choices(tr::SwitchTrace) = SwitchTraceChoiceMap(tr) +@inline get_retval(tr::SwitchTrace) = tr.retval +@inline get_args(tr::SwitchTrace) = tr.args +@inline get_score(tr::SwitchTrace) = tr.score +# TODO. @inline get_gen_fn(tr::SwitchTrace) = tr.gen_fn + +@inline function Base.getindex(tr::SwitchTrace, addr::Pair) + (first, rest) = addr + subtr = getfield(trace, first) + subtrace[rest] +end +@inline Base.getindex(tr::SwitchTrace, addr::Symbol) = getfield(trace, addr) + +function project(tr::SwitchTrace, selection::Selection) + weight = 0. + for k in [:cond, :branch] + subselection = selection[k] + weight += project(getindex(tr, k), subselection) + end + weight +end +project(tr::SwitchTrace, ::EmptySelection) = tr.noise + +@inline function get_submap(choices::SwitchTraceChoiceMap, addr::Symbol) + hasfield(choices, addr) || return EmptyChoiceMap() + get_choices(getfield(choices, addr)) +end From bd4f8307c2cd9446a71a66cfb3c44fa8c88db3eb Mon Sep 17 00:00:00 2001 From: femtomc Date: Mon, 16 Nov 2020 20:51:58 -0500 Subject: [PATCH 02/30] Initial implementation of propose and generate. --- src/modeling_library/switch/generate.jl | 47 +++++++++++++++++++++++ src/modeling_library/switch/propose.jl | 31 +++++++++++++++ src/modeling_library/switch/switch.jl | 51 ++++--------------------- src/modeling_library/switch/trace.jl | 38 ++++++++++++++++++ 4 files changed, 123 insertions(+), 44 deletions(-) create mode 100644 src/modeling_library/switch/generate.jl create mode 100644 src/modeling_library/switch/propose.jl create mode 100644 src/modeling_library/switch/trace.jl diff --git a/src/modeling_library/switch/generate.jl b/src/modeling_library/switch/generate.jl new file mode 100644 index 00000000..21a402e6 --- /dev/null +++ b/src/modeling_library/switch/generate.jl @@ -0,0 +1,47 @@ +mutable struct SwitchGenerateState{T1, T2, Tr} + score::Float64 + noise::Float64 + weight::Float64 + cond::Bool + subtrace::Tr + retval::Union{T1, T2} + SwitchGenerateState{T1, T2, Tr}(score::Float64, noise::Float64, weight::Float64) = new{T1, T2, Tr}(score, noise, weight) +end + +function process!(gen_fn::Switch{T1, T2, Tr}, + branch_p::Float64, + args::Tuple, + choices::ChoiceMap, + state::SwitchGenerateState{T1, T2, Tr}) where {T1, T2, Tr} + + # create flip distribution + flip_d = Bernoulli(branch_p) + + # check for constraints at :cond + constrained = has_value(choices, :cond) + !constrained && check_no_submap(choices, :cond) + + # get/constrain flip value + constrained ? (flip = get_value(choices, :cond); state.weight += logpdf(flip_d, flip)) : flip = rand(flip_d) + state.cond = flip + + # generate subtrace + constraints = get_submap(choices, :cond) + (subtrace, weight) = generate(flip ? gen_fn.a : gen_fn.b, args, constraints) + state.subtrace = subtrace + state.weight += weight + + # return from branch + get_retval(subtrace) +end + +function generate(gen_fn::Switch{T1, T2, Tr}, + args::Tuple, + choices::ChoiceMap) where {T1, T2, Tr} + + branch_p = args[1] + state = SwitchGenerateState{T1, T2, Tr}(0.0, 0.0, 0.0) + process!(gen_fn, branch_p, args[2 : end], choices, state) + trace = SwitchTrace{T1, T2, Tr}(gen_fn, branch_p, state.cond, state.subtrace, state.retval, args[2 : end], state.score, state.noise) + (trace, state.weight) +end diff --git a/src/modeling_library/switch/propose.jl b/src/modeling_library/switch/propose.jl new file mode 100644 index 00000000..fb9a2dce --- /dev/null +++ b/src/modeling_library/switch/propose.jl @@ -0,0 +1,31 @@ +mutable struct SwitchProposeState{T} + choices::DynamicChoiceMap + weight::Float64 + retval::T + SwitchProposeState{T}(choices, weight) where T = new{T}(choices, weight) +end + +function process_new!(gen_fn::Switch{T1, T2, Tr}, + branch_p::Float64, + args::Tuple, + state::SwitchProposeState{T}) where T + + flip_d = Bernoulli(branch_p) + flip = rand(flip_d) + (submap, weight, retval) = propose(flip ? gen_fn.a : gen_fn.b, args) + set_value!(state.choices, :cond, flip) + state.weight += logpdf(flip_d, flip) + set_submap!(state.choices, :branch, submap) + state.weight += weight + state.retval = retval +end + +function propose(gen_fn::Switch{T1, T2, Tr}, + args::Tuple) where {T1, T2, Tr} + + branch_p = args[1] + choices = choicemap() + state = SwitchProposeState{Union{T1, T2}}(choices, 0.0) + process_new!(gen_fn, branch_p, args[2:end], state) + (state.choices, state.weight, state.retval) +end diff --git a/src/modeling_library/switch/switch.jl b/src/modeling_library/switch/switch.jl index e1737516..46309e22 100644 --- a/src/modeling_library/switch/switch.jl +++ b/src/modeling_library/switch/switch.jl @@ -1,51 +1,14 @@ -# Trace used by Switch combinator. -struct SwitchTrace{G1, G2, T1, T2, Tr0, Tr} <: Trace - cond_fn::GenerativeFunction{Bool, Tr0} +struct Switch{T1, T2, Tr} <: GenerativeFunction{Union{T1, T2}, SwitchTrace{T1, T2, Tr}} a::GenerativeFunction{T1, Tr} b::GenerativeFunction{T2, Tr} - cond::Tr0 - branch::Tr - retval::Union{T1, T2} - args::Tuple - score::Float64 - noise::Float64 end -function SwitchTrace{G1, G2, T1, T2, Tr0, Tr1, Tr2}(cond::Generativefunction{Bool}, - a::GenerativeFunction{T1}, - b::GenerativeFunction{T2}, - cond_subtrace::Tr0, - branch_subtrace::Union{Tr1, Tr2}, - retval::Union{T1, T2}, - args::Tuple, - score::Float64 - noise::Float64) where {G1, G2, T1, T2, Tr0, Tr1, Tr2} -end - -@inline get_choices(tr::SwitchTrace) = SwitchTraceChoiceMap(tr) -@inline get_retval(tr::SwitchTrace) = tr.retval -@inline get_args(tr::SwitchTrace) = tr.args -@inline get_score(tr::SwitchTrace) = tr.score -# TODO. @inline get_gen_fn(tr::SwitchTrace) = tr.gen_fn +export Switch -@inline function Base.getindex(tr::SwitchTrace, addr::Pair) - (first, rest) = addr - subtr = getfield(trace, first) - subtrace[rest] -end -@inline Base.getindex(tr::SwitchTrace, addr::Symbol) = getfield(trace, addr) - -function project(tr::SwitchTrace, selection::Selection) - weight = 0. - for k in [:cond, :branch] - subselection = selection[k] - weight += project(getindex(tr, k), subselection) - end - weight -end -project(tr::SwitchTrace, ::EmptySelection) = tr.noise +has_argument_grads(switch_fn::Switch) = has_argument_grads(switch_fn.a) && has_argument_grads(switch_fn.b) +accepts_output_grad(switch_fn::Switch) = accepts_output_grad(switch_fn.a) && accepts_output_grad(switch_fn.b) -@inline function get_submap(choices::SwitchTraceChoiceMap, addr::Symbol) - hasfield(choices, addr) || return EmptyChoiceMap() - get_choices(getfield(choices, addr)) +function (gen_fn::Switch)(flip_p::Float64, args...) + (_, _, retval) = propose(gen_fn, (flip_p, args...)) + retval end diff --git a/src/modeling_library/switch/trace.jl b/src/modeling_library/switch/trace.jl new file mode 100644 index 00000000..b391d17e --- /dev/null +++ b/src/modeling_library/switch/trace.jl @@ -0,0 +1,38 @@ +struct SwitchTrace{T1, T2, Tr} <: Trace + kernel::Switch{T1, T2, Tr} + p::Float64 + cond::Bool + branch::Tr + retval::Union{T1, T2} + args::Tuple + score::Float64 + noise::Float64 +end + +@inline get_choices(tr::SwitchTrace) = SwitchTraceChoiceMap(tr) +@inline get_retval(tr::SwitchTrace) = tr.retval +@inline get_args(tr::SwitchTrace) = tr.args +@inline get_score(tr::SwitchTrace) = tr.score +@inline get_gen_fn(tr::SwitchTrace) = tr.kernel + +@inline function Base.getindex(tr::SwitchTrace, addr::Pair) + (first, rest) = addr + subtr = getfield(trace, first) + subtrace[rest] +end +@inline Base.getindex(tr::SwitchTrace, addr::Symbol) = getfield(trace, addr) + +function project(tr::SwitchTrace, selection::Selection) + weight = 0. + for k in [:cond, :branch] + subselection = selection[k] + weight += project(getindex(tr, k), subselection) + end + weight +end +project(tr::SwitchTrace, ::EmptySelection) = tr.noise + +@inline function get_submap(choices::SwitchTraceChoiceMap, addr::Symbol) + hasfield(choices, addr) || return EmptyChoiceMap() + get_choices(getfield(choices, addr)) +end From 374a7b0ec86782fe05668549a0fd9671e88186c0 Mon Sep 17 00:00:00 2001 From: femtomc Date: Mon, 16 Nov 2020 21:05:18 -0500 Subject: [PATCH 03/30] Added implementaton of simulate. --- src/modeling_library/switch/simulate.jl | 35 ++++++++++++++++++++ test/modeling_library/switch.jl | 44 +++++++++++++++++++++++++ 2 files changed, 79 insertions(+) create mode 100644 src/modeling_library/switch/simulate.jl create mode 100644 test/modeling_library/switch.jl diff --git a/src/modeling_library/switch/simulate.jl b/src/modeling_library/switch/simulate.jl new file mode 100644 index 00000000..9cd1e021 --- /dev/null +++ b/src/modeling_library/switch/simulate.jl @@ -0,0 +1,35 @@ +mutable struct SwitchSimulateState{T1, T2, Tr} + score::Float64 + noise::Float64 + cond::Bool + subtrace::Tr + retval::Union{T1, T2} + SwitchGenerateState{T1, T2, Tr}(score::Float64, noise::Float64) = new{T1, T2, Tr}(score, noise, weight) +end + +function process!(gen_fn::Switch{T1, T2, Tr}, + branch_p::Float64, + args::Tuple, + state::SwitchGenerateState{T1, T2, Tr}) where {T1, T2, Tr} + local subtrace::Tr + local retval::Union{T1, T2} + flip_d = Bernoulli(branch_p) + flip = rand(flip_d) + state.score += logpdf(flip_d, flip) + state.cond = flip + subtrace = simulate(flip ? gen_fn.a : gen_fn.b, args) + state.noise += project(subtrace, EmptySelection()) + state.subtrace = subtrace + state.score += get_score(subtrace) + get_retval(subtrace) +end + +function simulate(gen_fn::Switch{T1, T2, Tr}, + args::Tuple) where {T1, T2, Tr} + + branch_p = args[1] + state = SwitchSimulateState{T1, T2, Tr}(0.0, 0.0) + process!(gen_fn, branch_p, args[2 : end], state) + trace = SwitchTrace{T1, T2, Tr}(gen_fn, branch_p, state.cond, state.subtrace, state.retval, args[2 : end], state.score, state.noise) + (trace, state.weight) +end diff --git a/test/modeling_library/switch.jl b/test/modeling_library/switch.jl new file mode 100644 index 00000000..6343deae --- /dev/null +++ b/test/modeling_library/switch.jl @@ -0,0 +1,44 @@ +@testset "switch combinator" begin + + @gen (grad) function foo((grad)(x::Float64), (grad)(y::Float64)) + @param std::Float64 + z = @trace(normal(x + y, std), :z) + return z + end + + @gen (grad) function baz((grad)(x::Float64), (grad)(y::Float64)) + @param std::Float64 + z = @trace(normal(x + 2 * y, std), :z) + return z + end + + set_param!(foo, :std, 1.) + set_param!(baz, :std, 1.) + + bar = Switch(foo, baz) + args = (1.0, 3.0) + + @testset "simulate" begin + end + + @testset "generate" begin + end + + @testset "propose" begin + end + + @testset "assess" begin + end + + @testset "update" begin + end + + @testset "regenerate" begin + end + + @testset "choice_gradients" begin + end + + @testset "accumulate_param_gradients!" begin + end +end From 5872593eefee9a586fc9241f68c62543116aa108 Mon Sep 17 00:00:00 2001 From: femtomc Date: Tue, 17 Nov 2020 00:12:48 -0500 Subject: [PATCH 04/30] Corrected some bugs with Bernoulli vs bernoulli. --- src/modeling_library/modeling_library.jl | 4 ++++ src/modeling_library/switch/generate.jl | 10 +++++----- src/modeling_library/switch/propose.jl | 7 +++---- src/modeling_library/switch/simulate.jl | 13 ++++++------- src/modeling_library/switch/switch.jl | 6 +++++- src/modeling_library/switch/trace.jl | 14 +++++++------- 6 files changed, 30 insertions(+), 24 deletions(-) diff --git a/src/modeling_library/modeling_library.jl b/src/modeling_library/modeling_library.jl index d0797426..5182e552 100644 --- a/src/modeling_library/modeling_library.jl +++ b/src/modeling_library/modeling_library.jl @@ -66,12 +66,16 @@ include("dist_dsl/dist_dsl.jl") # code shared by vector-shaped combinators include("vector.jl") +# trace for switch combinator +include("switch/trace.jl") + # built-in generative function combinators include("choice_at/choice_at.jl") include("call_at/call_at.jl") include("map/map.jl") include("unfold/unfold.jl") include("recurse/recurse.jl") +include("switch/switch.jl") ############################################################# # abstractions for constructing custom generative functions # diff --git a/src/modeling_library/switch/generate.jl b/src/modeling_library/switch/generate.jl index 21a402e6..78a03815 100644 --- a/src/modeling_library/switch/generate.jl +++ b/src/modeling_library/switch/generate.jl @@ -5,7 +5,7 @@ mutable struct SwitchGenerateState{T1, T2, Tr} cond::Bool subtrace::Tr retval::Union{T1, T2} - SwitchGenerateState{T1, T2, Tr}(score::Float64, noise::Float64, weight::Float64) = new{T1, T2, Tr}(score, noise, weight) + SwitchGenerateState{T1, T2, Tr}(score::Float64, noise::Float64, weight::Float64) where {T1, T2, Tr} = new{T1, T2, Tr}(score, noise, weight) end function process!(gen_fn::Switch{T1, T2, Tr}, @@ -15,24 +15,24 @@ function process!(gen_fn::Switch{T1, T2, Tr}, state::SwitchGenerateState{T1, T2, Tr}) where {T1, T2, Tr} # create flip distribution - flip_d = Bernoulli(branch_p) + flip_d = bernoulli(branch_p) # check for constraints at :cond constrained = has_value(choices, :cond) !constrained && check_no_submap(choices, :cond) # get/constrain flip value - constrained ? (flip = get_value(choices, :cond); state.weight += logpdf(flip_d, flip)) : flip = rand(flip_d) + constrained ? (flip = get_value(choices, :cond); state.weight += logpdf(Bernoulli(), flip, branch_p)) : flip = rand(flip_d) state.cond = flip # generate subtrace - constraints = get_submap(choices, :cond) + constraints = get_submap(choices, :branch) (subtrace, weight) = generate(flip ? gen_fn.a : gen_fn.b, args, constraints) state.subtrace = subtrace state.weight += weight # return from branch - get_retval(subtrace) + state.retval = get_retval(subtrace) end function generate(gen_fn::Switch{T1, T2, Tr}, diff --git a/src/modeling_library/switch/propose.jl b/src/modeling_library/switch/propose.jl index fb9a2dce..ccfa1640 100644 --- a/src/modeling_library/switch/propose.jl +++ b/src/modeling_library/switch/propose.jl @@ -8,13 +8,12 @@ end function process_new!(gen_fn::Switch{T1, T2, Tr}, branch_p::Float64, args::Tuple, - state::SwitchProposeState{T}) where T + state::SwitchProposeState{Union{T1, T2}}) where {T1, T2, Tr} - flip_d = Bernoulli(branch_p) - flip = rand(flip_d) + flip = bernoulli(branch_p) (submap, weight, retval) = propose(flip ? gen_fn.a : gen_fn.b, args) set_value!(state.choices, :cond, flip) - state.weight += logpdf(flip_d, flip) + state.weight += logpdf(Bernoulli(), flip, branch_p) set_submap!(state.choices, :branch, submap) state.weight += weight state.retval = retval diff --git a/src/modeling_library/switch/simulate.jl b/src/modeling_library/switch/simulate.jl index 9cd1e021..bbaccc4c 100644 --- a/src/modeling_library/switch/simulate.jl +++ b/src/modeling_library/switch/simulate.jl @@ -4,24 +4,23 @@ mutable struct SwitchSimulateState{T1, T2, Tr} cond::Bool subtrace::Tr retval::Union{T1, T2} - SwitchGenerateState{T1, T2, Tr}(score::Float64, noise::Float64) = new{T1, T2, Tr}(score, noise, weight) + SwitchSimulateState{T1, T2, Tr}(score::Float64, noise::Float64) where {T1, T2, Tr} = new{T1, T2, Tr}(score, noise) end function process!(gen_fn::Switch{T1, T2, Tr}, branch_p::Float64, args::Tuple, - state::SwitchGenerateState{T1, T2, Tr}) where {T1, T2, Tr} + state::SwitchSimulateState{T1, T2, Tr}) where {T1, T2, Tr} local subtrace::Tr local retval::Union{T1, T2} - flip_d = Bernoulli(branch_p) - flip = rand(flip_d) - state.score += logpdf(flip_d, flip) + flip = bernoulli(branch_p) + state.score += logpdf(Bernoulli(), flip, branch_p) state.cond = flip subtrace = simulate(flip ? gen_fn.a : gen_fn.b, args) state.noise += project(subtrace, EmptySelection()) state.subtrace = subtrace state.score += get_score(subtrace) - get_retval(subtrace) + state.retval = get_retval(subtrace) end function simulate(gen_fn::Switch{T1, T2, Tr}, @@ -31,5 +30,5 @@ function simulate(gen_fn::Switch{T1, T2, Tr}, state = SwitchSimulateState{T1, T2, Tr}(0.0, 0.0) process!(gen_fn, branch_p, args[2 : end], state) trace = SwitchTrace{T1, T2, Tr}(gen_fn, branch_p, state.cond, state.subtrace, state.retval, args[2 : end], state.score, state.noise) - (trace, state.weight) + trace end diff --git a/src/modeling_library/switch/switch.jl b/src/modeling_library/switch/switch.jl index 46309e22..14d72d6e 100644 --- a/src/modeling_library/switch/switch.jl +++ b/src/modeling_library/switch/switch.jl @@ -1,4 +1,4 @@ -struct Switch{T1, T2, Tr} <: GenerativeFunction{Union{T1, T2}, SwitchTrace{T1, T2, Tr}} +struct Switch{T1, T2, Tr} <: GenerativeFunction{Union{T1, T2}, Tr} a::GenerativeFunction{T1, Tr} b::GenerativeFunction{T2, Tr} end @@ -12,3 +12,7 @@ function (gen_fn::Switch)(flip_p::Float64, args...) (_, _, retval) = propose(gen_fn, (flip_p, args...)) retval end + +include("propose.jl") +include("simulate.jl") +include("generate.jl") diff --git a/src/modeling_library/switch/trace.jl b/src/modeling_library/switch/trace.jl index b391d17e..c9f2ca5c 100644 --- a/src/modeling_library/switch/trace.jl +++ b/src/modeling_library/switch/trace.jl @@ -1,5 +1,5 @@ struct SwitchTrace{T1, T2, Tr} <: Trace - kernel::Switch{T1, T2, Tr} + kernel::GenerativeFunction{Union{T1, T2}, Tr} p::Float64 cond::Bool branch::Tr @@ -9,7 +9,12 @@ struct SwitchTrace{T1, T2, Tr} <: Trace noise::Float64 end -@inline get_choices(tr::SwitchTrace) = SwitchTraceChoiceMap(tr) +@inline function get_choices(tr::SwitchTrace) + choices = choicemap() + set_submap!(choices, :branch, get_choices(tr.branch)) + set_value!(choices, :cond, tr.cond) + choices +end @inline get_retval(tr::SwitchTrace) = tr.retval @inline get_args(tr::SwitchTrace) = tr.args @inline get_score(tr::SwitchTrace) = tr.score @@ -31,8 +36,3 @@ function project(tr::SwitchTrace, selection::Selection) weight end project(tr::SwitchTrace, ::EmptySelection) = tr.noise - -@inline function get_submap(choices::SwitchTraceChoiceMap, addr::Symbol) - hasfield(choices, addr) || return EmptyChoiceMap() - get_choices(getfield(choices, addr)) -end From 9c0a9f21a320ca6683f3c588ca2a361f25d508de Mon Sep 17 00:00:00 2001 From: femtomc Date: Tue, 17 Nov 2020 00:20:39 -0500 Subject: [PATCH 05/30] Added assess implementation. --- scratch/switch_comb.jl | 30 +++++++++++++++++++++++++++ src/modeling_library/switch/assess.jl | 26 +++++++++++++++++++++++ src/modeling_library/switch/switch.jl | 1 + 3 files changed, 57 insertions(+) create mode 100644 scratch/switch_comb.jl create mode 100644 src/modeling_library/switch/assess.jl diff --git a/scratch/switch_comb.jl b/scratch/switch_comb.jl new file mode 100644 index 00000000..6193d548 --- /dev/null +++ b/scratch/switch_comb.jl @@ -0,0 +1,30 @@ +module SwitchComb + +include("../src/Gen.jl") +using .Gen + +@gen (grad) function foo((grad)(x::Float64), (grad)(y::Float64)) + std::Float64 = 3.0 + z = @trace(normal(x + y, std), :z) + return z +end + +@gen (grad) function baz((grad)(x::Float64), (grad)(y::Float64)) + std::Float64 = 3.0 + z = @trace(normal(x + 2 * y, std), :z) + return z +end + +sc = Switch(foo, baz) +chm, _, _ = propose(sc, (0.3, 5.0, 3.0)) +display(chm) + +tr = simulate(sc, (0.3, 5.0, 3.0)) +display(get_choices(tr)) + +chm = choicemap() +chm[:cond] = true +tr, _ = generate(sc, (0.3, 5.0, 3.0), chm) +display(get_choices(tr)) + +end # module diff --git a/src/modeling_library/switch/assess.jl b/src/modeling_library/switch/assess.jl new file mode 100644 index 00000000..054a7137 --- /dev/null +++ b/src/modeling_library/switch/assess.jl @@ -0,0 +1,26 @@ +mutable struct SwitchAssessState{T} + weight::Float64 + retval::T +end + +function process_new!(gen_fn::Switch{T1, T2, Tr}, + branch_p::Float64, + args::Tuple, + choices::ChoiceMap, + state::SwitchAssessState{Union{T1, T2}}) where {T1, T2, Tr} + flip = get_value(choices, :cond) + state.weight += logpdf(Bernoulli(), flip, branch_p) + submap = get_submap(choices, :branch) + (weight, retval) = assess(gen_fn.kernel, kernel_args, submap) + state.weight += weight + state.retval = retval +end + +function assess(gen_fn::Switch{T1, T2, Tr}, + args::Tuple, + choices::ChoiceMap) where {T1, T2, Tr} + branch_p = args[1] + state = SwitchAssessState{Union{T1, T2}}(0.0) + process_new!(gen_fn, branch_p, args[2 : end], choices, state) + (state.weight, state.retval) +end diff --git a/src/modeling_library/switch/switch.jl b/src/modeling_library/switch/switch.jl index 14d72d6e..1170d3a2 100644 --- a/src/modeling_library/switch/switch.jl +++ b/src/modeling_library/switch/switch.jl @@ -13,6 +13,7 @@ function (gen_fn::Switch)(flip_p::Float64, args...) retval end +include("assess.jl") include("propose.jl") include("simulate.jl") include("generate.jl") From 95baf0781a7486e0f8360b99593638c24b38e384 Mon Sep 17 00:00:00 2001 From: femtomc Date: Tue, 17 Nov 2020 19:44:32 -0500 Subject: [PATCH 06/30] Split into two combinators: Switch and WithProbability implementations. --- scratch/switch_comb.jl | 8 +-- src/modeling_library/cond.jl | 72 +++++++++++++++++++++ src/modeling_library/modeling_library.jl | 5 +- src/modeling_library/switch/assess.jl | 19 +++--- src/modeling_library/switch/generate.jl | 45 +++++-------- src/modeling_library/switch/propose.jl | 23 +++---- src/modeling_library/switch/simulate.jl | 36 +++++------ src/modeling_library/switch/switch.jl | 18 ++++-- src/modeling_library/switch/trace.jl | 38 ----------- src/modeling_library/switch/update.jl | 56 ++++++++++++++++ src/modeling_library/with_prob/assess.jl | 26 ++++++++ src/modeling_library/with_prob/generate.jl | 47 ++++++++++++++ src/modeling_library/with_prob/propose.jl | 30 +++++++++ src/modeling_library/with_prob/simulate.jl | 33 ++++++++++ src/modeling_library/with_prob/update.jl | 54 ++++++++++++++++ src/modeling_library/with_prob/with_prob.jl | 19 ++++++ 16 files changed, 404 insertions(+), 125 deletions(-) create mode 100644 src/modeling_library/cond.jl delete mode 100644 src/modeling_library/switch/trace.jl create mode 100644 src/modeling_library/switch/update.jl create mode 100644 src/modeling_library/with_prob/assess.jl create mode 100644 src/modeling_library/with_prob/generate.jl create mode 100644 src/modeling_library/with_prob/propose.jl create mode 100644 src/modeling_library/with_prob/simulate.jl create mode 100644 src/modeling_library/with_prob/update.jl create mode 100644 src/modeling_library/with_prob/with_prob.jl diff --git a/scratch/switch_comb.jl b/scratch/switch_comb.jl index 6193d548..952a8f82 100644 --- a/scratch/switch_comb.jl +++ b/scratch/switch_comb.jl @@ -16,15 +16,15 @@ end end sc = Switch(foo, baz) -chm, _, _ = propose(sc, (0.3, 5.0, 3.0)) +chm, _, _ = propose(sc, (2, 5.0, 3.0)) display(chm) -tr = simulate(sc, (0.3, 5.0, 3.0)) +tr = simulate(sc, (2, 5.0, 3.0)) display(get_choices(tr)) chm = choicemap() -chm[:cond] = true -tr, _ = generate(sc, (0.3, 5.0, 3.0), chm) +chm[:z] = 5.0 +tr, _ = generate(sc, (2, 5.0, 3.0), chm) display(get_choices(tr)) end # module diff --git a/src/modeling_library/cond.jl b/src/modeling_library/cond.jl new file mode 100644 index 00000000..c77846d4 --- /dev/null +++ b/src/modeling_library/cond.jl @@ -0,0 +1,72 @@ +# ------------ WithProbability trace ------------ # + +struct WithProbabilityTrace{T1, T2, Tr} <: Trace + kernel::GenerativeFunction{Union{T1, T2}, Tr} + p::Float64 + cond::Bool + branch::Tr + retval::Union{T1, T2} + args::Tuple + score::Float64 + noise::Float64 +end + +@inline function get_choices(tr::WithProbabilityTrace) + choices = choicemap() + set_submap!(choices, :branch, get_choices(tr.branch)) + set_value!(choices, :cond, tr.cond) + choices +end +@inline get_retval(tr::WithProbabilityTrace) = tr.retval +@inline get_args(tr::WithProbabilityTrace) = tr.args +@inline get_score(tr::WithProbabilityTrace) = tr.score +@inline get_gen_fn(tr::WithProbabilityTrace) = tr.kernel + +@inline function Base.getindex(tr::WithProbabilityTrace, addr::Pair) + (first, rest) = addr + subtr = getfield(trace, first) + subtrace[rest] +end +@inline Base.getindex(tr::WithProbabilityTrace, addr::Symbol) = getfield(trace, addr) + +function project(tr::WithProbabilityTrace, selection::Selection) + weight = 0. + for k in [:cond, :branch] + subselection = selection[k] + weight += project(getindex(tr, k), subselection) + end + weight +end +project(tr::WithProbabilityTrace, ::EmptySelection) = tr.noise + +# ------------ Switch trace ------------ # + +struct SwitchTrace{T} <: Trace + kernel::GenerativeFunction{T} + index::Int + branch::Trace + retval::T + args::Tuple + score::Float64 + noise::Float64 +end + +@inline get_choices(tr::SwitchTrace) = get_choices(tr.branch) +@inline get_retval(tr::SwitchTrace) = tr.retval +@inline get_args(tr::SwitchTrace) = tr.args +@inline get_score(tr::SwitchTrace) = tr.score +@inline get_gen_fn(tr::SwitchTrace) = tr.kernel + +@inline function Base.getindex(tr::SwitchTrace, addr::Pair) + (first, rest) = addr + subtr = getfield(trace, first) + subtrace[rest] +end +@inline Base.getindex(tr::SwitchTrace, addr::Symbol) = getfield(trace, addr) + +function project(tr::SwitchTrace, selection::Selection) + weight = 0. + weight += project(tr.branch, selection) + weight +end +project(tr::SwitchTrace, ::EmptySelection) = tr.noise diff --git a/src/modeling_library/modeling_library.jl b/src/modeling_library/modeling_library.jl index 5182e552..a6004510 100644 --- a/src/modeling_library/modeling_library.jl +++ b/src/modeling_library/modeling_library.jl @@ -66,8 +66,8 @@ include("dist_dsl/dist_dsl.jl") # code shared by vector-shaped combinators include("vector.jl") -# trace for switch combinator -include("switch/trace.jl") +# traces for with prob/switch combinator +include("cond.jl") # built-in generative function combinators include("choice_at/choice_at.jl") @@ -76,6 +76,7 @@ include("map/map.jl") include("unfold/unfold.jl") include("recurse/recurse.jl") include("switch/switch.jl") +include("with_prob/with_prob.jl") ############################################################# # abstractions for constructing custom generative functions # diff --git a/src/modeling_library/switch/assess.jl b/src/modeling_library/switch/assess.jl index 054a7137..642bde05 100644 --- a/src/modeling_library/switch/assess.jl +++ b/src/modeling_library/switch/assess.jl @@ -3,24 +3,21 @@ mutable struct SwitchAssessState{T} retval::T end -function process_new!(gen_fn::Switch{T1, T2, Tr}, +function process_new!(gen_fn::Switch{N, K, T}, branch_p::Float64, args::Tuple, choices::ChoiceMap, - state::SwitchAssessState{Union{T1, T2}}) where {T1, T2, Tr} - flip = get_value(choices, :cond) - state.weight += logpdf(Bernoulli(), flip, branch_p) - submap = get_submap(choices, :branch) - (weight, retval) = assess(gen_fn.kernel, kernel_args, submap) + state::SwitchAssessState{T}) where {N, K, T} + (weight, retval) = assess(getindex(gen_fn.mix, index), kernel_args, choices) state.weight += weight state.retval = retval end -function assess(gen_fn::Switch{T1, T2, Tr}, +function assess(gen_fn::Switch{N, K, T}, args::Tuple, - choices::ChoiceMap) where {T1, T2, Tr} - branch_p = args[1] - state = SwitchAssessState{Union{T1, T2}}(0.0) - process_new!(gen_fn, branch_p, args[2 : end], choices, state) + choices::ChoiceMap) where {N, K, T} + index = args[1] + state = SwitchAssessState{T}(0.0) + process_new!(gen_fn, index, args[2 : end], choices, state) (state.weight, state.retval) end diff --git a/src/modeling_library/switch/generate.jl b/src/modeling_library/switch/generate.jl index 78a03815..fb77320e 100644 --- a/src/modeling_library/switch/generate.jl +++ b/src/modeling_library/switch/generate.jl @@ -1,47 +1,32 @@ -mutable struct SwitchGenerateState{T1, T2, Tr} +mutable struct SwitchGenerateState{T} score::Float64 noise::Float64 weight::Float64 - cond::Bool - subtrace::Tr - retval::Union{T1, T2} - SwitchGenerateState{T1, T2, Tr}(score::Float64, noise::Float64, weight::Float64) where {T1, T2, Tr} = new{T1, T2, Tr}(score, noise, weight) + index::Int + subtrace::Trace + retval::T + SwitchGenerateState{T}(score::Float64, noise::Float64, weight::Float64) where T = new{T}(score, noise, weight) end -function process!(gen_fn::Switch{T1, T2, Tr}, - branch_p::Float64, +function process!(gen_fn::Switch{N, K, T}, + index::Int, args::Tuple, choices::ChoiceMap, - state::SwitchGenerateState{T1, T2, Tr}) where {T1, T2, Tr} + state::SwitchGenerateState{T}) where {N, K, T} - # create flip distribution - flip_d = bernoulli(branch_p) - - # check for constraints at :cond - constrained = has_value(choices, :cond) - !constrained && check_no_submap(choices, :cond) - - # get/constrain flip value - constrained ? (flip = get_value(choices, :cond); state.weight += logpdf(Bernoulli(), flip, branch_p)) : flip = rand(flip_d) - state.cond = flip - - # generate subtrace - constraints = get_submap(choices, :branch) - (subtrace, weight) = generate(flip ? gen_fn.a : gen_fn.b, args, constraints) + (subtrace, weight) = generate(getindex(gen_fn.mix, index), args, choices) state.subtrace = subtrace state.weight += weight - - # return from branch state.retval = get_retval(subtrace) end -function generate(gen_fn::Switch{T1, T2, Tr}, +function generate(gen_fn::Switch{N, K, T}, args::Tuple, - choices::ChoiceMap) where {T1, T2, Tr} + choices::ChoiceMap) where {N, K, T} - branch_p = args[1] - state = SwitchGenerateState{T1, T2, Tr}(0.0, 0.0, 0.0) - process!(gen_fn, branch_p, args[2 : end], choices, state) - trace = SwitchTrace{T1, T2, Tr}(gen_fn, branch_p, state.cond, state.subtrace, state.retval, args[2 : end], state.score, state.noise) + index = args[1] + state = SwitchGenerateState{T}(0.0, 0.0, 0.0) + process!(gen_fn, index, args[2 : end], choices, state) + trace = SwitchTrace{T}(gen_fn, index, state.subtrace, state.retval, args[2 : end], state.score, state.noise) (trace, state.weight) end diff --git a/src/modeling_library/switch/propose.jl b/src/modeling_library/switch/propose.jl index ccfa1640..a1253470 100644 --- a/src/modeling_library/switch/propose.jl +++ b/src/modeling_library/switch/propose.jl @@ -5,26 +5,23 @@ mutable struct SwitchProposeState{T} SwitchProposeState{T}(choices, weight) where T = new{T}(choices, weight) end -function process_new!(gen_fn::Switch{T1, T2, Tr}, - branch_p::Float64, +function process_new!(gen_fn::Switch{N, K, T}, + index::Int, args::Tuple, - state::SwitchProposeState{Union{T1, T2}}) where {T1, T2, Tr} + state::SwitchProposeState{T}) where {N, K, T} - flip = bernoulli(branch_p) - (submap, weight, retval) = propose(flip ? gen_fn.a : gen_fn.b, args) - set_value!(state.choices, :cond, flip) - state.weight += logpdf(Bernoulli(), flip, branch_p) - set_submap!(state.choices, :branch, submap) + (submap, weight, retval) = propose(getindex(gen_fn.mix, index), args) + state.choices = submap state.weight += weight state.retval = retval end -function propose(gen_fn::Switch{T1, T2, Tr}, - args::Tuple) where {T1, T2, Tr} +function propose(gen_fn::Switch{N, K, T}, + args::Tuple) where {N, K, T} - branch_p = args[1] + index = args[1] choices = choicemap() - state = SwitchProposeState{Union{T1, T2}}(choices, 0.0) - process_new!(gen_fn, branch_p, args[2:end], state) + state = SwitchProposeState{T}(choices, 0.0) + process_new!(gen_fn, index, args[2:end], state) (state.choices, state.weight, state.retval) end diff --git a/src/modeling_library/switch/simulate.jl b/src/modeling_library/switch/simulate.jl index bbaccc4c..fc203dff 100644 --- a/src/modeling_library/switch/simulate.jl +++ b/src/modeling_library/switch/simulate.jl @@ -1,34 +1,30 @@ -mutable struct SwitchSimulateState{T1, T2, Tr} +mutable struct SwitchSimulateState{T} score::Float64 noise::Float64 - cond::Bool - subtrace::Tr - retval::Union{T1, T2} - SwitchSimulateState{T1, T2, Tr}(score::Float64, noise::Float64) where {T1, T2, Tr} = new{T1, T2, Tr}(score, noise) + index::Int + subtrace::Trace + retval::T + SwitchSimulateState{T}(score::Float64, noise::Float64) where T = new{T}(score, noise) end -function process!(gen_fn::Switch{T1, T2, Tr}, - branch_p::Float64, +function process!(gen_fn::Switch{N, K, T}, + index::Int, args::Tuple, - state::SwitchSimulateState{T1, T2, Tr}) where {T1, T2, Tr} - local subtrace::Tr - local retval::Union{T1, T2} - flip = bernoulli(branch_p) - state.score += logpdf(Bernoulli(), flip, branch_p) - state.cond = flip - subtrace = simulate(flip ? gen_fn.a : gen_fn.b, args) + state::SwitchSimulateState{T}) where {N, K, T} + local retval::T + subtrace = simulate(getindex(gen_fn.mix, index), args) state.noise += project(subtrace, EmptySelection()) state.subtrace = subtrace state.score += get_score(subtrace) state.retval = get_retval(subtrace) end -function simulate(gen_fn::Switch{T1, T2, Tr}, - args::Tuple) where {T1, T2, Tr} +function simulate(gen_fn::Switch{N, K, T}, + args::Tuple) where {N, K, T} - branch_p = args[1] - state = SwitchSimulateState{T1, T2, Tr}(0.0, 0.0) - process!(gen_fn, branch_p, args[2 : end], state) - trace = SwitchTrace{T1, T2, Tr}(gen_fn, branch_p, state.cond, state.subtrace, state.retval, args[2 : end], state.score, state.noise) + index = args[1] + state = SwitchSimulateState{T}(0.0, 0.0) + process!(gen_fn, index, args[2 : end], state) + trace = SwitchTrace{T}(gen_fn, index, state.subtrace, state.retval, args[2 : end], state.score, state.noise) trace end diff --git a/src/modeling_library/switch/switch.jl b/src/modeling_library/switch/switch.jl index 1170d3a2..0e368d17 100644 --- a/src/modeling_library/switch/switch.jl +++ b/src/modeling_library/switch/switch.jl @@ -1,15 +1,19 @@ -struct Switch{T1, T2, Tr} <: GenerativeFunction{Union{T1, T2}, Tr} - a::GenerativeFunction{T1, Tr} - b::GenerativeFunction{T2, Tr} +struct Switch{N, K, T} <: GenerativeFunction{T, Trace} + mix::NTuple{N, GenerativeFunction{T}} + function Switch(gen_fns::GenerativeFunction...) + @assert !isempty(gen_fns) + rettype = get_return_type(getindex(gen_fns, 1)) + new{length(gen_fns), typeof(gen_fns), rettype}(gen_fns) + end end export Switch -has_argument_grads(switch_fn::Switch) = has_argument_grads(switch_fn.a) && has_argument_grads(switch_fn.b) -accepts_output_grad(switch_fn::Switch) = accepts_output_grad(switch_fn.a) && accepts_output_grad(switch_fn.b) +has_argument_grads(switch_fn::Switch) = all(has_argument_grads, switch.mix) +accepts_output_grad(switch_fn::Switch) = all(accepts_output_grad, switch.mix) -function (gen_fn::Switch)(flip_p::Float64, args...) - (_, _, retval) = propose(gen_fn, (flip_p, args...)) +function (gen_fn::Switch)(index::Int, args...) + (_, _, retval) = propose(gen_fn, (index, args...)) retval end diff --git a/src/modeling_library/switch/trace.jl b/src/modeling_library/switch/trace.jl deleted file mode 100644 index c9f2ca5c..00000000 --- a/src/modeling_library/switch/trace.jl +++ /dev/null @@ -1,38 +0,0 @@ -struct SwitchTrace{T1, T2, Tr} <: Trace - kernel::GenerativeFunction{Union{T1, T2}, Tr} - p::Float64 - cond::Bool - branch::Tr - retval::Union{T1, T2} - args::Tuple - score::Float64 - noise::Float64 -end - -@inline function get_choices(tr::SwitchTrace) - choices = choicemap() - set_submap!(choices, :branch, get_choices(tr.branch)) - set_value!(choices, :cond, tr.cond) - choices -end -@inline get_retval(tr::SwitchTrace) = tr.retval -@inline get_args(tr::SwitchTrace) = tr.args -@inline get_score(tr::SwitchTrace) = tr.score -@inline get_gen_fn(tr::SwitchTrace) = tr.kernel - -@inline function Base.getindex(tr::SwitchTrace, addr::Pair) - (first, rest) = addr - subtr = getfield(trace, first) - subtrace[rest] -end -@inline Base.getindex(tr::SwitchTrace, addr::Symbol) = getfield(trace, addr) - -function project(tr::SwitchTrace, selection::Selection) - weight = 0. - for k in [:cond, :branch] - subselection = selection[k] - weight += project(getindex(tr, k), subselection) - end - weight -end -project(tr::SwitchTrace, ::EmptySelection) = tr.noise diff --git a/src/modeling_library/switch/update.jl b/src/modeling_library/switch/update.jl new file mode 100644 index 00000000..f7b92ea5 --- /dev/null +++ b/src/modeling_library/switch/update.jl @@ -0,0 +1,56 @@ +mutable struct SwitchUpdateState{T,U} + weight::Float64 + score::Float64 + noise::Float64 + cond::Bool + subtrace::U + retval::T + discard::DynamicChoiceMap + updated_retdiff::Diff +end + +function process!(gen_fn::Switch{T1, T2, Tr}, + branch_p::Float64, + args::Tuple, + choices::ChoiceMap, + kernel_argdiffs::Tuple, + state::SwitchUpdateState{Union{T1, T2}, Tr}) where {T1, T2, Tr} + local subtrace::Tr + local prev_subtrace::Tr + local retval::T + + # get new subtrace with recursive call to update() + submap = get_submap(choices, :branch) + prev_subtrace = state.subtrace + (subtrace, weight, retdiff, discard) = update(prev_subtrace, kernel_args, kernel_argdiffs, submap) + + # retrieve retdiff + if retdiff != NoChange() + state.updated_retdiff = retdiff + end + + # update state + state.weight += weight + set_submap!(state.discard, key, discard) + state.score += (get_score(subtrace) - get_score(prev_subtrace)) + state.noise += (project(subtrace, EmptySelection()) - project(prev_subtrace, EmptySelection())) + state.subtraces = assoc(state.subtraces, key, subtrace) + retval = get_retval(subtrace) + state.retval = assoc(state.retval, key, retval) + subtrace_empty = isempty(get_choices(subtrace)) + prev_subtrace_empty = isempty(get_choices(prev_subtrace)) + if !subtrace_empty && prev_subtrace_empty + state.num_nonempty += 1 + elseif subtrace_empty && !prev_subtrace_empty + state.num_nonempty -= 1 + end +end + +function update(trace::Switch{T1, T2, Tr}, + args::Tuple, + argdiffs::Tuple, + choices::ChoiceMap) where {T1, T2, Tr} + gen_fn = trace.gen_fn + branch_p = args[1] + return (new_trace, state.weight, retdiff, discard) +end diff --git a/src/modeling_library/with_prob/assess.jl b/src/modeling_library/with_prob/assess.jl new file mode 100644 index 00000000..da9a77a7 --- /dev/null +++ b/src/modeling_library/with_prob/assess.jl @@ -0,0 +1,26 @@ +mutable struct WithProbabilityAssessState{T} + weight::Float64 + retval::T +end + +function process_new!(gen_fn::WithProbability{T}, + branch_p::Float64, + args::Tuple, + choices::ChoiceMap, + state::WithProbabilityAssessState{T}) where T + flip = get_value(choices, :cond) + state.weight += logpdf(Bernoulli(), flip, branch_p) + submap = get_submap(choices, :branch) + (weight, retval) = assess(gen_fn.kernel, kernel_args, submap) + state.weight += weight + state.retval = retval +end + +function assess(gen_fn::WithProbability{T}, + args::Tuple, + choices::ChoiceMap) where T + branch_p = args[1] + state = WithProbabilityAssessState{T}(0.0) + process_new!(gen_fn, branch_p, args[2 : end], choices, state) + (state.weight, state.retval) +end diff --git a/src/modeling_library/with_prob/generate.jl b/src/modeling_library/with_prob/generate.jl new file mode 100644 index 00000000..45d5540a --- /dev/null +++ b/src/modeling_library/with_prob/generate.jl @@ -0,0 +1,47 @@ +mutable struct WithProbabilityGenerateState{T} + score::Float64 + noise::Float64 + weight::Float64 + cond::Bool + subtrace::Trace + retval::T + WithProbabilityGenerateState{T}(score::Float64, noise::Float64, weight::Float64) where T = new{T}(score, noise, weight) +end + +function process!(gen_fn::WithProbability{T}, + branch_p::Float64, + args::Tuple, + choices::ChoiceMap, + state::WithProbabilityGenerateState{T}) where T + + # sample from Bernoulli with probability branch_p + flip_d = bernoulli(branch_p) + + # check for constraints at :cond + constrained = has_value(choices, :cond) + !constrained && check_no_submap(choices, :cond) + + # get/constrain flip value + constrained ? (flip = get_value(choices, :cond); state.weight += logpdf(Bernoulli(), flip, branch_p)) : flip = rand(flip_d) + state.cond = flip + + # generate subtrace + constraints = get_submap(choices, :branch) + (subtrace, weight) = generate(flip ? gen_fn.a : gen_fn.b, args, constraints) + state.subtrace = subtrace + state.weight += weight + + # return from branch + state.retval = get_retval(subtrace) +end + +function generate(gen_fn::WithProbability{T}, + args::Tuple, + choices::ChoiceMap) where T + + branch_p = args[1] + state = WithProbabilityGenerateState{T}(0.0, 0.0, 0.0) + process!(gen_fn, branch_p, args[2 : end], choices, state) + trace = WithProbabilityTrace{T}(gen_fn, branch_p, state.cond, state.subtrace, state.retval, args[2 : end], state.score, state.noise) + (trace, state.weight) +end diff --git a/src/modeling_library/with_prob/propose.jl b/src/modeling_library/with_prob/propose.jl new file mode 100644 index 00000000..fd94b7db --- /dev/null +++ b/src/modeling_library/with_prob/propose.jl @@ -0,0 +1,30 @@ +mutable struct WithProbabilityProposeState{T} + choices::DynamicChoiceMap + weight::Float64 + retval::T + WithProbabilityProposeState{T}(choices, weight) where T = new{T}(choices, weight) +end + +function process_new!(gen_fn::WithProbability{T}, + branch_p::Float64, + args::Tuple, + state::WithProbabilityProposeState{T}) where T + + flip = bernoulli(branch_p) + (submap, weight, retval) = propose(flip ? gen_fn.a : gen_fn.b, args) + set_value!(state.choices, :cond, flip) + state.weight += logpdf(Bernoulli(), flip, branch_p) + set_submap!(state.choices, :branch, submap) + state.weight += weight + state.retval = retval +end + +function propose(gen_fn::WithProbability{T}, + args::Tuple) where T + + branch_p = args[1] + choices = choicemap() + state = WithProbabilityProposeState{T}(choices, 0.0) + process_new!(gen_fn, branch_p, args[2:end], state) + (state.choices, state.weight, state.retval) +end diff --git a/src/modeling_library/with_prob/simulate.jl b/src/modeling_library/with_prob/simulate.jl new file mode 100644 index 00000000..2c8d40fa --- /dev/null +++ b/src/modeling_library/with_prob/simulate.jl @@ -0,0 +1,33 @@ +mutable struct WithProbabilitySimulateState{T} + score::Float64 + noise::Float64 + cond::Bool + subtrace::Trace + retval::T + WithProbabilitySimulateState{T}(score::Float64, noise::Float64) where T = new{T}(score, noise) +end + +function process!(gen_fn::WithProbability{T}, + branch_p::Float64, + args::Tuple, + state::WithProbabilitySimulateState{T}) where T + local retval::T + flip = bernoulli(branch_p) + state.score += logpdf(Bernoulli(), flip, branch_p) + state.cond = flip + subtrace = simulate(flip ? gen_fn.a : gen_fn.b, args) + state.noise += project(subtrace, EmptySelection()) + state.subtrace = subtrace + state.score += get_score(subtrace) + state.retval = get_retval(subtrace) +end + +function simulate(gen_fn::WithProbability{T}, + args::Tuple) where T + + branch_p = args[1] + state = WithProbabilitySimulateState{T}(0.0, 0.0) + process!(gen_fn, branch_p, args[2 : end], state) + trace = WithProbabilityTrace{T}(gen_fn, branch_p, state.cond, state.subtrace, state.retval, args[2 : end], state.score, state.noise) + trace +end diff --git a/src/modeling_library/with_prob/update.jl b/src/modeling_library/with_prob/update.jl new file mode 100644 index 00000000..6278f527 --- /dev/null +++ b/src/modeling_library/with_prob/update.jl @@ -0,0 +1,54 @@ +mutable struct WithProbabilityUpdateState{T} + weight::Float64 + score::Float64 + noise::Float64 + cond::Bool + subtrace::Trace + retval::T + discard::DynamicChoiceMap + updated_retdiff::Diff +end + +function process!(gen_fn::WithProbability{T} + branch_p::Float64, + args::Tuple, + choices::ChoiceMap, + kernel_argdiffs::Tuple, + state::WithProbabilityUpdateState{T}) where T + local retval::T + + # get new subtrace with recursive call to update() + submap = get_submap(choices, :branch) + prev_subtrace = state.subtrace + (subtrace, weight, retdiff, discard) = update(prev_subtrace, kernel_args, kernel_argdiffs, submap) + + # retrieve retdiff + if retdiff != NoChange() + state.updated_retdiff = retdiff + end + + # update state + state.weight += weight + set_submap!(state.discard, key, discard) + state.score += (get_score(subtrace) - get_score(prev_subtrace)) + state.noise += (project(subtrace, EmptySelection()) - project(prev_subtrace, EmptySelection())) + state.subtraces = assoc(state.subtraces, key, subtrace) + retval = get_retval(subtrace) + state.retval = assoc(state.retval, key, retval) + subtrace_empty = isempty(get_choices(subtrace)) + prev_subtrace_empty = isempty(get_choices(prev_subtrace)) + if !subtrace_empty && prev_subtrace_empty + state.num_nonempty += 1 + elseif subtrace_empty && !prev_subtrace_empty + state.num_nonempty -= 1 + end +end + +function update(trace::WithProbability{T}, + args::Tuple, + argdiffs::Tuple, + choices::ChoiceMap) where T + gen_fn = trace.gen_fn + branch_p = args[1] + return (new_trace, state.weight, retdiff, discard) +end diff --git a/src/modeling_library/with_prob/with_prob.jl b/src/modeling_library/with_prob/with_prob.jl new file mode 100644 index 00000000..396a9b8d --- /dev/null +++ b/src/modeling_library/with_prob/with_prob.jl @@ -0,0 +1,19 @@ +struct WithProbability{T} <: GenerativeFunction{T, Trace} + a::GenerativeFunction{T} + b::GenerativeFunction{T} +end + +export WithProbability + +has_argument_grads(switch_fn::WithProbability) = has_argument_grads(switch_fn.a) && has_argument_grads(switch_fn.b) +accepts_output_grad(switch_fn::WithProbability) = accepts_output_grad(switch_fn.a) && accepts_output_grad(switch_fn.b) + +function (gen_fn::WithProbability)(flip_p::Float64, args...) + (_, _, retval) = propose(gen_fn, (flip_p, args...)) + retval +end + +include("assess.jl") +include("propose.jl") +include("simulate.jl") +include("generate.jl") From 29b7797134df094cd237a30b2f6661c810a0fff5 Mon Sep 17 00:00:00 2001 From: femtomc Date: Tue, 17 Nov 2020 22:03:41 -0500 Subject: [PATCH 07/30] Working on Switch update and regenerate. --- scratch/switch_comb.jl | 30 ------------ src/modeling_library/switch/assess.jl | 8 ++-- src/modeling_library/switch/generate.jl | 8 ++-- src/modeling_library/switch/propose.jl | 8 ++-- src/modeling_library/switch/simulate.jl | 25 ++++++++-- src/modeling_library/switch/switch.jl | 15 +++++- src/modeling_library/switch/update.jl | 4 +- test/modeling_library/switch.jl | 64 ++++++++++++------------- 8 files changed, 78 insertions(+), 84 deletions(-) delete mode 100644 scratch/switch_comb.jl diff --git a/scratch/switch_comb.jl b/scratch/switch_comb.jl deleted file mode 100644 index 952a8f82..00000000 --- a/scratch/switch_comb.jl +++ /dev/null @@ -1,30 +0,0 @@ -module SwitchComb - -include("../src/Gen.jl") -using .Gen - -@gen (grad) function foo((grad)(x::Float64), (grad)(y::Float64)) - std::Float64 = 3.0 - z = @trace(normal(x + y, std), :z) - return z -end - -@gen (grad) function baz((grad)(x::Float64), (grad)(y::Float64)) - std::Float64 = 3.0 - z = @trace(normal(x + 2 * y, std), :z) - return z -end - -sc = Switch(foo, baz) -chm, _, _ = propose(sc, (2, 5.0, 3.0)) -display(chm) - -tr = simulate(sc, (2, 5.0, 3.0)) -display(get_choices(tr)) - -chm = choicemap() -chm[:z] = 5.0 -tr, _ = generate(sc, (2, 5.0, 3.0), chm) -display(get_choices(tr)) - -end # module diff --git a/src/modeling_library/switch/assess.jl b/src/modeling_library/switch/assess.jl index 642bde05..808e9e8a 100644 --- a/src/modeling_library/switch/assess.jl +++ b/src/modeling_library/switch/assess.jl @@ -3,19 +3,19 @@ mutable struct SwitchAssessState{T} retval::T end -function process_new!(gen_fn::Switch{N, K, T}, +function process_new!(gen_fn::Switch{C, N, K, T}, branch_p::Float64, args::Tuple, choices::ChoiceMap, - state::SwitchAssessState{T}) where {N, K, T} + state::SwitchAssessState{T}) where {C, N, K, T} (weight, retval) = assess(getindex(gen_fn.mix, index), kernel_args, choices) state.weight += weight state.retval = retval end -function assess(gen_fn::Switch{N, K, T}, +function assess(gen_fn::Switch{C, N, K, T}, args::Tuple, - choices::ChoiceMap) where {N, K, T} + choices::ChoiceMap) where {C, N, K, T} index = args[1] state = SwitchAssessState{T}(0.0) process_new!(gen_fn, index, args[2 : end], choices, state) diff --git a/src/modeling_library/switch/generate.jl b/src/modeling_library/switch/generate.jl index fb77320e..0d12e951 100644 --- a/src/modeling_library/switch/generate.jl +++ b/src/modeling_library/switch/generate.jl @@ -8,11 +8,11 @@ mutable struct SwitchGenerateState{T} SwitchGenerateState{T}(score::Float64, noise::Float64, weight::Float64) where T = new{T}(score, noise, weight) end -function process!(gen_fn::Switch{N, K, T}, +function process!(gen_fn::Switch{C, N, K, T}, index::Int, args::Tuple, choices::ChoiceMap, - state::SwitchGenerateState{T}) where {N, K, T} + state::SwitchGenerateState{T}) where {C, N, K, T} (subtrace, weight) = generate(getindex(gen_fn.mix, index), args, choices) state.subtrace = subtrace @@ -20,9 +20,9 @@ function process!(gen_fn::Switch{N, K, T}, state.retval = get_retval(subtrace) end -function generate(gen_fn::Switch{N, K, T}, +function generate(gen_fn::Switch{C, N, K, T}, args::Tuple, - choices::ChoiceMap) where {N, K, T} + choices::ChoiceMap) where {C, N, K, T} index = args[1] state = SwitchGenerateState{T}(0.0, 0.0, 0.0) diff --git a/src/modeling_library/switch/propose.jl b/src/modeling_library/switch/propose.jl index a1253470..6adf0654 100644 --- a/src/modeling_library/switch/propose.jl +++ b/src/modeling_library/switch/propose.jl @@ -5,10 +5,10 @@ mutable struct SwitchProposeState{T} SwitchProposeState{T}(choices, weight) where T = new{T}(choices, weight) end -function process_new!(gen_fn::Switch{N, K, T}, +function process_new!(gen_fn::Switch{C, N, K, T}, index::Int, args::Tuple, - state::SwitchProposeState{T}) where {N, K, T} + state::SwitchProposeState{T}) where {C, N, K, T} (submap, weight, retval) = propose(getindex(gen_fn.mix, index), args) state.choices = submap @@ -16,8 +16,8 @@ function process_new!(gen_fn::Switch{N, K, T}, state.retval = retval end -function propose(gen_fn::Switch{N, K, T}, - args::Tuple) where {N, K, T} +function propose(gen_fn::Switch{C, N, K, T}, + args::Tuple) where {C, N, K, T} index = args[1] choices = choicemap() diff --git a/src/modeling_library/switch/simulate.jl b/src/modeling_library/switch/simulate.jl index fc203dff..dad59290 100644 --- a/src/modeling_library/switch/simulate.jl +++ b/src/modeling_library/switch/simulate.jl @@ -7,24 +7,39 @@ mutable struct SwitchSimulateState{T} SwitchSimulateState{T}(score::Float64, noise::Float64) where T = new{T}(score, noise) end -function process!(gen_fn::Switch{N, K, T}, +function process!(gen_fn::Switch{C, N, K, T}, index::Int, args::Tuple, - state::SwitchSimulateState{T}) where {N, K, T} + state::SwitchSimulateState{T}) where {C, N, K, T} local retval::T subtrace = simulate(getindex(gen_fn.mix, index), args) + state.index = index state.noise += project(subtrace, EmptySelection()) state.subtrace = subtrace state.score += get_score(subtrace) state.retval = get_retval(subtrace) end -function simulate(gen_fn::Switch{N, K, T}, - args::Tuple) where {N, K, T} +function process!(gen_fn::Switch{C, N, K, T}, + index::C, + args::Tuple, + state::SwitchSimulateState{T}) where {C, N, K, T} + local retval::T + index = getindex(gen_fn.cases, index) + state.index = index + subtrace = simulate(getindex(gen_fn.mix, index), args) + state.noise += project(subtrace, EmptySelection()) + state.subtrace = subtrace + state.score += get_score(subtrace) + state.retval = get_retval(subtrace) +end + +function simulate(gen_fn::Switch{C, N, K, T}, + args::Tuple) where {C, N, K, T} index = args[1] state = SwitchSimulateState{T}(0.0, 0.0) process!(gen_fn, index, args[2 : end], state) - trace = SwitchTrace{T}(gen_fn, index, state.subtrace, state.retval, args[2 : end], state.score, state.noise) + trace = SwitchTrace{T}(gen_fn, state.index, state.subtrace, state.retval, args[2 : end], state.score, state.noise) trace end diff --git a/src/modeling_library/switch/switch.jl b/src/modeling_library/switch/switch.jl index 0e368d17..48bfec9f 100644 --- a/src/modeling_library/switch/switch.jl +++ b/src/modeling_library/switch/switch.jl @@ -1,9 +1,15 @@ -struct Switch{N, K, T} <: GenerativeFunction{T, Trace} +struct Switch{C, N, K, T} <: GenerativeFunction{T, Trace} mix::NTuple{N, GenerativeFunction{T}} + cases::Dict{C, Int} function Switch(gen_fns::GenerativeFunction...) @assert !isempty(gen_fns) rettype = get_return_type(getindex(gen_fns, 1)) - new{length(gen_fns), typeof(gen_fns), rettype}(gen_fns) + new{Int, length(gen_fns), typeof(gen_fns), rettype}(gen_fns, Dict{Int, Int}()) + end + function Switch(d::Dict{C, Int}, gen_fns::GenerativeFunction...) where C + @assert !isempty(gen_fns) + rettype = get_return_type(getindex(gen_fns, 1)) + new{C, length(gen_fns), typeof(gen_fns), rettype}(gen_fns, d) end end @@ -17,6 +23,11 @@ function (gen_fn::Switch)(index::Int, args...) retval end +function (gen_fn::Switch{C})(index::C, args...) where C + (_, _, retval) = propose(gen_fn, (gen_fn.cases[index], args...)) + retval +end + include("assess.jl") include("propose.jl") include("simulate.jl") diff --git a/src/modeling_library/switch/update.jl b/src/modeling_library/switch/update.jl index f7b92ea5..493d6a7d 100644 --- a/src/modeling_library/switch/update.jl +++ b/src/modeling_library/switch/update.jl @@ -9,12 +9,12 @@ mutable struct SwitchUpdateState{T,U} updated_retdiff::Diff end -function process!(gen_fn::Switch{T1, T2, Tr}, +function process!(gen_fn::Switch{C, N, T, K}, branch_p::Float64, args::Tuple, choices::ChoiceMap, kernel_argdiffs::Tuple, - state::SwitchUpdateState{Union{T1, T2}, Tr}) where {T1, T2, Tr} + state::SwitchUpdateState{T}) where {C, N, T, K} local subtrace::Tr local prev_subtrace::Tr local retval::T diff --git a/test/modeling_library/switch.jl b/test/modeling_library/switch.jl index 6343deae..3464eeae 100644 --- a/test/modeling_library/switch.jl +++ b/test/modeling_library/switch.jl @@ -1,44 +1,42 @@ -@testset "switch combinator" begin +module SwitchComb - @gen (grad) function foo((grad)(x::Float64), (grad)(y::Float64)) - @param std::Float64 - z = @trace(normal(x + y, std), :z) - return z - end +include("../src/Gen.jl") +using .Gen - @gen (grad) function baz((grad)(x::Float64), (grad)(y::Float64)) - @param std::Float64 - z = @trace(normal(x + 2 * y, std), :z) - return z - end +# ------------ Toplevel caller ------------ # - set_param!(foo, :std, 1.) - set_param!(baz, :std, 1.) - - bar = Switch(foo, baz) - args = (1.0, 3.0) +@gen (grad) function foo((grad)(x::Float64), (grad)(y::Float64)) + std::Float64 = 3.0 + z = @trace(normal(x + y, std), :z) + return z +end - @testset "simulate" begin - end +@gen (grad) function baz((grad)(x::Float64), (grad)(y::Float64)) + std::Float64 = 3.0 + z = @trace(normal(x + 2 * y, std), :z) + return z +end - @testset "generate" begin - end +sc = Switch(Dict(:x => 1, :y => 2), foo, baz) +chm, _, _ = propose(sc, (2, 5.0, 3.0)) +display(chm) - @testset "propose" begin - end +tr = simulate(sc, (2, 5.0, 3.0)) +display(get_choices(tr)) - @testset "assess" begin - end +chm = choicemap() +chm[:z] = 5.0 +tr, _ = generate(sc, (2, 5.0, 3.0), chm) +display(get_choices(tr)) - @testset "update" begin - end +# ------------ Static DSL ------------ # - @testset "regenerate" begin - end +@gen (static) function bam(s::Symbol) + x ~ sc(s, 5.0, 3.0) +end +Gen.@load_generated_functions() - @testset "choice_gradients" begin - end +tr = simulate(bam, (:x, )) +display(get_choices(tr)) - @testset "accumulate_param_gradients!" begin - end -end +end # module From 3e6e3071311814f219413f2ebfe88f666f33318b Mon Sep 17 00:00:00 2001 From: femtomc Date: Tue, 17 Nov 2020 23:49:14 -0500 Subject: [PATCH 08/30] Added Switch update and regenerate. --- src/modeling_library/cond.jl | 18 ++---- src/modeling_library/switch/assess.jl | 14 +++-- src/modeling_library/switch/generate.jl | 4 +- src/modeling_library/switch/propose.jl | 12 ++-- src/modeling_library/switch/regenerate.jl | 44 ++++++++++++++ src/modeling_library/switch/simulate.jl | 14 +---- src/modeling_library/switch/switch.jl | 2 + src/modeling_library/switch/update.jl | 71 ++++++++++------------- src/modeling_library/with_prob/update.jl | 35 ++--------- test/modeling_library/switch.jl | 2 +- 10 files changed, 106 insertions(+), 110 deletions(-) create mode 100644 src/modeling_library/switch/regenerate.jl diff --git a/src/modeling_library/cond.jl b/src/modeling_library/cond.jl index c77846d4..ca164a54 100644 --- a/src/modeling_library/cond.jl +++ b/src/modeling_library/cond.jl @@ -30,12 +30,10 @@ end @inline Base.getindex(tr::WithProbabilityTrace, addr::Symbol) = getfield(trace, addr) function project(tr::WithProbabilityTrace, selection::Selection) - weight = 0. - for k in [:cond, :branch] - subselection = selection[k] - weight += project(getindex(tr, k), subselection) - end - weight + sum(map([:cond, :branch]) do k + subselection = selection[k] + project(getindex(tr, k), subselection) + end) end project(tr::WithProbabilityTrace, ::EmptySelection) = tr.noise @@ -64,9 +62,5 @@ end end @inline Base.getindex(tr::SwitchTrace, addr::Symbol) = getfield(trace, addr) -function project(tr::SwitchTrace, selection::Selection) - weight = 0. - weight += project(tr.branch, selection) - weight -end -project(tr::SwitchTrace, ::EmptySelection) = tr.noise +@inline project(tr::SwitchTrace, selection::Selection) = project(tr.branch, selection) +@inline project(tr::SwitchTrace, ::EmptySelection) = tr.noise diff --git a/src/modeling_library/switch/assess.jl b/src/modeling_library/switch/assess.jl index 808e9e8a..fcf3d2b9 100644 --- a/src/modeling_library/switch/assess.jl +++ b/src/modeling_library/switch/assess.jl @@ -3,21 +3,23 @@ mutable struct SwitchAssessState{T} retval::T end -function process_new!(gen_fn::Switch{C, N, K, T}, - branch_p::Float64, - args::Tuple, - choices::ChoiceMap, - state::SwitchAssessState{T}) where {C, N, K, T} +function process!(gen_fn::Switch{C, N, K, T}, + index::Int, + args::Tuple, + choices::ChoiceMap, + state::SwitchAssessState{T}) where {C, N, K, T} (weight, retval) = assess(getindex(gen_fn.mix, index), kernel_args, choices) state.weight += weight state.retval = retval end +@inline process!(gen_fn::Switch{C, N, K, T}, index::C, args::Tuple, choices::ChoiceMap, state::SwitchAssessState{T}) where {C, N, K, T} = process!(gen_fn, getindex(gen_fn.cases, index), args, choices, state) + function assess(gen_fn::Switch{C, N, K, T}, args::Tuple, choices::ChoiceMap) where {C, N, K, T} index = args[1] state = SwitchAssessState{T}(0.0) - process_new!(gen_fn, index, args[2 : end], choices, state) + process!(gen_fn, index, args[2 : end], choices, state) (state.weight, state.retval) end diff --git a/src/modeling_library/switch/generate.jl b/src/modeling_library/switch/generate.jl index 0d12e951..df20b3ff 100644 --- a/src/modeling_library/switch/generate.jl +++ b/src/modeling_library/switch/generate.jl @@ -13,13 +13,15 @@ function process!(gen_fn::Switch{C, N, K, T}, args::Tuple, choices::ChoiceMap, state::SwitchGenerateState{T}) where {C, N, K, T} - + (subtrace, weight) = generate(getindex(gen_fn.mix, index), args, choices) state.subtrace = subtrace state.weight += weight state.retval = get_retval(subtrace) end +@inline process!(gen_fn::Switch{C, N, K, T}, index::C, args::Tuple, choices::ChoiceMap, state::SwitchGenerateState{T}) where {C, N, K, T} = process!(gen_fn, getindex(gen_fn.cases, index), args, choices, state) + function generate(gen_fn::Switch{C, N, K, T}, args::Tuple, choices::ChoiceMap) where {C, N, K, T} diff --git a/src/modeling_library/switch/propose.jl b/src/modeling_library/switch/propose.jl index 6adf0654..c149c112 100644 --- a/src/modeling_library/switch/propose.jl +++ b/src/modeling_library/switch/propose.jl @@ -5,10 +5,10 @@ mutable struct SwitchProposeState{T} SwitchProposeState{T}(choices, weight) where T = new{T}(choices, weight) end -function process_new!(gen_fn::Switch{C, N, K, T}, - index::Int, - args::Tuple, - state::SwitchProposeState{T}) where {C, N, K, T} +function process!(gen_fn::Switch{C, N, K, T}, + index::Int, + args::Tuple, + state::SwitchProposeState{T}) where {C, N, K, T} (submap, weight, retval) = propose(getindex(gen_fn.mix, index), args) state.choices = submap @@ -16,12 +16,14 @@ function process_new!(gen_fn::Switch{C, N, K, T}, state.retval = retval end +@inline process!(gen_fn::Switch{C, N, K, T}, index::C, args::Tuple, state::SwitchProposeState{T}) where {C, N, K, T} = process!(gen_fn, getindex(gen_fn.cases, index), args, state) + function propose(gen_fn::Switch{C, N, K, T}, args::Tuple) where {C, N, K, T} index = args[1] choices = choicemap() state = SwitchProposeState{T}(choices, 0.0) - process_new!(gen_fn, index, args[2:end], state) + process!(gen_fn, index, args[2:end], state) (state.choices, state.weight, state.retval) end diff --git a/src/modeling_library/switch/regenerate.jl b/src/modeling_library/switch/regenerate.jl new file mode 100644 index 00000000..e5bcc27d --- /dev/null +++ b/src/modeling_library/switch/regenerate.jl @@ -0,0 +1,44 @@ +mutable struct SwitchRegenerateState{T} + weight::Float64 + score::Float64 + noise::Float64 + prev_trace::Trace + trace::Trace + index::Int + discard::DynamicChoiceMap + updated_retdiff::Diff + SwitchRegenerateState{T}(weight::Float64, score::Float64, noise::Float64, prev_trace::Trace) where T = new{T}(weight, score, noise, subtrace) +end + +function process!(gen_fn::Switch{C, N, T, K}, + index::Int, + index_argdiff::Diff, + args::Tuple, + kernel_argdiffs::Tuple, + selection::Selection, + state::SwitchRegenerateState{T}) where {C, N, T, K} + if index != getfield(state.prev_trace, :index) + decrement = get_score(state.prev_trace) + new_trace, weight, retdiff, discard = regenerate(getfield(state.prev_trace, :branch), args, kernel_argdiffs, selection) + state.weight = weight - decrement + else + new_trace, weight, retdiff, discard = regenerate(getfield(state.prev_trace, :branch), args, kernel_argdiffs, selection) + state.weight = weight + end + state.trace = new_trace + state.updated_retdiff = retdiff + state.discard = discard +end + +@inline process!(gen_fn::Switch{C, N, T, K}, index::C, index_argdiff::Diff, args::Tuple, selection::Selection, kernel_argdiffs::Tuple, state::SwitchRegenerateState{T}) where {C, N, T, K} = process!(gen_fn, getindex(gen_fn.cases, index), index_argdiff, args, selection, kernel_argdiffs, state) + +function regenerate(trace::SwitchTrace{T}, + args::Tuple, + argdiffs::Tuple, + selection::Selection) where T + gen_fn = trace.gen_fn + index, index_argdiff = args[1], argdiffs[1] + state = SwitchRegenerateState{T}(0.0, 0.0, 0.0, trace) + process!(gen_fn, index, index_argdiff, args, kernel_argdiffs, selection, argdiffs) + return (state.trace, state.weight, state.updated_retdiff, state.discard) +end diff --git a/src/modeling_library/switch/simulate.jl b/src/modeling_library/switch/simulate.jl index dad59290..1460a7fb 100644 --- a/src/modeling_library/switch/simulate.jl +++ b/src/modeling_library/switch/simulate.jl @@ -20,19 +20,7 @@ function process!(gen_fn::Switch{C, N, K, T}, state.retval = get_retval(subtrace) end -function process!(gen_fn::Switch{C, N, K, T}, - index::C, - args::Tuple, - state::SwitchSimulateState{T}) where {C, N, K, T} - local retval::T - index = getindex(gen_fn.cases, index) - state.index = index - subtrace = simulate(getindex(gen_fn.mix, index), args) - state.noise += project(subtrace, EmptySelection()) - state.subtrace = subtrace - state.score += get_score(subtrace) - state.retval = get_retval(subtrace) -end +@inline process!(gen_fn::Switch{C, N, K, T}, index::C, args::Tuple, state::SwitchSimulateState{T}) where {C, N, K, T} = process!(gen_fn, getindex(gen_fn.cases, index), args, state) function simulate(gen_fn::Switch{C, N, K, T}, args::Tuple) where {C, N, K, T} diff --git a/src/modeling_library/switch/switch.jl b/src/modeling_library/switch/switch.jl index 48bfec9f..a3e18bcf 100644 --- a/src/modeling_library/switch/switch.jl +++ b/src/modeling_library/switch/switch.jl @@ -32,3 +32,5 @@ include("assess.jl") include("propose.jl") include("simulate.jl") include("generate.jl") +include("update.jl") +include("regenerate.jl") diff --git a/src/modeling_library/switch/update.jl b/src/modeling_library/switch/update.jl index 493d6a7d..86d2bb5e 100644 --- a/src/modeling_library/switch/update.jl +++ b/src/modeling_library/switch/update.jl @@ -1,56 +1,45 @@ -mutable struct SwitchUpdateState{T,U} +mutable struct SwitchUpdateState{T} weight::Float64 score::Float64 noise::Float64 - cond::Bool - subtrace::U - retval::T + prev_trace::Trace + trace::Trace + index::Int discard::DynamicChoiceMap updated_retdiff::Diff + SwitchUpdateState{T}(weight::Float64, score::Float64, noise::Float64, prev_trace::Trace) where T = new{T}(weight, score, noise, subtrace) end function process!(gen_fn::Switch{C, N, T, K}, - branch_p::Float64, - args::Tuple, - choices::ChoiceMap, - kernel_argdiffs::Tuple, - state::SwitchUpdateState{T}) where {C, N, T, K} - local subtrace::Tr - local prev_subtrace::Tr - local retval::T - - # get new subtrace with recursive call to update() - submap = get_submap(choices, :branch) - prev_subtrace = state.subtrace - (subtrace, weight, retdiff, discard) = update(prev_subtrace, kernel_args, kernel_argdiffs, submap) - - # retrieve retdiff - if retdiff != NoChange() - state.updated_retdiff = retdiff - end - - # update state - state.weight += weight - set_submap!(state.discard, key, discard) - state.score += (get_score(subtrace) - get_score(prev_subtrace)) - state.noise += (project(subtrace, EmptySelection()) - project(prev_subtrace, EmptySelection())) - state.subtraces = assoc(state.subtraces, key, subtrace) - retval = get_retval(subtrace) - state.retval = assoc(state.retval, key, retval) - subtrace_empty = isempty(get_choices(subtrace)) - prev_subtrace_empty = isempty(get_choices(prev_subtrace)) - if !subtrace_empty && prev_subtrace_empty - state.num_nonempty += 1 - elseif subtrace_empty && !prev_subtrace_empty - state.num_nonempty -= 1 + index::Int, + index_argdiff::Diff, + args::Tuple, + kernel_argdiffs::Tuple, + choices::ChoiceMap, + state::SwitchUpdateState{T}) where {C, N, T, K} + if index != getfield(state.prev_trace, :index) + decrement = get_score(state.prev_trace) + merged = merge!(get_choices(state.prev_trace), choices) + new_trace, weight, retdiff, discard = update(getfield(state.prev_trace, :branch), args, kernel_argdiffs, merged) + state.weight = weight - decrement + else + new_trace, weight, retdiff, discard = update(getfield(state.prev_trace, :branch), args, kernel_argdiffs, choices) + state.weight = weight end + state.trace = new_trace + state.updated_retdiff = retdiff + state.discard = discard end -function update(trace::Switch{T1, T2, Tr}, +@inline process!(gen_fn::Switch{C, N, T, K}, index::C, index_argdiff::Diff, args::Tuple, choices::ChoiceMap, kernel_argdiffs::Tuple, state::SwitchUpdateState{T}) where {C, N, T, K} = process!(gen_fn, getindex(gen_fn.cases, index), index_argdiff, args, choices, kernel_argdiffs, state) + +function update(trace::SwitchTrace{T}, args::Tuple, argdiffs::Tuple, - choices::ChoiceMap) where {T1, T2, Tr} + choices::ChoiceMap) where T gen_fn = trace.gen_fn - branch_p = args[1] - return (new_trace, state.weight, retdiff, discard) + index, index_argdiff = args[1], argdiffs[1] + state = SwitchUpdateState{T}(0.0, 0.0, 0.0, trace) + process!(gen_fn, index, index_argdiff, args, kernel_argdiffs, choices, argdiffs) + return (state.trace, state.weight, state.updated_retdiff, state.discard) end diff --git a/src/modeling_library/with_prob/update.jl b/src/modeling_library/with_prob/update.jl index 6278f527..5e482c85 100644 --- a/src/modeling_library/with_prob/update.jl +++ b/src/modeling_library/with_prob/update.jl @@ -3,8 +3,8 @@ mutable struct WithProbabilityUpdateState{T} score::Float64 noise::Float64 cond::Bool - subtrace::Trace - retval::T + prev_trace::Trace + trace::Trace discard::DynamicChoiceMap updated_retdiff::Diff end @@ -12,36 +12,9 @@ end function process!(gen_fn::WithProbability{T} branch_p::Float64, args::Tuple, - choices::ChoiceMap, kernel_argdiffs::Tuple, + choices::ChoiceMap, state::WithProbabilityUpdateState{T}) where T - local retval::T - - # get new subtrace with recursive call to update() - submap = get_submap(choices, :branch) - prev_subtrace = state.subtrace - (subtrace, weight, retdiff, discard) = update(prev_subtrace, kernel_args, kernel_argdiffs, submap) - - # retrieve retdiff - if retdiff != NoChange() - state.updated_retdiff = retdiff - end - - # update state - state.weight += weight - set_submap!(state.discard, key, discard) - state.score += (get_score(subtrace) - get_score(prev_subtrace)) - state.noise += (project(subtrace, EmptySelection()) - project(prev_subtrace, EmptySelection())) - state.subtraces = assoc(state.subtraces, key, subtrace) - retval = get_retval(subtrace) - state.retval = assoc(state.retval, key, retval) - subtrace_empty = isempty(get_choices(subtrace)) - prev_subtrace_empty = isempty(get_choices(prev_subtrace)) - if !subtrace_empty && prev_subtrace_empty - state.num_nonempty += 1 - elseif subtrace_empty && !prev_subtrace_empty - state.num_nonempty -= 1 - end end function update(trace::WithProbability{T}, @@ -49,6 +22,6 @@ function update(trace::WithProbability{T}, argdiffs::Tuple, choices::ChoiceMap) where T gen_fn = trace.gen_fn - branch_p = args[1] + branch_p, branch_p_diff = args[1], argdiffs[1] return (new_trace, state.weight, retdiff, discard) end diff --git a/test/modeling_library/switch.jl b/test/modeling_library/switch.jl index 3464eeae..48bb2e18 100644 --- a/test/modeling_library/switch.jl +++ b/test/modeling_library/switch.jl @@ -1,6 +1,6 @@ module SwitchComb -include("../src/Gen.jl") +include("../../src/Gen.jl") using .Gen # ------------ Toplevel caller ------------ # From 7929b8630ca1a9a15cbb8b8d5381cfae23db430c Mon Sep 17 00:00:00 2001 From: femtomc Date: Wed, 18 Nov 2020 00:27:23 -0500 Subject: [PATCH 09/30] Added Switch update and regenerate - working out kinks in update. --- src/modeling_library/cond.jl | 8 ++++---- src/modeling_library/switch/assess.jl | 2 +- src/modeling_library/switch/generate.jl | 3 +-- src/modeling_library/switch/propose.jl | 2 +- src/modeling_library/switch/regenerate.jl | 11 ++++++----- src/modeling_library/switch/simulate.jl | 3 +-- src/modeling_library/switch/update.jl | 23 +++++++++++++---------- test/modeling_library/switch.jl | 5 +++++ 8 files changed, 32 insertions(+), 25 deletions(-) diff --git a/src/modeling_library/cond.jl b/src/modeling_library/cond.jl index ca164a54..270689a9 100644 --- a/src/modeling_library/cond.jl +++ b/src/modeling_library/cond.jl @@ -1,7 +1,7 @@ # ------------ WithProbability trace ------------ # struct WithProbabilityTrace{T1, T2, Tr} <: Trace - kernel::GenerativeFunction{Union{T1, T2}, Tr} + gen_fn::GenerativeFunction{Union{T1, T2}, Tr} p::Float64 cond::Bool branch::Tr @@ -20,7 +20,7 @@ end @inline get_retval(tr::WithProbabilityTrace) = tr.retval @inline get_args(tr::WithProbabilityTrace) = tr.args @inline get_score(tr::WithProbabilityTrace) = tr.score -@inline get_gen_fn(tr::WithProbabilityTrace) = tr.kernel +@inline get_gen_fn(tr::WithProbabilityTrace) = tr.gen_fn @inline function Base.getindex(tr::WithProbabilityTrace, addr::Pair) (first, rest) = addr @@ -40,7 +40,7 @@ project(tr::WithProbabilityTrace, ::EmptySelection) = tr.noise # ------------ Switch trace ------------ # struct SwitchTrace{T} <: Trace - kernel::GenerativeFunction{T} + gen_fn::GenerativeFunction{T} index::Int branch::Trace retval::T @@ -53,7 +53,7 @@ end @inline get_retval(tr::SwitchTrace) = tr.retval @inline get_args(tr::SwitchTrace) = tr.args @inline get_score(tr::SwitchTrace) = tr.score -@inline get_gen_fn(tr::SwitchTrace) = tr.kernel +@inline get_gen_fn(tr::SwitchTrace) = tr.gen_fn @inline function Base.getindex(tr::SwitchTrace, addr::Pair) (first, rest) = addr diff --git a/src/modeling_library/switch/assess.jl b/src/modeling_library/switch/assess.jl index fcf3d2b9..ed3656e5 100644 --- a/src/modeling_library/switch/assess.jl +++ b/src/modeling_library/switch/assess.jl @@ -21,5 +21,5 @@ function assess(gen_fn::Switch{C, N, K, T}, index = args[1] state = SwitchAssessState{T}(0.0) process!(gen_fn, index, args[2 : end], choices, state) - (state.weight, state.retval) + return state.weight, state.retval end diff --git a/src/modeling_library/switch/generate.jl b/src/modeling_library/switch/generate.jl index df20b3ff..456e5086 100644 --- a/src/modeling_library/switch/generate.jl +++ b/src/modeling_library/switch/generate.jl @@ -29,6 +29,5 @@ function generate(gen_fn::Switch{C, N, K, T}, index = args[1] state = SwitchGenerateState{T}(0.0, 0.0, 0.0) process!(gen_fn, index, args[2 : end], choices, state) - trace = SwitchTrace{T}(gen_fn, index, state.subtrace, state.retval, args[2 : end], state.score, state.noise) - (trace, state.weight) + return SwitchTrace{T}(gen_fn, index, state.subtrace, state.retval, args[2 : end], state.score, state.noise), state.weight end diff --git a/src/modeling_library/switch/propose.jl b/src/modeling_library/switch/propose.jl index c149c112..492f7ca5 100644 --- a/src/modeling_library/switch/propose.jl +++ b/src/modeling_library/switch/propose.jl @@ -25,5 +25,5 @@ function propose(gen_fn::Switch{C, N, K, T}, choices = choicemap() state = SwitchProposeState{T}(choices, 0.0) process!(gen_fn, index, args[2:end], state) - (state.choices, state.weight, state.retval) + return state.choices, state.weight, state.retval end diff --git a/src/modeling_library/switch/regenerate.jl b/src/modeling_library/switch/regenerate.jl index e5bcc27d..54d5f61b 100644 --- a/src/modeling_library/switch/regenerate.jl +++ b/src/modeling_library/switch/regenerate.jl @@ -7,18 +7,19 @@ mutable struct SwitchRegenerateState{T} index::Int discard::DynamicChoiceMap updated_retdiff::Diff - SwitchRegenerateState{T}(weight::Float64, score::Float64, noise::Float64, prev_trace::Trace) where T = new{T}(weight, score, noise, subtrace) + SwitchRegenerateState{T}(weight::Float64, score::Float64, noise::Float64, prev_trace::Trace) where T = new{T}(weight, score, noise, prev_trace) end -function process!(gen_fn::Switch{C, N, T, K}, +function process!(gen_fn::Switch{C, N, K, T}, index::Int, index_argdiff::Diff, args::Tuple, kernel_argdiffs::Tuple, selection::Selection, - state::SwitchRegenerateState{T}) where {C, N, T, K} + state::SwitchRegenerateState{T}) where {C, N, K, T} if index != getfield(state.prev_trace, :index) decrement = get_score(state.prev_trace) + kernel_argdiffs = map(_ -> UnknownChange(), kernel_argdiffs) new_trace, weight, retdiff, discard = regenerate(getfield(state.prev_trace, :branch), args, kernel_argdiffs, selection) state.weight = weight - decrement else @@ -30,7 +31,7 @@ function process!(gen_fn::Switch{C, N, T, K}, state.discard = discard end -@inline process!(gen_fn::Switch{C, N, T, K}, index::C, index_argdiff::Diff, args::Tuple, selection::Selection, kernel_argdiffs::Tuple, state::SwitchRegenerateState{T}) where {C, N, T, K} = process!(gen_fn, getindex(gen_fn.cases, index), index_argdiff, args, selection, kernel_argdiffs, state) +@inline process!(gen_fn::Switch{C, N, T, K}, index::C, index_argdiff::Diff, args::Tuple, kernel_argdiffs::Tuple, selection::Selection, state::SwitchRegenerateState{T}) where {C, N, T, K} = process!(gen_fn, getindex(gen_fn.cases, index), index_argdiff, args, kernel_argdiffs, selection, state) function regenerate(trace::SwitchTrace{T}, args::Tuple, @@ -39,6 +40,6 @@ function regenerate(trace::SwitchTrace{T}, gen_fn = trace.gen_fn index, index_argdiff = args[1], argdiffs[1] state = SwitchRegenerateState{T}(0.0, 0.0, 0.0, trace) - process!(gen_fn, index, index_argdiff, args, kernel_argdiffs, selection, argdiffs) + process!(gen_fn, index, index_argdiff, args[2 : end], kernel_argdiffs[2 : end], selection, argdiffs) return (state.trace, state.weight, state.updated_retdiff, state.discard) end diff --git a/src/modeling_library/switch/simulate.jl b/src/modeling_library/switch/simulate.jl index 1460a7fb..13528f4f 100644 --- a/src/modeling_library/switch/simulate.jl +++ b/src/modeling_library/switch/simulate.jl @@ -28,6 +28,5 @@ function simulate(gen_fn::Switch{C, N, K, T}, index = args[1] state = SwitchSimulateState{T}(0.0, 0.0) process!(gen_fn, index, args[2 : end], state) - trace = SwitchTrace{T}(gen_fn, state.index, state.subtrace, state.retval, args[2 : end], state.score, state.noise) - trace + SwitchTrace{T}(gen_fn, state.index, state.subtrace, state.retval, args[2 : end], state.score, state.noise) end diff --git a/src/modeling_library/switch/update.jl b/src/modeling_library/switch/update.jl index 86d2bb5e..86bf9fd2 100644 --- a/src/modeling_library/switch/update.jl +++ b/src/modeling_library/switch/update.jl @@ -7,31 +7,34 @@ mutable struct SwitchUpdateState{T} index::Int discard::DynamicChoiceMap updated_retdiff::Diff - SwitchUpdateState{T}(weight::Float64, score::Float64, noise::Float64, prev_trace::Trace) where T = new{T}(weight, score, noise, subtrace) + SwitchUpdateState{T}(weight::Float64, score::Float64, noise::Float64, prev_trace::Trace) where T = new{T}(weight, score, noise, prev_trace) end -function process!(gen_fn::Switch{C, N, T, K}, +function process!(gen_fn::Switch{C, N, K, T}, index::Int, index_argdiff::Diff, args::Tuple, kernel_argdiffs::Tuple, choices::ChoiceMap, - state::SwitchUpdateState{T}) where {C, N, T, K} + state::SwitchUpdateState{T}) where {C, N, K, T} if index != getfield(state.prev_trace, :index) - decrement = get_score(state.prev_trace) - merged = merge!(get_choices(state.prev_trace), choices) + merged = merge(get_choices(state.prev_trace), choices) + display(merged) + kernel_argdiffs = map(_ -> UnknownChange(), kernel_argdiffs) new_trace, weight, retdiff, discard = update(getfield(state.prev_trace, :branch), args, kernel_argdiffs, merged) - state.weight = weight - decrement else new_trace, weight, retdiff, discard = update(getfield(state.prev_trace, :branch), args, kernel_argdiffs, choices) - state.weight = weight end + state.index = index + state.weight = weight + state.noise = project(new_trace, EmptySelection()) - project(state.prev_trace, EmptySelection()) + state.score = get_score(new_trace) state.trace = new_trace state.updated_retdiff = retdiff state.discard = discard end -@inline process!(gen_fn::Switch{C, N, T, K}, index::C, index_argdiff::Diff, args::Tuple, choices::ChoiceMap, kernel_argdiffs::Tuple, state::SwitchUpdateState{T}) where {C, N, T, K} = process!(gen_fn, getindex(gen_fn.cases, index), index_argdiff, args, choices, kernel_argdiffs, state) +@inline process!(gen_fn::Switch{C, N, K, T}, index::C, index_argdiff::Diff, args::Tuple, kernel_argdiffs::Tuple, choices::ChoiceMap, state::SwitchUpdateState{T}) where {C, N, K, T} = process!(gen_fn, getindex(gen_fn.cases, index), index_argdiff, args, kernel_argdiffs, choices, state) function update(trace::SwitchTrace{T}, args::Tuple, @@ -40,6 +43,6 @@ function update(trace::SwitchTrace{T}, gen_fn = trace.gen_fn index, index_argdiff = args[1], argdiffs[1] state = SwitchUpdateState{T}(0.0, 0.0, 0.0, trace) - process!(gen_fn, index, index_argdiff, args, kernel_argdiffs, choices, argdiffs) - return (state.trace, state.weight, state.updated_retdiff, state.discard) + process!(gen_fn, index, index_argdiff, args[2 : end], argdiffs[2 : end], choices, state) + return SwitchTrace(gen_fn, state.index, state.trace, get_retval(state.trace), args, state.score, state.noise), state.weight, state.updated_retdiff, state.discard end diff --git a/test/modeling_library/switch.jl b/test/modeling_library/switch.jl index 48bb2e18..a5dac55e 100644 --- a/test/modeling_library/switch.jl +++ b/test/modeling_library/switch.jl @@ -38,5 +38,10 @@ Gen.@load_generated_functions() tr = simulate(bam, (:x, )) display(get_choices(tr)) +display(get_score(tr)) + +new_tr, w = update(tr, (:y, ), (UnknownChange(), ), choicemap()) +display(get_choices(new_tr)) +display(get_score(new_tr) - w) end # module From 73618a142b239aebd4e11f5a3c61487afd4d4c91 Mon Sep 17 00:00:00 2001 From: femtomc Date: Wed, 18 Nov 2020 08:36:11 -0500 Subject: [PATCH 10/30] update and regenerate appear to be computing the correct ratios. To confirm with Marco/Alex. --- src/modeling_library/switch/regenerate.jl | 28 ++++++++++++----------- src/modeling_library/switch/update.jl | 9 ++++---- test/modeling_library/switch.jl | 5 ++-- 3 files changed, 23 insertions(+), 19 deletions(-) diff --git a/src/modeling_library/switch/regenerate.jl b/src/modeling_library/switch/regenerate.jl index 54d5f61b..a3d4a35f 100644 --- a/src/modeling_library/switch/regenerate.jl +++ b/src/modeling_library/switch/regenerate.jl @@ -5,8 +5,7 @@ mutable struct SwitchRegenerateState{T} prev_trace::Trace trace::Trace index::Int - discard::DynamicChoiceMap - updated_retdiff::Diff + retdiff::Diff SwitchRegenerateState{T}(weight::Float64, score::Float64, noise::Float64, prev_trace::Trace) where T = new{T}(weight, score, noise, prev_trace) end @@ -18,20 +17,23 @@ function process!(gen_fn::Switch{C, N, K, T}, selection::Selection, state::SwitchRegenerateState{T}) where {C, N, K, T} if index != getfield(state.prev_trace, :index) - decrement = get_score(state.prev_trace) - kernel_argdiffs = map(_ -> UnknownChange(), kernel_argdiffs) - new_trace, weight, retdiff, discard = regenerate(getfield(state.prev_trace, :branch), args, kernel_argdiffs, selection) - state.weight = weight - decrement + merged = get_choices(state.prev_trace) + branch_fn = getfield(gen_fn.mix, index) + new_trace, weight = generate(branch_fn, args, merged) + retdiff = UnknownChange() + weight -= get_score(state.prev_trace) else - new_trace, weight, retdiff, discard = regenerate(getfield(state.prev_trace, :branch), args, kernel_argdiffs, selection) - state.weight = weight + new_trace, weight, retdiff = regenerate(getfield(state.prev_trace, :branch), args, kernel_argdiffs, selection) end + state.index = index + state.weight = weight + state.noise = project(new_trace, EmptySelection()) - project(state.prev_trace, EmptySelection()) + state.score = get_score(new_trace) state.trace = new_trace - state.updated_retdiff = retdiff - state.discard = discard + state.retdiff = retdiff end -@inline process!(gen_fn::Switch{C, N, T, K}, index::C, index_argdiff::Diff, args::Tuple, kernel_argdiffs::Tuple, selection::Selection, state::SwitchRegenerateState{T}) where {C, N, T, K} = process!(gen_fn, getindex(gen_fn.cases, index), index_argdiff, args, kernel_argdiffs, selection, state) +@inline process!(gen_fn::Switch{C, N, K, T}, index::C, index_argdiff::Diff, args::Tuple, kernel_argdiffs::Tuple, selection::Selection, state::SwitchRegenerateState{T}) where {C, N, K, T} = process!(gen_fn, getindex(gen_fn.cases, index), index_argdiff, args, kernel_argdiffs, selection, state) function regenerate(trace::SwitchTrace{T}, args::Tuple, @@ -40,6 +42,6 @@ function regenerate(trace::SwitchTrace{T}, gen_fn = trace.gen_fn index, index_argdiff = args[1], argdiffs[1] state = SwitchRegenerateState{T}(0.0, 0.0, 0.0, trace) - process!(gen_fn, index, index_argdiff, args[2 : end], kernel_argdiffs[2 : end], selection, argdiffs) - return (state.trace, state.weight, state.updated_retdiff, state.discard) + process!(gen_fn, index, index_argdiff, args[2 : end], argdiffs[2 : end], selection, state) + return SwitchTrace(gen_fn, state.index, state.trace, get_retval(state.trace), args, state.score, state.noise), state.weight, state.retdiff end diff --git a/src/modeling_library/switch/update.jl b/src/modeling_library/switch/update.jl index 86bf9fd2..83a9a1b5 100644 --- a/src/modeling_library/switch/update.jl +++ b/src/modeling_library/switch/update.jl @@ -5,7 +5,7 @@ mutable struct SwitchUpdateState{T} prev_trace::Trace trace::Trace index::Int - discard::DynamicChoiceMap + discard::ChoiceMap updated_retdiff::Diff SwitchUpdateState{T}(weight::Float64, score::Float64, noise::Float64, prev_trace::Trace) where T = new{T}(weight, score, noise, prev_trace) end @@ -19,9 +19,10 @@ function process!(gen_fn::Switch{C, N, K, T}, state::SwitchUpdateState{T}) where {C, N, K, T} if index != getfield(state.prev_trace, :index) merged = merge(get_choices(state.prev_trace), choices) - display(merged) - kernel_argdiffs = map(_ -> UnknownChange(), kernel_argdiffs) - new_trace, weight, retdiff, discard = update(getfield(state.prev_trace, :branch), args, kernel_argdiffs, merged) + branch_fn = getfield(gen_fn.mix, index) + new_trace, weight = generate(branch_fn, args, merged) + retdiff, discard = UnknownChange(), get_choices(getfield(state.prev_trace, :branch)) + weight -= get_score(state.prev_trace) else new_trace, weight, retdiff, discard = update(getfield(state.prev_trace, :branch), args, kernel_argdiffs, choices) end diff --git a/test/modeling_library/switch.jl b/test/modeling_library/switch.jl index a5dac55e..82a70295 100644 --- a/test/modeling_library/switch.jl +++ b/test/modeling_library/switch.jl @@ -37,11 +37,12 @@ end Gen.@load_generated_functions() tr = simulate(bam, (:x, )) -display(get_choices(tr)) display(get_score(tr)) new_tr, w = update(tr, (:y, ), (UnknownChange(), ), choicemap()) -display(get_choices(new_tr)) +display(get_score(new_tr) - w) + +new_tr, w = regenerate(tr, (:x, ), (UnknownChange(), ), select()) display(get_score(new_tr) - w) end # module From 252413f7f93a2fc7b35a07d0c0e12fa6ae099eee Mon Sep 17 00:00:00 2001 From: femtomc Date: Wed, 18 Nov 2020 08:38:51 -0500 Subject: [PATCH 11/30] Fixed generate index type bug. --- src/modeling_library/switch/generate.jl | 3 ++- test/modeling_library/switch.jl | 16 +++++++++++++++- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/src/modeling_library/switch/generate.jl b/src/modeling_library/switch/generate.jl index 456e5086..1c135c67 100644 --- a/src/modeling_library/switch/generate.jl +++ b/src/modeling_library/switch/generate.jl @@ -15,6 +15,7 @@ function process!(gen_fn::Switch{C, N, K, T}, state::SwitchGenerateState{T}) where {C, N, K, T} (subtrace, weight) = generate(getindex(gen_fn.mix, index), args, choices) + state.index = index state.subtrace = subtrace state.weight += weight state.retval = get_retval(subtrace) @@ -29,5 +30,5 @@ function generate(gen_fn::Switch{C, N, K, T}, index = args[1] state = SwitchGenerateState{T}(0.0, 0.0, 0.0) process!(gen_fn, index, args[2 : end], choices, state) - return SwitchTrace{T}(gen_fn, index, state.subtrace, state.retval, args[2 : end], state.score, state.noise), state.weight + return SwitchTrace{T}(gen_fn, state.index, state.subtrace, state.retval, args[2 : end], state.score, state.noise), state.weight end diff --git a/test/modeling_library/switch.jl b/test/modeling_library/switch.jl index 82a70295..dacaa573 100644 --- a/test/modeling_library/switch.jl +++ b/test/modeling_library/switch.jl @@ -17,7 +17,8 @@ end return z end -sc = Switch(Dict(:x => 1, :y => 2), foo, baz) +# Standard. +sc = Switch(foo, baz) chm, _, _ = propose(sc, (2, 5.0, 3.0)) display(chm) @@ -29,6 +30,19 @@ chm[:z] = 5.0 tr, _ = generate(sc, (2, 5.0, 3.0), chm) display(get_choices(tr)) +# Cases. +sc = Switch(Dict(:x => 1, :y => 2), foo, baz) +chm, _, _ = propose(sc, (:x, 5.0, 3.0)) +display(chm) + +tr = simulate(sc, (:x, 5.0, 3.0)) +display(get_choices(tr)) + +chm = choicemap() +chm[:z] = 5.0 +tr, _ = generate(sc, (:x, 5.0, 3.0), chm) +display(get_choices(tr)) + # ------------ Static DSL ------------ # @gen (static) function bam(s::Symbol) From ac3528e60ec3469165902fe9acdabbf39cad1ebc Mon Sep 17 00:00:00 2001 From: femtomc Date: Wed, 18 Nov 2020 08:53:06 -0500 Subject: [PATCH 12/30] Branch dispatch done using diff types. --- src/modeling_library/switch/update.jl | 33 +++++++++++++++++++-------- 1 file changed, 23 insertions(+), 10 deletions(-) diff --git a/src/modeling_library/switch/update.jl b/src/modeling_library/switch/update.jl index 83a9a1b5..492dc11a 100644 --- a/src/modeling_library/switch/update.jl +++ b/src/modeling_library/switch/update.jl @@ -12,20 +12,33 @@ end function process!(gen_fn::Switch{C, N, K, T}, index::Int, - index_argdiff::Diff, + index_argdiff::UnknownChange, # TODO: Diffed wrapper? + args::Tuple, + kernel_argdiffs::Tuple, + choices::ChoiceMap, + state::SwitchUpdateState{T}) where {C, N, K, T, DV} + merged = merge(get_choices(state.prev_trace), choices) + branch_fn = getfield(gen_fn.mix, index) + new_trace, weight = generate(branch_fn, args, merged) + retdiff, discard = UnknownChange(), get_choices(getfield(state.prev_trace, :branch)) + weight -= get_score(state.prev_trace) + state.index = index + state.weight = weight + state.noise = project(new_trace, EmptySelection()) - project(state.prev_trace, EmptySelection()) + state.score = get_score(new_trace) + state.trace = new_trace + state.updated_retdiff = retdiff + state.discard = discard +end + +function process!(gen_fn::Switch{C, N, K, T}, + index::Int, + index_argdiff::NoChange, # TODO: Diffed wrapper? args::Tuple, kernel_argdiffs::Tuple, choices::ChoiceMap, state::SwitchUpdateState{T}) where {C, N, K, T} - if index != getfield(state.prev_trace, :index) - merged = merge(get_choices(state.prev_trace), choices) - branch_fn = getfield(gen_fn.mix, index) - new_trace, weight = generate(branch_fn, args, merged) - retdiff, discard = UnknownChange(), get_choices(getfield(state.prev_trace, :branch)) - weight -= get_score(state.prev_trace) - else - new_trace, weight, retdiff, discard = update(getfield(state.prev_trace, :branch), args, kernel_argdiffs, choices) - end + new_trace, weight, retdiff, discard = update(getfield(state.prev_trace, :branch), args, kernel_argdiffs, choices) state.index = index state.weight = weight state.noise = project(new_trace, EmptySelection()) - project(state.prev_trace, EmptySelection()) From eaf3327d10763e7ac090258816c20ccf7d42d0dd Mon Sep 17 00:00:00 2001 From: femtomc Date: Wed, 18 Nov 2020 08:54:20 -0500 Subject: [PATCH 13/30] Branch dispatch done using diff types. --- src/modeling_library/switch/regenerate.jl | 32 ++++++++++++++++------- 1 file changed, 22 insertions(+), 10 deletions(-) diff --git a/src/modeling_library/switch/regenerate.jl b/src/modeling_library/switch/regenerate.jl index a3d4a35f..f3499603 100644 --- a/src/modeling_library/switch/regenerate.jl +++ b/src/modeling_library/switch/regenerate.jl @@ -11,20 +11,32 @@ end function process!(gen_fn::Switch{C, N, K, T}, index::Int, - index_argdiff::Diff, + index_argdiff::UnknownChange, args::Tuple, kernel_argdiffs::Tuple, selection::Selection, state::SwitchRegenerateState{T}) where {C, N, K, T} - if index != getfield(state.prev_trace, :index) - merged = get_choices(state.prev_trace) - branch_fn = getfield(gen_fn.mix, index) - new_trace, weight = generate(branch_fn, args, merged) - retdiff = UnknownChange() - weight -= get_score(state.prev_trace) - else - new_trace, weight, retdiff = regenerate(getfield(state.prev_trace, :branch), args, kernel_argdiffs, selection) - end + merged = get_choices(state.prev_trace) + branch_fn = getfield(gen_fn.mix, index) + new_trace, weight = generate(branch_fn, args, merged) + retdiff = UnknownChange() + weight -= get_score(state.prev_trace) + state.index = index + state.weight = weight + state.noise = project(new_trace, EmptySelection()) - project(state.prev_trace, EmptySelection()) + state.score = get_score(new_trace) + state.trace = new_trace + state.retdiff = retdiff +end + +function process!(gen_fn::Switch{C, N, K, T}, + index::Int, + index_argdiff::NoChange, + args::Tuple, + kernel_argdiffs::Tuple, + selection::Selection, + state::SwitchRegenerateState{T}) where {C, N, K, T} + new_trace, weight, retdiff = regenerate(getfield(state.prev_trace, :branch), args, kernel_argdiffs, selection) state.index = index state.weight = weight state.noise = project(new_trace, EmptySelection()) - project(state.prev_trace, EmptySelection()) From 6d58aac726f589a70f383663e7940d582047e4ac Mon Sep 17 00:00:00 2001 From: femtomc Date: Wed, 18 Nov 2020 08:55:47 -0500 Subject: [PATCH 14/30] Branch dispatch done using diff types. --- src/modeling_library/switch/update.jl | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/modeling_library/switch/update.jl b/src/modeling_library/switch/update.jl index 492dc11a..b343504b 100644 --- a/src/modeling_library/switch/update.jl +++ b/src/modeling_library/switch/update.jl @@ -17,11 +17,15 @@ function process!(gen_fn::Switch{C, N, K, T}, kernel_argdiffs::Tuple, choices::ChoiceMap, state::SwitchUpdateState{T}) where {C, N, K, T, DV} + + # Generate new trace. merged = merge(get_choices(state.prev_trace), choices) branch_fn = getfield(gen_fn.mix, index) new_trace, weight = generate(branch_fn, args, merged) retdiff, discard = UnknownChange(), get_choices(getfield(state.prev_trace, :branch)) weight -= get_score(state.prev_trace) + + # Set state. state.index = index state.weight = weight state.noise = project(new_trace, EmptySelection()) - project(state.prev_trace, EmptySelection()) @@ -38,7 +42,11 @@ function process!(gen_fn::Switch{C, N, K, T}, kernel_argdiffs::Tuple, choices::ChoiceMap, state::SwitchUpdateState{T}) where {C, N, K, T} + + # Update trace. new_trace, weight, retdiff, discard = update(getfield(state.prev_trace, :branch), args, kernel_argdiffs, choices) + + # Set state. state.index = index state.weight = weight state.noise = project(new_trace, EmptySelection()) - project(state.prev_trace, EmptySelection()) From e413e9cbd3ecf1e7521a0d0d6c413b477c3670c3 Mon Sep 17 00:00:00 2001 From: femtomc Date: Wed, 18 Nov 2020 15:09:08 -0500 Subject: [PATCH 15/30] Added custom methods in update for Switch which allow the merging of overlapping choice maps and the computation of the discard addresses in update. --- src/modeling_library/switch/update.jl | 51 ++++++++++++++++++++++++--- test/modeling_library/switch.jl | 34 +++++++++++------- 2 files changed, 68 insertions(+), 17 deletions(-) diff --git a/src/modeling_library/switch/update.jl b/src/modeling_library/switch/update.jl index b343504b..63204065 100644 --- a/src/modeling_library/switch/update.jl +++ b/src/modeling_library/switch/update.jl @@ -10,6 +10,50 @@ mutable struct SwitchUpdateState{T} SwitchUpdateState{T}(weight::Float64, score::Float64, noise::Float64, prev_trace::Trace) where T = new{T}(weight, score, noise, prev_trace) end +function update_recurse_merge(prev_choices::ChoiceMap, choices::ChoiceMap) + prev_choice_submap_iterator = get_submaps_shallow(prev_choices) + prev_choice_value_iterator = get_values_shallow(prev_choices) + choice_submap_iterator = get_submaps_shallow(choices) + choice_value_iterator = get_values_shallow(choices) + choices = DynamicChoiceMap() + for (key, value) in prev_choice_value_iterator + key in keys(choice_value_iterator) && continue + set_value!(choices, key, value) + end + for (key, node1) in prev_choice_submap_iterator + if key in keys(choice_submap_iterator) + node2 = get_submap(choices, key) + node = update_recurse_merge(node1, node2) + set_submap!(choices, key, node) + else + set_submap!(choices, key, node1) + end + end + for (key, value) in choice_value_iterator + set_value!(choices, key, value) + end + for (key, node) in filter((k, _) -> !(k in keys(prev_choice_submap_iterator)), choice_submap_iterator) + set_submap!(choices, key, node) + end + return choices +end + +function update_discard(prev_trace::Trace, choices::ChoiceMap, new_trace::Trace) + discard = choicemap() + prev_choices = get_choices(prev_trace) + for (k, v) in get_submaps_shallow(prev_choices) + get_submap(get_choices(new_trace), k) isa EmptyChoiceMap && continue + get_submap(choices, k) isa EmptyChoiceMap && continue + set_submap!(discard, k, v) + end + for (k, v) in get_values_shallow(prev_choices) + has_value(get_choices(new_trace), k) || continue + has_value(choices, k) || continue + set_value!(discard, k, v) + end + discard +end + function process!(gen_fn::Switch{C, N, K, T}, index::Int, index_argdiff::UnknownChange, # TODO: Diffed wrapper? @@ -19,11 +63,11 @@ function process!(gen_fn::Switch{C, N, K, T}, state::SwitchUpdateState{T}) where {C, N, K, T, DV} # Generate new trace. - merged = merge(get_choices(state.prev_trace), choices) + merged = update_recurse_merge(get_choices(state.prev_trace), choices) branch_fn = getfield(gen_fn.mix, index) new_trace, weight = generate(branch_fn, args, merged) - retdiff, discard = UnknownChange(), get_choices(getfield(state.prev_trace, :branch)) weight -= get_score(state.prev_trace) + state.discard = update_discard(state.prev_trace, choices, new_trace) # Set state. state.index = index @@ -31,8 +75,7 @@ function process!(gen_fn::Switch{C, N, K, T}, state.noise = project(new_trace, EmptySelection()) - project(state.prev_trace, EmptySelection()) state.score = get_score(new_trace) state.trace = new_trace - state.updated_retdiff = retdiff - state.discard = discard + state.updated_retdiff = UnknownChange() end function process!(gen_fn::Switch{C, N, K, T}, diff --git a/test/modeling_library/switch.jl b/test/modeling_library/switch.jl index dacaa573..19cea930 100644 --- a/test/modeling_library/switch.jl +++ b/test/modeling_library/switch.jl @@ -20,43 +20,51 @@ end # Standard. sc = Switch(foo, baz) chm, _, _ = propose(sc, (2, 5.0, 3.0)) -display(chm) tr = simulate(sc, (2, 5.0, 3.0)) -display(get_choices(tr)) chm = choicemap() chm[:z] = 5.0 tr, _ = generate(sc, (2, 5.0, 3.0), chm) -display(get_choices(tr)) # Cases. sc = Switch(Dict(:x => 1, :y => 2), foo, baz) chm, _, _ = propose(sc, (:x, 5.0, 3.0)) -display(chm) tr = simulate(sc, (:x, 5.0, 3.0)) -display(get_choices(tr)) chm = choicemap() chm[:z] = 5.0 tr, _ = generate(sc, (:x, 5.0, 3.0), chm) -display(get_choices(tr)) # ------------ Static DSL ------------ # -@gen (static) function bam(s::Symbol) +@gen (static) function bang((grad)(x::Float64), (grad)(y::Float64)) + std::Float64 = 3.0 + z = @trace(normal(x + y, std), :z) + return z +end + +@gen (static) function fuzz((grad)(x::Float64), (grad)(y::Float64)) + std::Float64 = 3.0 + z = @trace(normal(x + 2 * y, std), :z) + return z +end + +sc = Switch(bang, fuzz) + +@gen (static) function bam(s::Int) x ~ sc(s, 5.0, 3.0) end Gen.@load_generated_functions() -tr = simulate(bam, (:x, )) -display(get_score(tr)) +tr = simulate(bam, (1, )) -new_tr, w = update(tr, (:y, ), (UnknownChange(), ), choicemap()) -display(get_score(new_tr) - w) +chm = choicemap((:x => :z, 5.0)) +new_tr, w, rd, discard = update(tr, (2, ), (UnknownChange(), ), chm) +display(get_choices(new_tr)) +display(discard) -new_tr, w = regenerate(tr, (:x, ), (UnknownChange(), ), select()) -display(get_score(new_tr) - w) +new_tr, w = regenerate(tr, (1, ), (UnknownChange(), ), select()) end # module From 435493f0cba8bd5ba8cdf954644a1aa2f2e7b71a Mon Sep 17 00:00:00 2001 From: femtomc Date: Wed, 18 Nov 2020 15:19:04 -0500 Subject: [PATCH 16/30] Added custom methods in update for Switch which allow the merging of overlapping choice maps and the computation of the discard addresses in update. --- src/modeling_library/switch/update.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/modeling_library/switch/update.jl b/src/modeling_library/switch/update.jl index 63204065..773bed66 100644 --- a/src/modeling_library/switch/update.jl +++ b/src/modeling_library/switch/update.jl @@ -32,7 +32,9 @@ function update_recurse_merge(prev_choices::ChoiceMap, choices::ChoiceMap) for (key, value) in choice_value_iterator set_value!(choices, key, value) end - for (key, node) in filter((k, _) -> !(k in keys(prev_choice_submap_iterator)), choice_submap_iterator) + for (key, node) in filter(choice_submap_iterator) do (k, _) + !(k in keys(prev_choice_submap_iterator)) + end set_submap!(choices, key, node) end return choices From 32fec4f96f981baab3a1799be711714d4217db8e Mon Sep 17 00:00:00 2001 From: femtomc Date: Wed, 18 Nov 2020 17:34:01 -0500 Subject: [PATCH 17/30] Idiomatic check for EmptyChoiceMap. --- src/modeling_library/switch/update.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/modeling_library/switch/update.jl b/src/modeling_library/switch/update.jl index 773bed66..e60c53ee 100644 --- a/src/modeling_library/switch/update.jl +++ b/src/modeling_library/switch/update.jl @@ -44,8 +44,8 @@ function update_discard(prev_trace::Trace, choices::ChoiceMap, new_trace::Trace) discard = choicemap() prev_choices = get_choices(prev_trace) for (k, v) in get_submaps_shallow(prev_choices) - get_submap(get_choices(new_trace), k) isa EmptyChoiceMap && continue - get_submap(choices, k) isa EmptyChoiceMap && continue + isempty(get_submap(get_choices(new_trace), k)) && continue + isempty(get_submap(choices, k)) && continue set_submap!(discard, k, v) end for (k, v) in get_values_shallow(prev_choices) From bb767e71266a87265b8c1cf87cf1a56f5641c384 Mon Sep 17 00:00:00 2001 From: femtomc Date: Wed, 18 Nov 2020 17:52:49 -0500 Subject: [PATCH 18/30] Working on backprop - seems simple? Could it really be? --- src/modeling_library/switch/backprop.jl | 2 ++ src/modeling_library/switch/regenerate.jl | 3 +-- src/modeling_library/switch/switch.jl | 1 + 3 files changed, 4 insertions(+), 2 deletions(-) create mode 100644 src/modeling_library/switch/backprop.jl diff --git a/src/modeling_library/switch/backprop.jl b/src/modeling_library/switch/backprop.jl new file mode 100644 index 00000000..4f3fb96f --- /dev/null +++ b/src/modeling_library/switch/backprop.jl @@ -0,0 +1,2 @@ +@inline choice_gradients(trace::SwitchTrace{T}, selection::Selection, retval_grad) where T = (nothing, choice_gradients(getfield(trace, :branch), selection, retval_grad)...) +@inline accumulate_param_gradients(trace::SwitchTrace{T}, retval_grad) where {T} = (nothing, accumulate_param_gradients(getfield(trace, :branch), retval_grad)...) diff --git a/src/modeling_library/switch/regenerate.jl b/src/modeling_library/switch/regenerate.jl index f3499603..d5c9d121 100644 --- a/src/modeling_library/switch/regenerate.jl +++ b/src/modeling_library/switch/regenerate.jl @@ -16,9 +16,8 @@ function process!(gen_fn::Switch{C, N, K, T}, kernel_argdiffs::Tuple, selection::Selection, state::SwitchRegenerateState{T}) where {C, N, K, T} - merged = get_choices(state.prev_trace) branch_fn = getfield(gen_fn.mix, index) - new_trace, weight = generate(branch_fn, args, merged) + new_trace, weight = generate(branch_fn, args, get_choices(state.prev_trace)) retdiff = UnknownChange() weight -= get_score(state.prev_trace) state.index = index diff --git a/src/modeling_library/switch/switch.jl b/src/modeling_library/switch/switch.jl index a3e18bcf..e8ee756e 100644 --- a/src/modeling_library/switch/switch.jl +++ b/src/modeling_library/switch/switch.jl @@ -34,3 +34,4 @@ include("simulate.jl") include("generate.jl") include("update.jl") include("regenerate.jl") +include("backprop.jl") From a35e2e70382c0af81261078712c4d7c9b193e777 Mon Sep 17 00:00:00 2001 From: femtomc Date: Wed, 18 Nov 2020 17:55:20 -0500 Subject: [PATCH 19/30] Extracting WithProb combinator into another PR. --- src/modeling_library/cond.jl | 39 ----------------- src/modeling_library/modeling_library.jl | 1 - src/modeling_library/with_prob/assess.jl | 26 ------------ src/modeling_library/with_prob/generate.jl | 47 --------------------- src/modeling_library/with_prob/propose.jl | 30 ------------- src/modeling_library/with_prob/simulate.jl | 33 --------------- src/modeling_library/with_prob/update.jl | 27 ------------ src/modeling_library/with_prob/with_prob.jl | 19 --------- 8 files changed, 222 deletions(-) delete mode 100644 src/modeling_library/with_prob/assess.jl delete mode 100644 src/modeling_library/with_prob/generate.jl delete mode 100644 src/modeling_library/with_prob/propose.jl delete mode 100644 src/modeling_library/with_prob/simulate.jl delete mode 100644 src/modeling_library/with_prob/update.jl delete mode 100644 src/modeling_library/with_prob/with_prob.jl diff --git a/src/modeling_library/cond.jl b/src/modeling_library/cond.jl index 270689a9..bc10b2c7 100644 --- a/src/modeling_library/cond.jl +++ b/src/modeling_library/cond.jl @@ -1,42 +1,3 @@ -# ------------ WithProbability trace ------------ # - -struct WithProbabilityTrace{T1, T2, Tr} <: Trace - gen_fn::GenerativeFunction{Union{T1, T2}, Tr} - p::Float64 - cond::Bool - branch::Tr - retval::Union{T1, T2} - args::Tuple - score::Float64 - noise::Float64 -end - -@inline function get_choices(tr::WithProbabilityTrace) - choices = choicemap() - set_submap!(choices, :branch, get_choices(tr.branch)) - set_value!(choices, :cond, tr.cond) - choices -end -@inline get_retval(tr::WithProbabilityTrace) = tr.retval -@inline get_args(tr::WithProbabilityTrace) = tr.args -@inline get_score(tr::WithProbabilityTrace) = tr.score -@inline get_gen_fn(tr::WithProbabilityTrace) = tr.gen_fn - -@inline function Base.getindex(tr::WithProbabilityTrace, addr::Pair) - (first, rest) = addr - subtr = getfield(trace, first) - subtrace[rest] -end -@inline Base.getindex(tr::WithProbabilityTrace, addr::Symbol) = getfield(trace, addr) - -function project(tr::WithProbabilityTrace, selection::Selection) - sum(map([:cond, :branch]) do k - subselection = selection[k] - project(getindex(tr, k), subselection) - end) -end -project(tr::WithProbabilityTrace, ::EmptySelection) = tr.noise - # ------------ Switch trace ------------ # struct SwitchTrace{T} <: Trace diff --git a/src/modeling_library/modeling_library.jl b/src/modeling_library/modeling_library.jl index a6004510..2572b0fd 100644 --- a/src/modeling_library/modeling_library.jl +++ b/src/modeling_library/modeling_library.jl @@ -76,7 +76,6 @@ include("map/map.jl") include("unfold/unfold.jl") include("recurse/recurse.jl") include("switch/switch.jl") -include("with_prob/with_prob.jl") ############################################################# # abstractions for constructing custom generative functions # diff --git a/src/modeling_library/with_prob/assess.jl b/src/modeling_library/with_prob/assess.jl deleted file mode 100644 index da9a77a7..00000000 --- a/src/modeling_library/with_prob/assess.jl +++ /dev/null @@ -1,26 +0,0 @@ -mutable struct WithProbabilityAssessState{T} - weight::Float64 - retval::T -end - -function process_new!(gen_fn::WithProbability{T}, - branch_p::Float64, - args::Tuple, - choices::ChoiceMap, - state::WithProbabilityAssessState{T}) where T - flip = get_value(choices, :cond) - state.weight += logpdf(Bernoulli(), flip, branch_p) - submap = get_submap(choices, :branch) - (weight, retval) = assess(gen_fn.kernel, kernel_args, submap) - state.weight += weight - state.retval = retval -end - -function assess(gen_fn::WithProbability{T}, - args::Tuple, - choices::ChoiceMap) where T - branch_p = args[1] - state = WithProbabilityAssessState{T}(0.0) - process_new!(gen_fn, branch_p, args[2 : end], choices, state) - (state.weight, state.retval) -end diff --git a/src/modeling_library/with_prob/generate.jl b/src/modeling_library/with_prob/generate.jl deleted file mode 100644 index 45d5540a..00000000 --- a/src/modeling_library/with_prob/generate.jl +++ /dev/null @@ -1,47 +0,0 @@ -mutable struct WithProbabilityGenerateState{T} - score::Float64 - noise::Float64 - weight::Float64 - cond::Bool - subtrace::Trace - retval::T - WithProbabilityGenerateState{T}(score::Float64, noise::Float64, weight::Float64) where T = new{T}(score, noise, weight) -end - -function process!(gen_fn::WithProbability{T}, - branch_p::Float64, - args::Tuple, - choices::ChoiceMap, - state::WithProbabilityGenerateState{T}) where T - - # sample from Bernoulli with probability branch_p - flip_d = bernoulli(branch_p) - - # check for constraints at :cond - constrained = has_value(choices, :cond) - !constrained && check_no_submap(choices, :cond) - - # get/constrain flip value - constrained ? (flip = get_value(choices, :cond); state.weight += logpdf(Bernoulli(), flip, branch_p)) : flip = rand(flip_d) - state.cond = flip - - # generate subtrace - constraints = get_submap(choices, :branch) - (subtrace, weight) = generate(flip ? gen_fn.a : gen_fn.b, args, constraints) - state.subtrace = subtrace - state.weight += weight - - # return from branch - state.retval = get_retval(subtrace) -end - -function generate(gen_fn::WithProbability{T}, - args::Tuple, - choices::ChoiceMap) where T - - branch_p = args[1] - state = WithProbabilityGenerateState{T}(0.0, 0.0, 0.0) - process!(gen_fn, branch_p, args[2 : end], choices, state) - trace = WithProbabilityTrace{T}(gen_fn, branch_p, state.cond, state.subtrace, state.retval, args[2 : end], state.score, state.noise) - (trace, state.weight) -end diff --git a/src/modeling_library/with_prob/propose.jl b/src/modeling_library/with_prob/propose.jl deleted file mode 100644 index fd94b7db..00000000 --- a/src/modeling_library/with_prob/propose.jl +++ /dev/null @@ -1,30 +0,0 @@ -mutable struct WithProbabilityProposeState{T} - choices::DynamicChoiceMap - weight::Float64 - retval::T - WithProbabilityProposeState{T}(choices, weight) where T = new{T}(choices, weight) -end - -function process_new!(gen_fn::WithProbability{T}, - branch_p::Float64, - args::Tuple, - state::WithProbabilityProposeState{T}) where T - - flip = bernoulli(branch_p) - (submap, weight, retval) = propose(flip ? gen_fn.a : gen_fn.b, args) - set_value!(state.choices, :cond, flip) - state.weight += logpdf(Bernoulli(), flip, branch_p) - set_submap!(state.choices, :branch, submap) - state.weight += weight - state.retval = retval -end - -function propose(gen_fn::WithProbability{T}, - args::Tuple) where T - - branch_p = args[1] - choices = choicemap() - state = WithProbabilityProposeState{T}(choices, 0.0) - process_new!(gen_fn, branch_p, args[2:end], state) - (state.choices, state.weight, state.retval) -end diff --git a/src/modeling_library/with_prob/simulate.jl b/src/modeling_library/with_prob/simulate.jl deleted file mode 100644 index 2c8d40fa..00000000 --- a/src/modeling_library/with_prob/simulate.jl +++ /dev/null @@ -1,33 +0,0 @@ -mutable struct WithProbabilitySimulateState{T} - score::Float64 - noise::Float64 - cond::Bool - subtrace::Trace - retval::T - WithProbabilitySimulateState{T}(score::Float64, noise::Float64) where T = new{T}(score, noise) -end - -function process!(gen_fn::WithProbability{T}, - branch_p::Float64, - args::Tuple, - state::WithProbabilitySimulateState{T}) where T - local retval::T - flip = bernoulli(branch_p) - state.score += logpdf(Bernoulli(), flip, branch_p) - state.cond = flip - subtrace = simulate(flip ? gen_fn.a : gen_fn.b, args) - state.noise += project(subtrace, EmptySelection()) - state.subtrace = subtrace - state.score += get_score(subtrace) - state.retval = get_retval(subtrace) -end - -function simulate(gen_fn::WithProbability{T}, - args::Tuple) where T - - branch_p = args[1] - state = WithProbabilitySimulateState{T}(0.0, 0.0) - process!(gen_fn, branch_p, args[2 : end], state) - trace = WithProbabilityTrace{T}(gen_fn, branch_p, state.cond, state.subtrace, state.retval, args[2 : end], state.score, state.noise) - trace -end diff --git a/src/modeling_library/with_prob/update.jl b/src/modeling_library/with_prob/update.jl deleted file mode 100644 index 5e482c85..00000000 --- a/src/modeling_library/with_prob/update.jl +++ /dev/null @@ -1,27 +0,0 @@ -mutable struct WithProbabilityUpdateState{T} - weight::Float64 - score::Float64 - noise::Float64 - cond::Bool - prev_trace::Trace - trace::Trace - discard::DynamicChoiceMap - updated_retdiff::Diff -end - -function process!(gen_fn::WithProbability{T} - branch_p::Float64, - args::Tuple, - kernel_argdiffs::Tuple, - choices::ChoiceMap, - state::WithProbabilityUpdateState{T}) where T -end - -function update(trace::WithProbability{T}, - args::Tuple, - argdiffs::Tuple, - choices::ChoiceMap) where T - gen_fn = trace.gen_fn - branch_p, branch_p_diff = args[1], argdiffs[1] - return (new_trace, state.weight, retdiff, discard) -end diff --git a/src/modeling_library/with_prob/with_prob.jl b/src/modeling_library/with_prob/with_prob.jl deleted file mode 100644 index 396a9b8d..00000000 --- a/src/modeling_library/with_prob/with_prob.jl +++ /dev/null @@ -1,19 +0,0 @@ -struct WithProbability{T} <: GenerativeFunction{T, Trace} - a::GenerativeFunction{T} - b::GenerativeFunction{T} -end - -export WithProbability - -has_argument_grads(switch_fn::WithProbability) = has_argument_grads(switch_fn.a) && has_argument_grads(switch_fn.b) -accepts_output_grad(switch_fn::WithProbability) = accepts_output_grad(switch_fn.a) && accepts_output_grad(switch_fn.b) - -function (gen_fn::WithProbability)(flip_p::Float64, args...) - (_, _, retval) = propose(gen_fn, (flip_p, args...)) - retval -end - -include("assess.jl") -include("propose.jl") -include("simulate.jl") -include("generate.jl") From 562667e9aa70762d28a91a0795783174e825e613 Mon Sep 17 00:00:00 2001 From: femtomc Date: Wed, 18 Nov 2020 19:33:40 -0500 Subject: [PATCH 20/30] Testing backprop. --- test/modeling_library/switch.jl | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/test/modeling_library/switch.jl b/test/modeling_library/switch.jl index 19cea930..1674d40f 100644 --- a/test/modeling_library/switch.jl +++ b/test/modeling_library/switch.jl @@ -55,6 +55,7 @@ sc = Switch(bang, fuzz) @gen (static) function bam(s::Int) x ~ sc(s, 5.0, 3.0) + return x end Gen.@load_generated_functions() @@ -62,9 +63,15 @@ tr = simulate(bam, (1, )) chm = choicemap((:x => :z, 5.0)) new_tr, w, rd, discard = update(tr, (2, ), (UnknownChange(), ), chm) -display(get_choices(new_tr)) display(discard) +display(get_choices(new_tr)) new_tr, w = regenerate(tr, (1, ), (UnknownChange(), ), select()) +sel = AllSelection() +arg_grads, cvs, cgs = choice_gradients(tr, sel, 1.0) +display(arg_grads) +display(cvs) +display(cgs) + end # module From b74a071d4b02c4f3a94cdd07324ca6c9ffccede6 Mon Sep 17 00:00:00 2001 From: femtomc Date: Wed, 18 Nov 2020 19:48:19 -0500 Subject: [PATCH 21/30] Fixed backprop - was thinking in Zygote lang. Gradients appear to be propagating correctly now. --- src/modeling_library/switch/backprop.jl | 4 ++-- src/static_ir/backprop.jl | 2 +- test/modeling_library/switch.jl | 3 +-- 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/src/modeling_library/switch/backprop.jl b/src/modeling_library/switch/backprop.jl index 4f3fb96f..5a2fc04a 100644 --- a/src/modeling_library/switch/backprop.jl +++ b/src/modeling_library/switch/backprop.jl @@ -1,2 +1,2 @@ -@inline choice_gradients(trace::SwitchTrace{T}, selection::Selection, retval_grad) where T = (nothing, choice_gradients(getfield(trace, :branch), selection, retval_grad)...) -@inline accumulate_param_gradients(trace::SwitchTrace{T}, retval_grad) where {T} = (nothing, accumulate_param_gradients(getfield(trace, :branch), retval_grad)...) +@inline choice_gradients(trace::SwitchTrace{T}, selection::Selection, retval_grad) where T = choice_gradients(getfield(trace, :branch), selection, retval_grad) +@inline accumulate_param_gradients(trace::SwitchTrace{T}, retval_grad) where {T} = accumulate_param_gradients(getfield(trace, :branch), retval_grad) diff --git a/src/static_ir/backprop.jl b/src/static_ir/backprop.jl index 99594cb7..96f63da0 100644 --- a/src/static_ir/backprop.jl +++ b/src/static_ir/backprop.jl @@ -353,7 +353,7 @@ function get_selected_choices(::EmptyAddressSchema, ::StaticIR) end function get_selected_choices(::AllAddressSchema, ir::StaticIR) - Set{RandomChoiceNodes}(ir.choice_nodes) + Set{RandomChoiceNode}(ir.choice_nodes) end function get_selected_choices(schema::StaticAddressSchema, ir::StaticIR) diff --git a/test/modeling_library/switch.jl b/test/modeling_library/switch.jl index 1674d40f..6cce0f86 100644 --- a/test/modeling_library/switch.jl +++ b/test/modeling_library/switch.jl @@ -68,10 +68,9 @@ display(get_choices(new_tr)) new_tr, w = regenerate(tr, (1, ), (UnknownChange(), ), select()) -sel = AllSelection() +sel = select(:x => :z) arg_grads, cvs, cgs = choice_gradients(tr, sel, 1.0) display(arg_grads) -display(cvs) display(cgs) end # module From 849d61e0dde8eabbcc3cdbe068ad4338a323f251 Mon Sep 17 00:00:00 2001 From: femtomc Date: Wed, 18 Nov 2020 21:38:19 -0500 Subject: [PATCH 22/30] Added docstring and docs example. --- docs/src/ref/combinators.md | 37 +++++++++++++++++++++++++++ src/modeling_library/switch/switch.jl | 18 +++++++++++++ test/modeling_library/switch.jl | 7 +++-- 3 files changed, 60 insertions(+), 2 deletions(-) diff --git a/docs/src/ref/combinators.md b/docs/src/ref/combinators.md index 8c18258b..71f51cc8 100644 --- a/docs/src/ref/combinators.md +++ b/docs/src/ref/combinators.md @@ -119,4 +119,41 @@ TODO: document me schematic of recurse combinatokr ``` +## Switch combinator +```@docs +Switch +``` + +In the schematic below, the kernel is denoted `S` and accepts an integer index `k`. + +Consider the following constructions: + +```julia +@gen function bang((grad)(x::Float64), (grad)(y::Float64)) + std::Float64 = 3.0 + z = @trace(normal(x + y, std), :z) + return z +end + +@gen function fuzz((grad)(x::Float64), (grad)(y::Float64)) + std::Float64 = 3.0 + z = @trace(normal(x + 2 * y, std), :z) + return z +end + +sc = Switch(bang, fuzz) +``` + +This creates a new generative function `sc`. We can then obtain the trace of `sc`: + +```julia +(trace, _) = simulate(sc, (2, 5.0, 3.0)) +``` + +The resulting trace contains the subtrace from the branch with index `2` - in this case, a call to `fuzz`: + +``` +│ +└── :z : 13.552870875213735 +``` diff --git a/src/modeling_library/switch/switch.jl b/src/modeling_library/switch/switch.jl index e8ee756e..2e6b17e6 100644 --- a/src/modeling_library/switch/switch.jl +++ b/src/modeling_library/switch/switch.jl @@ -35,3 +35,21 @@ include("generate.jl") include("update.jl") include("regenerate.jl") include("backprop.jl") + +@doc( +""" + gen_fn = Switch(gen_fns::GenerativeFunction...) + +Returns a new generative function that accepts an argument tuple of type `Tuple{Int, ...}` where the first index indicates which branch to call. + + gen_fn = Switch(d::Dict{T, Int}, gen_fns::GenerativeFunction...) where T + +Returns a new generative function that accepts an argument tuple of type `Tuple{Int, ...}` or an argument tuple of type `Tuple{T, ...}` where the first index either indicates which branch to call, or indicates an index into `d` which maps to the selected branch. This form is meant for convenience - it allows the programmer to use `d` like if-else or case statements. + +`Switch` is designed to allow for the expression of patterns of if-else control flow. `gen_fns` must satisfy a few requirements: + +1. Each `gen_fn` in `gen_fns` must accept the same argument types. +2. Each `gen_fn` in `gen_fns` must return the same return type. + +Otherwise, each `gen_fn` can come from different modeling languages, possess different traces, etc. +""", Switch) diff --git a/test/modeling_library/switch.jl b/test/modeling_library/switch.jl index 6cce0f86..fac5fe0b 100644 --- a/test/modeling_library/switch.jl +++ b/test/modeling_library/switch.jl @@ -51,14 +51,17 @@ end return z end -sc = Switch(bang, fuzz) - @gen (static) function bam(s::Int) x ~ sc(s, 5.0, 3.0) return x end Gen.@load_generated_functions() +sc = Switch(bang, fuzz) +tr = simulate(sc, (2, 5.0, 3.0)) +display(get_choices(tr)) +display(tr.index) + tr = simulate(bam, (1, )) chm = choicemap((:x => :z, 5.0)) From adf73a5a21bf1c2d851b53ec38502aaa3985bb7c Mon Sep 17 00:00:00 2001 From: femtomc Date: Thu, 19 Nov 2020 18:30:11 -0500 Subject: [PATCH 23/30] Fixed numerous bugs uncovered while constructing test suite. One serious bug in semantics for regenerate - when switching branches, should generate with choice map constraints except those addresses which are in selection. --- src/modeling_library/cond.jl | 2 +- src/modeling_library/switch/assess.jl | 3 +- src/modeling_library/switch/regenerate.jl | 23 +- src/modeling_library/switch/simulate.jl | 2 +- src/modeling_library/switch/switch.jl | 7 +- src/modeling_library/switch/update.jl | 20 +- test/inference/particle_filter.jl | 2 +- test/modeling_library/modeling_library.jl | 1 + test/modeling_library/switch.jl | 244 +++++++++++++++------- 9 files changed, 208 insertions(+), 96 deletions(-) diff --git a/src/modeling_library/cond.jl b/src/modeling_library/cond.jl index bc10b2c7..5d398524 100644 --- a/src/modeling_library/cond.jl +++ b/src/modeling_library/cond.jl @@ -21,7 +21,7 @@ end subtr = getfield(trace, first) subtrace[rest] end -@inline Base.getindex(tr::SwitchTrace, addr::Symbol) = getfield(trace, addr) +@inline Base.getindex(tr::SwitchTrace, addr::Symbol) = getindex(tr.branch, addr) @inline project(tr::SwitchTrace, selection::Selection) = project(tr.branch, selection) @inline project(tr::SwitchTrace, ::EmptySelection) = tr.noise diff --git a/src/modeling_library/switch/assess.jl b/src/modeling_library/switch/assess.jl index ed3656e5..eb932855 100644 --- a/src/modeling_library/switch/assess.jl +++ b/src/modeling_library/switch/assess.jl @@ -1,6 +1,7 @@ mutable struct SwitchAssessState{T} weight::Float64 retval::T + SwitchAssessState{T}(weight::Float64) where T = new{T}(weight) end function process!(gen_fn::Switch{C, N, K, T}, @@ -8,7 +9,7 @@ function process!(gen_fn::Switch{C, N, K, T}, args::Tuple, choices::ChoiceMap, state::SwitchAssessState{T}) where {C, N, K, T} - (weight, retval) = assess(getindex(gen_fn.mix, index), kernel_args, choices) + (weight, retval) = assess(getindex(gen_fn.mix, index), args, choices) state.weight += weight state.retval = retval end diff --git a/src/modeling_library/switch/regenerate.jl b/src/modeling_library/switch/regenerate.jl index d5c9d121..9e582c7c 100644 --- a/src/modeling_library/switch/regenerate.jl +++ b/src/modeling_library/switch/regenerate.jl @@ -9,6 +9,26 @@ mutable struct SwitchRegenerateState{T} SwitchRegenerateState{T}(weight::Float64, score::Float64, noise::Float64, prev_trace::Trace) where T = new{T}(weight, score, noise, prev_trace) end +function regenerate_recurse_merge(prev_choices::ChoiceMap, selection::Selection) + prev_choice_submap_iterator = get_submaps_shallow(prev_choices) + prev_choice_value_iterator = get_values_shallow(prev_choices) + new_choices = DynamicChoiceMap() + for (key, value) in prev_choice_value_iterator + key in selection && continue + set_value!(new_choices, key, value) + end + for (key, node1) in prev_choice_submap_iterator + if key in selection + subsel = get_subselection(selection, key) + node = regenerate_recurse_merge(node1, subsel) + set_submap!(new_choices, key, node) + else + set_submap!(new_choices, key, node1) + end + end + return new_choices +end + function process!(gen_fn::Switch{C, N, K, T}, index::Int, index_argdiff::UnknownChange, @@ -17,7 +37,8 @@ function process!(gen_fn::Switch{C, N, K, T}, selection::Selection, state::SwitchRegenerateState{T}) where {C, N, K, T} branch_fn = getfield(gen_fn.mix, index) - new_trace, weight = generate(branch_fn, args, get_choices(state.prev_trace)) + merged = regenerate_recurse_merge(get_choices(state.prev_trace), selection) + new_trace, weight = generate(branch_fn, args, merged) retdiff = UnknownChange() weight -= get_score(state.prev_trace) state.index = index diff --git a/src/modeling_library/switch/simulate.jl b/src/modeling_library/switch/simulate.jl index 13528f4f..ef75bb48 100644 --- a/src/modeling_library/switch/simulate.jl +++ b/src/modeling_library/switch/simulate.jl @@ -28,5 +28,5 @@ function simulate(gen_fn::Switch{C, N, K, T}, index = args[1] state = SwitchSimulateState{T}(0.0, 0.0) process!(gen_fn, index, args[2 : end], state) - SwitchTrace{T}(gen_fn, state.index, state.subtrace, state.retval, args[2 : end], state.score, state.noise) + return SwitchTrace{T}(gen_fn, state.index, state.subtrace, state.retval, args[2 : end], state.score, state.noise) end diff --git a/src/modeling_library/switch/switch.jl b/src/modeling_library/switch/switch.jl index 2e6b17e6..ca90556b 100644 --- a/src/modeling_library/switch/switch.jl +++ b/src/modeling_library/switch/switch.jl @@ -12,11 +12,12 @@ struct Switch{C, N, K, T} <: GenerativeFunction{T, Trace} new{C, length(gen_fns), typeof(gen_fns), rettype}(gen_fns, d) end end - export Switch -has_argument_grads(switch_fn::Switch) = all(has_argument_grads, switch.mix) -accepts_output_grad(switch_fn::Switch) = all(accepts_output_grad, switch.mix) +has_argument_grads(switch_fn::Switch) = map(zip(map(has_argument_grads, switch_fn.mix)...)) do as + all(as) +end +accepts_output_grad(switch_fn::Switch) = all(accepts_output_grad, switch_fn.mix) function (gen_fn::Switch)(index::Int, args...) (_, _, retval) = propose(gen_fn, (index, args...)) diff --git a/src/modeling_library/switch/update.jl b/src/modeling_library/switch/update.jl index e60c53ee..dbedce94 100644 --- a/src/modeling_library/switch/update.jl +++ b/src/modeling_library/switch/update.jl @@ -15,29 +15,29 @@ function update_recurse_merge(prev_choices::ChoiceMap, choices::ChoiceMap) prev_choice_value_iterator = get_values_shallow(prev_choices) choice_submap_iterator = get_submaps_shallow(choices) choice_value_iterator = get_values_shallow(choices) - choices = DynamicChoiceMap() + new_choices = DynamicChoiceMap() for (key, value) in prev_choice_value_iterator key in keys(choice_value_iterator) && continue - set_value!(choices, key, value) + set_value!(new_choices, key, value) end for (key, node1) in prev_choice_submap_iterator if key in keys(choice_submap_iterator) node2 = get_submap(choices, key) node = update_recurse_merge(node1, node2) - set_submap!(choices, key, node) + set_submap!(new_choices, key, node) else - set_submap!(choices, key, node1) + set_submap!(new_choices, key, node1) end end for (key, value) in choice_value_iterator - set_value!(choices, key, value) + set_value!(new_choices, key, value) end - for (key, node) in filter(choice_submap_iterator) do (k, _) - !(k in keys(prev_choice_submap_iterator)) - end - set_submap!(choices, key, node) + sel, _ = zip(prev_choice_submap_iterator...) + comp = complement(select(sel...)) + for (key, node) in get_submaps_shallow(get_selected(choices, comp)) + set_submap!(new_choices, key, node) end - return choices + return new_choices end function update_discard(prev_trace::Trace, choices::ChoiceMap, new_trace::Trace) diff --git a/test/inference/particle_filter.jl b/test/inference/particle_filter.jl index 982d0eec..2e32a60a 100644 --- a/test/inference/particle_filter.jl +++ b/test/inference/particle_filter.jl @@ -164,7 +164,7 @@ end # check log marginal likelihood estimate expected_log_ml = log(hmm_forward_alg(prior, emission_dists, transition_dists, obs_x)) actual_log_ml_est = log_ml_estimate(state) - @test isapprox(expected_log_ml, actual_log_ml_est, atol=0.01) + @test isapprox(expected_log_ml, actual_log_ml_est, atol=0.02) end end diff --git a/test/modeling_library/modeling_library.jl b/test/modeling_library/modeling_library.jl index 616110f8..2ebb8929 100644 --- a/test/modeling_library/modeling_library.jl +++ b/test/modeling_library/modeling_library.jl @@ -5,4 +5,5 @@ include("call_at.jl") include("map.jl") include("unfold.jl") include("recurse.jl") +include("switch.jl") include("dist_dsl.jl") diff --git a/test/modeling_library/switch.jl b/test/modeling_library/switch.jl index fac5fe0b..d781c949 100644 --- a/test/modeling_library/switch.jl +++ b/test/modeling_library/switch.jl @@ -1,79 +1,167 @@ -module SwitchComb - -include("../../src/Gen.jl") -using .Gen - -# ------------ Toplevel caller ------------ # - -@gen (grad) function foo((grad)(x::Float64), (grad)(y::Float64)) - std::Float64 = 3.0 - z = @trace(normal(x + y, std), :z) - return z -end - -@gen (grad) function baz((grad)(x::Float64), (grad)(y::Float64)) - std::Float64 = 3.0 - z = @trace(normal(x + 2 * y, std), :z) - return z -end - -# Standard. -sc = Switch(foo, baz) -chm, _, _ = propose(sc, (2, 5.0, 3.0)) - -tr = simulate(sc, (2, 5.0, 3.0)) - -chm = choicemap() -chm[:z] = 5.0 -tr, _ = generate(sc, (2, 5.0, 3.0), chm) - -# Cases. -sc = Switch(Dict(:x => 1, :y => 2), foo, baz) -chm, _, _ = propose(sc, (:x, 5.0, 3.0)) - -tr = simulate(sc, (:x, 5.0, 3.0)) - -chm = choicemap() -chm[:z] = 5.0 -tr, _ = generate(sc, (:x, 5.0, 3.0), chm) - -# ------------ Static DSL ------------ # - -@gen (static) function bang((grad)(x::Float64), (grad)(y::Float64)) - std::Float64 = 3.0 - z = @trace(normal(x + y, std), :z) - return z -end - -@gen (static) function fuzz((grad)(x::Float64), (grad)(y::Float64)) - std::Float64 = 3.0 - z = @trace(normal(x + 2 * y, std), :z) - return z -end - -@gen (static) function bam(s::Int) - x ~ sc(s, 5.0, 3.0) - return x +@testset "switch combinator" begin + + # ------------ Trace ------------ # + + @gen function swtrg() + z ~ normal(3.0, 5.0) + return z + end + + @testset "switch trace" begin + tr = simulate(swtrg, ()) + swtr = Gen.SwitchTrace(swtrg, 1, tr, get_retval(tr), (), get_score(tr), 0.0) + @test swtr[:z] == tr[:z] + @test project(swtr, AllSelection()) == project(swtr.branch, AllSelection()) + @test project(swtr, EmptySelection()) == swtr.noise + end + + # ------------ Bare combinator ------------ # + + # Model chunk. + @gen (grad) function foo((grad)(x::Float64), (grad)(y::Float64)) + std::Float64 = 3.0 + z = @trace(normal(x + y, std), :z) + return z + end + + @gen (grad) function baz((grad)(x::Float64), (grad)(y::Float64)) + std::Float64 = 3.0 + z = @trace(normal(x + 2 * y, std), :z) + return z + end + sc = Switch(foo, baz) + # ----. + + @testset "simulate" begin + tr = simulate(sc, (1, 5.0, 3.0)) + @test isapprox(get_score(tr), logpdf(normal, tr[:z], 5.0 + 3.0, 3.0)) + tr = simulate(sc, (2, 5.0, 3.0)) + @test isapprox(get_score(tr), logpdf(normal, tr[:z], 5.0 + 2 * 3.0, 3.0)) + end + + @testset "generate" begin + chm = choicemap() + chm[:z] = 5.0 + tr, w = generate(sc, (2, 5.0, 3.0), chm) + assignment = get_choices(tr) + @test assignment[:z] == 5.0 + @test isapprox(w, logpdf(normal, 5.0, 5.0 + 2 * 3.0, 3.0)) + end + + @testset "assess" begin + chm = choicemap() + chm[:z] = 5.0 + w, ret = assess(sc, (2, 5.0, 3.0), chm) + @test isapprox(w, logpdf(normal, 5.0, 5.0 + 2 * 3.0, 3.0)) + end + + @testset "propose" begin + chm, w = propose(sc, (2, 5.0, 3.0)) + @test isapprox(w, logpdf(normal, chm[:z], 5.0 + 2 * 3.0, 3.0)) + end + + @testset "update" begin + tr = simulate(sc, (1, 5.0, 3.0)) + old_sc = get_score(tr) + chm = choicemap((:x => :z, 5.0)) + new_tr, w, rd, discard = update(tr, (2, 5.0, 3.0), + (UnknownChange(), NoChange(), NoChange()), + chm) + @test old_sc == get_score(new_tr) - w + chm = choicemap((:x => :z, 10.0)) + new_tr, w, rd, discard = update(tr, (1, 5.0, 3.0), + (UnknownChange(), NoChange(), NoChange()), + chm) + @test old_sc == get_score(new_tr) - w + end + + @testset "regenerate" begin + tr = simulate(sc, (2, 5.0, 3.0)) + old_sc = get_score(tr) + sel = select(:z) + new_tr, w, rd = regenerate(tr, (2, 5.0, 3.0), + (UnknownChange(), NoChange(), NoChange()), + sel) + @test old_sc == get_score(new_tr) - w + new_tr, w, rd = regenerate(tr, (1, 5.0, 3.0), + (UnknownChange(), NoChange(), NoChange()), + sel) + @test old_sc == get_score(new_tr) - w + end + + @testset "choice gradients" begin + tr = simulate(sc, (2, 5.0, 3.0)) + sel = select(:z) + arg_grads, cvs, cgs = choice_gradients(tr, sel, 1.0) + end + + @testset "accumulate parameter gradients" begin + end + + # ------------ Hierarchy ------------ # + + # Model chunk. + @gen (grad) function bang((grad)(x::Float64), (grad)(y::Float64)) + std::Float64 = 3.0 + z = @trace(normal(x + y, std), :z) + return z + end + @gen (grad) function fuzz((grad)(x::Float64), (grad)(y::Float64)) + std::Float64 = 3.0 + z = @trace(normal(x + 2 * y, std), :z) + return z + end + sc = Switch(bang, fuzz) + @gen (grad) function bam(s::Int) + x ~ sc(s, 5.0, 3.0) + return x + end + # ----. + + @testset "simulate" begin + tr = simulate(bam, (2, )) + end + + @testset "generate" begin + chm = choicemap() + chm[:x => :z] = 5.0 + tr, _ = generate(sc, (2, 5.0, 3.0), chm) + end + + @testset "assess" begin + end + + @testset "propose" begin + end + + @testset "update" begin + tr = simulate(bam, (2, )) + old_sc = get_score(tr) + chm = choicemap((:x => :z, 5.0)) + new_tr, w, rd, discard = update(tr, (2, ), (UnknownChange(), ), chm) + @test old_sc == get_score(new_tr) - w + chm = choicemap((:x => :z, 10.0)) + new_tr, w, rd, discard = update(tr, (1, ), (UnknownChange(), ), chm) + @test old_sc == get_score(new_tr) - w + end + + @testset "regenerate" begin + tr = simulate(bam, (2, )) + old_sc = get_score(tr) + new_tr, w = regenerate(tr, (1, ), (UnknownChange(), ), select()) + @test old_sc == get_score(new_tr) - w + new_tr, w = regenerate(tr, (2, ), (UnknownChange(), ), select()) + @test old_sc == get_score(new_tr) - w + new_tr, w = regenerate(tr, (1, ), (UnknownChange(), ), select()) + @test old_sc == get_score(new_tr) - w + end + + @testset "choice gradients" begin + tr = simulate(bam, (2, )) + sel = select(:x => :z) + arg_grads, cvs, cgs = choice_gradients(tr, sel, 1.0) + end + + @testset "accumulate parameter gradients" begin + end end -Gen.@load_generated_functions() - -sc = Switch(bang, fuzz) -tr = simulate(sc, (2, 5.0, 3.0)) -display(get_choices(tr)) -display(tr.index) - -tr = simulate(bam, (1, )) - -chm = choicemap((:x => :z, 5.0)) -new_tr, w, rd, discard = update(tr, (2, ), (UnknownChange(), ), chm) -display(discard) -display(get_choices(new_tr)) - -new_tr, w = regenerate(tr, (1, ), (UnknownChange(), ), select()) - -sel = select(:x => :z) -arg_grads, cvs, cgs = choice_gradients(tr, sel, 1.0) -display(arg_grads) -display(cgs) - -end # module From dfe0125f653412fa4f68e4dc6e1b8b7c102f6fff Mon Sep 17 00:00:00 2001 From: femtomc Date: Thu, 19 Nov 2020 19:30:29 -0500 Subject: [PATCH 24/30] Fixed numerous bugs uncovered while constructing test suite. One serious bug in semantics for regenerate - when switching branches, should generate with choice map constraints except those addresses which are in selection. --- src/modeling_library/switch/regenerate.jl | 15 +++++++++------ test/modeling_library/switch.jl | 8 ++++---- 2 files changed, 13 insertions(+), 10 deletions(-) diff --git a/src/modeling_library/switch/regenerate.jl b/src/modeling_library/switch/regenerate.jl index 9e582c7c..caaa36b2 100644 --- a/src/modeling_library/switch/regenerate.jl +++ b/src/modeling_library/switch/regenerate.jl @@ -9,17 +9,19 @@ mutable struct SwitchRegenerateState{T} SwitchRegenerateState{T}(weight::Float64, score::Float64, noise::Float64, prev_trace::Trace) where T = new{T}(weight, score, noise, prev_trace) end +@inline regenerate_recurse_merge(prev_choices::ChoiceMap, selection::EmptySelection) = prev_choices +@inline regenerate_recurse_merge(prev_choices::ChoiceMap, selection::AllSelection) = choicemap() function regenerate_recurse_merge(prev_choices::ChoiceMap, selection::Selection) - prev_choice_submap_iterator = get_submaps_shallow(prev_choices) prev_choice_value_iterator = get_values_shallow(prev_choices) - new_choices = DynamicChoiceMap() + prev_choice_submap_iterator = get_submaps_shallow(prev_choices) + new_choices = choicemap() for (key, value) in prev_choice_value_iterator - key in selection && continue + in(key, selection) && continue set_value!(new_choices, key, value) end for (key, node1) in prev_choice_submap_iterator - if key in selection - subsel = get_subselection(selection, key) + if in(key, selection) + subsel = getindex(selection, key) node = regenerate_recurse_merge(node1, subsel) set_submap!(new_choices, key, node) else @@ -40,7 +42,8 @@ function process!(gen_fn::Switch{C, N, K, T}, merged = regenerate_recurse_merge(get_choices(state.prev_trace), selection) new_trace, weight = generate(branch_fn, args, merged) retdiff = UnknownChange() - weight -= get_score(state.prev_trace) + weight -= project(state.prev_trace, complement(selection)) + weight += (project(new_trace, selection) - project(state.prev_trace, selection)) state.index = index state.weight = weight state.noise = project(new_trace, EmptySelection()) - project(state.prev_trace, EmptySelection()) diff --git a/test/modeling_library/switch.jl b/test/modeling_library/switch.jl index d781c949..41e52510 100644 --- a/test/modeling_library/switch.jl +++ b/test/modeling_library/switch.jl @@ -80,12 +80,12 @@ old_sc = get_score(tr) sel = select(:z) new_tr, w, rd = regenerate(tr, (2, 5.0, 3.0), - (UnknownChange(), NoChange(), NoChange()), - sel) + (UnknownChange(), NoChange(), NoChange()), + sel) @test old_sc == get_score(new_tr) - w new_tr, w, rd = regenerate(tr, (1, 5.0, 3.0), - (UnknownChange(), NoChange(), NoChange()), - sel) + (UnknownChange(), NoChange(), NoChange()), + sel) @test old_sc == get_score(new_tr) - w end From 3717d6579e9cc10b0036cfb1609aaf765f317131 Mon Sep 17 00:00:00 2001 From: femtomc Date: Thu, 19 Nov 2020 20:41:37 -0500 Subject: [PATCH 25/30] Tests for everything but gradients - working on gradients now. --- src/modeling_library/cond.jl | 9 +- src/modeling_library/switch/assess.jl | 2 +- test/modeling_library/switch.jl | 115 ++++++++++++++++++++++++-- 3 files changed, 110 insertions(+), 16 deletions(-) diff --git a/src/modeling_library/cond.jl b/src/modeling_library/cond.jl index 5d398524..9c0ce4fd 100644 --- a/src/modeling_library/cond.jl +++ b/src/modeling_library/cond.jl @@ -15,13 +15,6 @@ end @inline get_args(tr::SwitchTrace) = tr.args @inline get_score(tr::SwitchTrace) = tr.score @inline get_gen_fn(tr::SwitchTrace) = tr.gen_fn - -@inline function Base.getindex(tr::SwitchTrace, addr::Pair) - (first, rest) = addr - subtr = getfield(trace, first) - subtrace[rest] -end -@inline Base.getindex(tr::SwitchTrace, addr::Symbol) = getindex(tr.branch, addr) - +@inline Base.getindex(tr::SwitchTrace, addr) = Base.getindex(tr.branch, addr) @inline project(tr::SwitchTrace, selection::Selection) = project(tr.branch, selection) @inline project(tr::SwitchTrace, ::EmptySelection) = tr.noise diff --git a/src/modeling_library/switch/assess.jl b/src/modeling_library/switch/assess.jl index eb932855..a843535e 100644 --- a/src/modeling_library/switch/assess.jl +++ b/src/modeling_library/switch/assess.jl @@ -10,7 +10,7 @@ function process!(gen_fn::Switch{C, N, K, T}, choices::ChoiceMap, state::SwitchAssessState{T}) where {C, N, K, T} (weight, retval) = assess(getindex(gen_fn.mix, index), args, choices) - state.weight += weight + state.weight = weight state.retval = retval end diff --git a/test/modeling_library/switch.jl b/test/modeling_library/switch.jl index 41e52510..15da4cba 100644 --- a/test/modeling_library/switch.jl +++ b/test/modeling_library/switch.jl @@ -18,18 +18,18 @@ # ------------ Bare combinator ------------ # # Model chunk. - @gen (grad) function foo((grad)(x::Float64), (grad)(y::Float64)) + @gen (grad) function bang0((grad)(x::Float64), (grad)(y::Float64)) std::Float64 = 3.0 z = @trace(normal(x + y, std), :z) return z end - @gen (grad) function baz((grad)(x::Float64), (grad)(y::Float64)) + @gen (grad) function fuzz0((grad)(x::Float64), (grad)(y::Float64)) std::Float64 = 3.0 z = @trace(normal(x + 2 * y, std), :z) return z end - sc = Switch(foo, baz) + sc = Switch(bang0, fuzz0) # ----. @testset "simulate" begin @@ -101,17 +101,17 @@ # ------------ Hierarchy ------------ # # Model chunk. - @gen (grad) function bang((grad)(x::Float64), (grad)(y::Float64)) + @gen (grad) function bang1((grad)(x::Float64), (grad)(y::Float64)) std::Float64 = 3.0 z = @trace(normal(x + y, std), :z) return z end - @gen (grad) function fuzz((grad)(x::Float64), (grad)(y::Float64)) + @gen (grad) function fuzz1((grad)(x::Float64), (grad)(y::Float64)) std::Float64 = 3.0 z = @trace(normal(x + 2 * y, std), :z) return z end - sc = Switch(bang, fuzz) + sc = Switch(bang1, fuzz1) @gen (grad) function bam(s::Int) x ~ sc(s, 5.0, 3.0) return x @@ -120,18 +120,28 @@ @testset "simulate" begin tr = simulate(bam, (2, )) + @test isapprox(get_score(tr), logpdf(normal, tr[:x => :z], 5.0 + 2 * 3.0, 3.0)) end @testset "generate" begin chm = choicemap() chm[:x => :z] = 5.0 - tr, _ = generate(sc, (2, 5.0, 3.0), chm) + tr, w = generate(bam, (2, ), chm) + assignment = get_choices(tr) + @test assignment[:x => :z] == 5.0 + @test isapprox(w, logpdf(normal, 5.0, 5.0 + 2 * 3.0, 3.0)) end @testset "assess" begin + chm = choicemap() + chm[:x => :z] = 5.0 + w, ret = assess(bam, (2, ), chm) + @test isapprox(w, logpdf(normal, 5.0, 5.0 + 2 * 3.0, 3.0)) end @testset "propose" begin + chm, w = propose(bam, (2, )) + @test isapprox(w, logpdf(normal, chm[:x => :z], 5.0 + 2 * 3.0, 3.0)) end @testset "update" begin @@ -164,4 +174,95 @@ @testset "accumulate parameter gradients" begin end + + # ------------ (More complex) hierarchy ------------ # + + # Model chunk. + @gen (grad) function bang2((grad)(x::Float64), (grad)(y::Float64)) + std::Float64 = 3.0 + z = @trace(normal(x + y, std), :z) + return z + end + @gen (grad) function fuzz2((grad)(x::Float64), (grad)(y::Float64)) + std::Float64 = 3.0 + z = @trace(normal(x + 2 * y, std), :z) + q = @trace(bang2(z, y), :q) + return z + end + sc2 = Switch(bang2, fuzz2) + @gen (grad) function bam2(s::Int) + x ~ sc2(s, 5.0, 3.0) + return x + end + # ----. + + @testset "simulate" begin + tr = simulate(bam2, (1, )) + @test isapprox(get_score(tr), logpdf(normal, tr[:x => :z], 5.0 + 3.0, 3.0)) + tr = simulate(bam2, (2, )) + @test isapprox(get_score(tr), logpdf(normal, tr[:x => :z], 5.0 + 2 * 3.0, 3.0) + logpdf(normal, tr[:x => :q => :z], tr[:x => :z] + 3.0, 3.0)) + end + + @testset "generate" begin + chm = choicemap() + chm[:x => :z] = 5.0 + tr, w = generate(bam2, (1, ), chm) + assignment = get_choices(tr) + @test assignment[:x => :z] == 5.0 + @test isapprox(w, logpdf(normal, 5.0, 5.0 + 3.0, 3.0)) + tr, w = generate(bam2, (2, ), chm) + assignment = get_choices(tr) + @test assignment[:x => :z] == 5.0 + @test isapprox(w, logpdf(normal, tr[:x => :z], 5.0 + 2 * 3.0, 3.0)) + end + + @testset "assess" begin + chm = choicemap() + chm[:x => :z] = 5.0 + w, ret = assess(bam2, (1, ), chm) + @test isapprox(w, logpdf(normal, 5.0, 5.0 + 3.0, 3.0)) + chm[:x => :q => :z] = 5.0 + w, ret = assess(bam2, (2, ), chm) + @test isapprox(w, logpdf(normal, 5.0, 5.0 + 2 * 3.0, 3.0) + logpdf(normal, 5.0, 5.0 + 3.0, 3.0)) + end + + @testset "propose" begin + chm, w = propose(bam2, (1, )) + @test isapprox(w, logpdf(normal, chm[:x => :z], 5.0 + 3.0, 3.0)) + chm, w = propose(bam2, (2, )) + @test isapprox(w, logpdf(normal, chm[:x => :z], 5.0 + 2 * 3.0, 3.0) + logpdf(normal, chm[:x => :q => :z], chm[:x => :z] + 3.0, 3.0)) + end + + @testset "update" begin + tr = simulate(bam2, (2, )) + old_sc = get_score(tr) + chm = choicemap((:x => :z, 5.0)) + new_tr, w, rd, discard = update(tr, (2, ), (UnknownChange(), ), chm) + @test old_sc == get_score(new_tr) - w + chm = choicemap((:x => :z, 10.0)) + new_tr, w, rd, discard = update(tr, (1, ), (UnknownChange(), ), chm) + @test old_sc == get_score(new_tr) - w + end + + @testset "regenerate" begin + tr = simulate(bam2, (2, )) + old_sc = get_score(tr) + new_tr, w = regenerate(tr, (1, ), (UnknownChange(), ), select()) + @test old_sc == get_score(new_tr) - w + new_tr, w = regenerate(tr, (2, ), (UnknownChange(), ), select()) + @test old_sc == get_score(new_tr) - w + new_tr, w = regenerate(tr, (1, ), (UnknownChange(), ), select()) + @test old_sc == get_score(new_tr) - w + new_tr, w = regenerate(tr, (1, ), (UnknownChange(), ), select(:x => :z)) + @test old_sc == get_score(new_tr) - w + end + + @testset "choice gradients" begin + tr = simulate(bam2, (2, )) + sel = select(:x => :z) + arg_grads, cvs, cgs = choice_gradients(tr, sel, 1.0) + end + + @testset "accumulate parameter gradients" begin + end end From cb62fb50d11527e8974a9bdbbda4edbe4d4ed930 Mon Sep 17 00:00:00 2001 From: femtomc Date: Fri, 20 Nov 2020 09:45:01 -0500 Subject: [PATCH 26/30] Last tests I need to write: accumulate_param_gradients! --- src/modeling_library/switch/backprop.jl | 2 +- test/modeling_library/switch.jl | 91 ++++++++++++++++--------- 2 files changed, 60 insertions(+), 33 deletions(-) diff --git a/src/modeling_library/switch/backprop.jl b/src/modeling_library/switch/backprop.jl index 5a2fc04a..28add242 100644 --- a/src/modeling_library/switch/backprop.jl +++ b/src/modeling_library/switch/backprop.jl @@ -1,2 +1,2 @@ @inline choice_gradients(trace::SwitchTrace{T}, selection::Selection, retval_grad) where T = choice_gradients(getfield(trace, :branch), selection, retval_grad) -@inline accumulate_param_gradients(trace::SwitchTrace{T}, retval_grad) where {T} = accumulate_param_gradients(getfield(trace, :branch), retval_grad) +@inline accumulate_param_gradients!(trace::SwitchTrace{T}, retval_grad, scale_factor = 1.) where {T} = accumulate_param_gradients!(getfield(trace, :branch), retval_grad, scale_factor) diff --git a/test/modeling_library/switch.jl b/test/modeling_library/switch.jl index 15da4cba..aa6f9e02 100644 --- a/test/modeling_library/switch.jl +++ b/test/modeling_library/switch.jl @@ -67,12 +67,12 @@ new_tr, w, rd, discard = update(tr, (2, 5.0, 3.0), (UnknownChange(), NoChange(), NoChange()), chm) - @test old_sc == get_score(new_tr) - w + @test isapprox(old_sc, get_score(new_tr) - w) chm = choicemap((:x => :z, 10.0)) new_tr, w, rd, discard = update(tr, (1, 5.0, 3.0), (UnknownChange(), NoChange(), NoChange()), chm) - @test old_sc == get_score(new_tr) - w + @test isapprox(old_sc, get_score(new_tr) - w) end @testset "regenerate" begin @@ -82,35 +82,43 @@ new_tr, w, rd = regenerate(tr, (2, 5.0, 3.0), (UnknownChange(), NoChange(), NoChange()), sel) - @test old_sc == get_score(new_tr) - w + @test isapprox(old_sc, get_score(new_tr) - w) new_tr, w, rd = regenerate(tr, (1, 5.0, 3.0), (UnknownChange(), NoChange(), NoChange()), sel) - @test old_sc == get_score(new_tr) - w + @test isapprox(old_sc, get_score(new_tr) - w) end @testset "choice gradients" begin - tr = simulate(sc, (2, 5.0, 3.0)) - sel = select(:z) - arg_grads, cvs, cgs = choice_gradients(tr, sel, 1.0) - end - - @testset "accumulate parameter gradients" begin + for z in [1.0, 3.0, 5.0, 10.0] + chm = choicemap((:z, z)) + tr, _ = generate(sc, (1, 5.0, 3.0), chm) + sel = select(:z) + input_grads, choices, gradients = choice_gradients(tr, sel) + expected_choice_grad = logpdf_grad(normal, z, 5.0 + 3.0, 3.0) + @test isapprox(gradients[:z], expected_choice_grad[1]) + tr, _ = generate(sc, (2, 5.0, 3.0), chm) + input_grads, choices, gradients = choice_gradients(tr, sel) + expected_choice_grad = logpdf_grad(normal, z, 5.0 + 2 * 3.0, 3.0) + @test isapprox(gradients[:z], expected_choice_grad[1]) + end end # ------------ Hierarchy ------------ # # Model chunk. @gen (grad) function bang1((grad)(x::Float64), (grad)(y::Float64)) - std::Float64 = 3.0 + @param(std::Float64) z = @trace(normal(x + y, std), :z) return z end + init_param!(bang1, :std, 3.0) @gen (grad) function fuzz1((grad)(x::Float64), (grad)(y::Float64)) - std::Float64 = 3.0 + @param(std::Float64) z = @trace(normal(x + 2 * y, std), :z) return z end + init_param!(fuzz1, :std, 3.0) sc = Switch(bang1, fuzz1) @gen (grad) function bam(s::Int) x ~ sc(s, 5.0, 3.0) @@ -149,30 +157,47 @@ old_sc = get_score(tr) chm = choicemap((:x => :z, 5.0)) new_tr, w, rd, discard = update(tr, (2, ), (UnknownChange(), ), chm) - @test old_sc == get_score(new_tr) - w + @test isapprox(old_sc, get_score(new_tr) - w) chm = choicemap((:x => :z, 10.0)) new_tr, w, rd, discard = update(tr, (1, ), (UnknownChange(), ), chm) - @test old_sc == get_score(new_tr) - w + @test isapprox(old_sc, get_score(new_tr) - w) end @testset "regenerate" begin tr = simulate(bam, (2, )) old_sc = get_score(tr) new_tr, w = regenerate(tr, (1, ), (UnknownChange(), ), select()) - @test old_sc == get_score(new_tr) - w + @test isapprox(old_sc, get_score(new_tr) - w) new_tr, w = regenerate(tr, (2, ), (UnknownChange(), ), select()) - @test old_sc == get_score(new_tr) - w + @test isapprox(old_sc, get_score(new_tr) - w) new_tr, w = regenerate(tr, (1, ), (UnknownChange(), ), select()) - @test old_sc == get_score(new_tr) - w + @test isapprox(old_sc, get_score(new_tr) - w) end @testset "choice gradients" begin - tr = simulate(bam, (2, )) - sel = select(:x => :z) - arg_grads, cvs, cgs = choice_gradients(tr, sel, 1.0) + for z in [1.0, 3.0, 5.0, 10.0] + chm = choicemap((:x => :z, z)) + tr, _ = generate(bam, (1, ), chm) + sel = select(:x => :z) + input_grads, choices, gradients = choice_gradients(tr, sel) + expected_choice_grad = logpdf_grad(normal, z, 5.0 + 3.0, 3.0) + @test isapprox(gradients[:x => :z], expected_choice_grad[1]) + chm = choicemap((:x => :z, z)) + tr, _ = generate(bam, (2, ), chm) + sel = select(:x => :z) + input_grads, choices, gradients = choice_gradients(tr, sel) + expected_choice_grad = logpdf_grad(normal, z, 5.0 + 2 * 3.0, 3.0) + @test isapprox(gradients[:x => :z], expected_choice_grad[1]) + end end @testset "accumulate parameter gradients" begin + tr = simulate(bam, (1, )) + zero_param_grad!(bang1, :std) + input_grads = accumulate_param_gradients!(tr, 1.0) + tr = simulate(bam, (2, )) + zero_param_grad!(fuzz1, :std) + input_grads = accumulate_param_gradients!(tr, 1.0) end # ------------ (More complex) hierarchy ------------ # @@ -238,31 +263,33 @@ old_sc = get_score(tr) chm = choicemap((:x => :z, 5.0)) new_tr, w, rd, discard = update(tr, (2, ), (UnknownChange(), ), chm) - @test old_sc == get_score(new_tr) - w + @test isapprox(old_sc, get_score(new_tr) - w) chm = choicemap((:x => :z, 10.0)) new_tr, w, rd, discard = update(tr, (1, ), (UnknownChange(), ), chm) - @test old_sc == get_score(new_tr) - w + @test isapprox(old_sc, get_score(new_tr) - w) end @testset "regenerate" begin tr = simulate(bam2, (2, )) old_sc = get_score(tr) new_tr, w = regenerate(tr, (1, ), (UnknownChange(), ), select()) - @test old_sc == get_score(new_tr) - w + @test isapprox(old_sc, get_score(new_tr) - w) new_tr, w = regenerate(tr, (2, ), (UnknownChange(), ), select()) - @test old_sc == get_score(new_tr) - w + @test isapprox(old_sc, get_score(new_tr) - w) new_tr, w = regenerate(tr, (1, ), (UnknownChange(), ), select()) - @test old_sc == get_score(new_tr) - w + @test isapprox(old_sc, get_score(new_tr) - w) new_tr, w = regenerate(tr, (1, ), (UnknownChange(), ), select(:x => :z)) - @test old_sc == get_score(new_tr) - w + @test isapprox(old_sc, get_score(new_tr) - w) end @testset "choice gradients" begin - tr = simulate(bam2, (2, )) - sel = select(:x => :z) - arg_grads, cvs, cgs = choice_gradients(tr, sel, 1.0) - end - - @testset "accumulate parameter gradients" begin + for z in [1.0, 3.0, 5.0, 10.0] + chm = choicemap((:x => :z, z)) + tr, _ = generate(bam2, (1, ), chm) + sel = select(:x => :z) + input_grads, choices, gradients = choice_gradients(tr, sel) + expected_choice_grad = logpdf_grad(normal, z, 5.0 + 3.0, 3.0) + @test isapprox(gradients[:x => :z], expected_choice_grad[1]) + end end end From 97473d06cd33e5036fc4c4c21c4246841f41fb50 Mon Sep 17 00:00:00 2001 From: femtomc Date: Fri, 20 Nov 2020 10:54:18 -0500 Subject: [PATCH 27/30] Added accumulate_param_gradients! tests. --- test/modeling_library/switch.jl | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/test/modeling_library/switch.jl b/test/modeling_library/switch.jl index aa6f9e02..7346a356 100644 --- a/test/modeling_library/switch.jl +++ b/test/modeling_library/switch.jl @@ -192,12 +192,19 @@ end @testset "accumulate parameter gradients" begin - tr = simulate(bam, (1, )) - zero_param_grad!(bang1, :std) - input_grads = accumulate_param_gradients!(tr, 1.0) - tr = simulate(bam, (2, )) - zero_param_grad!(fuzz1, :std) - input_grads = accumulate_param_gradients!(tr, 1.0) + for z in [1.0, 3.0, 5.0, 10.0] + chm = choicemap((:z, z)) + tr, _ = generate(bam, (1, ), chm) + zero_param_grad!(bang1, :std) + input_grads = accumulate_param_gradients!(tr, 1.0) + expected_std_grad = logpdf_grad(normal, tr[:x => :z], 5.0 + 3.0, 3.0)[3] + @test isapprox(get_param_grad(bang1, :std), expected_std_grad) + tr, _ = generate(bam, (2, ), chm) + zero_param_grad!(fuzz1, :std) + input_grads = accumulate_param_gradients!(tr, 1.0) + expected_std_grad = logpdf_grad(normal, tr[:x => :z], 5.0 + 2 * 3.0, 3.0)[3] + @test isapprox(get_param_grad(fuzz1, :std), expected_std_grad) + end end # ------------ (More complex) hierarchy ------------ # From 176b9e9b3a85a98c5b0843bd9ea634f8cbb97620 Mon Sep 17 00:00:00 2001 From: femtomc Date: Fri, 20 Nov 2020 11:02:41 -0500 Subject: [PATCH 28/30] Reverted particle filter fix - will be handled in another issue. --- test/inference/particle_filter.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/inference/particle_filter.jl b/test/inference/particle_filter.jl index 2e32a60a..982d0eec 100644 --- a/test/inference/particle_filter.jl +++ b/test/inference/particle_filter.jl @@ -164,7 +164,7 @@ end # check log marginal likelihood estimate expected_log_ml = log(hmm_forward_alg(prior, emission_dists, transition_dists, obs_x)) actual_log_ml_est = log_ml_estimate(state) - @test isapprox(expected_log_ml, actual_log_ml_est, atol=0.02) + @test isapprox(expected_log_ml, actual_log_ml_est, atol=0.01) end end From 0465965adc5a7f8d099fe909490008749383fe5b Mon Sep 17 00:00:00 2001 From: femtomc Date: Sun, 22 Nov 2020 12:39:44 -0500 Subject: [PATCH 29/30] Renamed mix field of Switch generative function to branches to more accurately reflect the pattern. --- src/modeling_library/switch/assess.jl | 2 +- src/modeling_library/switch/generate.jl | 2 +- src/modeling_library/switch/propose.jl | 2 +- src/modeling_library/switch/regenerate.jl | 2 +- src/modeling_library/switch/simulate.jl | 2 +- src/modeling_library/switch/switch.jl | 6 +++--- src/modeling_library/switch/update.jl | 2 +- 7 files changed, 9 insertions(+), 9 deletions(-) diff --git a/src/modeling_library/switch/assess.jl b/src/modeling_library/switch/assess.jl index a843535e..4371eb8a 100644 --- a/src/modeling_library/switch/assess.jl +++ b/src/modeling_library/switch/assess.jl @@ -9,7 +9,7 @@ function process!(gen_fn::Switch{C, N, K, T}, args::Tuple, choices::ChoiceMap, state::SwitchAssessState{T}) where {C, N, K, T} - (weight, retval) = assess(getindex(gen_fn.mix, index), args, choices) + (weight, retval) = assess(getindex(gen_fn.branches, index), args, choices) state.weight = weight state.retval = retval end diff --git a/src/modeling_library/switch/generate.jl b/src/modeling_library/switch/generate.jl index 1c135c67..bd03f632 100644 --- a/src/modeling_library/switch/generate.jl +++ b/src/modeling_library/switch/generate.jl @@ -14,7 +14,7 @@ function process!(gen_fn::Switch{C, N, K, T}, choices::ChoiceMap, state::SwitchGenerateState{T}) where {C, N, K, T} - (subtrace, weight) = generate(getindex(gen_fn.mix, index), args, choices) + (subtrace, weight) = generate(getindex(gen_fn.branches, index), args, choices) state.index = index state.subtrace = subtrace state.weight += weight diff --git a/src/modeling_library/switch/propose.jl b/src/modeling_library/switch/propose.jl index 492f7ca5..b4df1d97 100644 --- a/src/modeling_library/switch/propose.jl +++ b/src/modeling_library/switch/propose.jl @@ -10,7 +10,7 @@ function process!(gen_fn::Switch{C, N, K, T}, args::Tuple, state::SwitchProposeState{T}) where {C, N, K, T} - (submap, weight, retval) = propose(getindex(gen_fn.mix, index), args) + (submap, weight, retval) = propose(getindex(gen_fn.branches, index), args) state.choices = submap state.weight += weight state.retval = retval diff --git a/src/modeling_library/switch/regenerate.jl b/src/modeling_library/switch/regenerate.jl index caaa36b2..49bce056 100644 --- a/src/modeling_library/switch/regenerate.jl +++ b/src/modeling_library/switch/regenerate.jl @@ -38,7 +38,7 @@ function process!(gen_fn::Switch{C, N, K, T}, kernel_argdiffs::Tuple, selection::Selection, state::SwitchRegenerateState{T}) where {C, N, K, T} - branch_fn = getfield(gen_fn.mix, index) + branch_fn = getfield(gen_fn.branches, index) merged = regenerate_recurse_merge(get_choices(state.prev_trace), selection) new_trace, weight = generate(branch_fn, args, merged) retdiff = UnknownChange() diff --git a/src/modeling_library/switch/simulate.jl b/src/modeling_library/switch/simulate.jl index ef75bb48..fc4b3b02 100644 --- a/src/modeling_library/switch/simulate.jl +++ b/src/modeling_library/switch/simulate.jl @@ -12,7 +12,7 @@ function process!(gen_fn::Switch{C, N, K, T}, args::Tuple, state::SwitchSimulateState{T}) where {C, N, K, T} local retval::T - subtrace = simulate(getindex(gen_fn.mix, index), args) + subtrace = simulate(getindex(gen_fn.branches, index), args) state.index = index state.noise += project(subtrace, EmptySelection()) state.subtrace = subtrace diff --git a/src/modeling_library/switch/switch.jl b/src/modeling_library/switch/switch.jl index ca90556b..82114344 100644 --- a/src/modeling_library/switch/switch.jl +++ b/src/modeling_library/switch/switch.jl @@ -1,5 +1,5 @@ struct Switch{C, N, K, T} <: GenerativeFunction{T, Trace} - mix::NTuple{N, GenerativeFunction{T}} + branches::NTuple{N, GenerativeFunction{T}} cases::Dict{C, Int} function Switch(gen_fns::GenerativeFunction...) @assert !isempty(gen_fns) @@ -14,10 +14,10 @@ struct Switch{C, N, K, T} <: GenerativeFunction{T, Trace} end export Switch -has_argument_grads(switch_fn::Switch) = map(zip(map(has_argument_grads, switch_fn.mix)...)) do as +has_argument_grads(switch_fn::Switch) = map(zip(map(has_argument_grads, switch_fn.branches)...)) do as all(as) end -accepts_output_grad(switch_fn::Switch) = all(accepts_output_grad, switch_fn.mix) +accepts_output_grad(switch_fn::Switch) = all(accepts_output_grad, switch_fn.branches) function (gen_fn::Switch)(index::Int, args...) (_, _, retval) = propose(gen_fn, (index, args...)) diff --git a/src/modeling_library/switch/update.jl b/src/modeling_library/switch/update.jl index dbedce94..d7f69daa 100644 --- a/src/modeling_library/switch/update.jl +++ b/src/modeling_library/switch/update.jl @@ -66,7 +66,7 @@ function process!(gen_fn::Switch{C, N, K, T}, # Generate new trace. merged = update_recurse_merge(get_choices(state.prev_trace), choices) - branch_fn = getfield(gen_fn.mix, index) + branch_fn = getfield(gen_fn.branches, index) new_trace, weight = generate(branch_fn, args, merged) weight -= get_score(state.prev_trace) state.discard = update_discard(state.prev_trace, choices, new_trace) From 43c7274648a6d9ea21b42097e28e8acc0cdad846 Mon Sep 17 00:00:00 2001 From: femtomc Date: Sat, 5 Dec 2020 13:15:25 -0500 Subject: [PATCH 30/30] Addressed review comments. Added docstrings where necessary. Corrected update_discard. Added test to test the discard functionality in a hierarchical model example. --- src/modeling_library/switch/regenerate.jl | 26 +----- src/modeling_library/switch/update.jl | 96 ++++++++++++++--------- test/modeling_library/switch.jl | 38 +++++++++ 3 files changed, 101 insertions(+), 59 deletions(-) diff --git a/src/modeling_library/switch/regenerate.jl b/src/modeling_library/switch/regenerate.jl index 49bce056..cb1094af 100644 --- a/src/modeling_library/switch/regenerate.jl +++ b/src/modeling_library/switch/regenerate.jl @@ -9,37 +9,15 @@ mutable struct SwitchRegenerateState{T} SwitchRegenerateState{T}(weight::Float64, score::Float64, noise::Float64, prev_trace::Trace) where T = new{T}(weight, score, noise, prev_trace) end -@inline regenerate_recurse_merge(prev_choices::ChoiceMap, selection::EmptySelection) = prev_choices -@inline regenerate_recurse_merge(prev_choices::ChoiceMap, selection::AllSelection) = choicemap() -function regenerate_recurse_merge(prev_choices::ChoiceMap, selection::Selection) - prev_choice_value_iterator = get_values_shallow(prev_choices) - prev_choice_submap_iterator = get_submaps_shallow(prev_choices) - new_choices = choicemap() - for (key, value) in prev_choice_value_iterator - in(key, selection) && continue - set_value!(new_choices, key, value) - end - for (key, node1) in prev_choice_submap_iterator - if in(key, selection) - subsel = getindex(selection, key) - node = regenerate_recurse_merge(node1, subsel) - set_submap!(new_choices, key, node) - else - set_submap!(new_choices, key, node1) - end - end - return new_choices -end - function process!(gen_fn::Switch{C, N, K, T}, index::Int, - index_argdiff::UnknownChange, + index_argdiff::Diff, args::Tuple, kernel_argdiffs::Tuple, selection::Selection, state::SwitchRegenerateState{T}) where {C, N, K, T} branch_fn = getfield(gen_fn.branches, index) - merged = regenerate_recurse_merge(get_choices(state.prev_trace), selection) + merged = get_selected(get_choices(state.prev_trace), complement(selection)) new_trace, weight = generate(branch_fn, args, merged) retdiff = UnknownChange() weight -= project(state.prev_trace, complement(selection)) diff --git a/src/modeling_library/switch/update.jl b/src/modeling_library/switch/update.jl index d7f69daa..6aa672fd 100644 --- a/src/modeling_library/switch/update.jl +++ b/src/modeling_library/switch/update.jl @@ -16,53 +16,79 @@ function update_recurse_merge(prev_choices::ChoiceMap, choices::ChoiceMap) choice_submap_iterator = get_submaps_shallow(choices) choice_value_iterator = get_values_shallow(choices) new_choices = DynamicChoiceMap() - for (key, value) in prev_choice_value_iterator - key in keys(choice_value_iterator) && continue - set_value!(new_choices, key, value) + + # Add (address, value) to new_choices from prev_choices if address does not occur in choices. + for (address, value) in prev_choice_value_iterator + address in keys(choice_value_iterator) && continue + set_value!(new_choices, address, value) end - for (key, node1) in prev_choice_submap_iterator - if key in keys(choice_submap_iterator) - node2 = get_submap(choices, key) + + # Add (address, submap) to new_choices from prev_choices if address does not occur in choices. + # If it does, enter a recursive call to update_recurse_merge. + for (address, node1) in prev_choice_submap_iterator + if address in keys(choice_submap_iterator) + node2 = get_submap(choices, address) node = update_recurse_merge(node1, node2) - set_submap!(new_choices, key, node) + set_submap!(new_choices, address, node) else - set_submap!(new_choices, key, node1) + set_submap!(new_choices, address, node1) end end - for (key, value) in choice_value_iterator - set_value!(new_choices, key, value) + + # Add (address, value) from choices to new_choices. This is okay because we've excluded any conflicting addresses from the prev_choices above. + for (address, value) in choice_value_iterator + set_value!(new_choices, address, value) end + sel, _ = zip(prev_choice_submap_iterator...) comp = complement(select(sel...)) - for (key, node) in get_submaps_shallow(get_selected(choices, comp)) - set_submap!(new_choices, key, node) + for (address, node) in get_submaps_shallow(get_selected(choices, comp)) + set_submap!(new_choices, address, node) end return new_choices end -function update_discard(prev_trace::Trace, choices::ChoiceMap, new_trace::Trace) +@doc( +""" +update_recurse_merge(prev_choices::ChoiceMap, choices::ChoiceMap) + +Returns choices that are in constraints, merged with all choices in the previous trace that do not have the same address as some choice in the constraints." +""", update_recurse_merge) + +function update_discard(prev_choices::ChoiceMap, choices::ChoiceMap, new_choices::ChoiceMap) discard = choicemap() - prev_choices = get_choices(prev_trace) for (k, v) in get_submaps_shallow(prev_choices) - isempty(get_submap(get_choices(new_trace), k)) && continue - isempty(get_submap(choices, k)) && continue - set_submap!(discard, k, v) + new_submap = get_submap(new_choices, k) + choices_submap = get_submap(choices, k) + sub_discard = update_discard(v, choices_submap, new_submap) + set_submap!(discard, k, sub_discard) end for (k, v) in get_values_shallow(prev_choices) - has_value(get_choices(new_trace), k) || continue - has_value(choices, k) || continue - set_value!(discard, k, v) + if (!has_value(new_choices, k) || has_value(choices, k)) + set_value!(discard, k, v) + end end discard end +@doc( +""" +update_discard(prev_choices::ChoiceMap, choices::ChoiceMap, new_choices::ChoiceMap) + +Returns choices from previous trace that: + 1. have an address which does not appear in the new trace. + 2. have an address which does appear in the constraints. +""", update_discard) + +@inline update_discard(prev_trace::Trace, choices::ChoiceMap, new_trace::Trace) = update_discard(get_choices(prev_trace), choices, get_choices(new_trace)) + function process!(gen_fn::Switch{C, N, K, T}, - index::Int, - index_argdiff::UnknownChange, # TODO: Diffed wrapper? - args::Tuple, - kernel_argdiffs::Tuple, - choices::ChoiceMap, - state::SwitchUpdateState{T}) where {C, N, K, T, DV} + index::Int, + index_argdiff::UnknownChange, + args::Tuple, + kernel_argdiffs::Tuple, + choices::ChoiceMap, + state::SwitchUpdateState{T}) where {C, N, K, T, DV} # Generate new trace. merged = update_recurse_merge(get_choices(state.prev_trace), choices) @@ -81,12 +107,12 @@ function process!(gen_fn::Switch{C, N, K, T}, end function process!(gen_fn::Switch{C, N, K, T}, - index::Int, - index_argdiff::NoChange, # TODO: Diffed wrapper? - args::Tuple, - kernel_argdiffs::Tuple, - choices::ChoiceMap, - state::SwitchUpdateState{T}) where {C, N, K, T} + index::Int, + index_argdiff::NoChange, # TODO: Diffed wrapper? + args::Tuple, + kernel_argdiffs::Tuple, + choices::ChoiceMap, + state::SwitchUpdateState{T}) where {C, N, K, T} # Update trace. new_trace, weight, retdiff, discard = update(getfield(state.prev_trace, :branch), args, kernel_argdiffs, choices) @@ -104,9 +130,9 @@ end @inline process!(gen_fn::Switch{C, N, K, T}, index::C, index_argdiff::Diff, args::Tuple, kernel_argdiffs::Tuple, choices::ChoiceMap, state::SwitchUpdateState{T}) where {C, N, K, T} = process!(gen_fn, getindex(gen_fn.cases, index), index_argdiff, args, kernel_argdiffs, choices, state) function update(trace::SwitchTrace{T}, - args::Tuple, - argdiffs::Tuple, - choices::ChoiceMap) where T + args::Tuple, + argdiffs::Tuple, + choices::ChoiceMap) where T gen_fn = trace.gen_fn index, index_argdiff = args[1], argdiffs[1] state = SwitchUpdateState{T}(0.0, 0.0, 0.0, trace) diff --git a/test/modeling_library/switch.jl b/test/modeling_library/switch.jl index 7346a356..8c183aa1 100644 --- a/test/modeling_library/switch.jl +++ b/test/modeling_library/switch.jl @@ -299,4 +299,42 @@ @test isapprox(gradients[:x => :z], expected_choice_grad[1]) end end + + # ------------ (More complex) hierarchy to test discard ------------ # + + # Model chunk. + @gen (grad) function bang3((grad)(x::Float64), (grad)(y::Float64)) + std::Float64 = 3.0 + z = @trace(normal(x + y, std), :z) + q = @trace(bang2(z, y), :q) + return z + end + @gen (grad) function fuzz3((grad)(x::Float64), (grad)(y::Float64)) + std::Float64 = 3.0 + z = @trace(normal(x + 2 * y, std), :z) + m = @trace(normal(x + 3 * y, std), :m) + q = @trace(bang3(z, y), :q) + return z + end + sc3 = Switch(bang3, fuzz3) + @gen (grad) function bam3(s::Int) + x ~ sc3(s, 5.0, 3.0) + return x + end + # ----. + + @testset "update" begin + tr = simulate(bam3, (2, )) + old_sc = get_score(tr) + chm = choicemap((:x => :z, 5.0)) + future_discarded = tr[:x => :z] + new_tr, w, rd, discard = update(tr, (2, ), (UnknownChange(), ), chm) + @test discard[:x => :z] == future_discarded + @test isapprox(old_sc, get_score(new_tr) - w) + chm = choicemap((:x => :z, 10.0)) + future_discarded = tr[:x => :q => :q => :z] + new_tr, w, rd, discard = update(tr, (1, ), (UnknownChange(), ), chm) + @test discard[:x => :q => :q => :z] == future_discarded + @test isapprox(old_sc, get_score(new_tr) - w) + end end