diff --git a/ext/DistributionsChainRulesCoreExt/eachvariate.jl b/ext/DistributionsChainRulesCoreExt/eachvariate.jl index 359a6703eb..454d203359 100644 --- a/ext/DistributionsChainRulesCoreExt/eachvariate.jl +++ b/ext/DistributionsChainRulesCoreExt/eachvariate.jl @@ -1,4 +1,4 @@ -function ChainRulesCore.rrule(::Type{Distributions.EachVariate{V}}, x::AbstractArray{<:Real}) where {V} +function ChainRulesCore.rrule(::ChainRulesCore.RuleConfig, ::Type{Distributions.EachVariate{V}}, x::AbstractArray{<:Real}) where {V} y = Distributions.EachVariate{V}(x) size_x = size(x) function EachVariate_pullback(Δ) diff --git a/ext/DistributionsChainRulesCoreExt/multivariate/dirichlet.jl b/ext/DistributionsChainRulesCoreExt/multivariate/dirichlet.jl index 5aa0d727d5..69b0dbe18c 100644 --- a/ext/DistributionsChainRulesCoreExt/multivariate/dirichlet.jl +++ b/ext/DistributionsChainRulesCoreExt/multivariate/dirichlet.jl @@ -1,4 +1,4 @@ -function ChainRulesCore.frule((_, Δalpha)::Tuple{Any,Any}, ::Type{DT}, alpha::AbstractVector{T}; check_args::Bool = true) where {T <: Real, DT <: Union{Dirichlet{T}, Dirichlet}} +function ChainRulesCore.frule(::ChainRulesCore.RuleConfig, (_, Δalpha)::Tuple{Any,Any}, ::Type{DT}, alpha::AbstractVector{T}; check_args::Bool = true) where {T <: Real, DT <: Union{Dirichlet{T}, Dirichlet}} d = DT(alpha; check_args=check_args) ∂alpha0 = sum(Δalpha) digamma_alpha0 = SpecialFunctions.digamma(d.alpha0) @@ -9,7 +9,7 @@ function ChainRulesCore.frule((_, Δalpha)::Tuple{Any,Any}, ::Type{DT}, alpha::A return d, Δd end -function ChainRulesCore.rrule(::Type{DT}, alpha::AbstractVector{T}; check_args::Bool = true) where {T <: Real, DT <: Union{Dirichlet{T}, Dirichlet}} +function ChainRulesCore.rrule(::ChainRulesCore.RuleConfig, ::Type{DT}, alpha::AbstractVector{T}; check_args::Bool = true) where {T <: Real, DT <: Union{Dirichlet{T}, Dirichlet}} d = DT(alpha; check_args=check_args) digamma_alpha0 = SpecialFunctions.digamma(d.alpha0) function Dirichlet_pullback(_Δd) @@ -20,7 +20,7 @@ function ChainRulesCore.rrule(::Type{DT}, alpha::AbstractVector{T}; check_args:: return d, Dirichlet_pullback end -function ChainRulesCore.frule((_, Δd, Δx)::Tuple{Any,Any,Any}, ::typeof(Distributions._logpdf), d::Dirichlet, x::AbstractVector{<:Real}) +function ChainRulesCore.frule(::ChainRulesCore.RuleConfig, (_, Δd, Δx)::Tuple{Any,Any,Any}, ::typeof(Distributions._logpdf), d::Dirichlet, x::AbstractVector{<:Real}) Ω = Distributions._logpdf(d, x) ∂alpha = sum(Broadcast.instantiate(Broadcast.broadcasted(Δd.alpha, Δx, d.alpha, x) do Δalphai, Δxi, alphai, xi StatsFuns.xlogy(Δalphai, xi) + (alphai - 1) * Δxi / xi @@ -33,7 +33,7 @@ function ChainRulesCore.frule((_, Δd, Δx)::Tuple{Any,Any,Any}, ::typeof(Distri return Ω, ΔΩ end -function ChainRulesCore.rrule(::typeof(Distributions._logpdf), d::T, x::AbstractVector{<:Real}) where {T<:Dirichlet} +function ChainRulesCore.rrule(::ChainRulesCore.RuleConfig, ::typeof(Distributions._logpdf), d::T, x::AbstractVector{<:Real}) where {T<:Dirichlet} Ω = Distributions._logpdf(d, x) isfinite_Ω = isfinite(Ω) alpha = d.alpha diff --git a/ext/DistributionsChainRulesCoreExt/univariate/continuous/uniform.jl b/ext/DistributionsChainRulesCoreExt/univariate/continuous/uniform.jl index 0461329577..9acaa06910 100644 --- a/ext/DistributionsChainRulesCoreExt/univariate/continuous/uniform.jl +++ b/ext/DistributionsChainRulesCoreExt/univariate/continuous/uniform.jl @@ -1,4 +1,4 @@ -function ChainRulesCore.frule((_, Δd, _), ::typeof(logpdf), d::Uniform, x::Real) +function ChainRulesCore.frule(::ChainRulesCore.RuleConfig, (_, Δd, _), ::typeof(logpdf), d::Uniform, x::Real) # Compute log probability a, b = params(d) insupport = a <= x <= b @@ -12,7 +12,7 @@ function ChainRulesCore.frule((_, Δd, _), ::typeof(logpdf), d::Uniform, x::Real return Ω, ΔΩ end -function ChainRulesCore.rrule(::typeof(logpdf), d::Uniform, x::Real) +function ChainRulesCore.rrule(::ChainRulesCore.RuleConfig, ::typeof(logpdf), d::Uniform, x::Real) # Compute log probability a, b = params(d) insupport = a <= x <= b diff --git a/ext/DistributionsChainRulesCoreExt/univariate/discrete/negativebinomial.jl b/ext/DistributionsChainRulesCoreExt/univariate/discrete/negativebinomial.jl index 06ee0294f7..0a29dbe7cb 100644 --- a/ext/DistributionsChainRulesCoreExt/univariate/discrete/negativebinomial.jl +++ b/ext/DistributionsChainRulesCoreExt/univariate/discrete/negativebinomial.jl @@ -11,7 +11,7 @@ function (f::LogPDFNegativeBinomialPullback{D})(Δ) where {D} return ChainRulesCore.NoTangent(), Δd, ChainRulesCore.NoTangent() end -function ChainRulesCore.rrule(::typeof(logpdf), d::NegativeBinomial, k::Real) +function ChainRulesCore.rrule(::ChainRulesCore.RuleConfig, ::typeof(logpdf), d::NegativeBinomial, k::Real) # Compute log probability (as in the definition of `logpdf(d, k)` above) r, p = params(d) z = StatsFuns.xlogy(r, p) + StatsFuns.xlog1py(k, -p) diff --git a/ext/DistributionsChainRulesCoreExt/univariate/discrete/poissonbinomial.jl b/ext/DistributionsChainRulesCoreExt/univariate/discrete/poissonbinomial.jl index aa27a9fd94..65d77942d3 100644 --- a/ext/DistributionsChainRulesCoreExt/univariate/discrete/poissonbinomial.jl +++ b/ext/DistributionsChainRulesCoreExt/univariate/discrete/poissonbinomial.jl @@ -2,13 +2,13 @@ for f in (:poissonbinomial_pdf, :poissonbinomial_pdf_fft) pullback = Symbol(f, :_pullback) @eval begin function ChainRulesCore.frule( - (_, Δp)::Tuple{<:Any,<:AbstractVector{<:Real}}, ::typeof(Distributions.$f), p::AbstractVector{<:Real} + ::ChainRulesCore.RuleConfig, (_, Δp)::Tuple{<:Any,<:AbstractVector{<:Real}}, ::typeof(Distributions.$f), p::AbstractVector{<:Real} ) y = Distributions.$f(p) A = Distributions.poissonbinomial_pdf_partialderivatives(p) return y, A' * Δp end - function ChainRulesCore.rrule(::typeof(Distributions.$f), p::AbstractVector{<:Real}) + function ChainRulesCore.rrule(::ChainRulesCore.RuleConfig, ::typeof(Distributions.$f), p::AbstractVector{<:Real}) y = Distributions.$f(p) A = Distributions.poissonbinomial_pdf_partialderivatives(p) function $pullback(Δy)