diff --git a/Project.toml b/Project.toml index 793933d4..60488b66 100644 --- a/Project.toml +++ b/Project.toml @@ -25,7 +25,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [compat] Adapt = "2" ChainRules = "0.7" -ChainRulesCore = "0.9.7" +ChainRulesCore = "0.9.9" Compat = "3.6" DiffRules = "0.1, 1.0" Distributions = "0.23.3" diff --git a/src/univariate.jl b/src/univariate.jl index 726d3e23..a7469ab9 100644 --- a/src/univariate.jl +++ b/src/univariate.jl @@ -41,15 +41,10 @@ end function TuringPoissonBinomial(p::AbstractArray{<:Real}; check_args = true) pb = Distributions.poissonbinomial_pdf_fft(p) ϵ = eps(eltype(pb)) - check_args && @assert all(x -> x >= -ϵ, pb) && _isapprox(sum(pb), 1, ϵ) + check_args && @assert all(x -> x >= -ϵ, pb) && isapprox(sum(pb), 1; atol=ϵ) return TuringPoissonBinomial(p, pb) end -# @non_differentiable cannot deal with kwargs. -# This circumvents the kwarg in `isapprox` in the constructor above -_isapprox(x,y,atol) = isapprox(x,y,atol=atol) -@non_differentiable _isapprox(::Any,::Any,::Any) - function logpdf(d::TuringPoissonBinomial{T}, k::Int) where T<:Real insupport(d, k) ? log(d.pmf[k + 1]) : -T(Inf) end