From 17f680590899ff785eb6f72d2d71f9df945dcbb6 Mon Sep 17 00:00:00 2001 From: Xuan Date: Sat, 27 Mar 2021 01:23:53 -0400 Subject: [PATCH 01/12] Fix typos in GeneralTraceTranslator and DeterministicTraceTranslator. --- src/inference/trace_translators.jl | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/src/inference/trace_translators.jl b/src/inference/trace_translators.jl index 93becd05b..55381083c 100644 --- a/src/inference/trace_translators.jl +++ b/src/inference/trace_translators.jl @@ -69,7 +69,7 @@ end struct AuxInputTraceRetValToken end -struct ModelOutputTraceToken +struct ModelOutputTraceToken end struct AuxOutputTraceToken @@ -227,7 +227,7 @@ struct FirstPassResults "output proposal choice map ``u'``" u_back::ChoiceMap - + t_cont_reads::Dict u_cont_reads::Dict t_cont_writes::Dict @@ -326,7 +326,7 @@ function write(state::FirstPassState, dest::AuxOutputAddress, value, ::Continuou return value end -function copy(state::FirstPassState, src::ModelInputAddress, dest::ModelOutputAddress) +function copy(state::FirstPassState, src::ModelInputAddress, dest::ModelOutputAddress) from_addr, to_addr = src.addr, dest.addr model_choices = get_choices(state.model_trace) push!(state.results.t_copy_reads, from_addr) @@ -474,7 +474,7 @@ discard_skip_read_addr(addr, discard::ChoiceMap) = !has_value(discard, addr) discard_skip_read_addr(addr, discard::Nothing) = false function store_addr_info!(dict::Dict, addr, value::Real, next_index::Int) - dict[addr] = next_index + dict[addr] = next_index return 1 # number of elements of array end @@ -540,7 +540,7 @@ function jacobian_correction(transform::TraceTransformDSLProgram, prev_model_tra first_pass_results.t_copy_reads, first_pass_results.u_cont_reads, first_pass_results.u_copy_reads, discard) - + # create mappings for output addresses that are needed for Jacobian (cont_constraints_key_to_index, cont_u_back_key_to_index, n_output) = assemble_output_maps( first_pass_results.t_cont_writes, @@ -559,7 +559,7 @@ function jacobian_correction(transform::TraceTransformDSLProgram, prev_model_tra output_arr = Vector{T}(undef, n_output) jacobian_pass_state = JacobianPassState( - prev_model_trace, proposal_trace, input_arr, output_arr, + prev_model_trace, proposal_trace, input_arr, output_arr, t_key_to_index, u_key_to_index, cont_constraints_key_to_index, cont_u_back_key_to_index) @@ -584,7 +584,7 @@ function jacobian_correction(transform::TraceTransformDSLProgram, prev_model_tra if isinf(correction) @error "Weight correction is infinite; the function may not be an bijection" end - + return correction end @@ -659,7 +659,7 @@ function (translator::DeterministicTraceTranslator)( log_weight = new_model_score - prev_model_score + log_abs_determinant if check - check_observations(get_choices(new_model_trace), observations) + check_observations(get_choices(new_model_trace), translator.new_observations) (prev_model_trace_rt, _) = deterministic_trace_translator_run_transform( inverse(translator.f), prev_observations, new_model_trace, get_gen_fn(prev_model_trace), get_args(prev_model_trace)) @@ -709,7 +709,7 @@ function general_trace_translator_run_transform( f::TraceTransformDSLProgram, new_observations::ChoiceMap, prev_model_trace::Trace, forward_proposal_trace::Trace, p_new::GenerativeFunction, p_new_args::Tuple, - q_backard::GenerativeFunction, q_backward_args::Tuple) + q_backward::GenerativeFunction, q_backward_args::Tuple) first_pass_results = run_first_pass(f, prev_model_trace, forward_proposal_trace) log_abs_determinant = jacobian_correction( f, prev_model_trace, forward_proposal_trace, first_pass_results, nothing) @@ -724,7 +724,7 @@ function (translator::GeneralTraceTranslator)( prev_model_trace::Trace; check=false, prev_observations=EmptyChoiceMap()) # sample auxiliary trace - forward_proposal_trace = simulate(proposal, (prev_model_trace, translator.q_forward_args...,)) + forward_proposal_trace = simulate(translator.q_forward, (prev_model_trace, translator.q_forward_args...,)) # apply trace transform (new_model_trace, backward_proposal_trace, log_abs_determinant) = general_trace_translator_run_transform( @@ -772,7 +772,7 @@ Run the translator with: (output_trace, log_weight) = translator(input_trace) """ -@with_kw struct SimpleExtendingTraceTranslator +@with_kw struct SimpleExtendingTraceTranslator p_new_args::Tuple = () argdiffs::Tuple = () new_obs::ChoiceMap = EmptyChoiceMap() @@ -884,7 +884,7 @@ function (translator::SymmetricTraceTranslator{<:Function})( forward_retval = get_retval(forward_trace) (new_model_trace, backward_choices, log_weight) = translator.involution( prev_model_trace, forward_choices, forward_retval, translator.q_args) - (backward_score, backward_retval) = assess(translator.q, (new_model_trace, translator.q_args...), backward_choices) + (backward_score, backward_retval) = assess(translator.q, (new_model_trace, translator.q_args...), backward_choices) log_weight += (backward_score - forward_score) From e03e81f496c86618d9a0bab0b27e0cea63f09a72 Mon Sep 17 00:00:00 2001 From: Xuan Date: Sat, 27 Mar 2021 01:27:11 -0400 Subject: [PATCH 02/12] Update field names in SimpleExtendingTraceTranslator for consistency. --- src/inference/trace_translators.jl | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/src/inference/trace_translators.jl b/src/inference/trace_translators.jl index 55381083c..ae261e997 100644 --- a/src/inference/trace_translators.jl +++ b/src/inference/trace_translators.jl @@ -761,10 +761,10 @@ end """ translator = SimpleExtendingTraceTranslator(; p_new_args::Tuple = (), - argdiffs::Tuple = (), - new_obs::ChoiceMap = EmptyChoiceMap(), - q_fwd::GenerativeFunction, - q_fwd_args::Tuple = ()) + p_argdiffs::Tuple = (), + new_observations::ChoiceMap = EmptyChoiceMap(), + q_forward::GenerativeFunction, + q_forward_args::Tuple = ()) Constructor for a simple extending trace translator. @@ -774,23 +774,23 @@ Run the translator with: """ @with_kw struct SimpleExtendingTraceTranslator p_new_args::Tuple = () - argdiffs::Tuple = () - new_obs::ChoiceMap = EmptyChoiceMap() - q_fwd::GenerativeFunction - q_fwd_args::Tuple = () + p_argdiffs::Tuple = () + new_observations::ChoiceMap = EmptyChoiceMap() + q_forward::GenerativeFunction + q_forward_args::Tuple = () end function (translator::SimpleExtendingTraceTranslator)(prev_model_trace::Trace) # simulate from auxiliary program - forward_proposal_trace = simulate(translator.q_fwd, (prev_model_trace, translator.q_fwd_args...,)) + forward_proposal_trace = simulate(translator.q_forward, (prev_model_trace, translator.q_forward_args...,)) forward_proposal_score = get_score(forward_proposal_trace) # computing the new trace via update - constraints = merge(get_choices(forward_proposal_trace), translator.new_obs) + constraints = merge(get_choices(forward_proposal_trace), translator.new_observations) (new_model_trace, log_model_weight, _, discard) = update( prev_model_trace, translator.p_new_args, - translator.argdiffs, constraints) + translator.p_argdiffs, constraints) if !isempty(discard) @error("can only extend the trace with random choices, cannot remove random choices") From cc3bccb4631a01f344d5c9487dcc0697be0ddb6c Mon Sep 17 00:00:00 2001 From: Xuan Date: Sat, 27 Mar 2021 01:43:23 -0400 Subject: [PATCH 03/12] Fix typo in @transform macro, throw error on syntax mismatch. --- src/inference/trace_translators.jl | 9 +++------ 1 file changed, 3 insertions(+), 6 deletions(-) diff --git a/src/inference/trace_translators.jl b/src/inference/trace_translators.jl index ae261e997..bce99a7f9 100644 --- a/src/inference/trace_translators.jl +++ b/src/inference/trace_translators.jl @@ -126,16 +126,15 @@ macro transform(f_expr, src_expr, to_symbol::Symbol, dest_expr, body) err = true end if MacroTools.@capture(dest_expr, (model_out_, aux_out_)) - elseif MacroTools.@capture(src_expr, (model_out_)) + elseif MacroTools.@capture(dest_expr, (model_out_)) aux_out = gensym("dummy_aux") else err = true end + if err error(syntax_err) end - fn! = gensym("$(esc(f))_fn!") - + fn! = gensym(Symbol(f, "_fn!")) return quote - # mutates the state function $fn!( $(esc(bij_state))::Union{FirstPassState,JacobianPassState}, @@ -149,9 +148,7 @@ macro transform(f_expr, src_expr, to_symbol::Symbol, dest_expr, body) $(esc(body)) return nothing end - Core.@__doc__ $(esc(f)) = TraceTransformDSLProgram($fn!, nothing) - end end From 6a60dcba103f560f7d5bdf45071457fa2764c878 Mon Sep 17 00:00:00 2001 From: Xuan Date: Sat, 27 Mar 2021 01:58:59 -0400 Subject: [PATCH 04/12] Fixed bugs for trace transforms without auxiliary traces. --- src/inference/trace_translators.jl | 41 +++++++++++++++++++++--------- 1 file changed, 29 insertions(+), 12 deletions(-) diff --git a/src/inference/trace_translators.jl b/src/inference/trace_translators.jl index bce99a7f9..e4250bb18 100644 --- a/src/inference/trace_translators.jl +++ b/src/inference/trace_translators.jl @@ -139,8 +139,8 @@ macro transform(f_expr, src_expr, to_symbol::Symbol, dest_expr, body) function $fn!( $(esc(bij_state))::Union{FirstPassState,JacobianPassState}, $(map(esc, args)...)) - model_args = get_args($(esc(bij_state)).model_trace) - aux_args = get_args($(esc(bij_state)).aux_trace) + model_args = get_model_args($(esc(bij_state))) + aux_args = get_aux_args($(esc(bij_state))) $(esc(model_in)) = ModelInputTraceToken(model_args) $(esc(model_out)) = ModelOutputTraceToken() $(esc(aux_in)) = AuxInputTraceToken(aux_args) @@ -241,22 +241,31 @@ function FirstPassResults() end struct FirstPassState - - "trace containing the input model choice map ``t``" + "Trace containing the input model choice map ``t``" model_trace::Trace model_choices::ChoiceMap - "the input proposal choice map ``u``" - aux_trace::Trace + "The input proposal choice map ``u``" + aux_trace::Union{Trace,Nothing} aux_choices::ChoiceMap results::FirstPassResults end -function FirstPassState(model_trace, aux_trace) - return FirstPassState( - model_trace, get_choices(model_trace), - aux_trace, get_choices(aux_trace), FirstPassResults()) +FirstPassState(model_trace::Trace, aux_trace::Trace) = + FirstPassState(model_trace, get_choices(model_trace), + aux_trace, get_choices(aux_trace), FirstPassResults()) + +FirstPassState(model_trace::Trace, aux_trace::Nothing) = + FirstPassState(model_trace, get_choices(model_trace), + aux_trace, EmptyChoiceMap(), FirstPassResults()) + +function get_model_args(state::FirstPassState) + return get_args(state.model_trace) +end + +function get_aux_args(state::FirstPassState) + return state.aux_trace === nothing ? () : get_args(state.aux_trace) end function run_first_pass(transform::TraceTransformDSLProgram, model_trace, aux_trace) @@ -376,8 +385,8 @@ end ##################################################################### struct JacobianPassState{T<:Real} - model_trace - aux_trace + model_trace::Trace + aux_trace::Union{Trace,Nothing} input_arr::AbstractArray{T} output_arr::Array{T} t_key_to_index::Dict @@ -386,6 +395,14 @@ struct JacobianPassState{T<:Real} cont_u_back_key_to_index::Dict end +function get_model_args(state::JacobianPassState) + return get_args(state.model_trace) +end + +function get_aux_args(state::JacobianPassState) + return state.aux_trace === nothing ? () : get_args(state.aux_trace) +end + function read(state::JacobianPassState, src::ModelInputTraceRetValToken, ::DiscreteAnn) return get_retval(state.model_trace) end From 508b3fe26227b2f8f96f4691022720695861209e Mon Sep 17 00:00:00 2001 From: Xuan Date: Sat, 27 Mar 2021 03:03:24 -0400 Subject: [PATCH 05/12] More typo fixes. --- src/inference/trace_translators.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/inference/trace_translators.jl b/src/inference/trace_translators.jl index e4250bb18..533d31096 100644 --- a/src/inference/trace_translators.jl +++ b/src/inference/trace_translators.jl @@ -665,7 +665,7 @@ function (translator::DeterministicTraceTranslator)( # apply trace transform (new_model_trace, log_abs_determinant) = deterministic_trace_translator_run_transform( - translator.f, translator.new_observations, prev_model_trace, translator.p_new, translator.p_new_args) + translator.f, translator.new_observations, prev_model_trace, translator.p_new, translator.p_args) # compute log weight prev_model_score = get_score(prev_model_trace) @@ -742,7 +742,7 @@ function (translator::GeneralTraceTranslator)( # apply trace transform (new_model_trace, backward_proposal_trace, log_abs_determinant) = general_trace_translator_run_transform( - translator.f, prev_model_trace, forward_proposal_trace, translator.new_observations, + translator.f, translator.new_observations, prev_model_trace, forward_proposal_trace, translator.p_new, translator.p_new_args, translator.q_backward, translator.q_backward_args) # compute log weight From e24073e4267a19f2b1ed867f5b1a66f4253973b8 Mon Sep 17 00:00:00 2001 From: Xuan Date: Sat, 27 Mar 2021 03:04:04 -0400 Subject: [PATCH 06/12] Add translator test cases based on examples in docs. --- test/inference/inference.jl | 1 + test/inference/trace_translators.jl | 171 ++++++++++++++++++++++++++++ 2 files changed, 172 insertions(+) create mode 100644 test/inference/trace_translators.jl diff --git a/test/inference/inference.jl b/test/inference/inference.jl index 35f00bab1..c86bff2d2 100644 --- a/test/inference/inference.jl +++ b/test/inference/inference.jl @@ -7,3 +7,4 @@ include("hmc.jl") include("map_optimize.jl") include("elliptical_slice.jl") include("mh.jl") +include("trace_translators.jl") diff --git a/test/inference/trace_translators.jl b/test/inference/trace_translators.jl new file mode 100644 index 000000000..0ce361fa9 --- /dev/null +++ b/test/inference/trace_translators.jl @@ -0,0 +1,171 @@ +@testset "trace translators" begin + +@testset "DeterministicTraceTranslator" begin + + @gen function p1() + r ~ inv_gamma(1, 1) + theta ~ uniform(-pi/2, pi/2) + end + + @gen function p2() + x ~ normal(0, 1) + y ~ normal(0, 1) + end + + @transform f (t1) to (t2) begin + r = @read(t1[:r], :continuous) + theta = @read(t1[:theta], :continuous) + @write(t2[:x], r * cos(theta), :continuous) + @write(t2[:y], r * sin(theta), :continuous) + end + + @transform finv (t2) to (t1) begin + x = @read(t2[:x], :continuous) + y = @read(t2[:y], :continuous) + r = sqrt(x^2 + y^2) + @write(t1[:r], sqrt(x^2 + y^2), :continuous) + @write(t1[:theta], atan(y, x), :continuous) + end + + pair_bijections!(f, finv) + + translator = DeterministicTraceTranslator(p2, (), choicemap(), f) + t1, _ = generate(p1, (), choicemap(:theta => 0)) + t2, log_weight = translator(t1; check=true) + @test t2[:y] == 0 + + translator = DeterministicTraceTranslator(p1, (), choicemap(), finv) + t2, _ = generate(p2, (), choicemap(:y => 0, :x => 1)) + t1, log_weight = translator(t2; check=true) + @test t1[:theta] == 0 + +end + +@testset "SimpleExtendingTraceTranslator" begin + + @gen function model(T::Int) + for t in 1:T + z = {(:z, t)} ~ normal(0, 1) + x = {(:x, t)} ~ normal(z, 1) + end + end + + @gen function proposal(trace::Trace, x) + t = get_args(trace)[1] + 1 + {(:z, t)} ~ normal(x, 1) + end + + translator = SimpleExtendingTraceTranslator( + p_new_args=(2,), p_argdiffs=(UnknownChange(),), + new_observations=choicemap((:x, 2) => 5.0), + q_forward=proposal, q_forward_args=(5.0,)) + t1, _ = generate(model, (1,), choicemap()) + t2, log_weight = translator(t1) + + prop_choices = choicemap((:z, 2) => t2[(:z, 2)]) + prop_weight, _ = assess(proposal, (t1, 5.0), prop_choices) + constraints = merge(prop_choices, choicemap((:x, 2) => 5.0)) + t3, up_weight, _, _ = update(t1, (2,), (UnknownChange(),), constraints) + @test log_weight == up_weight - prop_weight + +end + +@testset "SymmetricTraceTranslator" begin + + @gen function model() + z ~ bernoulli(0.5) + if z + i ~ uniform_discrete(1, 10) + else + x ~ uniform(0, 1) + end + end + + @gen function proposal(trace) + if trace[:z] + dx ~ uniform(0.0, 0.1) + end + end + + @transform involution (p1_trace, q1_trace) to (p2_trace, q2_trace) begin + if @read(p1_trace[:z], :discrete) + @write(p2_trace[:z], false, :discrete) + i = @read(p1_trace[:i], :discrete) + dx = @read(q1_trace[:dx], :continuous) + x = (i-1)/10 + dx + @write(p2_trace[:x], x, :continuous) + else + @write(p2_trace[:z], true, :discrete) + x = @read(p1_trace[:x], :continuous) + i = ceil(x * 10) + @write(p2_trace[:i], i, :discrete) + @write(q2_trace[:dx], x - (i-1)/10, :continuous) + end + end + + is_involution!(involution) + + translator = SymmetricTraceTranslator(proposal, (), involution) + + t1, _ = generate(model, (), choicemap(:z => false, :x => 0.95)) + t2, log_weight = translator(t1; check=true) + @test t2[:z] == true && t2[:i] == 10 + +end + +@testset "GeneralTraceTranslator" begin + + @gen function p1() + x ~ uniform(0, 1) + y ~ uniform(0, 1) + end + + @gen function p2() + i ~ uniform_discrete(1, 10) # interval [(i-1)/10, i/10] + j ~ uniform_discrete(1, 10) # interval [(j-1)/10, j/10] + end + + @gen function q1(p1_trace) end + + @gen function q2(p2_trace) + dx ~ uniform(0.0, 0.1) + dy ~ uniform(0.0, 0.1) + end + + @transform f (p1_trace, q1_trace) to (p2_trace, q2_trace) begin + x = @read(p1_trace[:x], :continuous) + y = @read(p1_trace[:y], :continuous) + i = ceil(x * 10) + j = ceil(y * 10) + @write(p2_trace[:i], i, :discrete) + @write(p2_trace[:j], j, :discrete) + @write(q2_trace[:dx], x - (i-1)/10, :continuous) + @write(q2_trace[:dy], y - (j-1)/10, :continuous) + end + + @transform f_inv (p2_trace, q2_trace) to (p1_trace, q1_trace) begin + i = @read(p2_trace[:i], :discrete) + j = @read(p2_trace[:j], :discrete) + dx = @read(q2_trace[:dx], :continuous) + dy = @read(q2_trace[:dy], :continuous) + x = (i-1)/10 + dx + y = (j-1)/10 + dy + @write(p1_trace[:x], x, :continuous) + @write(p1_trace[:y], y, :continuous) + end + + pair_bijections!(f, f_inv) + + translator = GeneralTraceTranslator( + p_new=p2, p_new_args=(), + new_observations=choicemap(), + q_forward=q1, q_forward_args=(), + q_backward=q2, q_backward_args=(), f=f) + + t1, _ = generate(p1, (), choicemap(:x => 0.05, :y => 0.95)) + t2, log_weight = translator(t1; check=true) + @test t2[:i] == 1 && t2[:j] == 10 + +end + +end From 1177b58197b6a73ddc3a46437af82cb1c6619eb0 Mon Sep 17 00:00:00 2001 From: Xuan Date: Sat, 27 Mar 2021 03:06:06 -0400 Subject: [PATCH 07/12] Fixed bugs in examples for trace translator docs. --- docs/src/ref/trace_translators.md | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/docs/src/ref/trace_translators.md b/docs/src/ref/trace_translators.md index f8146b369..e120f43b3 100644 --- a/docs/src/ref/trace_translators.md +++ b/docs/src/ref/trace_translators.md @@ -78,7 +78,7 @@ Note that the transform DSL code does not specify what the two generative functi This information will be required for computing probabilities and probability densities of traces. We provide this information by constructing a **Trace Translator** that wraps the transform along with this transformation: ```julia -translator = DeterministicTraceTranslator(p2, (), f) +translator = DeterministicTraceTranslator(p2, (), choicemap(), f) ``` We then can then apply the translator to a trace of `p1` using function call syntax. The translator returns a trace of `p2` and a log-weight that we can use to compute the probability (density) of the resulting trace: @@ -228,12 +228,10 @@ We construct `q1` and `q2` so that the two spaces have the same size, and a one- For our example above, we construct `q2` to sample the coordinate (``[0, 0.1]^2``) relative to the cell. We construct `q1` to be empty--there is already a mapping from each trace of `p1` to each trace of `p2` that simply identifies what cell ``(i, j)`` a given point in ``[0, 1]^2`` is in, so no extra random choices are needed. ```julia -@gen function q1() +@gen function q1(p1_trace) end @gen function q2(p2_trace) - i = p2_trace[:i] - j = p2_trace[:j] dx ~ uniform(0.0, 0.1) dy ~ uniform(0.0, 0.1) end @@ -251,8 +249,8 @@ For example, the following defines a trace transform that maps from pairs of tra j = ceil(y * 10) @write(p2_trace[:i], i, :discrete) @write(p2_trace[:j], j, :discrete) - @write(q2_trace[:dx], x / 10, :continuous) - @write(q2_trace[:dy], y / 10, :continuous) + @write(q2_trace[:dx], x - (i-1)/10, :continuous) + @write(q2_trace[:dy], y - (j-1)/10, :continuous) end ``` and the inverse transform: @@ -265,7 +263,7 @@ and the inverse transform: x = (i-1)/10 + dx y = (j-1)/10 + dy @write(p1_trace[:x], x, :continuous) - @write(p1_trace[:y], x, :continuous) + @write(p1_trace[:y], y, :continuous) end ``` which we associate as inverses: @@ -289,7 +287,7 @@ translator = GeneralTraceTranslator( ``` Then, we can apply the trace translator to a trace (`t1`) of `p1` and get a trace (`t2`) of `p2` and a log-weight: ```julia -(t2, log_weight) = translator(t1) +t2, log_weight = translator(t1) ``` From f722f5ee385c5b8822b6655d06a7fdb2c4a4a1d0 Mon Sep 17 00:00:00 2001 From: Xuan Date: Sat, 27 Mar 2021 03:24:56 -0400 Subject: [PATCH 08/12] Document SimpleExtendingTraceTranslator. --- docs/src/ref/trace_translators.md | 44 ++++++++++++++++++++++++++++++- 1 file changed, 43 insertions(+), 1 deletion(-) diff --git a/docs/src/ref/trace_translators.md b/docs/src/ref/trace_translators.md index e120f43b3..cec0cf963 100644 --- a/docs/src/ref/trace_translators.md +++ b/docs/src/ref/trace_translators.md @@ -307,7 +307,49 @@ This has two benefits when the previous and new traces have random choices that ## Simple Extending Trace Translators -TODO Document +Simple extending trace translators extend an existing trace with new random +choices sampled from a proposal distribution, as well as any new observations. +The arguments of the trace may also be updated. This type of trace translation +is the basic operation used in [Particle Filtering](@ref). For example, +we might have a model that sequentially samples new latent variables `(:z, t)` +and observations `(:x, t)` for each timestep `t`: + +```julia +@gen function model(T::Int) + for t in 1:T + z = {(:z, t)} ~ normal(0, 1) + x = {(:x, t)} ~ normal(z, 1) + end +end +``` + +Each time we observe a new `(:x ,t)`, we might want to propose `(:z, t)` so that +it is close in value: + +```julia +@gen function proposal(trace::Trace, x::Real) + t = get_args(trace)[1] + 1 + {(:z, t)} ~ normal(x, 1) +end +``` + +Suppose we initially generated a trace up to timestep `t=1`, e.g. by calling +`t1 = simulate(model, (1,))`. Now we observe `(:x, 2)` to be `5.0`. By +constructing a simple extending trace translator, we can simultaneously +update the trace `t1` with new arguments, introduce the new observation +at `(:x, 2)`, and propose a likely value for `(:z, 2)`: + +```julia +translator = SimpleExtendingTraceTranslator( + p_new_args=(2,), p_argdiffs=(UnknownChange(),), + new_observations=choicemap((:x, 2) => 5.0), + q_forward=proposal, q_forward_args=(5.0,)) +t2, log_weight = translator(t1) +``` + +Similar functionality can be achieved through a combination of [`propose`](@ref) +on the proposal and [`update`](@ref) on the original trace, but using a trace +translator provides a nice layer of abstraction. ## Trace Transform DSL From f2491258e382b0d1af58ecd96c9b40b04684b5c1 Mon Sep 17 00:00:00 2001 From: Xuan Date: Sat, 27 Mar 2021 03:25:42 -0400 Subject: [PATCH 09/12] Use simulate instead of generate in test case. --- test/inference/trace_translators.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/inference/trace_translators.jl b/test/inference/trace_translators.jl index 0ce361fa9..6008ff5cf 100644 --- a/test/inference/trace_translators.jl +++ b/test/inference/trace_translators.jl @@ -59,7 +59,7 @@ end p_new_args=(2,), p_argdiffs=(UnknownChange(),), new_observations=choicemap((:x, 2) => 5.0), q_forward=proposal, q_forward_args=(5.0,)) - t1, _ = generate(model, (1,), choicemap()) + t1 = simulate(model, (1,)) t2, log_weight = translator(t1) prop_choices = choicemap((:z, 2) => t2[(:z, 2)]) From 51ad6e78150a0120cb9718975c257c3d4192b119 Mon Sep 17 00:00:00 2001 From: Xuan Date: Sat, 27 Mar 2021 13:07:32 -0400 Subject: [PATCH 10/12] Refactor `run_transform` code. --- src/inference/trace_translators.jl | 89 ++++++++++++++++-------------- 1 file changed, 49 insertions(+), 40 deletions(-) diff --git a/src/inference/trace_translators.jl b/src/inference/trace_translators.jl index 533d31096..865613c97 100644 --- a/src/inference/trace_translators.jl +++ b/src/inference/trace_translators.jl @@ -1,7 +1,7 @@ import ForwardDiff import MacroTools import LinearAlgebra -import Parameters: @with_kw +import Parameters: @with_kw, @unpack ####################### # trace transform DSL # @@ -644,14 +644,21 @@ Run the translator with: f::TraceTransformDSLProgram # a bijection end -function deterministic_trace_translator_run_transform( - f::TraceTransformDSLProgram, new_observations::ChoiceMap, - prev_model_trace::Trace, p_new::GenerativeFunction, p_new_args::Tuple) +function inverse(translator::DeterministicTraceTranslator, + prev_model_trace::Trace, prev_observations::ChoiceMap=EmptyChoiceMap()) + return DeterministicTraceTranslator( + get_gen_fn(prev_model_trace), get_args(prev_model_trace), + prev_observations, inverse(translator.f)) +end + +function run_transform(translator::DeterministicTraceTranslator, + prev_model_trace::Trace) + @unpack p_new, p_args, new_observations, f = translator first_pass_results = run_first_pass(f, prev_model_trace, nothing) log_abs_determinant = jacobian_correction( f, prev_model_trace, nothing, first_pass_results, nothing) constraints = merge(first_pass_results.constraints, new_observations) - (new_model_trace, _) = generate(p_new, p_new_args, constraints) + (new_model_trace, _) = generate(p_new, p_args, constraints) return (new_model_trace, log_abs_determinant) end @@ -664,8 +671,8 @@ function (translator::DeterministicTraceTranslator)( prev_model_trace::Trace; check=false, prev_observations=EmptyChoiceMap()) # apply trace transform - (new_model_trace, log_abs_determinant) = deterministic_trace_translator_run_transform( - translator.f, translator.new_observations, prev_model_trace, translator.p_new, translator.p_args) + (new_model_trace, log_abs_determinant) = + run_transform(translator, prev_model_trace) # compute log weight prev_model_score = get_score(prev_model_trace) @@ -674,9 +681,8 @@ function (translator::DeterministicTraceTranslator)( if check check_observations(get_choices(new_model_trace), translator.new_observations) - (prev_model_trace_rt, _) = deterministic_trace_translator_run_transform( - inverse(translator.f), prev_observations, new_model_trace, - get_gen_fn(prev_model_trace), get_args(prev_model_trace)) + inverter = inverse(translator, prev_model_trace) + prev_model_trace_rt, _ = run_transform(inverter, new_model_trace) check_round_trip(prev_model_trace, prev_model_trace_rt) end @@ -719,11 +725,19 @@ If `check` is enabled, then `prev_observations` is a choice map containing the o f::TraceTransformDSLProgram # a bijection end -function general_trace_translator_run_transform( - f::TraceTransformDSLProgram, new_observations::ChoiceMap, - prev_model_trace::Trace, forward_proposal_trace::Trace, - p_new::GenerativeFunction, p_new_args::Tuple, - q_backward::GenerativeFunction, q_backward_args::Tuple) +function inverse(translator::GeneralTraceTranslator, prev_model_trace::Trace, + prev_observations::ChoiceMap=EmptyChoiceMap()) + return GeneralTraceTranslator( + get_gen_fn(prev_model_trace), get_args(prev_model_trace), + prev_observations, translator.q_backward, translator.q_backward_args, + translator.q_forward, translator.q_forward_args, + inverse(translator.f)) +end + +function run_transform(translator::GeneralTraceTranslator, + prev_model_trace::Trace, forward_proposal_trace::Trace) + @unpack f, new_observations = translator + @unpack p_new, p_new_args, q_backward, q_backward_args = translator first_pass_results = run_first_pass(f, prev_model_trace, forward_proposal_trace) log_abs_determinant = jacobian_correction( f, prev_model_trace, forward_proposal_trace, first_pass_results, nothing) @@ -741,9 +755,8 @@ function (translator::GeneralTraceTranslator)( forward_proposal_trace = simulate(translator.q_forward, (prev_model_trace, translator.q_forward_args...,)) # apply trace transform - (new_model_trace, backward_proposal_trace, log_abs_determinant) = general_trace_translator_run_transform( - translator.f, translator.new_observations, prev_model_trace, forward_proposal_trace, - translator.p_new, translator.p_new_args, translator.q_backward, translator.q_backward_args) + (new_model_trace, backward_proposal_trace, log_abs_determinant) = + run_transform(translator, prev_model_trace, forward_proposal_trace) # compute log weight prev_model_score = get_score(prev_model_trace) @@ -753,16 +766,11 @@ function (translator::GeneralTraceTranslator)( log_weight = new_model_score - prev_model_score + backward_proposal_score + forward_proposal_score + log_abs_determinant if check - forward_proposal_choices = get_choices(forward_proposal_trace) - f_inv = inverse(translator.f) - (prev_model_trace_rt, forward_proposal_trace_rt, _) = general_trace_translator_run_transform( - inverse(translator.f), prev_observations, - new_model_trace, backward_proposal_trace, - get_gen_fn(prev_model_trace), get_args(prev_model_trace), - translator.q_forward, translator.q_forward_args) - check_round_trip( - prev_model_trace, prev_model_trace_rt, - forward_proposal_trace, forward_proposal_trace_rt) + inverter = inverse(translator, prev_model_trace, prev_observations) + (prev_model_trace_rt, forward_proposal_trace_rt, _) = + run_transform(inverter, new_model_trace, backward_proposal_trace) + check_round_trip(prev_model_trace, prev_model_trace_rt, + forward_proposal_trace, forward_proposal_trace_rt) end return (new_model_trace, log_weight) @@ -843,10 +851,13 @@ If `check` is enabled, then `observations` is a choice map containing the observ involution::T # an involution end -function symmetric_trace_translator_run_transform( - involution::TraceTransformDSLProgram, - prev_model_trace::Trace, forward_proposal_trace::Trace, - q::GenerativeFunction, q_args::Tuple) +function inverse(translator::SymmetricTraceTranslator, prev_model_trace=nothing) + return translator +end + +function run_transform(translator::SymmetricTraceTranslator, + prev_model_trace::Trace, forward_proposal_trace::Trace) + @unpack involution, q, q_args = translator first_pass_results = run_first_pass(involution, prev_model_trace, forward_proposal_trace) (new_model_trace, log_model_weight, _, discard) = update( prev_model_trace, get_args(prev_model_trace), @@ -866,8 +877,8 @@ function (translator::SymmetricTraceTranslator{TraceTransformDSLProgram})( forward_proposal_trace = simulate(translator.q, (prev_model_trace, translator.q_args...,)) # apply trace transform - (new_model_trace, backward_proposal_trace, log_abs_determinant) = symmetric_trace_translator_run_transform( - translator.involution, prev_model_trace, forward_proposal_trace, translator.q, translator.q_args) + (new_model_trace, backward_proposal_trace, log_abs_determinant) = + run_transform(translator, prev_model_trace, forward_proposal_trace) # compute log weight prev_model_score = get_score(prev_model_trace) @@ -878,12 +889,10 @@ function (translator::SymmetricTraceTranslator{TraceTransformDSLProgram})( if check check_observations(get_choices(new_model_trace), observations) - forward_proposal_choices = get_choices(forward_proposal_trace) - (prev_model_trace_rt, forward_proposal_trace_rt, _) = symmetric_trace_translator_run_transform( - translator.involution, new_model_trace, backward_proposal_trace, translator.q, translator.q_args) - check_round_trip( - prev_model_trace, prev_model_trace_rt, - forward_proposal_trace, forward_proposal_trace_rt) + (prev_model_trace_rt, forward_proposal_trace_rt, _) = + run_transform(translator, new_model_trace, backward_proposal_trace) + check_round_trip(prev_model_trace, prev_model_trace_rt, + forward_proposal_trace, forward_proposal_trace_rt) end return (new_model_trace, log_weight) From 5f5097da009212588e8b89a033735a6dd32136c3 Mon Sep 17 00:00:00 2001 From: Xuan Date: Sat, 27 Mar 2021 13:27:24 -0400 Subject: [PATCH 11/12] Coding style and readability improvements. --- src/inference/trace_translators.jl | 83 +++++++++++++++++++----------- 1 file changed, 53 insertions(+), 30 deletions(-) diff --git a/src/inference/trace_translators.jl b/src/inference/trace_translators.jl index 865613c97..a637bc1bf 100644 --- a/src/inference/trace_translators.jl +++ b/src/inference/trace_translators.jl @@ -23,7 +23,8 @@ end """ pair_bijections!(f1::TraceTransformDSLProgram, f2::TraceTransformDSLProgram) -Assert that a pair of bijections contsructed using the [Trace Transform DSL](@ref) are inverses of one another. +Assert that a pair of bijections contsructed using the [Trace Transform DSL](@ref) are +inverses of one another. """ function pair_bijections!(f1::TraceTransformDSLProgram, f2::TraceTransformDSLProgram) f1.inverse = f2 @@ -46,7 +47,8 @@ end Obtain the inverse of a bijection that was constructed with the [Trace Transform DSL](@ref). -The inverse must have been associated with the bijection either via [`pair_bijections!`](@ref) or [`is_involution!`])(@ref). +The inverse must have been associated with the bijection either via +[`pair_bijections!`](@ref) or [`is_involution!`])(@ref). """ function inverse(bijection::TraceTransformDSLProgram) if isnothing(bijection.inverse) @@ -110,7 +112,9 @@ const bij_state = gensym("bij_state") Write a program in the [Trace Transform DSL](@ref). """ macro transform(f_expr, src_expr, to_symbol::Symbol, dest_expr, body) - syntax_err = "valid syntactic forms:\n@transform f (..) to (..) begin .. end\n@transform f(..) (..) to (..) begin .. end" + syntax_err = """valid syntactic forms: + @transform f (..) to (..) begin .. end + @transform f(..) (..) to (..) begin .. end""" err = false if MacroTools.@capture(f_expr, f_(args__)) elseif MacroTools.@capture(f_expr, f_) @@ -180,7 +184,8 @@ end Macro for reading the value of a random choice from an input trace in the [Trace Transform DSL](@ref). - is of the form [] where is an input trace, and is either :discrete or :continuous. + is of the form [] where is an input trace, and +is either :discrete or :continuous. """ macro read(src, ann::QuoteNode) return quote read($(esc(bij_state)), $(esc(src)), $(esc(typed(ann.value)))) end @@ -191,7 +196,8 @@ end Macro for writing the value of a random choice to an output trace in the [Trace Transform DSL](@ref). - is of the form [] where is an input trace, and is either :discrete or :continuous. + is of the form [] where is an input trace, and + is either :discrete or :continuous. """ macro write(dest, val, ann::QuoteNode) return quote write($(esc(bij_state)), $(esc(dest)), $(esc(val)), $(esc(typed(ann.value)))) end @@ -200,9 +206,11 @@ end """ @copy(, ) -Macro for copying the value of a random choice (or a whole namespace of random choices) from an input trace to an output trace in the [Trace Transform DSL](@ref). +Macro for copying the value of a random choice (or a whole namespace of random choices) +from an input trace to an output trace in the [Trace Transform DSL](@ref). - is of the form [] where is an input trace, and is either :discrete or :continuous. + is of the form [] where is an input trace, +and is either :discrete or :continuous. """ macro copy(src, dest) return quote copy($(esc(bij_state)), $(esc(src)), $(esc(dest))) end @@ -498,9 +506,9 @@ function store_addr_info!(dict::Dict, addr, value::AbstractArray{<:Real}, next_i return len # number of elements of array end -function assemble_input_array_and_maps( - t_cont_reads, t_copy_reads, u_cont_reads, u_copy_reads, discard::Union{ChoiceMap,Nothing}) - +function assemble_input_array_and_maps(t_cont_reads, t_copy_reads, + u_cont_reads, u_copy_reads, + discard::Union{ChoiceMap,Nothing}) input_arr = Vector{Float64}() next_input_index = 1 @@ -534,18 +542,21 @@ function assemble_output_maps(t_cont_writes, u_cont_writes) cont_constraints_key_to_index = Dict() for (addr, v) in t_cont_writes - next_output_index += store_addr_info!(cont_constraints_key_to_index, addr, v, next_output_index) + next_output_index += + store_addr_info!(cont_constraints_key_to_index, addr, v, next_output_index) end cont_u_back_key_to_index = Dict() for (addr, v) in u_cont_writes - next_output_index += store_addr_info!(cont_u_back_key_to_index, addr, v, next_output_index) + next_output_index += + store_addr_info!(cont_u_back_key_to_index, addr, v, next_output_index) end return (cont_constraints_key_to_index, cont_u_back_key_to_index, next_output_index-1) end -function jacobian_correction(transform::TraceTransformDSLProgram, prev_model_trace, proposal_trace, first_pass_results, discard) +function jacobian_correction(transform::TraceTransformDSLProgram, + prev_model_trace, proposal_trace, first_pass_results, discard) # create input array and mappings input addresses that are needed for Jacobian # exclude addresses that were copied explicitly to another address @@ -680,7 +691,8 @@ function (translator::DeterministicTraceTranslator)( log_weight = new_model_score - prev_model_score + log_abs_determinant if check - check_observations(get_choices(new_model_trace), translator.new_observations) + check_observations(get_choices(new_model_trace), + translator.new_observations) inverter = inverse(translator, prev_model_trace) prev_model_trace_rt, _ = run_transform(inverter, new_model_trace) check_round_trip(prev_model_trace, prev_model_trace_rt) @@ -710,9 +722,11 @@ Run the translator with: (output_trace, log_weight) = translator(input_trace; check=false, prev_observations=EmptyChoiceMap()) -Use `check` to enable a bijection check (this requires that the transform `f` has been paired with its inverse using [`pair_bijections!](@ref) or [`is_involution`](@ref)). +Use `check` to enable a bijection check (this requires that the transform `f` has been +paired with its inverse using [`pair_bijections!](@ref) or [`is_involution`](@ref)). -If `check` is enabled, then `prev_observations` is a choice map containing the observed random choices in the previous trace. +If `check` is enabled, then `prev_observations` is a choice map containing the observed +random choices in the previous trace. """ @with_kw struct GeneralTraceTranslator p_new::GenerativeFunction @@ -752,7 +766,8 @@ function (translator::GeneralTraceTranslator)( prev_model_trace::Trace; check=false, prev_observations=EmptyChoiceMap()) # sample auxiliary trace - forward_proposal_trace = simulate(translator.q_forward, (prev_model_trace, translator.q_forward_args...,)) + forward_proposal_trace = + simulate(translator.q_forward, (prev_model_trace, translator.q_forward_args...,)) # apply trace transform (new_model_trace, backward_proposal_trace, log_abs_determinant) = @@ -763,7 +778,8 @@ function (translator::GeneralTraceTranslator)( new_model_score = get_score(new_model_trace) forward_proposal_score = get_score(forward_proposal_trace) backward_proposal_score = get_score(backward_proposal_trace) - log_weight = new_model_score - prev_model_score + backward_proposal_score + forward_proposal_score + log_abs_determinant + log_weight = new_model_score - prev_model_score + + backward_proposal_score + forward_proposal_score + log_abs_determinant if check inverter = inverse(translator, prev_model_trace, prev_observations) @@ -805,7 +821,8 @@ end function (translator::SimpleExtendingTraceTranslator)(prev_model_trace::Trace) # simulate from auxiliary program - forward_proposal_trace = simulate(translator.q_forward, (prev_model_trace, translator.q_forward_args...,)) + forward_proposal_trace = + simulate(translator.q_forward, (prev_model_trace, translator.q_forward_args...,)) forward_proposal_score = get_score(forward_proposal_trace) # computing the new trace via update @@ -815,7 +832,7 @@ function (translator::SimpleExtendingTraceTranslator)(prev_model_trace::Trace) translator.p_argdiffs, constraints) if !isempty(discard) - @error("can only extend the trace with random choices, cannot remove random choices") + @error("Can only extend the trace with random choices, not remove them.") error("Invalid SimpleExtendingTraceTranslator") end @@ -835,15 +852,18 @@ end Constructor for a symmetric trace translator. -The involution is either constructed via the [`@transform`](@ref) macro (recommended), or can be provided as a Julia function. +The involution is either constructed via the [`@transform`](@ref) macro (recommended), +or can be provided as a Julia function. Run the translator with: (output_trace, log_weight) = translator(input_trace; check=false, observations=EmptyChoiceMap()) -Use `check` to enable the involution check (this requires that the transform `f` has been marked with [`is_involution`](@ref)). +Use `check` to enable the involution check (this requires that the transform `f` has been +marked with [`is_involution`](@ref)). -If `check` is enabled, then `observations` is a choice map containing the observed random choices, and the check will additionally ensure they are not mutated by the involution. +If `check` is enabled, then `observations` is a choice map containing the observed random +choices, and the check will additionally ensure they are not mutated by the involution. """ @with_kw struct SymmetricTraceTranslator{T <: Union{TraceTransformDSLProgram,Function}} q::GenerativeFunction @@ -874,7 +894,8 @@ function (translator::SymmetricTraceTranslator{TraceTransformDSLProgram})( prev_model_trace::Trace; check=false, observations=EmptyChoiceMap()) # simulate from auxiliary program - forward_proposal_trace = simulate(translator.q, (prev_model_trace, translator.q_args...,)) + forward_proposal_trace = + simulate(translator.q, (prev_model_trace, translator.q_args...,)) # apply trace transform (new_model_trace, backward_proposal_trace, log_abs_determinant) = @@ -885,7 +906,8 @@ function (translator::SymmetricTraceTranslator{TraceTransformDSLProgram})( new_model_score = get_score(new_model_trace) forward_proposal_score = get_score(forward_proposal_trace) backward_proposal_score = get_score(backward_proposal_trace) - log_weight = new_model_score - prev_model_score + backward_proposal_score - forward_proposal_score + log_abs_determinant + log_weight = new_model_score - prev_model_score + + backward_proposal_score - forward_proposal_score + log_abs_determinant if check check_observations(get_choices(new_model_trace), observations) @@ -907,7 +929,8 @@ function (translator::SymmetricTraceTranslator{<:Function})( forward_retval = get_retval(forward_trace) (new_model_trace, backward_choices, log_weight) = translator.involution( prev_model_trace, forward_choices, forward_retval, translator.q_args) - (backward_score, backward_retval) = assess(translator.q, (new_model_trace, translator.q_args...), backward_choices) + (backward_score, backward_retval) = + assess(translator.q, (new_model_trace, translator.q_args...), backward_choices) log_weight += (backward_score - forward_score) @@ -915,10 +938,10 @@ function (translator::SymmetricTraceTranslator{<:Function})( check_observations(get_choices(new_model_trace), observations) (prev_model_trace_rt, forward_choices_rt, _) = translator.involution( new_model_trace, backward_choices, backward_retval, translator.q_args) - (forward_trace_rt, _) = generate(translator.q, (prev_model_trace, translator.q_args...), forward_choices_rt) - check_round_trip( - prev_model_trace, prev_model_trace_rt, - forward_trace, forward_trace_rt) + (forward_trace_rt, _) = generate( + translator.q, (prev_model_trace, translator.q_args...), forward_choices_rt) + check_round_trip(prev_model_trace, prev_model_trace_rt, + forward_trace, forward_trace_rt) end return (new_model_trace, log_weight) From 14b24035fe1cdb49727aa068bc5bcb97a7da65c8 Mon Sep 17 00:00:00 2001 From: Xuan Date: Sat, 27 Mar 2021 13:47:27 -0400 Subject: [PATCH 12/12] Make trace translators mutable, subtypes of the TraceTranslator type. --- src/inference/trace_translators.jl | 33 +++++++++++++++++++++--------- 1 file changed, 23 insertions(+), 10 deletions(-) diff --git a/src/inference/trace_translators.jl b/src/inference/trace_translators.jl index a637bc1bf..73d8ff26f 100644 --- a/src/inference/trace_translators.jl +++ b/src/inference/trace_translators.jl @@ -632,6 +632,21 @@ function check_round_trip( return nothing end +################################ +# TraceTranslator # +################################ + +"Abstract type for trace translators." +abstract type TraceTranslator end + +""" + (new_trace, log_weight) = (translator::TraceTranslator)(trace) + +Apply a trace translator on an input trace, returning a new trace and an incremental +log weight. +""" +(translator::TraceTranslator)(trace::Trace; kwargs...) = error("Not implemented.") + ################################ # DeterministicTraceTranslator # ################################ @@ -648,7 +663,7 @@ Run the translator with: (output_trace, log_weight) = translator(input_trace) """ -@with_kw struct DeterministicTraceTranslator +@with_kw mutable struct DeterministicTraceTranslator <: TraceTranslator p_new::GenerativeFunction p_args::Tuple = () new_observations::ChoiceMap = EmptyChoiceMap() @@ -673,11 +688,6 @@ function run_transform(translator::DeterministicTraceTranslator, return (new_model_trace, log_abs_determinant) end -""" - (new_trace, log_weight) = (translator::DeterministicTraceTranslator)(trace) - -Apply a trace translator. -""" function (translator::DeterministicTraceTranslator)( prev_model_trace::Trace; check=false, prev_observations=EmptyChoiceMap()) @@ -728,7 +738,7 @@ paired with its inverse using [`pair_bijections!](@ref) or [`is_involution`](@re If `check` is enabled, then `prev_observations` is a choice map containing the observed random choices in the previous trace. """ -@with_kw struct GeneralTraceTranslator +@with_kw mutable struct GeneralTraceTranslator <: TraceTranslator p_new::GenerativeFunction p_new_args::Tuple = () new_observations::ChoiceMap = EmptyChoiceMap() @@ -810,7 +820,7 @@ Run the translator with: (output_trace, log_weight) = translator(input_trace) """ -@with_kw struct SimpleExtendingTraceTranslator +@with_kw mutable struct SimpleExtendingTraceTranslator <: TraceTranslator p_new_args::Tuple = () p_argdiffs::Tuple = () new_observations::ChoiceMap = EmptyChoiceMap() @@ -844,6 +854,8 @@ end # SymmetricTraceTranslator # ############################ +const TransformFunction = Union{TraceTransformDSLProgram,Function} + """ translator = SymmetricTraceTranslator(; q::GenerativeFunction, @@ -865,7 +877,7 @@ marked with [`is_involution`](@ref)). If `check` is enabled, then `observations` is a choice map containing the observed random choices, and the check will additionally ensure they are not mutated by the involution. """ -@with_kw struct SymmetricTraceTranslator{T <: Union{TraceTransformDSLProgram,Function}} +@with_kw mutable struct SymmetricTraceTranslator{T <: TransformFunction} <: TraceTranslator q::GenerativeFunction q_args::Tuple = () involution::T # an involution @@ -951,4 +963,5 @@ end export @transform export @read, @write, @copy, @tcall export TraceTransformDSLProgram, pair_bijections!, is_involution!, inverse -export DeterministicTraceTranslator, SymmetricTraceTranslator, SimpleExtendingTraceTranslator, GeneralTraceTranslator +export TraceTranslator, DeterministicTraceTranslator, SymmetricTraceTranslator, + SimpleExtendingTraceTranslator, GeneralTraceTranslator