Skip to content

Commit

Permalink
3: Add function fit_beta_mixture (#15)
Browse files Browse the repository at this point in the history
  • Loading branch information
danielinteractive authored Feb 4, 2024
1 parent 6165f4b commit ceb3c13
Show file tree
Hide file tree
Showing 14 changed files with 1,410 additions and 40 deletions.
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,7 @@ docs/site/
# committed for packages, but should be committed for applications that require a static
# environment.
Manifest.toml

.vscode/settings.json

.DS_Store
4 changes: 4 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,11 @@ DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
ExpectationMaximization = "e1fe09cc-5134-44c2-a941-50f4cd97986a"
FreqTables = "da1fdf0e-e0ff-5433-a45f-9bb5ff651cb1"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Pluto = "c3e4b0f8-55cb-11ea-2926-15256bba5781"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Turing = "fce5fe82-541a-59a6-adf8-730c64b5f9a0"
4 changes: 2 additions & 2 deletions design/ReconcileWithR.jl
Original file line number Diff line number Diff line change
Expand Up @@ -684,7 +684,7 @@ weakdeps = ["Dates", "LinearAlgebra"]
[[deps.CompilerSupportLibraries_jll]]
deps = ["Artifacts", "Libdl"]
uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae"
version = "1.0.5+0"
version = "1.0.2+0"
[[deps.CompositionsBase]]
git-tree-sha1 = "802bb88cd69dfd1509f6670416bd4434015693ad"
Expand Down Expand Up @@ -1623,7 +1623,7 @@ version = "0.42.2+0"
[[deps.Pkg]]
deps = ["Artifacts", "Dates", "Downloads", "FileWatching", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", "Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"]
uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
version = "1.9.2"
version = "1.9.0"
[[deps.PlotThemes]]
deps = ["PlotUtils", "Statistics"]
Expand Down
9 changes: 8 additions & 1 deletion src/SafetySignalDetection.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,17 @@ module SafetySignalDetection
using Turing
using StatsPlots
using Distributions
using SpecialFunctions
using Statistics
using LinearAlgebra
using ExpectationMaximization

export
meta_analytic
meta_analytic,
fit_beta_mixture

include("meta_analytic.jl")
include("fit_mle.jl")
include("fit_beta_mixture.jl")

end
26 changes: 26 additions & 0 deletions src/fit_beta_mixture.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
function init_beta_dists(n_components::Int)
zero_one_range = range(start = 0, stop = 1, length = n_components + 2)
alpha_range = zero_one_range[2:(n_components + 1)]
[Beta(alpha, 1) for alpha in alpha_range]
end

"""
Fit a beta mixture to a vector of prior samples
This function returns a beta mixture of `n_components` components approximating
the distribution of the sample vector `x`.
"""
function fit_beta_mixture(x::AbstractArray{T}, n_components::Int) where T<:Real
0 < n_components || throw(DomainError(n_components, "there must be at least one component"))
# Remove outliers to stabilize the fitting.
lower_quant = quantile!(x, 0.001)
upper_quant = quantile!(x, 0.999)
x = filter(y -> y > lower_quant && y < upper_quant, x)

# We initialize here with Beta distributions that are not identical but have increasing alpha parameters.
beta_dists = init_beta_dists(n_components)
mix_guess = MixtureModel(beta_dists)

# Fit the MLE with the classic EM algorithm.
fit_mle(mix_guess, x)
end
87 changes: 87 additions & 0 deletions src/fit_mle.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
import Distributions: fit_mle, suffstats, varm

# Weighted MLE for beta distribution
# This supplements Distributions.jl and is needed for the classic EM algorithm to work for the beta distribution.

# sufficient statistics - these capture everything of the data we need
struct BetaStats <: SufficientStats
sum_log_x::Float64 # (weighted) sum of log(x)
sum_log_1mx::Float64 # (weighted) sum of log(1 - x)
tw::Float64 # total sample weight
x_bar::Float64 # (weighted) mean of x
v_bar::Float64 # (weighted) variance of x
end

function suffstats(::Type{<:Beta}, x::AbstractArray{T}, w::AbstractArray{T}) where T<:Real

tw = 0.0
sum_log_x = 0.0 * zero(T)
sum_log_1mx = 0.0 * zero(T)
x_bar = 0.0 * zero(T)

for i in eachindex(x, w)
@inbounds xi = x[i]
0 < xi < 1 || throw(DomainError(xi, "samples must be larger than 0 and smaller than 1"))
@inbounds wi = w[i]
wi >= 0 || throw(DomainError(wi, "weights must be non-negative"))
isfinite(wi) || throw(DomainError(wi, "weights must be finite"))
@inbounds sum_log_x += wi * log(xi)
@inbounds sum_log_1mx += wi * log(one(T) - xi)
@inbounds x_bar += wi * xi
tw += wi
end
sum_log_x /= tw
sum_log_1mx /= tw
x_bar /= tw
v_bar = varm(x, x_bar)

BetaStats(sum_log_x, sum_log_1mx, tw, x_bar, v_bar)
end

# without weights we just put weight 1 for each observation
function suffstats(::Type{<:Beta}, x::AbstractArray{T}) where T<:Real

w = ones(Float64, length(x))
suffstats(Beta, x, w)

end

# generic fit function based on the sufficient statistics, on the log scale to be robust
function fit_mle(::Type{<:Beta}, ss::BetaStats;
maxiter::Int=1000, tol::Float64=1e-14)

# Initialization of parameters at the moment estimators (I guess)
temp = ((ss.x_bar * (1 - ss.x_bar)) / ss.v_bar) - 1
α₀ = ss.x_bar * temp
β₀ = (1 - ss.x_bar) * temp

g₁ = ss.sum_log_x
g₂ = ss.sum_log_1mx

θ= [log(α₀) ; log(β₀)]

converged = false
t=0
while !converged && t < maxiter
t += 1
α = exp(θ[1])
β = exp(θ[2])
temp1 = digamma+ β)
temp2 = trigamma+ β)
temp3 = g₁ + temp1 - digamma(α)
grad = [temp3 * α
(temp1 + g₂ - digamma(β)) * β]
hess = [((temp2 - trigamma(α)) * α + temp3) * α temp2 * β * α
temp2 * α * β ((temp2 - trigamma(β)) * β + temp1 + g₂ - digamma(β)) * β ]
Δθ = hess\grad #newton step
θ .-= Δθ
converged = dot(Δθ,Δθ) < 2*tol #stopping criterion
end

α = exp(θ[1])
β = exp(θ[2])
return Beta(α, β)
end

# then define methods for the original data
fit_mle(::Type{<:Beta}, x::AbstractArray{T}, w::AbstractArray{T}; maxiter::Int=1000, tol::Float64=1e-14) where T<:Real = fit_mle(Beta, suffstats(Beta, x, w))
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
[deps]
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Expand Down
Loading

0 comments on commit ceb3c13

Please sign in to comment.