Skip to content

Commit

Permalink
test for gaussian_bitblast_sample
Browse files Browse the repository at this point in the history
  • Loading branch information
PoorvaGarg committed Nov 28, 2024
1 parent 9d72d6e commit a9c7bf7
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 4 deletions.
7 changes: 7 additions & 0 deletions src/dist/number/fix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -688,6 +688,7 @@ A modified version of the function bitblast that works with the assumption of lo
function bitblast_sample(::Type{DistFix{W,F}}, dist::ContinuousUnivariateDistribution,
numpieces::Int, start::Float64, stop::Float64, offset::Float64, width::Float64) where {W,F}

@show start, stop, numpieces, offset, width
# count bits and pieces
@assert -(2^(W-F-1)) <= start < stop <= 2^(W-F-1)
f_range_bits = log2((stop - start)*2^F)
Expand All @@ -710,6 +711,8 @@ function bitblast_sample(::Type{DistFix{W,F}}, dist::ContinuousUnivariateDistrib
piece_probs[i] = 0
# Warning: A potential source of high runtime
for j=1:intervals_per_piece
@show firstinter + offset + width + (j-1)/2^F
@show firstinter + offset + (j-1)/2^F
piece_probs[i] += cdf(dist, firstinter + offset + width + (j-1)/2^F) - cdf(dist, firstinter + offset + (j-1)/2^F)
end
total_prob += piece_probs[i]
Expand All @@ -719,6 +722,10 @@ function bitblast_sample(::Type{DistFix{W,F}}, dist::ContinuousUnivariateDistrib
linear_piece_probs[i] = (border_probs[i][1] + border_probs[i][2])/2 * 2^(bits_per_piece)
end


@show piece_probs
@show total_prob

PieceChoice = DistUInt{max(1,Int(log2(numpieces)))}
piecechoice = discrete(PieceChoice, piece_probs ./ total_prob) # selector variable for pieces
slope_flips = Vector(undef, numpieces)
Expand Down
8 changes: 4 additions & 4 deletions src/util.jl
Original file line number Diff line number Diff line change
Expand Up @@ -131,16 +131,16 @@ The functions takes as input lower order bits for a gaussian distribution and re
"""
function gaussian_bitblast_sample(::Type{DistFix{W, F}}, mean::Float64, std::Float64, numpieces::Int64, start::Float64, stop::Float64, lsb::Vector{Bool}) where {W, F}
distribution = Normal(mean, std)
nbits = length(lsbs)
nbits = length(lsb)
DFiP = DistFix{W-nbits, F-nbits}
width = 1/2^F
offset = 0.0
for i in 1:nbits
if lsbs[i]
offset += 2^(F-nbits+i)
if lsb[i]
offset += 1/2^(F-nbits+i)
end
end

sub_gaussian = bitblast_sample(DFiP, distribution, numpieces, start, stop, offset, width)
DistFix{W, F}([sub_gaussian.mantissa.number.bits..., lsbs...])
DistFix{W, F}([sub_gaussian.mantissa.number.bits..., lsb...])
end
14 changes: 14 additions & 0 deletions test/util_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,3 +104,17 @@ end
└── Leaf
"""
end

# TODO: test for negative F
@testset "gaussian_bitblast_sample" begin
x1 = gaussian_bitblast_sample(DistFix{3, 1}, 0.0, 1.0, 2, -1.0, 1.0, [false])
x2 = gaussian_bitblast_sample(DistFix{3, 1}, 0.0, 1.0, 2, -1.0, 1.0, [true])
x = ifelse(flip(0.5), x1, x2)
p1 = pr(x)

y = bitblast(DistFix{3, 1}, Normal(0, 1), 4, -1.0, 1.0)
p2 = pr(y)
for i in keys(p1)
@test p1[i] p2[i]
end
end

0 comments on commit a9c7bf7

Please sign in to comment.