You signed in with another tab or window. Reload to refresh your session. You signed out in another tab or window. Reload to refresh your session. You switched accounts on another tab or window. Reload to refresh your session. Dismiss alert
# Probabilistic value from zero to one. TODO: use DistFix
using Revise
using Dice
using Random
# Finding accuracy errors in an approximate inverse square root function

# Current output of this program (nondeterministic):
# 15.127460 seconds (141.17 M allocations: 6.819 GiB, 6.44% gc time, 19.28% compilation time)
# Initial p(bugfound) = 0.011
# Trained p(bugfound) = 0.99

################################################################################
# HYPERPARAMETERS AND CONFIG
################################################################################

# How many steps of Newton's method to take. Increase to make the approximate fn
# more accurate
STEPS = 5

# Precision of the fixed point number generator. Increase to be able to generate
# numbers closer to 0
PREC = 10

# We consider a bug to be found if we get relative error exceeding BUG_THRESH
# Note: relative error is from approx_inv_sqrt(x)^2 to 1/x. This prevents us
# from needing a preexisting implementation of sqrt.
# Maybe this helps the story, but it also makes "relative error" a bit harder to
# explain so maybe we should just use `sqrt`, and measure error from
# approx_inv_sqrt(x) to 1/sqrt(x).
BUG_THRESH = 0.01

# Training
NUM_EPOCHS = 100   # number of gradient-descent epochs
NUM_SAMPLES = 1000 # samples drawn from the generator per epoch
LR = 0.03          # learning rate applied to each gradient step

################################################################################
# DistZeroToOne
################################################################################

# Probabilistic value from zero to one. TODO: use DistFix
# The import must be a real statement (not part of the comment above) so that the
# method definitions that follow extend Dice's generic functions.
import Dice: tobits, frombits, prob_equals
"""
    DistZeroToOne(mantissa::DistUInt)
    DistZeroToOne(x::Float64, W)

Fixed-point probabilistic value, stored as an unsigned-integer `mantissa`.
A concrete value is decoded as `mantissa / 2^n`, where `n` is the number of
mantissa bits (see `Dice.frombits` below).

The second form lifts the concrete number `x` into a `W`-bit constant by
converting `x * 2^W` to an `Int`; this throws an `InexactError` when `x * 2^W`
is not an integer.
"""
struct DistZeroToOne <: Dist{Any}
    mantissa::DistUInt
end

DistZeroToOne(x::Float64, W) = DistZeroToOne(DistUInt{W}(Int(x * 2^W)))

# The methods below are defined with module-qualified names so that they extend
# Dice's generic functions even if no `import Dice: ...` statement is in effect.

# The bits of a DistZeroToOne are exactly the bits of its mantissa.
Dice.tobits(x::DistZeroToOne) = Dice.tobits(x.mantissa)

# Decode the mantissa in `world`, then rescale by 2^(number of mantissa bits).
function Dice.frombits(x::DistZeroToOne, world)
    scale = 2^float(length(x.mantissa.bits))
    return float(Dice.frombits(x.mantissa, world)) / scale
end

# Two values are equal exactly when their mantissas are equal.
Dice.prob_equals(x::DistZeroToOne, y::DistZeroToOne) =
    Dice.prob_equals(x.mantissa, y.mantissa)

################################################################################
# Approximate inverse square root and its error
################################################################################

# Thanks to REINFORCE, none of these functions need be implemented in Dice

"""
    approx_sqrt(x, steps)

Approximate `sqrt(x)` with `steps` iterations of Newton's method, starting from
an initial guess of one. With `steps == 0` the initial guess is returned as is.
"""
function approx_sqrt(x, steps)
    # Start from a floating-point one so `guess` keeps a single type for the
    # whole loop (an integer `1` would turn into a float after the first step).
    guess = one(float(x))
    for _ in 1:steps
        guess = (guess + x / guess) / 2
    end
    return guess
end

"""
    approx_inv_sqrt(x, steps)

Approximate `1 / sqrt(x)` as the reciprocal of `approx_sqrt(x, steps)`.
"""
approx_inv_sqrt(x, steps) = 1 / approx_sqrt(x, steps)

"""
    rel_error(actual, expected)

Return the relative error `abs((actual - expected) / expected)`.
"""
rel_error(actual, expected) = abs((actual - expected) / expected)

# Note we don't need an "oracle" (an existing `sqrt` function) to target error!
# We instead compare the squared approx-inverse-sqrt to the inverse.
# Maybe this helps the story, but it also makes "relative error" a bit harder to
# explain so maybe we should just use `sqrt`.
"""
    approx_inv_sqrt_error(x, steps)

Relative error of `approx_inv_sqrt(x, steps)^2` with respect to `1 / x`.
For `x == 0` the reference value is infinite and the result is `NaN`.
"""
approx_inv_sqrt_error(x, steps) = rel_error(approx_inv_sqrt(x, steps)^2, 1 / x)

################################################################################
# Generator
################################################################################

# Current value of every trainable variable (updated in place during training).
var_vals = Valuation()

# Maps a weight's name to the AD node computing that weight, for later inspection.
adnodes_of_interest = Dict{String, ADNode}()

"""
    register_weight!(s)

Create a trainable weight named `s` and return its AD node.

A fresh variable called `"\$(s)_before_sigmoid"` is initialised to `0` in
`var_vals`; the returned weight is `sigmoid` of that variable (presumably this
keeps the weight a valid probability for any real value of the variable — confirm
against Dice's `sigmoid`). The weight is also recorded in `adnodes_of_interest`
under key `s`.
"""
function register_weight!(s)
    var = Var("$(s)_before_sigmoid")
    var_vals[var] = 0
    weight = sigmoid(var)
    adnodes_of_interest[s] = weight
    return weight
end

# Uniform from 0 to 1
# Generator under training: a PREC-bit mantissa in which every bit is an
# independent flip whose probability is its own registered weight ("x_1" … "x_PREC").
g = DistZeroToOne(DistUInt{PREC}([
    flip(register_weight!("x_$(i)"))
    for i in 1:PREC
]))
################################################################################
# Train to maximize error
################################################################################

# Per-epoch statistics: mean relative error over valid samples, and the fraction
# of samples whose relative error exceeded BUG_THRESH.
history_err = Float64[]
history_p_bugfound = Float64[]

@time for epoch in 1:NUM_EPOCHS
    # Draw concrete inputs from the generator under the current weights.
    samples = Dice.with_concrete_ad_flips(var_vals, g) do
        [sample(Random.default_rng(), g) for _ in 1:NUM_SAMPLES]
    end

    # Compile the "generator produced exactly this sample" events once per epoch.
    l = Dice.LogPrExpander(WMC(BDDCompiler([
        prob_equals(g, DistZeroToOne(x, PREC))
        for x in samples
    ])))

    # Each sample contributes a 4-vector that `sum` adds elementwise:
    #   [REINFORCE loss term, error, 1 if the error is valid, 1 if a bug was found]
    loss, total_dist, valid_samples, samples_finding_bug = sum(
        begin
            lpr_eq = Dice.expand_logprs(l, LogPr(prob_equals(g, DistZeroToOne(x, PREC))))
            dist = approx_inv_sqrt_error(x, STEPS)
            if isnan(dist)
                # Undefined error (e.g. x == 0): the sample contributes nothing.
                [Dice.Constant(0), 0, 0, 0]
            else
                # Minimising -log Pr(x) * error pushes probability mass toward
                # inputs with large error.
                [-lpr_eq * dist, dist, 1, dist > BUG_THRESH ? 1 : 0]
            end
        end
        for x in samples
    )

    push!(history_err, total_dist / valid_samples)
    push!(history_p_bugfound, samples_finding_bug / NUM_SAMPLES)

    # Gradient-descent step on the loss; only trainable variables are updated.
    vals, derivs = differentiate(
        var_vals,
        Derivs(loss => 1)
    )
    for (adnode, d) in derivs
        if adnode isa Var
            var_vals[adnode] -= d * LR
        end
    end
end

println("Initial p(bugfound) = $(first(history_p_bugfound))")
println("Trained p(bugfound) = $(last(history_p_bugfound))")
The text was updated successfully, but these errors were encountered:
###############################################################################
Thanks to REINFORCE, none of these functions need be implemented in Dice
We instead compare the squared approx-inverse-sqrt to the inverse.
Maybe this helps the story, but it also makes "relative error" a bit harder to
explain so maybe we should just use `sqrt`.
###############################################################################
###############################################################################
Dice.jl/examples/qc/invsqrt.jl
Line 41 in 50aad17
The text was updated successfully, but these errors were encountered: