Skip to content

Commit

Permalink
Merge pull request #397 from probcomp/20210327-ztangent-translatorfixes
Browse files Browse the repository at this point in the history
Trace translator fixes, test cases, and minor enhancements.
  • Loading branch information
ztangent authored Mar 30, 2021
2 parents ca012bc + 14b2403 commit 68dd3a7
Show file tree
Hide file tree
Showing 4 changed files with 395 additions and 124 deletions.
58 changes: 49 additions & 9 deletions docs/src/ref/trace_translators.md
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ Note that the transform DSL code does not specify what the two generative functi
This information will be required for computing probabilities and probability densities of traces.
We provide this information by constructing a **Trace Translator** that wraps the transform along with this transformation:
```julia
translator = DeterministicTraceTranslator(p2, (), f)
translator = DeterministicTraceTranslator(p2, (), choicemap(), f)
```
We then can then apply the translator to a trace of `p1` using function call syntax.
The translator returns a trace of `p2` and a log-weight that we can use to compute the probability (density) of the resulting trace:
Expand Down Expand Up @@ -228,12 +228,10 @@ We construct `q1` and `q2` so that the two spaces have the same size, and a one-
For our example above, we construct `q2` to sample the coordinate (``[0, 0.1]^2``) relative to the cell.
We construct `q1` to be empty--there is already a mapping from each trace of `p1` to each trace of `p2` that simply identifies what cell ``(i, j)`` a given point in ``[0, 1]^2`` is in, so no extra random choices are needed.
```julia
@gen function q1()
@gen function q1(p1_trace)
end

@gen function q2(p2_trace)
i = p2_trace[:i]
j = p2_trace[:j]
dx ~ uniform(0.0, 0.1)
dy ~ uniform(0.0, 0.1)
end
Expand All @@ -251,8 +249,8 @@ For example, the following defines a trace transform that maps from pairs of tra
j = ceil(y * 10)
@write(p2_trace[:i], i, :discrete)
@write(p2_trace[:j], j, :discrete)
@write(q2_trace[:dx], x / 10, :continuous)
@write(q2_trace[:dy], y / 10, :continuous)
@write(q2_trace[:dx], x - (i-1)/10, :continuous)
@write(q2_trace[:dy], y - (j-1)/10, :continuous)
end
```
and the inverse transform:
Expand All @@ -265,7 +263,7 @@ and the inverse transform:
x = (i-1)/10 + dx
y = (j-1)/10 + dy
@write(p1_trace[:x], x, :continuous)
@write(p1_trace[:y], x, :continuous)
@write(p1_trace[:y], y, :continuous)
end
```
which we associate as inverses:
Expand All @@ -289,7 +287,7 @@ translator = GeneralTraceTranslator(
```
Then, we can apply the trace translator to a trace (`t1`) of `p1` and get a trace (`t2`) of `p2` and a log-weight:
```julia
(t2, log_weight) = translator(t1)
t2, log_weight = translator(t1)
```


Expand All @@ -309,7 +307,49 @@ This has two benefits when the previous and new traces have random choices that

## Simple Extending Trace Translators

TODO Document
Simple extending trace translators extend an existing trace with new random
choices sampled from a proposal distribution, as well as any new observations.
The arguments of the trace may also be updated. This type of trace translation
is the basic operation used in [Particle Filtering](@ref). For example,
we might have a model that sequentially samples new latent variables `(:z, t)`
and observations `(:x, t)` for each timestep `t`:

```julia
@gen function model(T::Int)
for t in 1:T
z = {(:z, t)} ~ normal(0, 1)
x = {(:x, t)} ~ normal(z, 1)
end
end
```

Each time we observe a new `(:x ,t)`, we might want to propose `(:z, t)` so that
it is close in value:

```julia
@gen function proposal(trace::Trace, x::Real)
t = get_args(trace)[1] + 1
{(:z, t)} ~ normal(x, 1)
end
```

Suppose we initially generated a trace up to timestep `t=1`, e.g. by calling
`t1 = simulate(model, (1,))`. Now we observe `(:x, 2)` to be `5.0`. By
constructing a simple extending trace translator, we can simultaneously
update the trace `t1` with new arguments, introduce the new observation
at `(:x, 2)`, and propose a likely value for `(:z, 2)`:

```julia
translator = SimpleExtendingTraceTranslator(
p_new_args=(2,), p_argdiffs=(UnknownChange(),),
new_observations=choicemap((:x, 2) => 5.0),
q_forward=proposal, q_forward_args=(5.0,))
t2, log_weight = translator(t1)
```

Similar functionality can be achieved through a combination of [`propose`](@ref)
on the proposal and [`update`](@ref) on the original trace, but using a trace
translator provides a nice layer of abstraction.

## Trace Transform DSL

Expand Down
Loading

0 comments on commit 68dd3a7

Please sign in to comment.