You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
# The bridge between autodiff and cudd
export LogPr, compute_mixed, train!, pr_mixed, support_mixed, with_concrete_ad_flips
"""
    LogPr(bool::Dist{Bool})

Autodiff leaf node standing for the log-probability of `bool`.

A `LogPr` is never evaluated or differentiated directly. It has to be rewritten
first (see `expand_logprs`), and reaching one in `compute_leaf` or `backward`
raises an error.
"""
mutable struct LogPr <: ADNode
    bool::Dist{Bool}
end

# LogPr has no AD children of its own.
NodeType(::Type{LogPr}) = Leaf()

function compute_leaf(::LogPr)
    error("LogPr must be expanded")
end

function backward(::LogPr, _, _)
    error("LogPr must be expanded")
end
"""
    LogPrExpander(w::WMC)

State used by `expand_logprs`:
- `w` is the weighted model counter passed to `logprob` for each `LogPr`.
- `cache` is a node-to-node table handed to `foldup`.

Presumably `cache` acts as `foldup`'s memo table, so shared subgraphs are
rewritten only once (TODO confirm against `foldup`). `compute_mixed` reads it
back to find the rewritten version of each root.
"""
mutable struct LogPrExpander
    w::WMC
    cache::Dict{ADNode, ADNode}
    function LogPrExpander(w)
        new(w, Dict{ADNode, ADNode}())
    end
end

"""
    expand_logprs(l::LogPrExpander, root::ADNode)::ADNode

Rebuild the autodiff graph under `root` with every `LogPr` replaced by
`logprob(l.w, x.bool)`, whose result is itself expanded recursively.

`Var` and `Constant` leaves are returned unchanged. Every supported interior
node is rebuilt as the same node type around its rewritten children.
"""
function expand_logprs(l::LogPrExpander, root::ADNode)::ADNode
    # Leaves: only LogPr actually changes.
    fl(x::LogPr) = expand_logprs(l, logprob(l.w, x.bool))
    fl(x::Var) = x
    fl(x::Constant) = x
    # Interior nodes: same constructor, rewritten children.
    fi(x::Add, call) = Add(call(x.x), call(x.y))
    fi(x::Mul, call) = Mul(call(x.x), call(x.y))
    fi(x::Pow, call) = Pow(call(x.x), call(x.y))
    fi(x::Sin, call) = Sin(call(x.x))
    fi(x::Cos, call) = Cos(call(x.x))
    fi(x::Log, call) = Log(call(x.x))
    fi(x::ADMatrix, call) = ADMatrix(map(call, x.x))
    fi(x::GetIndex, call) = GetIndex(call(x.x), x.i)
    fi(x::Map, call) = Map(x.f, x.f′, call(x.x))
    fi(x::Transpose, call) = Transpose(call(x.x))
    fi(x::NodeLogPr, call) = NodeLogPr(call(x.pr), call(x.hi), call(x.lo))
    foldup(root, fl, fi, ADNode, l.cache)
end

# The LogPrs under `roots` contain Dist{Bool} DAGs. Collect the minimal roots of
# all those DAGs: the visited Dist{Bool} nodes that are not a child of any other
# visited Dist{Bool} node.
function bool_roots(roots)
    # TODO: have non-root removal be done in src/inference/cudd/compile.jl
    seen_adnodes = Dict{ADNode, Nothing}()
    seen_bools = Dict{AnyBool, Nothing}()
    non_roots = Set{AnyBool}()
    to_visit = Vector{ADNode}(roots)
    while !isempty(to_visit)
        x = pop!(to_visit)
        # seen_adnodes is shared by every walk, presumably so foreach skips
        # nodes that an earlier walk already visited.
        foreach(x, seen_adnodes) do y
            if y isa LogPr
                foreach(y.bool, seen_bools) do bool
                    union!(non_roots, children(bool))
                    # A Flip whose prob is an unvisited ADNode may hide further
                    # LogPrs, so queue that ADNode for its own walk.
                    if bool isa Flip && bool.prob isa ADNode &&
                       !haskey(seen_adnodes, bool.prob)
                        push!(to_visit, bool.prob)
                    end
                end
            end
        end
    end
    setdiff(keys(seen_bools), non_roots)
end

"""
    compute_mixed(var_vals::Valuation, root::ADNode)

Value of a single `root`, which may contain `LogPr`s, under `var_vals`.
"""
function compute_mixed(var_vals::Valuation, root::ADNode)
    compute_mixed(var_vals, [root])[root]
end

"""
    compute_mixed(var_vals::Valuation, roots)

Evaluate several `roots`, which may contain `LogPr`s, under `var_vals`.

One `BDDCompiler` is built over the bool roots of all `roots`. Every root is
then expanded with a shared `LogPrExpander` and computed. Returns a `Dict`
mapping each original root to its value.
"""
function compute_mixed(var_vals::Valuation, roots)
    l = LogPrExpander(WMC(BDDCompiler(bool_roots(roots))))
    expanded_roots = [expand_logprs(l, x) for x in roots]
    vals = compute(var_vals, expanded_roots)
    Dict(root => vals[l.cache[root]] for root in roots)
end

"""
    train!(var_vals::Valuation, loss::ADNode; epochs::Integer, learning_rate::Real)

Run `epochs` steps of plain gradient descent on `loss`. Each step does
`var_vals[v] -= d * learning_rate` for every `Var` `v` with derivative `d`,
updating `var_vals` in place. `LogPr`s in `loss` are expanded once, before
training starts.

Returns a vector of `epochs + 1` losses:
- one per epoch, the value reported by `differentiate` before that epoch's
  update;
- then the loss under the final `var_vals`.
"""
function train!(
    var_vals::Valuation,
    loss::ADNode;
    epochs::Integer,
    learning_rate::Real,
)
    losses = []
    l = LogPrExpander(WMC(BDDCompiler(bool_roots([loss]))))
    loss = expand_logprs(l, loss)
    for _ in 1:epochs
        vals, derivs = differentiate(var_vals, Derivs(loss => 1))
        # update vars
        for (adnode, d) in derivs
            if adnode isa Var
                var_vals[adnode] -= d * learning_rate
            end
        end
        push!(losses, vals[loss])
    end
    push!(losses, compute_mixed(var_vals, loss))
    return losses
end

"""
    collect_flips(bools)

Return the `Flip` nodes visited by `foreach_down(bools)`, in visit order.
"""
function collect_flips(bools)
    flips = Vector{Flip}()
    foreach_down(bools) do x
        x isa Flip && push!(flips, x)
    end
    flips
end

# Call `f()` while temporarily replacing each Flip prob that is an ADNode with
# `concretize(prob)`. The Flips are all those collected from `tobits` of each of
# `dists`. Original probs are restored in a `finally`, so a throwing `f` or
# `concretize` never leaves the caller's distributions mutated.
function _with_replaced_ad_flip_probs(f, concretize, dists)
    # Collect every flip before mutating anything.
    flips = Flip[]
    for dist in dists
        append!(flips, collect_flips(tobits(dist)))
    end
    originals = Dict{Flip, ADNode}()
    try
        for x in flips
            # A flip shared between dists is replaced only once: on a repeat
            # visit its prob is no longer an ADNode, so the original is kept.
            if x.prob isa ADNode
                originals[x] = x.prob
                x.prob = concretize(x.prob)
            end
        end
        return f()
    finally
        for (x, prob) in originals
            x.prob = prob
        end
    end
end

"""
    with_concrete_ad_flips(f, var_vals, dists...)

Call `f()` while every `Flip` reachable from `dists` whose `prob` is an
`ADNode` has that prob replaced by its value under `var_vals`. `LogPr`s inside
such probs are expanded first.

The original probs are restored afterwards, including when `f` throws.
Returns `f()`.
"""
function with_concrete_ad_flips(f, var_vals, dists...)
    a = ADComputer(var_vals)
    l = LogPrExpander(WMC(BDDCompiler()))
    concretize(prob) = compute(a, expand_logprs(l, prob))
    return _with_replaced_ad_flip_probs(f, concretize, dists)
end

"""
    pr_mixed(var_vals)

Return a function with the same call signature as `pr`. It evaluates `pr` on
its arguments while AD-valued flip probabilities are fixed to their values
under `var_vals`. Each positional argument is passed to
`with_concrete_ad_flips`.
"""
function pr_mixed(var_vals)
    (args...; kwargs...) -> with_concrete_ad_flips(var_vals, args...) do
        pr(args...; kwargs...)
    end
end

"""
    support_mixed(dist)

Return `keys(pr(dist))`, computed while every flip whose prob is an `ADNode`
temporarily has prob `0.5`. The original probs are restored afterwards,
including when `pr` throws.
"""
function support_mixed(dist)
    return _with_replaced_ad_flip_probs(() -> keys(pr(dist)), _ -> 0.5, (dist,))
end
The text was updated successfully, but these errors were encountered:
https://github.com/Juice-jl/Dice.jl/blob/d80e01a7dd7019eb8daa543199b6276e7682dbc8/src/autodiff_pr/train.jl#L39
The text was updated successfully, but these errors were encountered: