You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
# Sample some tree until it's valid (TODO: make this faster)
# We found that sampling from the BDD, like sampling from the computation graph, also
# took ~160s for 200 well-typed samples.
# OTHER TIMINGS IN THIS FILE ARE WRONG.
# We mainly care about `sample_one_as_dist_compile` in this file.
# using Revise
using Dice
using BenchmarkTools
using ProfileView
using Random
"""
    comp_graph_size(roots)

Return how many times `Dice.foreach_down(roots)` invokes its callback, i.e. (presumably)
the number of distinct nodes of the computation graph reachable from `roots`.
"""
function comp_graph_size(roots)
    visited = Ref(0)
    # Callback passed positionally; equivalent to the `foreach_down(roots) do _ ... end` form.
    Dice.foreach_down(_ -> (visited[] += 1), roots)
    return visited[]  # 2040 (value recorded by the original author)
end

include("benchmarks.jl")
# Run configuration: fixed seed, scratch output directory, and a log sink that discards
# everything written to it.
SEED = 0
out_dir = "/tmp"
log_path = "/dev/null"

# Sibling-derived STLC generator rooted at `Expr.t`; `ty_sizes` gives a size bound per
# type (NOTE(review): exact meaning of the bounds/stack_size lives in the generator).
generation_params = LangSiblingDerivedGenerator{STLC}(
    root_ty = Expr.t,
    ty_sizes = [Expr.t => 5, Typ.t => 2],
    stack_size = 2,
    intwidth = 3,
)

# The opened handle is presumably the run's log stream; it is never closed explicitly
# and lives for the whole script.
rs = RunState(
    Valuation(),
    Dict{String,ADNode}(),
    open(log_path, "w"),
    out_dir,
    MersenneTwister(SEED),
    nothing,
    generation_params,
)

# Build the generator's distribution over expressions.
generation::Generation = generate(rs, generation_params)
g::Dist = generation.value

# Evaluator for AD expressions under the current parameter values `rs.var_vals`.
a = ADComputer(rs.var_vals)
# Number of well-typed samples the timed loop collects.
NUM_SAMPLES = 200

"""
    sample_one_as_dist_compile(c::BDDCompiler, a::ADComputer, d::Dist, roots;
                               rng=Random.default_rng())

Draw one sample of `d` by walking, for each bit in `roots`, the BDD that `c` compiles it to.

Each BDD level corresponds to a flip (`c.level_to_flip[level]`). The first time a level is
reached during this call, a coin with probability `compute(a, c.level_to_flip[level].prob)`
is drawn from `rng` and the outcome is memoized, so every root sees the same assignment.
Visited BDD nodes are memoized as well, so shared subgraphs are walked only once.

# Arguments
- `c`: compiler providing `mgr`, `level_to_flip`, and `compile` for the roots.
- `a`: evaluator used to turn each flip's probability expression into a number.
- `d`: the distribution being sampled; handed to `Dice.frombits_as_dist`.
- `roots`: the bits of `d` (the benchmark passes `Dice.tobits(d)`).

# Keywords
- `rng`: source of randomness for the flips. The default, `Random.default_rng()`, is the
  generator the bare `rand()` uses, so omitting it keeps the previous behavior; pass e.g.
  `rs.rng` to make sampling reproducible from the run's seed.

# Returns
The result of `Dice.frombits_as_dist(d, bits)` for the sampled bit assignment.
"""
function sample_one_as_dist_compile(
    c::BDDCompiler, a::ADComputer, d::Dist, roots;
    rng::Random.AbstractRNG = Random.default_rng(),
)
    # Per-call sampling state. Concrete `Int` keys (rather than abstract `Integer`) avoid
    # dynamic dispatch on hashing/comparison; integer levels are converted on insertion.
    level_to_tf = Dict{Int,Bool}()
    bdd_node_to_tf = Dict{CuddNode,Bool}()
    # The terminal nodes have fixed truth values.
    bdd_node_to_tf[Dice.constant(c.mgr, true)] = true
    bdd_node_to_tf[Dice.constant(c.mgr, false)] = false

    # Decide the flip at `level`, sampling it only the first time it is seen.
    function sample_level(c, level::Integer)
        get!(level_to_tf, level) do
            prob = compute(a, c.level_to_flip[level].prob)
            rand(rng) < prob
        end
    end

    # A Dice boolean is first compiled to its BDD node.
    function sample_one(c, bdd_node_to_tf, x::AnyBool)
        sample_one(c, bdd_node_to_tf, compile(c, x))
    end

    # Follow the high or low edge according to this node's level, memoizing the result.
    # NOTE: relies on `get!(f, ::Dict, key)` tolerating the recursive insertions that `f`
    # makes into the same Dict (Base re-checks the slot after calling `f`).
    function sample_one(c, bdd_node_to_tf, x::CuddNode)
        get!(bdd_node_to_tf, x) do
            if sample_level(c, Dice.level(x))
                sample_one(c, bdd_node_to_tf, Dice.high(x))
            else
                sample_one(c, bdd_node_to_tf, Dice.low(x))
            end
        end
    end

    # NOTE(review): left untyped (`Dict{Any,Any}`) because narrowing it depends on what
    # `Dice.frombits_as_dist` accepts; confirm before changing.
    bits = Dict()
    for root in roots
        bits[root] = sample_one(c, bdd_node_to_tf, root)
    end
    return Dice.frombits_as_dist(d, bits)
end

"""
    wellTyped(e)

Return `true` when `typecheck(e)` matches `Some(_)` and `false` when it matches `None()`.

`e` is expected to be deterministic (such as a sample produced by
`sample_one_as_dist_compile`); this is checked with `@assert` as an internal invariant.
"""
function wellTyped(e)
    @assert isdeterministic(e)
    @match typecheck(e) [
        Some(_) -> true,
        None() -> false,
    ]
end
# Total number of samples drawn by the loop, whether or not they typecheck.
retries = Ref(0)

# Time rejection sampling of `NUM_SAMPLES` well-typed expressions from `g`.
# The commented-out `@benchmark` marks an alternative timing harness.
#== @benchmark ==# @time begin
    samples = Any[]
    d = g
    roots = Dice.tobits(d)
    c = BDDCompiler(roots)
    a = Dice.ADComputer(rs.var_vals)
    # Keep drawing until enough samples pass the type checker; ill-typed ones are dropped.
    while length(samples) < NUM_SAMPLES
        retries[] += 1
        s = sample_one_as_dist_compile(c, a, d, roots)
        wellTyped(s) && push!(samples, s)
    end
end
The text was updated successfully, but these errors were encountered:
Dice.jl/qc/benchmarks/stlc_faster_200samples_from_bdd.jl
Line 37 in daa0fc5
The text was updated successfully, but these errors were encountered: