Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix type stability of sampling from Chisq, TDist, Gamma #1885

Merged
merged 12 commits into from
Aug 23, 2024
Merged
2 changes: 1 addition & 1 deletion src/samplers/gamma.jl
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,6 @@ end

function rand(rng::AbstractRNG, s::GammaIPSampler)
x = rand(rng, s.s)
e = randexp(rng)
e = randexp(rng, typeof(x))
x*exp(s.nia*e)
end
2 changes: 1 addition & 1 deletion src/univariate/continuous/exponential.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ cf(d::Exponential, t::Real) = 1/(1 - t * im * scale(d))


#### Sampling
rand(rng::AbstractRNG, d::Exponential) = xval(d, randexp(rng))
rand(rng::AbstractRNG, d::Exponential) where {T} = xval(d, randexp(rng, float(T)))
Red-Portal marked this conversation as resolved.
Show resolved Hide resolved


#### Fit model
Expand Down
2 changes: 1 addition & 1 deletion src/univariate/continuous/tdist.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ end
function rand(rng::AbstractRNG, d::TDist)
ν = d.ν
z = sqrt(rand(rng, Chisq{typeof(ν)}(ν)) / ν)
return randn(rng) / (isinf(ν) ? one(z) : z)
return randn(rng, typeof(z)) / (isinf(ν) ? one(z) : z)
end

function cf(d::TDist{T}, t::Real) where T <: Real
Expand Down
7 changes: 5 additions & 2 deletions test/univariate/continuous/chisq.jl
Original file line number Diff line number Diff line change
@@ -1,2 +1,5 @@
test_cgf(Chisq(1), (0.49, -1, -100, -1f6))
test_cgf(Chisq(3), (0.49, -1, -100, -1f6))

@testset "Chisq" begin
test_cgf(Chisq(1), (0.49, -1, -100, -1.0f6))
test_cgf(Chisq(3), (0.49, -1, -100, -1.0f6))
end
40 changes: 25 additions & 15 deletions test/univariate/continuous/gamma.jl
Original file line number Diff line number Diff line change
@@ -1,25 +1,35 @@
using Test, Distributions, OffsetArrays

test_cgf(Gamma(1 ,1 ), (0.9, -1, -100f0, -1e6))
test_cgf(Gamma(10 ,1 ), (0.9, -1, -100f0, -1e6))
test_cgf(Gamma(0.2, 10), (0.08, -1, -100f0, -1e6))
@testset "Gamma" begin
test_cgf(Gamma(1, 1), (0.9, -1, -100.0f0, -1e6))
test_cgf(Gamma(10, 1), (0.9, -1, -100.0f0, -1e6))
test_cgf(Gamma(0.2, 10), (0.08, -1, -100.0f0, -1e6))

@testset "Gamma suffstats and OffsetArrays" begin
a = rand(Gamma(), 11)
wa = 1.0:11.0
@testset "Gamma suffstats and OffsetArrays" begin
a = rand(Gamma(), 11)
wa = 1.0:11.0

resulta = @inferred(suffstats(Gamma, a))
resulta = @inferred(suffstats(Gamma, a))

resultwa = @inferred(suffstats(Gamma, a, wa))
resultwa = @inferred(suffstats(Gamma, a, wa))

b = OffsetArray(a, -5:5)
wb = OffsetArray(wa, -5:5)
b = OffsetArray(a, -5:5)
wb = OffsetArray(wa, -5:5)

resultb = @inferred(suffstats(Gamma, b))
@test resulta == resultb
resultb = @inferred(suffstats(Gamma, b))
@test resulta == resultb

resultwb = @inferred(suffstats(Gamma, b, wb))
@test resultwa == resultwb
resultwb = @inferred(suffstats(Gamma, b, wb))
@test resultwa == resultwb

@test_throws DimensionMismatch suffstats(Gamma, a, wb)
@test_throws DimensionMismatch suffstats(Gamma, a, wb)
end

@test rand(Gamma(1.0, 1.0)) isa Float64
@test rand(Gamma(0.5, 1.0)) isa Float64
@test rand(Gamma(2.0, 1.0)) isa Float64

@test rand(Gamma(1.0f0, 1.0f0)) isa Float32
@test rand(Gamma(0.5f0, 1.0f0)) isa Float32
@test rand(Gamma(2.0f0, 1.0f0)) isa Float32
Red-Portal marked this conversation as resolved.
Show resolved Hide resolved
end
19 changes: 14 additions & 5 deletions test/univariate/continuous/tdist.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,19 @@ using ForwardDiff

using Test

@testset "Type stability of `rand` (#1614)" begin
if VERSION >= v"1.9.0-DEV.348"
# randn(::BigFloat) was only added in https://github.com/JuliaLang/julia/pull/44714
@inferred(rand(TDist(big"1.0")))
@testset "TDist" begin
@testset "Type stability of `rand` (#1614)" begin
if VERSION >= v"1.9.0-DEV.348"
# randn(::BigFloat) was only added in https://github.com/JuliaLang/julia/pull/44714
@inferred(rand(TDist(big"1.0")))
end
@inferred(rand(TDist(ForwardDiff.Dual(1.0))))

end
@inferred(rand(TDist(ForwardDiff.Dual(1.0))))

@test rand(TDist(1.0)) isa Float64
@test rand(TDist(1.0f0)) isa Float32

@test entropy(TDist(1.0)) isa Float64
@test entropy(TDist(1.0f0)) isa Float32
Red-Portal marked this conversation as resolved.
Show resolved Hide resolved
end
Loading