Skip to content

Commit

Permalink
benchmark training and sample from bdd
Browse files Browse the repository at this point in the history
  • Loading branch information
rtjoa committed Nov 1, 2024
1 parent 34f7767 commit daa0fc5
Show file tree
Hide file tree
Showing 7 changed files with 458 additions and 5 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ version = "0.1.0"

[deps]
ADEV = "91c67158-5de4-465b-a572-6ca3a628f939"
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
CUDD = "345a2cc7-28d8-58b2-abdf-cff77ea7d7f1"
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
Expand All @@ -18,6 +19,7 @@ Jive = "ba5e3d4b-8524-549f-bc71-e76ad9e9deed"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
ProfileView = "c46f51b8-102a-5cf2-8d2c-8597cb0e0da7"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Revise = "295af30f-e4ad-537b-8983-00126c2a3abe"
Expand Down
8 changes: 8 additions & 0 deletions qc/benchmarks/benchmarks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -348,6 +348,14 @@ function wellTyped(e::OptExpr.t)
None() -> false,
]
end
# Report whether a deterministic expression type-checks.
#
# `e` is asserted to be deterministic. The result of `typecheck(e)` is then
# matched: `Some(_)` maps to `true`, `None()` maps to `false`.
function wellTyped(e::Expr.t)
    @assert isdeterministic(e)
    return @match typecheck(e) [
        Some(_) -> true,
        None() -> false,
    ]
end


##################################
# Sampling STLC entropy loss
Expand Down
152 changes: 152 additions & 0 deletions qc/benchmarks/rbt_faster.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
using Revise
using Dice
include("benchmarks.jl")

# Generator configuration for the RBT benchmark.
generation_params = LangSiblingDerivedGenerator{RBT}(
    root_ty=ColorKVTree.t,
    ty_sizes=[ColorKVTree.t => 4, Color.t => 0],
    stack_size=2,
    intwidth=3,
)

# Run state: logging goes to `log_path`, and the RNG is seeded with `SEED`.
SEED = 0
out_dir = "/tmp"
log_path = "/dev/null"
rs = RunState(
    Valuation(),
    Dict{String,ADNode}(),
    open(log_path, "w"),
    out_dir,
    MersenneTwister(SEED),
    nothing,
    generation_params,
)

# Build the generator once; `g` is the value sampled from below.
generation::Generation = generate(rs, generation_params)

g::Dist = generation.value

# Assignments
# rs.var_vals

# Distribution of constructors of root node:
pr_mixed(rs.var_vals)(g.union.which)

# Sample some tree until it's valid (TODO: make this faster)
a = ADComputer(rs.var_vals)

# A tree counts as a red-black tree when all three invariants hold. The checks
# short-circuit in this order: bookkeeping, balance, order.
function isRBT(t)
    return satisfies_bookkeeping_invariant(t) &&
           satisfies_balance_invariant(t) &&
           satisfies_order_invariant(t)
end

using BenchmarkTools

# Benchmark rejection sampling: keep drawing trees from `g` until 200 of them
# satisfy `isRBT`.
#
# The globals `rs.rng`, `a` and `g` are interpolated with `$` so that
# BenchmarkTools times the sampling itself rather than access to non-constant
# globals. The RNG is still the same object, so its state keeps advancing
# across evaluations exactly as before.
#
# `samples` and `some_tree` are assigned inside the benchmarked expression, so
# they are not available as globals once the benchmark has finished.
@benchmark begin
    samples = []
    while length(samples) < 200
        some_tree = sample_as_dist($(rs.rng), $a, $g)
        if isRBT(some_tree)
            push!(samples, some_tree)
        end
    end
end

# one sample
# BenchmarkTools.Trial: 1683 samples with 1 evaluation.
# Range (min … max): 1.789 ms … 29.207 ms ┊ GC (min … max): 0.00% … 77.88%
# Time (median): 2.012 ms ┊ GC (median): 0.00%
# Time (mean ± σ): 2.895 ms ± 2.100 ms ┊ GC (mean ± σ): 4.64% ± 7.26%

# █▇▅▃ ▃▅▃▂▃▁ ▁▂▂▁
# █████▆▄▁▁▁██████▇▆▁▄▆████▇▇▇▅▇▄▇▇▅▆▆▅▅▇▄▄▅▅▅▆▄▆▄▁▁▄▄▅▁▄▄▁▄ █
# 1.79 ms Histogram: log(frequency) by time 8.87 ms <

# Memory estimate: 759.81 KiB, allocs estimate: 19182.

# 200 samples
# BenchmarkTools.Trial: 9 samples with 1 evaluation.
# Range (min … max): 551.427 ms … 637.939 ms ┊ GC (min … max): 3.72% … 6.07%
# Time (median): 571.511 ms ┊ GC (median): 3.63%
# Time (mean ± σ): 577.534 ms ± 29.908 ms ┊ GC (mean ± σ): 4.56% ± 1.59%

# █ ▃
# █▁▁▁▁▇▁▁▁▁▁▁▁▁█▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▇▁▁▁▁▁▁▇▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▇ ▁
# 551 ms Histogram: frequency by time 638 ms <

# Memory estimate: 207.46 MiB, allocs estimate: 5280362.

# NOTE(review): `some_tree` is only assigned inside the `@benchmark` body above
# and inside the `while` loop further down. When this file is run top to bottom
# it may not exist as a global at this point — confirm before relying on it.
some_tree

# 0.551 s per 200-sample batch (minimum from the benchmark above) * 1000 / 60 ~= 9 minutes spent on sampling


# Every other epoch, we spend 1/2 a second taking ~300 samples in order to get
# exactly 200 samples that meet the spec.

# "smart conditional sampling" saves at most 2/9 of runtime for RBT

# time per epoch: ~.25

# Rejection-sample until 200 trees satisfy `isRBT`, counting every draw
# (accepted or rejected) in `retries`.
#
# The `global` annotations matter because this is a top-level `while` loop:
# - without one, `retries += 1` is an ambiguous soft-scope assignment, which
#   fails with an UndefVarError when the file is run non-interactively;
# - without one, `some_tree` would be a loop-local variable, but later code in
#   this script reads it as a global.
# When the loop ends, `some_tree` is the last accepted tree.
retries = 0
samples = []
while length(samples) < 200
    global retries += 1
    global some_tree = sample_as_dist(rs.rng, a, g)
    if isRBT(some_tree)
        push!(samples, some_tree)
    end
end

retries # 321 samples taken


# Expander backed by a BDD compiler whose roots are the events `g == sample`,
# one per collected sample.
l = Dice.LogPrExpander(
    WMC(BDDCompiler([prob_equals(g, sample) for sample in samples]))
)

num_meeting = 0
@time begin
    # Each sample contributes a two-element vector; `sum` adds them
    # element-wise and the result is destructured into the two totals.
    loss, actual_loss = sum(
        let lpr_eq = Dice.expand_logprs(l, LogPr(prob_equals(g, sample)))
            [lpr_eq * compute(a, lpr_eq), lpr_eq]
        end for sample in samples
    )
end
# 1.74s on first run, ~.5 seconds on later runs

length(l.cache) # 935

# 0.165 seconds first time


# Benchmark one backward pass through `loss`.
# `rs.var_vals` and `loss` are interpolated with `$` so the measurement does not
# include access to non-constant globals.
@benchmark vals, derivs = differentiate(
    $(rs.var_vals),
    Derivs([$loss => 1.0]),
)

# BenchmarkTools.Trial: 1867 samples with 1 evaluation.
# Range (min … max): 2.441 ms … 23.635 ms ┊ GC (min … max): 0.00% … 88.88%
# Time (median): 2.544 ms ┊ GC (median): 0.00%
# Time (mean ± σ): 2.634 ms ± 1.110 ms ┊ GC (mean ± σ): 2.54% ± 5.30%

# ▁▁▁▅▆▆▆▇█▇▅▇▂▂
# ▃▄▆████████████████▇▇▆▆▆▅▅▄▄▄▄▃▄▃▃▃▃▃▃▃▃▃▃▃▃▃▂▃▃▃▃▂▂▂▁▂▂▂▂ ▄
# 2.44 ms Histogram: frequency by time 2.9 ms <

# Memory estimate: 635.62 KiB, allocs estimate: 19618.

# Count how many times `Dice.foreach_down` invokes its callback when started
# from `loss`. A one-element vector serves as the mutable counter.
ct = [0]
Dice.foreach_down(_ -> (ct[1] += 1), loss)
ct # 1334

using ProfileView

# Objective for a single tree: the log-probability of the event `some_tree == g`.
p_eq_g = prob_equals(some_tree, g)
to_maximize::Dice.ADNode = LogPr(p_eq_g)

pr_mixed(rs.var_vals)(p_eq_g)

# Expand the LogPr node with an expander compiled for just this objective.
l = Dice.LogPrExpander(WMC(BDDCompiler(Dice.bool_roots([to_maximize]))))
to_maximize_expanded = Dice.expand_logprs(l, to_maximize)

# Profile one differentiation of the expanded objective.
ProfileView.@profview begin
    vals, derivs = Dice.differentiate(
        rs.var_vals,
        Derivs(to_maximize_expanded => 1.0),
    )
end

87 changes: 87 additions & 0 deletions qc/benchmarks/stlc_faster_10samples.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
using Revise
using Dice
using BenchmarkTools
using ProfileView

include("benchmarks.jl")

# Generator configuration for the STLC benchmark.
generation_params = LangSiblingDerivedGenerator{STLC}(
    root_ty=Expr.t,
    ty_sizes=[Expr.t => 5, Typ.t => 2],
    stack_size=2,
    intwidth=3,
)

# Run state: logging goes to `log_path`, and the RNG is seeded with `SEED`.
SEED = 0
out_dir = "/tmp"
log_path = "/dev/null"
rs = RunState(
    Valuation(),
    Dict{String,ADNode}(),
    open(log_path, "w"),
    out_dir,
    MersenneTwister(SEED),
    nothing,
    generation_params,
)

# Build the generator once; `g` is the value sampled from below.
generation::Generation = generate(rs, generation_params)

g::Dist = generation.value

# Sample some tree until it's valid (TODO: make this faster)
a = ADComputer(rs.var_vals)

# Number of well-typed samples to collect.
NUM_SAMPLES = 10

# Report whether a deterministic expression type-checks.
#
# `e` is asserted to be deterministic. The result of `typecheck(e)` is then
# matched: `Some(_)` maps to `true`, `None()` maps to `false`.
function wellTyped(e)
    @assert isdeterministic(e)
    return @match typecheck(e) [
        Some(_) -> true,
        None() -> false,
    ]
end

# Rejection-sample until NUM_SAMPLES well-typed terms are collected.
# `retries` counts every draw, accepted or rejected.
retries = Ref(0)
#== @benchmark ==# @time begin
    samples = []
    while length(samples) < NUM_SAMPLES
        retries[] += 1
        s = sample_as_dist(rs.rng, a, g)
        wellTyped(s) && push!(samples, s)
    end
end
# Single result which took 26.426 s (3.00% GC) to evaluate, (7s, 26s, 30s, 40s)
# with a memory estimate of 388.02 MiB, over 8512429 allocations.
retries[] # 30

# Expander backed by a BDD compiler whose roots are the events `g == sample`,
# one per collected sample.
l = Dice.LogPrExpander(
    WMC(BDDCompiler([prob_equals(g, sample) for sample in samples]))
)
@time begin
    # Each sample contributes a two-element vector; `sum` adds them
    # element-wise and the result is destructured into the two totals.
    loss, actual_loss = sum(
        let lpr_eq = Dice.expand_logprs(l, LogPr(prob_equals(g, sample)))
            [lpr_eq * compute(a, lpr_eq), lpr_eq]
        end for sample in samples
    )
end
# 5.3s first run, 1.4s rest

length(l.cache) # 331

# Benchmark one backward pass through `loss`.
# `rs.var_vals` and `loss` are interpolated with `$` so the measurement does not
# include access to non-constant globals.
@benchmark vals, derivs = differentiate(
    $(rs.var_vals),
    Derivs([$loss => 1.0]),
)
# BenchmarkTools.Trial: 1060 samples with 1 evaluation.
# Range (min … max): 2.029 ms … 137.030 ms ┊ GC (min … max): 0.00% … 98.14%
# Time (median): 2.879 ms ┊ GC (median): 0.00%
# Time (mean ± σ): 4.377 ms ± 6.119 ms ┊ GC (mean ± σ): 4.36% ± 4.07%

# ██▇▆▅▃▃▂▃▁▁▂▂ ▁
# ██████████████████▅▇▆▄▆▄▆▇▄▆▄▅▁▆▁▅▇▁▄▄▁▁▁▄▁▁▅▄▆▁▄▄▁▁▁▄▁▁▁▅▅ █
# 2.03 ms Histogram: log(frequency) by time 22.6 ms <

# Memory estimate: 292.17 KiB, allocs estimate: 8034.

# Count how many times `Dice.foreach_down` invokes its callback when started
# from `loss`.
ct = Ref(0)
Dice.foreach_down(_ -> (ct[] += 1), loss)
ct[] # 350

85 changes: 80 additions & 5 deletions qc/benchmarks/stlc_faster_200samples.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,21 @@
# We found sampling from the BDD, like sampling from the computation graph, also took ~160s for 200 well-typed samples.
# OTHER TIMINGS IN THIS FILE ARE WRONG
# mainly we care about `sample_one_as_dist_compile` in this file

using Revise
using Dice
using BenchmarkTools
using ProfileView

# Return the number of times `Dice.foreach_down` invokes its callback when
# started from `roots`. The original author recorded 2040 for one input.
function comp_graph_size(roots)
    visited = Ref(0)
    Dice.foreach_down(_ -> (visited[] += 1), roots)
    return visited[]
end


include("benchmarks.jl")

generation_params = LangSiblingDerivedGenerator{STLC}(
Expand All @@ -26,6 +39,42 @@ a = ADComputer(rs.var_vals)

NUM_SAMPLES = 200

# Draw one joint sample of `d` by walking the BDDs that `c` compiles for `roots`.
#
# Arguments:
# - `c`: BDD compiler. `c.mgr` supplies the terminal nodes, and
#   `c.level_to_flip[level].prob` is the probability attached to a BDD level.
# - `a`: evaluates a flip's probability via `compute(a, prob)`.
# - `d`: distribution that the sampled bits are reassembled into.
# - `roots`: the boolean roots to sample.
# - `rng` (keyword, optional): random number generator used for the coin flips.
#   Pass a seeded generator (e.g. the run state's RNG) for reproducible samples.
#   When omitted, the global generator is used through `rand()`, as before.
#
# Returns the result of `Dice.frombits_as_dist(d, bits)`, where `bits` maps each
# root to the Boolean it was sampled to.
function sample_one_as_dist_compile(
    c::BDDCompiler, a::ADComputer, d::Dist, roots; rng=nothing
)
    # State for one sampling
    bdd_node_to_tf = Dict{CuddNode,Bool}()
    level_to_tf = Dict{Integer,Bool}()
    # Pre-seed the memo with both terminals so the walk ends on a plain lookup.
    bdd_node_to_tf[Dice.constant(c.mgr, true)] = true
    bdd_node_to_tf[Dice.constant(c.mgr, false)] = false

    # One uniform draw, from the caller's generator when one was supplied.
    draw() = rng === nothing ? rand() : rand(rng)

    # Each level is decided at most once per call, so every root that mentions
    # the same level sees the same outcome.
    function sample_level(level::Integer)
        return get!(level_to_tf, level) do
            prob = compute(a, c.level_to_flip[level].prob)
            draw() < prob
        end
    end

    # Booleans are compiled to a BDD node first and then sampled as a node.
    function sample_one(x::AnyBool)
        return sample_one(compile(c, x))
    end

    # Follow the high or low child according to the outcome at this node's
    # level. Results are memoised per node.
    function sample_one(x::CuddNode)
        return get!(bdd_node_to_tf, x) do
            if sample_level(Dice.level(x))
                sample_one(Dice.high(x))
            else
                sample_one(Dice.low(x))
            end
        end
    end

    bits = Dict()
    for root in roots
        bits[root] = sample_one(root)
    end
    return Dice.frombits_as_dist(d, bits)
end


function wellTyped(e)
@assert isdeterministic(e)
@match typecheck(e) [
Expand All @@ -37,22 +86,48 @@ end
retries = Ref(0)
#== @benchmark ==# @time begin
samples = []
d = g
roots = Dice.tobits(d)
c = BDDCompiler(roots)
a = Dice.ADComputer(rs.var_vals)
while length(samples) < NUM_SAMPLES
retries[] += 1
s = sample_as_dist(rs.rng, a, g)
s = sample_one_as_dist_compile(c, a, d, roots)
if wellTyped(s)
push!(samples, s)
end
end
end
# 174s, 155s
# 174s, 155s, 281
retries[] # 607, 556

@time l = Dice.LogPrExpander(WMC(BDDCompiler([
@time eqs = [
prob_equals(g, sample)
for sample in samples
])))
# 32s, 36s
]
# 27 sec

comp_graph_size(eqs) # 2040
comp_graph_size(Dice.tobits(g)) # 16825

# @benchmark prob_equals(g, samples[1])
# BenchmarkTools.Trial: 24 samples with 1 evaluation.
# Range (min … max): 104.068 ms … 592.294 ms ┊ GC (min … max): 0.00% … 0.00%
# Time (median): 168.251 ms ┊ GC (median): 0.00%
# Time (mean ± σ): 186.669 ms ± 97.022 ms ┊ GC (mean ± σ): 0.00% ± 0.00%
# ▃▃█ ▃▃▃ ▃▃
# ▇███▁▁███▇██▁▇▇▁▁▇▁▁▁▁▁▇▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▇ ▁
# 104 ms Histogram: frequency by time 592 ms <
# Memory estimate: 6.51 MiB, allocs estimate: 235681.

@time c = BDDCompiler(eqs)
# 0.28553 s, 0.23, 0.1

@time w = WMC(c)
# instant

@time l = Dice.LogPrExpander(w)
# instant

@time loss, actual_loss = sum(
begin
Expand Down
Loading

0 comments on commit daa0fc5

Please sign in to comment.