Skip to content

Commit

Permalink
Make trace translators mutable, subtypes of the TraceTranslator type.
Browse files Browse the repository at this point in the history
  • Loading branch information
ztangent committed Mar 27, 2021
1 parent 5f5097d commit 14b2403
Showing 1 changed file with 23 additions and 10 deletions.
33 changes: 23 additions & 10 deletions src/inference/trace_translators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -632,6 +632,21 @@ function check_round_trip(
return nothing
end

################################
# TraceTranslator #
################################

"Abstract type for trace translators."
abstract type TraceTranslator end

"""
(new_trace, log_weight) = (translator::TraceTranslator)(trace)
Apply a trace translator on an input trace, returning a new trace and an incremental
log weight.
"""
(translator::TraceTranslator)(trace::Trace; kwargs...) = error("Not implemented.")

################################
# DeterministicTraceTranslator #
################################
Expand All @@ -648,7 +663,7 @@ Run the translator with:
(output_trace, log_weight) = translator(input_trace)
"""
@with_kw struct DeterministicTraceTranslator
@with_kw mutable struct DeterministicTraceTranslator <: TraceTranslator
p_new::GenerativeFunction
p_args::Tuple = ()
new_observations::ChoiceMap = EmptyChoiceMap()
Expand All @@ -673,11 +688,6 @@ function run_transform(translator::DeterministicTraceTranslator,
return (new_model_trace, log_abs_determinant)
end

"""
(new_trace, log_weight) = (translator::DeterministicTraceTranslator)(trace)
Apply a trace translator.
"""
function (translator::DeterministicTraceTranslator)(
prev_model_trace::Trace; check=false, prev_observations=EmptyChoiceMap())

Expand Down Expand Up @@ -728,7 +738,7 @@ paired with its inverse using [`pair_bijections!](@ref) or [`is_involution`](@re
If `check` is enabled, then `prev_observations` is a choice map containing the observed
random choices in the previous trace.
"""
@with_kw struct GeneralTraceTranslator
@with_kw mutable struct GeneralTraceTranslator <: TraceTranslator
p_new::GenerativeFunction
p_new_args::Tuple = ()
new_observations::ChoiceMap = EmptyChoiceMap()
Expand Down Expand Up @@ -810,7 +820,7 @@ Run the translator with:
(output_trace, log_weight) = translator(input_trace)
"""
@with_kw struct SimpleExtendingTraceTranslator
@with_kw mutable struct SimpleExtendingTraceTranslator <: TraceTranslator
p_new_args::Tuple = ()
p_argdiffs::Tuple = ()
new_observations::ChoiceMap = EmptyChoiceMap()
Expand Down Expand Up @@ -844,6 +854,8 @@ end
# SymmetricTraceTranslator #
############################

const TransformFunction = Union{TraceTransformDSLProgram,Function}

"""
translator = SymmetricTraceTranslator(;
q::GenerativeFunction,
Expand All @@ -865,7 +877,7 @@ marked with [`is_involution`](@ref)).
If `check` is enabled, then `observations` is a choice map containing the observed random
choices, and the check will additionally ensure they are not mutated by the involution.
"""
@with_kw struct SymmetricTraceTranslator{T <: Union{TraceTransformDSLProgram,Function}}
@with_kw mutable struct SymmetricTraceTranslator{T <: TransformFunction} <: TraceTranslator
q::GenerativeFunction
q_args::Tuple = ()
involution::T # an involution
Expand Down Expand Up @@ -951,4 +963,5 @@ end
export @transform
export @read, @write, @copy, @tcall
export TraceTransformDSLProgram, pair_bijections!, is_involution!, inverse
export DeterministicTraceTranslator, SymmetricTraceTranslator, SimpleExtendingTraceTranslator, GeneralTraceTranslator
export TraceTranslator, DeterministicTraceTranslator, SymmetricTraceTranslator,
SimpleExtendingTraceTranslator, GeneralTraceTranslator

0 comments on commit 14b2403

Please sign in to comment.