diff --git a/src/dist/number/fix.jl b/src/dist/number/fix.jl index ffb77c6a..372031f4 100644 --- a/src/dist/number/fix.jl +++ b/src/dist/number/fix.jl @@ -21,7 +21,7 @@ function DistFix{W, F}(b::AbstractVector) where {W,F} end function DistFix{W, F}(x::Float64) where {W,F} - mantissa = DistInt{W}(floor(Int, x*2^F)) + mantissa = DistInt{W}(floor(Int, x*2^float(F))) DistFix{W, F}(mantissa) end @@ -63,7 +63,7 @@ end tobits(x::DistFix) = tobits(x.mantissa) function frombits(x::DistFix{W, F}, world) where {W,F} - frombits(x.mantissa, world)/2^F + frombits(x.mantissa, world)/2^float(F) end function expectation(x::DistFix{W, F}; kwargs...) where {W,F} @@ -125,7 +125,7 @@ function bitblast(::Type{DistFix{W,F}}, dist::ContinuousUnivariateDistribution, # count bits and pieces @assert -(2^(W-F-1)) <= start < stop <= 2^(W-F-1) - f_range_bits = log2((stop - start)*2^F) + f_range_bits = log2((stop - start)*2^float(F)) @assert isinteger(f_range_bits) "The number of $(1/2^F)-sized intervals between $start and $stop must be a power of two (not $f_range_bits)." @assert ispow2(numpieces) "Number of pieces must be a power of two (not $numpieces)" intervals_per_piece = (2^Int(f_range_bits))/numpieces @@ -140,14 +140,15 @@ function bitblast(::Type{DistFix{W,F}}, dist::ContinuousUnivariateDistribution, linear_piece_probs = Vector(undef, numpieces) # prob of each piece if it were linear between end points for i=1:numpieces - firstinter = start + (i-1)*intervals_per_piece/2^F - lastinter = start + (i)*intervals_per_piece/2^F + firstinter = start + (i-1)*intervals_per_piece/2^float(F) + lastinter = start + (i)*intervals_per_piece/2^float(F) piece_probs[i] = (cdf(dist, lastinter) - cdf(dist, firstinter)) total_prob += piece_probs[i] - border_probs[i] = [cdf(dist, firstinter + 1/2^F ) - cdf(dist, firstinter), - cdf(dist, lastinter) - cdf(dist, lastinter - 1/2^F )] + border_probs[i] = [cdf(dist, firstinter + 1/2^float(F) ) - cdf(dist, firstinter), + cdf(dist, lastinter) - cdf(dist, lastinter - 1/2^float(F) )] + @show bits_per_piece linear_piece_probs[i] = (border_probs[i][1] + border_probs[i][2])/2 * 2^(bits_per_piece) end @@ -172,8 +173,8 @@ function bitblast(::Type{DistFix{W,F}}, dist::ContinuousUnivariateDistribution, z = nothing for i=1:numpieces iszero(linear_piece_probs[i]) && continue - firstinterval = DistFix{W,F}(start + (i-1)*2^bits_per_piece/2^F) - lastinterval = DistFix{W,F}(start + (i*2^bits_per_piece-1)/2^F) + firstinterval = DistFix{W,F}(start + (i-1)*2^bits_per_piece/2^float(F)) + lastinterval = DistFix{W,F}(start + (i*2^bits_per_piece-1)/2^float(F)) linear_dist = if isdecreasing[i] (ifelse(slope_flips[i], @@ -688,7 +689,6 @@ A modified version of the function bitblast that works with the assumption of lo function bitblast_sample(::Type{DistFix{W,F}}, dist::ContinuousUnivariateDistribution, numpieces::Int, start::Float64, stop::Float64, offset::Float64, width::Float64) where {W,F} - @show start, stop, numpieces, offset, width # count bits and pieces @assert -(2^(W-F-1)) <= start < stop <= 2^(W-F-1) f_range_bits = log2((stop - start)*2^F) diff --git a/test/dist/number/fix_test.jl b/test/dist/number/fix_test.jl index 48a03d3f..922e9051 100644 --- a/test/dist/number/fix_test.jl +++ b/test/dist/number/fix_test.jl @@ -119,7 +119,8 @@ end @test sum(p) ≈ 1.0 @test sum(q) ≈ 1.0 ans = 0 - for i=1:length(p) + l = length(p) + for i=1:l if p[i] > 0 ans += p[i] *(log(p[i]) - log(q[i])) end @@ -171,6 +172,14 @@ end p2 = map(a -> a[2], sort([(k, v) for (k, v) in p])) @test p2 ≈ q + # Negative F + y = bitblast(DistFix{5, -1}, Normal(1, 1), 2, -4.0, 4.0) + p = pr(y) + d = TruncatedNormal(1, 1, -4, 4) + for i in keys(p) + @test p[i] ≈ cdf(d, i+2) - cdf(d, i) + end + #TODO: write tests for continuous distribution other than gaussian #TODO: Write tests for exponential pieces end diff --git a/test/util_test.jl b/test/util_test.jl index 58105d8b..bc966202 100644 --- a/test/util_test.jl +++ b/test/util_test.jl @@ -107,14 +107,18 @@ end # TODO: test for negative F @testset "gaussian_bitblast_sample" begin - x1 = gaussian_bitblast_sample(DistFix{3, 1}, 0.0, 1.0, 2, -1.0, 1.0, [false]) - x2 = gaussian_bitblast_sample(DistFix{3, 1}, 0.0, 1.0, 2, -1.0, 1.0, [true]) + x1 = gaussian_bitblast_sample(DistFix{3, 1}, 0.0, 1.0, 2, -2.0, 2.0, [false]) + x2 = gaussian_bitblast_sample(DistFix{3, 1}, 0.0, 1.0, 2, -2.0, 2.0, [true]) x = ifelse(flip(0.5), x1, x2) p1 = pr(x) - y = bitblast(DistFix{3, 1}, Normal(0, 1), 4, -1.0, 1.0) + y = bitblast(DistFix{3, 1}, Normal(0, 1), 4, -2.0, 2.0) p2 = pr(y) for i in keys(p1) @test p1[i] ≈ p2[i] end + + x1 = gaussian_bitblast_sample(DistFix{3, 1}, 0.0, 1.0, 2, -2.0, 2.0, [false, false]) + p = pr(x1) + end