Skip to content

Commit

Permalink
negative F
Browse files Browse the repository at this point in the history
  • Loading branch information
PoorvaGarg committed Nov 29, 2024
1 parent a9c7bf7 commit 9ddeb86
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 14 deletions.
20 changes: 10 additions & 10 deletions src/dist/number/fix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ function DistFix{W, F}(b::AbstractVector) where {W,F}
end

function DistFix{W, F}(x::Float64) where {W,F}
mantissa = DistInt{W}(floor(Int, x*2^F))
mantissa = DistInt{W}(floor(Int, x*2^float(F)))
DistFix{W, F}(mantissa)
end

Expand Down Expand Up @@ -63,7 +63,7 @@ end
tobits(x::DistFix) = tobits(x.mantissa)

function frombits(x::DistFix{W, F}, world) where {W,F}
frombits(x.mantissa, world)/2^F
frombits(x.mantissa, world)/2^float(F)
end

function expectation(x::DistFix{W, F}; kwargs...) where {W,F}
Expand Down Expand Up @@ -125,7 +125,7 @@ function bitblast(::Type{DistFix{W,F}}, dist::ContinuousUnivariateDistribution,

# count bits and pieces
@assert -(2^(W-F-1)) <= start < stop <= 2^(W-F-1)
f_range_bits = log2((stop - start)*2^F)
f_range_bits = log2((stop - start)*2^float(F))
@assert isinteger(f_range_bits) "The number of $(1/2^F)-sized intervals between $start and $stop must be a power of two (not $f_range_bits)."
@assert ispow2(numpieces) "Number of pieces must be a power of two (not $numpieces)"
intervals_per_piece = (2^Int(f_range_bits))/numpieces
Expand All @@ -140,14 +140,15 @@ function bitblast(::Type{DistFix{W,F}}, dist::ContinuousUnivariateDistribution,
linear_piece_probs = Vector(undef, numpieces) # prob of each piece if it were linear between end points

for i=1:numpieces
firstinter = start + (i-1)*intervals_per_piece/2^F
lastinter = start + (i)*intervals_per_piece/2^F
firstinter = start + (i-1)*intervals_per_piece/2^float(F)
lastinter = start + (i)*intervals_per_piece/2^float(F)

piece_probs[i] = (cdf(dist, lastinter) - cdf(dist, firstinter))
total_prob += piece_probs[i]

border_probs[i] = [cdf(dist, firstinter + 1/2^F ) - cdf(dist, firstinter),
cdf(dist, lastinter) - cdf(dist, lastinter - 1/2^F )]
border_probs[i] = [cdf(dist, firstinter + 1/2^float(F) ) - cdf(dist, firstinter),
cdf(dist, lastinter) - cdf(dist, lastinter - 1/2^float(F) )]
@show bits_per_piece
linear_piece_probs[i] = (border_probs[i][1] + border_probs[i][2])/2 * 2^(bits_per_piece)
end

Expand All @@ -172,8 +173,8 @@ function bitblast(::Type{DistFix{W,F}}, dist::ContinuousUnivariateDistribution,
z = nothing
for i=1:numpieces
iszero(linear_piece_probs[i]) && continue
firstinterval = DistFix{W,F}(start + (i-1)*2^bits_per_piece/2^F)
lastinterval = DistFix{W,F}(start + (i*2^bits_per_piece-1)/2^F)
firstinterval = DistFix{W,F}(start + (i-1)*2^bits_per_piece/2^float(F))
lastinterval = DistFix{W,F}(start + (i*2^bits_per_piece-1)/2^float(F))
linear_dist =
if isdecreasing[i]
(ifelse(slope_flips[i],
Expand Down Expand Up @@ -688,7 +689,6 @@ A modified version of the function bitblast that works with the assumption of lo
function bitblast_sample(::Type{DistFix{W,F}}, dist::ContinuousUnivariateDistribution,
numpieces::Int, start::Float64, stop::Float64, offset::Float64, width::Float64) where {W,F}

@show start, stop, numpieces, offset, width
# count bits and pieces
@assert -(2^(W-F-1)) <= start < stop <= 2^(W-F-1)
f_range_bits = log2((stop - start)*2^F)
Expand Down
11 changes: 10 additions & 1 deletion test/dist/number/fix_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,8 @@ end
@test sum(p) 1.0
@test sum(q) 1.0
ans = 0
for i=1:length(p)
l = length(p)
for i=1:l
if p[i] > 0
ans += p[i] *(log(p[i]) - log(q[i]))
end
Expand Down Expand Up @@ -171,6 +172,14 @@ end
p2 = map(a -> a[2], sort([(k, v) for (k, v) in p]))
@test p2 q

# Negative F
y = bitblast(DistFix{5, -1}, Normal(1, 1), 2, -4.0, 4.0)
p = pr(y)
d = TruncatedNormal(1, 1, -4, 4)
for i in keys(p)
@test p[i] cdf(d, i+2) - cdf(d, i)
end

#TODO: write tests for continuous distribution other than gaussian
#TODO: Write tests for exponential pieces
end
Expand Down
10 changes: 7 additions & 3 deletions test/util_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -107,14 +107,18 @@ end

# TODO: test for negative F
@testset "gaussian_bitblast_sample" begin
x1 = gaussian_bitblast_sample(DistFix{3, 1}, 0.0, 1.0, 2, -1.0, 1.0, [false])
x2 = gaussian_bitblast_sample(DistFix{3, 1}, 0.0, 1.0, 2, -1.0, 1.0, [true])
x1 = gaussian_bitblast_sample(DistFix{3, 1}, 0.0, 1.0, 2, -2.0, 2.0, [false])
x2 = gaussian_bitblast_sample(DistFix{3, 1}, 0.0, 1.0, 2, -2.0, 2.0, [true])
x = ifelse(flip(0.5), x1, x2)
p1 = pr(x)

y = bitblast(DistFix{3, 1}, Normal(0, 1), 4, -1.0, 1.0)
y = bitblast(DistFix{3, 1}, Normal(0, 1), 4, -2.0, 2.0)
p2 = pr(y)
for i in keys(p1)
@test p1[i] p2[i]
end

x1 = gaussian_bitblast_sample(DistFix{3, 1}, 0.0, 1.0, 2, -2.0, 2.0, [false, false])
p = pr(x1)

end

0 comments on commit 9ddeb86

Please sign in to comment.