Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

(Ready for review): Switch combinator #334

Merged
merged 31 commits into from
Dec 8, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
3e4f695
Initial work on a Switch combinator.
femtomc Nov 17, 2020
bd4f830
Initial implementation of propose and generate.
femtomc Nov 17, 2020
374a7b0
Added implementaton of simulate.
femtomc Nov 17, 2020
5872593
Corrected some bugs with Bernoulli vs bernoulli.
femtomc Nov 17, 2020
9c0a9f2
Added assess implementation.
femtomc Nov 17, 2020
95baf07
Split into two combinators: Switch and WithProbability implementations.
femtomc Nov 18, 2020
29b7797
Working on Switch update and regenerate.
femtomc Nov 18, 2020
3e6e307
Added Switch update and regenerate.
femtomc Nov 18, 2020
7929b86
Added Switch update and regenerate - working out kinks in update.
femtomc Nov 18, 2020
73618a1
update and regenerate appear to be computing the correct ratios. To c…
femtomc Nov 18, 2020
252413f
Fixed generate index type bug.
femtomc Nov 18, 2020
ac3528e
Branch dispatch done using diff types.
femtomc Nov 18, 2020
eaf3327
Branch dispatch done using diff types.
femtomc Nov 18, 2020
6d58aac
Branch dispatch done using diff types.
femtomc Nov 18, 2020
e413e9c
Added custom methods in update for Switch which allow the merging of …
femtomc Nov 18, 2020
435493f
Added custom methods in update for Switch which allow the merging of …
femtomc Nov 18, 2020
32fec4f
Idiomatic check for EmptyChoiceMap.
femtomc Nov 18, 2020
bb767e7
Working on backprop - seems simple? Could it really be?
femtomc Nov 18, 2020
a35e2e7
Extracting WithProb combinator into another PR.
femtomc Nov 18, 2020
562667e
Testing backprop.
femtomc Nov 19, 2020
b74a071
Fixed backprop - was thinking in Zygote lang. Gradients appear to be …
femtomc Nov 19, 2020
915811d
Merge branch 'master' of https://github.com/probcomp/Gen.jl into 2020…
femtomc Nov 19, 2020
849d61e
Added docstring and docs example.
femtomc Nov 19, 2020
adf73a5
Fixed numerous bugs uncovered while constructing test suite. One seri…
femtomc Nov 19, 2020
dfe0125
Fixed numerous bugs uncovered while constructing test suite. One seri…
femtomc Nov 20, 2020
3717d65
Tests for everything but gradients - working on gradients now.
femtomc Nov 20, 2020
cb62fb5
Last tests I need to write: accumulate_param_gradients!
femtomc Nov 20, 2020
97473d0
Added accumulate_param_gradients! tests.
femtomc Nov 20, 2020
176b9e9
Reverted particle filter fix - will be handled in another issue.
femtomc Nov 20, 2020
0465965
Renamed mix field of Switch generative function to branches to more a…
femtomc Nov 22, 2020
43c7274
Addressed review comments. Added docstrings where necessary. Correcte…
femtomc Dec 5, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions docs/src/ref/combinators.md
Original file line number Diff line number Diff line change
Expand Up @@ -119,4 +119,41 @@ TODO: document me
<img src="../../images/recurse_combinator.png" alt="schematic of recurse combinatokr" width="70%"/>
</div>
```
## Switch combinator

```@docs
Switch
```

In the schematic below, the kernel is denoted `S` and accepts an integer index `k`.

Consider the following constructions:

```julia
@gen function bang((grad)(x::Float64), (grad)(y::Float64))
std::Float64 = 3.0
z = @trace(normal(x + y, std), :z)
return z
end

@gen function fuzz((grad)(x::Float64), (grad)(y::Float64))
std::Float64 = 3.0
z = @trace(normal(x + 2 * y, std), :z)
return z
end

sc = Switch(bang, fuzz)
```

This creates a new generative function `sc`. We can then obtain the trace of `sc`:

```julia
(trace, _) = simulate(sc, (2, 5.0, 3.0))
```

The resulting trace contains the subtrace from the branch with index `2` - in this case, a call to `fuzz`:

```
└── :z : 13.552870875213735
```
20 changes: 20 additions & 0 deletions src/modeling_library/cond.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
# ------------ Switch trace ------------ #

struct SwitchTrace{T} <: Trace
gen_fn::GenerativeFunction{T}
index::Int
branch::Trace
retval::T
args::Tuple
score::Float64
noise::Float64
end

@inline get_choices(tr::SwitchTrace) = get_choices(tr.branch)
@inline get_retval(tr::SwitchTrace) = tr.retval
@inline get_args(tr::SwitchTrace) = tr.args
@inline get_score(tr::SwitchTrace) = tr.score
@inline get_gen_fn(tr::SwitchTrace) = tr.gen_fn
@inline Base.getindex(tr::SwitchTrace, addr) = Base.getindex(tr.branch, addr)
@inline project(tr::SwitchTrace, selection::Selection) = project(tr.branch, selection)
@inline project(tr::SwitchTrace, ::EmptySelection) = tr.noise
4 changes: 4 additions & 0 deletions src/modeling_library/modeling_library.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,12 +66,16 @@ include("dist_dsl/dist_dsl.jl")
# code shared by vector-shaped combinators
include("vector.jl")

# traces for with prob/switch combinator
include("cond.jl")

# built-in generative function combinators
include("choice_at/choice_at.jl")
include("call_at/call_at.jl")
include("map/map.jl")
include("unfold/unfold.jl")
include("recurse/recurse.jl")
include("switch/switch.jl")

#############################################################
# abstractions for constructing custom generative functions #
Expand Down
26 changes: 26 additions & 0 deletions src/modeling_library/switch/assess.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
mutable struct SwitchAssessState{T}
weight::Float64
retval::T
SwitchAssessState{T}(weight::Float64) where T = new{T}(weight)
end

function process!(gen_fn::Switch{C, N, K, T},
index::Int,
args::Tuple,
choices::ChoiceMap,
state::SwitchAssessState{T}) where {C, N, K, T}
(weight, retval) = assess(getindex(gen_fn.branches, index), args, choices)
state.weight = weight
state.retval = retval
end

@inline process!(gen_fn::Switch{C, N, K, T}, index::C, args::Tuple, choices::ChoiceMap, state::SwitchAssessState{T}) where {C, N, K, T} = process!(gen_fn, getindex(gen_fn.cases, index), args, choices, state)

function assess(gen_fn::Switch{C, N, K, T},
args::Tuple,
choices::ChoiceMap) where {C, N, K, T}
index = args[1]
state = SwitchAssessState{T}(0.0)
process!(gen_fn, index, args[2 : end], choices, state)
return state.weight, state.retval
end
2 changes: 2 additions & 0 deletions src/modeling_library/switch/backprop.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
@inline choice_gradients(trace::SwitchTrace{T}, selection::Selection, retval_grad) where T = choice_gradients(getfield(trace, :branch), selection, retval_grad)
@inline accumulate_param_gradients!(trace::SwitchTrace{T}, retval_grad, scale_factor = 1.) where {T} = accumulate_param_gradients!(getfield(trace, :branch), retval_grad, scale_factor)
34 changes: 34 additions & 0 deletions src/modeling_library/switch/generate.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
mutable struct SwitchGenerateState{T}
score::Float64
noise::Float64
weight::Float64
index::Int
subtrace::Trace
retval::T
SwitchGenerateState{T}(score::Float64, noise::Float64, weight::Float64) where T = new{T}(score, noise, weight)
end

function process!(gen_fn::Switch{C, N, K, T},
index::Int,
args::Tuple,
choices::ChoiceMap,
state::SwitchGenerateState{T}) where {C, N, K, T}

(subtrace, weight) = generate(getindex(gen_fn.branches, index), args, choices)
state.index = index
state.subtrace = subtrace
state.weight += weight
state.retval = get_retval(subtrace)
end

@inline process!(gen_fn::Switch{C, N, K, T}, index::C, args::Tuple, choices::ChoiceMap, state::SwitchGenerateState{T}) where {C, N, K, T} = process!(gen_fn, getindex(gen_fn.cases, index), args, choices, state)

function generate(gen_fn::Switch{C, N, K, T},
args::Tuple,
choices::ChoiceMap) where {C, N, K, T}

index = args[1]
state = SwitchGenerateState{T}(0.0, 0.0, 0.0)
process!(gen_fn, index, args[2 : end], choices, state)
return SwitchTrace{T}(gen_fn, state.index, state.subtrace, state.retval, args[2 : end], state.score, state.noise), state.weight
end
29 changes: 29 additions & 0 deletions src/modeling_library/switch/propose.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
mutable struct SwitchProposeState{T}
choices::DynamicChoiceMap
weight::Float64
retval::T
SwitchProposeState{T}(choices, weight) where T = new{T}(choices, weight)
end

function process!(gen_fn::Switch{C, N, K, T},
index::Int,
args::Tuple,
state::SwitchProposeState{T}) where {C, N, K, T}

(submap, weight, retval) = propose(getindex(gen_fn.branches, index), args)
state.choices = submap
state.weight += weight
state.retval = retval
end

@inline process!(gen_fn::Switch{C, N, K, T}, index::C, args::Tuple, state::SwitchProposeState{T}) where {C, N, K, T} = process!(gen_fn, getindex(gen_fn.cases, index), args, state)

function propose(gen_fn::Switch{C, N, K, T},
args::Tuple) where {C, N, K, T}

index = args[1]
choices = choicemap()
state = SwitchProposeState{T}(choices, 0.0)
process!(gen_fn, index, args[2:end], state)
return state.choices, state.weight, state.retval
end
60 changes: 60 additions & 0 deletions src/modeling_library/switch/regenerate.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
mutable struct SwitchRegenerateState{T}
weight::Float64
score::Float64
noise::Float64
prev_trace::Trace
trace::Trace
index::Int
retdiff::Diff
SwitchRegenerateState{T}(weight::Float64, score::Float64, noise::Float64, prev_trace::Trace) where T = new{T}(weight, score, noise, prev_trace)
end

function process!(gen_fn::Switch{C, N, K, T},
index::Int,
index_argdiff::Diff,
args::Tuple,
kernel_argdiffs::Tuple,
selection::Selection,
state::SwitchRegenerateState{T}) where {C, N, K, T}
branch_fn = getfield(gen_fn.branches, index)
merged = get_selected(get_choices(state.prev_trace), complement(selection))
new_trace, weight = generate(branch_fn, args, merged)
retdiff = UnknownChange()
weight -= project(state.prev_trace, complement(selection))
weight += (project(new_trace, selection) - project(state.prev_trace, selection))
state.index = index
state.weight = weight
state.noise = project(new_trace, EmptySelection()) - project(state.prev_trace, EmptySelection())
state.score = get_score(new_trace)
state.trace = new_trace
state.retdiff = retdiff
end

function process!(gen_fn::Switch{C, N, K, T},
index::Int,
index_argdiff::NoChange,
args::Tuple,
kernel_argdiffs::Tuple,
selection::Selection,
state::SwitchRegenerateState{T}) where {C, N, K, T}
new_trace, weight, retdiff = regenerate(getfield(state.prev_trace, :branch), args, kernel_argdiffs, selection)
state.index = index
state.weight = weight
state.noise = project(new_trace, EmptySelection()) - project(state.prev_trace, EmptySelection())
state.score = get_score(new_trace)
state.trace = new_trace
state.retdiff = retdiff
end

@inline process!(gen_fn::Switch{C, N, K, T}, index::C, index_argdiff::Diff, args::Tuple, kernel_argdiffs::Tuple, selection::Selection, state::SwitchRegenerateState{T}) where {C, N, K, T} = process!(gen_fn, getindex(gen_fn.cases, index), index_argdiff, args, kernel_argdiffs, selection, state)

function regenerate(trace::SwitchTrace{T},
args::Tuple,
argdiffs::Tuple,
selection::Selection) where T
gen_fn = trace.gen_fn
index, index_argdiff = args[1], argdiffs[1]
state = SwitchRegenerateState{T}(0.0, 0.0, 0.0, trace)
process!(gen_fn, index, index_argdiff, args[2 : end], argdiffs[2 : end], selection, state)
return SwitchTrace(gen_fn, state.index, state.trace, get_retval(state.trace), args, state.score, state.noise), state.weight, state.retdiff
end
32 changes: 32 additions & 0 deletions src/modeling_library/switch/simulate.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
mutable struct SwitchSimulateState{T}
score::Float64
noise::Float64
index::Int
subtrace::Trace
retval::T
SwitchSimulateState{T}(score::Float64, noise::Float64) where T = new{T}(score, noise)
end

function process!(gen_fn::Switch{C, N, K, T},
index::Int,
args::Tuple,
state::SwitchSimulateState{T}) where {C, N, K, T}
local retval::T
subtrace = simulate(getindex(gen_fn.branches, index), args)
state.index = index
state.noise += project(subtrace, EmptySelection())
state.subtrace = subtrace
state.score += get_score(subtrace)
state.retval = get_retval(subtrace)
end

@inline process!(gen_fn::Switch{C, N, K, T}, index::C, args::Tuple, state::SwitchSimulateState{T}) where {C, N, K, T} = process!(gen_fn, getindex(gen_fn.cases, index), args, state)

function simulate(gen_fn::Switch{C, N, K, T},
args::Tuple) where {C, N, K, T}

index = args[1]
state = SwitchSimulateState{T}(0.0, 0.0)
process!(gen_fn, index, args[2 : end], state)
return SwitchTrace{T}(gen_fn, state.index, state.subtrace, state.retval, args[2 : end], state.score, state.noise)
end
56 changes: 56 additions & 0 deletions src/modeling_library/switch/switch.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
struct Switch{C, N, K, T} <: GenerativeFunction{T, Trace}
branches::NTuple{N, GenerativeFunction{T}}
cases::Dict{C, Int}
function Switch(gen_fns::GenerativeFunction...)
@assert !isempty(gen_fns)
rettype = get_return_type(getindex(gen_fns, 1))
new{Int, length(gen_fns), typeof(gen_fns), rettype}(gen_fns, Dict{Int, Int}())
end
function Switch(d::Dict{C, Int}, gen_fns::GenerativeFunction...) where C
@assert !isempty(gen_fns)
rettype = get_return_type(getindex(gen_fns, 1))
new{C, length(gen_fns), typeof(gen_fns), rettype}(gen_fns, d)
end
end
export Switch

has_argument_grads(switch_fn::Switch) = map(zip(map(has_argument_grads, switch_fn.branches)...)) do as
all(as)
end
accepts_output_grad(switch_fn::Switch) = all(accepts_output_grad, switch_fn.branches)

function (gen_fn::Switch)(index::Int, args...)
(_, _, retval) = propose(gen_fn, (index, args...))
retval
end

function (gen_fn::Switch{C})(index::C, args...) where C
(_, _, retval) = propose(gen_fn, (gen_fn.cases[index], args...))
retval
end

include("assess.jl")
include("propose.jl")
include("simulate.jl")
include("generate.jl")
include("update.jl")
include("regenerate.jl")
include("backprop.jl")

@doc(
"""
gen_fn = Switch(gen_fns::GenerativeFunction...)

Returns a new generative function that accepts an argument tuple of type `Tuple{Int, ...}` where the first index indicates which branch to call.

gen_fn = Switch(d::Dict{T, Int}, gen_fns::GenerativeFunction...) where T

Returns a new generative function that accepts an argument tuple of type `Tuple{Int, ...}` or an argument tuple of type `Tuple{T, ...}` where the first index either indicates which branch to call, or indicates an index into `d` which maps to the selected branch. This form is meant for convenience - it allows the programmer to use `d` like if-else or case statements.

`Switch` is designed to allow for the expression of patterns of if-else control flow. `gen_fns` must satisfy a few requirements:

1. Each `gen_fn` in `gen_fns` must accept the same argument types.
2. Each `gen_fn` in `gen_fns` must return the same return type.

Otherwise, each `gen_fn` can come from different modeling languages, possess different traces, etc.
""", Switch)
Loading