Skip to content

Commit

Permalink
Merge pull request #510 from ztangent/20230627-ztangent-fix_load_gene…
Browse files Browse the repository at this point in the history
…rated_functions

Re-implement get_schema using type params.
  • Loading branch information
alex-lew authored Sep 13, 2023
2 parents 3287746 + 6604c09 commit 2ee1d7a
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 20 deletions.
34 changes: 14 additions & 20 deletions src/static_ir/trace.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,6 @@ struct StaticIRTraceAssmt{T} <: ChoiceMap
trace::T
end

function get_schema end

@inline get_address_schema(::Type{StaticIRTraceAssmt{T}}) where {T} = get_schema(T)

@inline Base.isempty(choices::StaticIRTraceAssmt) = isempty(choices.trace)
Expand Down Expand Up @@ -35,7 +33,11 @@ static_get_submap(::StaticIRTraceAssmt, ::Val) = EmptyChoiceMap()
# trace type generation #
#########################

abstract type StaticIRTrace <: Trace end
abstract type StaticIRTrace{T} <: Trace end

function get_schema(::Type{<:StaticIRTrace{T}}) where {T}
StaticAddressSchema(Set{Symbol}(T))
end

@inline function static_get_subtrace(trace::StaticIRTrace, addr)
error("Not implemented")
Expand Down Expand Up @@ -124,7 +126,11 @@ function generate_trace_struct(ir::StaticIR, trace_struct_name::Symbol, options:
mutable = false
fields = get_trace_fields(ir, options)
field_exprs = map((f) -> Expr(:(::), f.fieldname, f.typ), fields)
Expr(:struct, mutable, Expr(:(<:), trace_struct_name, QuoteNode(StaticIRTrace)),
choice_addrs = [node.addr for node in ir.choice_nodes]
call_addrs = [node.addr for node in ir.call_nodes]
addrs = Tuple(vcat(choice_addrs, call_addrs))
parent_type = Expr(:curly, QuoteNode(StaticIRTrace), addrs)
Expr(:struct, mutable, Expr(:(<:), trace_struct_name, parent_type),
Expr(:block, field_exprs..., Expr(:(::), static_ir_gen_fn_ref, QuoteNode(Any))))
end

Expand Down Expand Up @@ -271,17 +277,6 @@ function generate_static_get_submap(ir::StaticIR, trace_struct_name::Symbol)
methods
end

function generate_get_schema(ir::StaticIR, trace_struct_name::Symbol)
choice_addrs = [QuoteNode(node.addr) for node in ir.choice_nodes]
call_addrs = [QuoteNode(node.addr) for node in ir.call_nodes]
addrs = vcat(choice_addrs, call_addrs)
Expr(:function,
Expr(:call, GlobalRef(Gen, :get_schema), :(::Type{$trace_struct_name})),
Expr(:block,
:($(QuoteNode(StaticAddressSchema))(
Set{Symbol}([$(addrs...)])))))
end

function generate_trace_type_and_methods(ir::StaticIR, name::Symbol, options::StaticIRGenerativeFunctionOptions)
trace_struct_name = gensym("StaticIRTrace_$name")
trace_struct_expr = generate_trace_struct(ir, trace_struct_name, options)
Expand All @@ -290,7 +285,6 @@ function generate_trace_type_and_methods(ir::StaticIR, name::Symbol, options::St
get_args_expr = generate_get_args(ir, trace_struct_name)
get_retval_expr = generate_get_retval(ir, trace_struct_name)
get_choices_expr = generate_get_choices(trace_struct_name)
get_schema_expr = generate_get_schema(ir, trace_struct_name)
get_values_shallow_expr = generate_get_values_shallow(ir, trace_struct_name)
get_submaps_shallow_expr = generate_get_submaps_shallow(ir, trace_struct_name)
static_get_value_exprs = generate_static_get_value(ir, trace_struct_name)
Expand All @@ -299,10 +293,10 @@ function generate_trace_type_and_methods(ir::StaticIR, name::Symbol, options::St
getindex_exprs = generate_getindex(ir, trace_struct_name)

exprs = Expr(:block, trace_struct_expr, isempty_expr, get_score_expr,
get_args_expr, get_retval_expr,
get_choices_expr, get_schema_expr, get_values_shallow_expr,
get_submaps_shallow_expr, static_get_value_exprs...,
static_has_value_exprs..., static_get_submap_exprs..., getindex_exprs...)
get_args_expr, get_retval_expr, get_choices_expr,
get_values_shallow_expr, get_submaps_shallow_expr,
static_get_value_exprs..., static_has_value_exprs...,
static_get_submap_exprs..., getindex_exprs...)
(exprs, trace_struct_name)
end

Expand Down
12 changes: 12 additions & 0 deletions test/dsl/static_dsl.jl
Original file line number Diff line number Diff line change
Expand Up @@ -609,6 +609,18 @@ ch = get_choices(tr)
@test length(get_values_shallow(ch)) == 1
@test length(get_submaps_shallow(ch)) == 1

@gen (static) function baz(trace)
x ~ normal(trace[:x], 0.1)
return x
end

ch, w, rval = propose(baz, (tr,))
@test has_value(ch, :x)
@test ch[:x] == rval

new_tr, _ = generate(bar1, (), ch)
@test new_tr[:x] == ch[:x]

end

@testset "returning a SML function from macro" begin
Expand Down

0 comments on commit 2ee1d7a

Please sign in to comment.