Imports
+using Distributions
+using Random
+using LinearAlgebra
+using StatsBase
+using CairoMakie
+using Colors
+using ColorSchemes
+= ColorSchemes.seaborn_colorblind6; cmap
diff --git a/_freeze/other/fun/ml/em/execute-results/html.json b/_freeze/other/fun/ml/em/execute-results/html.json new file mode 100644 index 0000000..eacff5e --- /dev/null +++ b/_freeze/other/fun/ml/em/execute-results/html.json @@ -0,0 +1,12 @@ +{ + "hash": "b3704aeedd4cffe8519bd0fe859a76e7", + "result": { + "engine": "jupyter", + "markdown": "---\ntitle: Visualizing Expectation Maximization\nsubtitle: k-means vs. Gaussian Mixture Models (GMM)\nauthor: Connor Robertson\nexecute:\n daemon: true\n---\n\n## Overview\n\nI recently read a [great overview of Gaussian Mixture Models](https://jakevdp.github.io/PythonDataScienceHandbook/05.12-gaussian-mixtures.html) and thought, \"Man, I'd like to see this clustering visualized.\"\nSo here I am to visualize it myself.\n\n## Expectation maximization (EM)\n\nBest thing to do is just [read the Wikipedia](https://en.wikipedia.org/wiki/Expectation%E2%80%93maximization_algorithm), but the general procedure here is:\n\n::: {#imp-emsteps .callout-important}\n## EM steps\n1. Sample some points from a distribution\n2. Find the expected value of each point\n3. Adjust parameters to maximize the expected values of the points\n4. Repeat 2 and 3\n:::\n\nThat's all very abstract, but I will define it for each of our cases more concretely (which, fortunately, are simple).\n\n## k-means\n\nk-means is a clustering algorithm that looks to find $k$ clusters in a dataset which are parametrized by the *mean* of the cluster.\nGo figure.\nThe naive algorithm for k-means is Lloyd's algorithm which has the following steps (in a similar spirit to EM):\n\n::: {#imp-kmsteps .callout-important}\n## k-means steps\n1. Sample $k$ reasonable cluster centers from your domain\n2. Compute the closest of the $k$ centers for each of your data points\n3. Adjust the locations of the $k$ centers to be the center of the corresponding data points\n4. Repeat 2 and 3\n:::\n\n### Creating sample data\nTo explore this idea, let's consider two sets of sample data from some Gaussian distributions in two dimensions.\nThe first will be circular (diagonal covariance matrix) and the second will be skewed (dense covariance matrix).\nLet's write this out in simple code:\n\n::: {#f809c5b9 .cell execution_count=1}\n``` {.julia .cell-code code-fold=\"true\" code-summary=\"Imports\"}\nusing Distributions\nusing Random\nusing LinearAlgebra\nusing StatsBase\nusing CairoMakie\nusing Colors\nusing ColorSchemes\ncmap = ColorSchemes.seaborn_colorblind6;\n```\n:::\n\n\n::: {#2985af2d .cell execution_count=2}\n``` {.julia .cell-code code-fold=\"true\" code-summary=\"Distribution creation functions\"}\nfunction create_spherical_gaussian(mu, var)\n cov = diagm(ones(length(mu)) .* var)\n return MvNormal(mu, cov)\nend\n\nfunction create_oblong_gaussian(mu, cov)\n return MvNormal(mu, cov)\nend;\n```\n:::\n\n\n::: {#e6aa65d7 .cell execution_count=3}\n``` {.julia .cell-code}\n# Settings\nn_distributions = 5\nn_samples = 20\n\n# Means/variances/covariances\nRandom.seed!(1234)\nmeans = [2i*rand(2) for i in 1:n_distributions]\nvars_1 = [0.2*i*rand() for i in 1:n_distributions]\ncovs_2 = []\nfor i in 1:n_distributions\n cov = 0.2i*rand(2,2)\n cov = cov' * cov # Ensure its positive definite\n push!(covs_2, cov)\nend\n\n# Create distributions\ndists_1 = [create_spherical_gaussian(m, v) for (m,v) in zip(means, vars_1)]\ndists_2 = [create_oblong_gaussian(m, c) for (m,c) in zip(means, covs_2)]\n\n# Sample distributions\nsamples_1 = [rand(d, n_samples) for d in dists_1]\nsamples_2 = [rand(d, n_samples) for d in dists_2]\n\n# Create datasets (combine samples and shuffle)\ndata_1 = hcat(samples_1...)[:,randperm(n_distributions*n_samples)]\ndata_2 = hcat(samples_2...)[:,randperm(n_distributions*n_samples)];\n```\n:::\n\n\nWe can plot the samples we've taken from the distributions to get a sense of the clustering:\n\n::: {#04529378 .cell fig-width='100%' execution_count=4}\n``` {.julia .cell-code}\nfig = Figure(size=(800,400))\nax1 = Axis(fig[1,1])\nax2 = Axis(fig[1,2])\nfor i in 1:n_distributions\n scatter!(ax1, samples_1[i][1,:], samples_1[i][2,:], color=cmap[i], label=nothing)\n scatter!(ax2, samples_2[i][1,:], samples_2[i][2,:], color=cmap[i], label=\"Distribution $i\")\nend\nhidedecorations!(ax1, ticklabels=false)\nhidedecorations!(ax2, ticklabels=false)\naxislegend(ax2; position=:lt)\ndisplay(fig);\n```\n\n::: {.cell-output .cell-output-stderr}\n```\n┌ Warning: Found `resolution` in the theme when creating a `Scene`. The `resolution` keyword for `Scene`s and `Figure`s has been deprecated. Use `Figure(; size = ...` or `Scene(; size = ...)` instead, which better reflects that this is a unitless size and not a pixel resolution. The key could also come from `set_theme!` calls or related theming functions.\n└ @ Makie ~/.julia/packages/Makie/VRavR/src/scenes.jl:220\n```\n:::\n\n::: {.cell-output .cell-output-display}\n![](em_files/figure-html/cell-5-output-2.svg){}\n:::\n:::\n\n\n### Implementation\n\nNow that we have some sample data to try to cluster, let's implement a simple k-means algorithm using the steps in @imp-kmsteps:\n\n::: {#79463777 .cell execution_count=5}\n``` {.julia .cell-code code-fold=\"true\" code-summary=\"Utilities\"}\nfunction distances_to_center(data, center)\n distances = zeros(size(data,2))\n for i in 1:size(data,2)\n distances[i] = mean((data[:,i] .- center).^2)\n end\n return distances\nend\n\nfunction plot_cluster_results!(fig, centers, labels, data, true_data)\n n_clusters = length(true_data)\n k = size(centers, 2)\n\n ax1 = Axis(fig[1,1])\n ax2 = Axis(fig[1,2])\n for i in 1:k\n approx_cluster = data[:,labels .== i]\n scatter!(ax1, approx_cluster[1,:], approx_cluster[2,:], color=cmap[i], label=\"Approx cluster $i\")\n scatter!(ax1, [centers[1,i]], [centers[2,i]], color=cmap[i], marker=:star5, markersize=10, label=nothing, strokewidth=2)\n end\n for i in 1:n_clusters\n scatter!(ax2, true_data[i][1,:], true_data[i][2,:], color=cmap[i], label=\"True cluster $i\")\n end\n hidedecorations!(ax1, ticklabels=false)\n hidedecorations!(ax2, ticklabels=false)\n axislegend(ax1; position=:rb)\n axislegend(ax2; position=:rb)\n return fig\nend;\n\nfunction animate_kmeans!!(centers, data_labels, data, samples, filename)\n fig = Figure(size=(800,400))\n record(fig, filename; framerate=2) do io\n converged = false\n nsteps = 1\n while converged == false\n # Clear the figure\n empty!(fig)\n\n # Step\n converged = kmeans!!(centers, data_labels, data, 1)\n\n # Plot\n plot_cluster_results!(fig, centers, data_labels, data, samples)\n text!(fig.scene, 5, 0, text=\"Step: $nsteps\", font=:bold)\n\n # Capture frame\n recordframe!(io)\n nsteps += 1\n end\n end\nend;\n```\n:::\n\n\n::: {#c24d179a .cell execution_count=6}\n``` {.julia .cell-code code-fold=\"true\" code-summary=\"kmeans\"}\nfunction kmeans!!(centers, data_labels, data, max_steps=100, tol=1e-5)\n k = size(centers,2)\n center_shift = Inf * ones(k)\n nsteps = 0\n\n while (maximum(center_shift) > tol) & (nsteps < max_steps)\n # Compute the distance from each center to our data and assign closest\n distances = hcat([distances_to_center(data, c) for c in eachcol(centers)]...)\n data_labels .= [argmin(d) for d in eachrow(distances)]\n\n # Reassign centers to mean of closest data\n for i in 1:k\n closest_data = data[:,data_labels .== i]\n new_center = mean(closest_data, dims=2)\n center_shift[i] = norm(centers[:,i] .- new_center)\n centers[:,i] .= new_center\n end\n nsteps += 1\n end\n return (maximum(center_shift) < tol)\nend\n```\n\n::: {.cell-output .cell-output-display execution_count=7}\n```\nkmeans!! (generic function with 3 methods)\n```\n:::\n:::\n\n\n::: {#f821f7dd .cell execution_count=7}\n``` {.julia .cell-code}\n# Sample initial cluster centers within our sample data\nRandom.seed!(1)\nk = 5\nsmin = minimum(data_1, dims=2)\nsmax = maximum(data_1, dims=2)\ncenters = (smax - smin) .* rand(2,k) .+ smin\ndata_labels = ones(Int64, size(data_1,2))\n\n# Repeat steps 2 and 3 until centers aren't moving\nkmeans!!(centers, data_labels, data_1)\n\n# Visualize results\nfig = Figure(size=(800,400))\nplot_cluster_results!(fig, centers, data_labels, data_1, samples_1)\n```\n\n::: {.cell-output .cell-output-stderr}\n```\n┌ Warning: Found `resolution` in the theme when creating a `Scene`. The `resolution` keyword for `Scene`s and `Figure`s has been deprecated. Use `Figure(; size = ...` or `Scene(; size = ...)` instead, which better reflects that this is a unitless size and not a pixel resolution. The key could also come from `set_theme!` calls or related theming functions.\n└ @ Makie ~/.julia/packages/Makie/VRavR/src/scenes.jl:220\n```\n:::\n\n::: {.cell-output .cell-output-display execution_count=8}\n![](em_files/figure-html/cell-8-output-2.svg){}\n:::\n:::\n\n\nNotice that the results are highly dependent on our initial sampling. If we adjust the random seeding for the initial centers, we don't get such nice results:\n\n::: {#7db88fe3 .cell execution_count=8}\n``` {.julia .cell-code}\n# Sample initial cluster centers within our sample data\nRandom.seed!(2)\nk = 5\nsmin = minimum(data_1, dims=2)\nsmax = maximum(data_1, dims=2)\ncenters = (smax - smin) .* rand(2,k) .+ smin\ndata_labels = ones(size(data_1,2))\n\n# Execute\nkmeans!!(centers, data_labels, data_1)\n\n# Visualize results\nfig = Figure(size=(800,400))\nplot_cluster_results!(fig, centers, data_labels, data_1, samples_1)\n```\n\n::: {.cell-output .cell-output-stderr}\n```\n┌ Warning: Found `resolution` in the theme when creating a `Scene`. The `resolution` keyword for `Scene`s and `Figure`s has been deprecated. Use `Figure(; size = ...` or `Scene(; size = ...)` instead, which better reflects that this is a unitless size and not a pixel resolution. The key could also come from `set_theme!` calls or related theming functions.\n└ @ Makie ~/.julia/packages/Makie/VRavR/src/scenes.jl:220\n```\n:::\n\n::: {.cell-output .cell-output-display execution_count=9}\n![](em_files/figure-html/cell-9-output-2.svg){}\n:::\n:::\n\n\nWe can also try this on our second dataset:\n\n::: {#72825580 .cell execution_count=9}\n``` {.julia .cell-code}\n# Sample initial cluster centers within our sample data\nRandom.seed!(1)\nk = 5\nsmin = minimum(data_2, dims=2)\nsmax = maximum(data_2, dims=2)\ncenters = (smax - smin) .* rand(2,k) .+ smin\ndata_labels = ones(size(data_2,2))\n\n# Execute\nkmeans!!(centers, data_labels, data_2)\n\n# Visualize results\nfig = Figure(size=(800,400))\nplot_cluster_results!(fig, centers, data_labels, data_2, samples_2)\n```\n\n::: {.cell-output .cell-output-stderr}\n```\n┌ Warning: Found `resolution` in the theme when creating a `Scene`. The `resolution` keyword for `Scene`s and `Figure`s has been deprecated. Use `Figure(; size = ...` or `Scene(; size = ...)` instead, which better reflects that this is a unitless size and not a pixel resolution. The key could also come from `set_theme!` calls or related theming functions.\n└ @ Makie ~/.julia/packages/Makie/VRavR/src/scenes.jl:220\n```\n:::\n\n::: {.cell-output .cell-output-display execution_count=10}\n![](em_files/figure-html/cell-10-output-2.svg){}\n:::\n:::\n\n\n### Animating the process\n\nNow, the entire purpose of this document is actually to visually explore the methods as they evolve, so let's set up some animation structure.\nFortunately, `Makie.jl` makes this easy.\nWe only need to iteratively take a step and capture the plotted frame at each step.\n\n::: {#72429596 .cell execution_count=10}\n``` {.julia .cell-code}\nRandom.seed!(5)\nk = 5\nsmin = minimum(data_1, dims=2)\nsmax = maximum(data_1, dims=2)\ncenters = (smax - smin) .* rand(2,k) .+ smin\ndata_labels = ones(size(data_1,2))\n\nanimate_kmeans!!(centers, data_labels, data_1, samples_1, \"gifs/kmeans.gif\")\n```\n\n::: {.cell-output .cell-output-stderr}\n```\n┌ Warning: Found `resolution` in the theme when creating a `Scene`. The `resolution` keyword for `Scene`s and `Figure`s has been deprecated. Use `Figure(; size = ...` or `Scene(; size = ...)` instead, which better reflects that this is a unitless size and not a pixel resolution. The key could also come from `set_theme!` calls or related theming functions.\n└ @ Makie ~/.julia/packages/Makie/VRavR/src/scenes.jl:220\n```\n:::\n\n::: {.cell-output .cell-output-display execution_count=11}\n```\n\"gifs/kmeans.gif\"\n```\n:::\n:::\n\n\n![](gifs/kmeans.gif)\n\nAmazing! Now we can see the progress of kmeans in each step (where the stars represent the cluster means).\n\n## Gaussian Mixture Models (GMM)\n\nGaussian mixture models really boil down to using a Gaussian distribution for each cluster.\nWith these cluster distributions, we can assign a cluster probability to each data point and then subsequently update the distributions to new means and covariances.\nThis procedure is a true EM process because it aims to maximize the probability of cluster inclusion for each datapoint.\n\nThe naive algorithm for this EM procedure is as follows:\n\n::: {#imp-gmmsteps .callout-important}\n## GMM steps\n1. Sample $k$ reasonable cluster means and covariances for Gaussians from your domain\n2. Compute the PDF values corresponding to each Gaussian and normalize to get cluster probabilities for each datapoint\n3. Adjust the means and covariances of the $k$ Gaussians using the cluster probabilities from each datapoint\n4. Repeat 2 and 3\n:::\n\n### Implementing EM\n\nTo start, we are going to implement @imp-gmmsteps by considering spherical Gaussians in which the covariance is a diagonally constrained.\n\n::: {#d3f156fa .cell execution_count=11}\n``` {.julia .cell-code code-fold=\"true\" code-summary=\"Utilities\"}\nfunction plot_gaussians(gaussians)\n# Show initial gaussian clusters\n fig = Figure()\n ax = Axis(fig[1,1])\n sample_x = range(smin[1], smax[1], 1000)\n sample_y = range(smin[2], smax[2], 1000)\n for (i,g) in enumerate(gaussians)\n sample_pdfs = [pdf(g, [x,y]) for x in sample_x, y in sample_y]\n sample_pdfs ./= sum(sample_pdfs)\n my_colormap = reverse([cmap[i]*(1 - j/256) + RGB(1, 1, 1)*(j/256) - RGBAf(0,0,0,j/256) for j in 0:256-1])\n contourf!(sample_x, sample_y, sample_pdfs, colormap=my_colormap, label=\"Gaussian $i\")\n end\n axislegend(ax; position=:rb)\n hidedecorations!(ax, ticklabels=false)\nend\n\nfunction plot_gmm_results!(fig, gaussians, labels, data, true_data)\n n_clusters = length(true_data)\n k = size(centers, 2)\n\n sample_x = range(smin[1], smax[1], 1000)\n sample_y = range(smin[2], smax[2], 1000)\n ax1 = Axis(fig[1,1])\n ax2 = Axis(fig[1,2])\n for i in 1:k\n approx_cluster = data[:,labels .== i]\n scatter!(ax1, approx_cluster[1,:], approx_cluster[2,:], color=cmap[i], label=\"Approx cluster $i\")\n sample_pdfs = [pdf(gaussians[i], [x,y]) for x in sample_x, y in sample_y]\n my_colormap = reverse([cmap[i]*(1 - j/256) + RGB(1, 1, 1)*(j/256) - RGBAf(0,0,0,j/256) for j in 0:256-1])\n contourf!(ax1, sample_x, sample_y, sample_pdfs, colormap=my_colormap, label=nothing)\n end\n for i in 1:n_clusters\n scatter!(ax2, true_data[i][1,:], true_data[i][2,:], color=cmap[i], label=\"True cluster $i\")\n end\n hidedecorations!(ax1, ticklabels=false)\n hidedecorations!(ax2, ticklabels=false)\n axislegend(ax1; position=:rb)\n axislegend(ax2; position=:rb)\n return fig\nend;\n\n# Execute one step at a time, visualize results at each step, then record a frame\nfunction animate_gmm!!(gaussians, data_labels, data, samples, filename)\n fig = Figure(size=(800,400))\n record(fig, filename; framerate=2) do io\n converged = false\n nsteps = 1\n while converged == false\n # Clear the figure\n empty!(fig)\n\n # Step\n converged = gmm!!(gaussians, data_labels, data, 1)\n\n # Plot\n plot_gmm_results!(fig, gaussians, data_labels, data, samples)\n text!(fig.scene, 5, 0, text=\"Step: $nsteps\", font=:bold)\n\n # Capture frame\n recordframe!(io)\n nsteps += 1\n end\n end;\nend\n```\n\n::: {.cell-output .cell-output-display execution_count=12}\n```\nanimate_gmm!! (generic function with 1 method)\n```\n:::\n:::\n\n\n::: {#6c5faebc .cell execution_count=12}\n``` {.julia .cell-code}\n# Sample initial cluster means and variances\nRandom.seed!(1)\nk = 5\nsmin = minimum(data_1, dims=2)\nsmax = maximum(data_1, dims=2)\nmeans = (smax - smin) .* rand(2,k) .+ smin\nvars = ones(k)\ngaussians = [MvNormal(means[:,i],vars[i]) for i in 1:k]\ndata_labels = ones(Int64, size(data_1,2))\n\nplot_gaussians(gaussians);\n```\n\n::: {.cell-output .cell-output-stderr}\n```\n┌ Warning: Found `resolution` in the theme when creating a `Scene`. The `resolution` keyword for `Scene`s and `Figure`s has been deprecated. Use `Figure(; size = ...` or `Scene(; size = ...)` instead, which better reflects that this is a unitless size and not a pixel resolution. The key could also come from `set_theme!` calls or related theming functions.\n└ @ Makie ~/.julia/packages/Makie/VRavR/src/scenes.jl:220\n```\n:::\n:::\n\n\nAs you can observe, we have some great circular starting distributions.\nWe now need to align these with the data.\nLet's implement the GMM procedure from @imp-gmmsteps:\n\n::: {#7722da20 .cell execution_count=13}\n``` {.julia .cell-code code-fold=\"true\" code-summary=\"GMM\"}\nfunction gmm!!(gaussians::Vector{IsoNormal}, data_labels, data, max_steps=100, tol=1e-5)\n k = length(gaussians)\n ll_shift = 10.0\n nsteps = 0\n gweights = ones(k) ./ k\n data_weights = zeros(length(data_labels),k)\n\n while (ll_shift > tol) & (nsteps < max_steps)\n # Compute the pdfs for each datapoint to each gaussian and assign the most probable\n cluster_probs = hcat([pdf(g, data) for g in gaussians]...)\n ll_old = sum(log.(cluster_probs * gweights))\n all_weight = cluster_probs * gweights\n data_weights .= gweights' .* cluster_probs ./ all_weight\n\n # Reassign gaussians to weighted means and std devs\n for i in 1:k\n Nk = sum(data_weights[:,i])\n new_mean = sum(data .* data_weights[:,i]', dims=2) ./ Nk\n new_std = sqrt(mean(sum((data .- new_mean).^2 .* data_weights[:,i]', dims=2) ./ Nk))\n gaussians[i] = MvNormal(vec(new_mean), new_std)\n end\n\n # Check change in log likelihood for convergence\n cluster_probs = hcat([pdf(g, data) for g in gaussians]...)\n ll_new = sum(log.(cluster_probs * gweights))\n ll_shift = ll_new - ll_old\n ll_old = ll_new\n nsteps += 1\n end\n data_labels .= [argmax(d) for d in eachrow(data_weights)]\n return (ll_shift < tol)\nend;\n```\n:::\n\n\nNow, let's put it all together and run it for our initial distributions.\n\n::: {#cfd39015 .cell execution_count=14}\n``` {.julia .cell-code}\nRandom.seed!(5)\nk = 5\nsmin = minimum(data_1, dims=2)\nsmax = maximum(data_1, dims=2)\nmeans = (smax - smin) .* rand(2,k) .+ smin\nvars = ones(k)\ngaussians = [MvNormal(means[:,i],vars[i]) for i in 1:k]\ndata_labels = ones(Int64, size(data_1,2))\n\nanimate_gmm!!(gaussians, data_labels, data_1, samples_1, \"gifs/gmm1.gif\");\n```\n\n::: {.cell-output .cell-output-stderr}\n```\n┌ Warning: Found `resolution` in the theme when creating a `Scene`. The `resolution` keyword for `Scene`s and `Figure`s has been deprecated. Use `Figure(; size = ...` or `Scene(; size = ...)` instead, which better reflects that this is a unitless size and not a pixel resolution. The key could also come from `set_theme!` calls or related theming functions.\n└ @ Makie ~/.julia/packages/Makie/VRavR/src/scenes.jl:220\n```\n:::\n:::\n\n\n![](gifs/gmm1.gif)\n\nGreat! These are pretty good fits.\nHowever, applying it to our second dataset quickly reveals that spherical Gaussians just won't cut it:\n\n::: {#3960f7a0 .cell execution_count=15}\n``` {.julia .cell-code}\nRandom.seed!(1)\nk = 5\nsmin = minimum(data_1, dims=2)\nsmax = maximum(data_1, dims=2)\nmeans = (smax - smin) .* rand(2,k) .+ smin\nvars = ones(k)\ngaussians = [MvNormal(means[:,i],vars[i]) for i in 1:k]\ndata_labels = ones(Int64, size(data_2,2))\n\nanimate_gmm!!(gaussians, data_labels, data_2, samples_2, \"gifs/gmm2.gif\");\n```\n\n::: {.cell-output .cell-output-stderr}\n```\n┌ Warning: Found `resolution` in the theme when creating a `Scene`. The `resolution` keyword for `Scene`s and `Figure`s has been deprecated. Use `Figure(; size = ...` or `Scene(; size = ...)` instead, which better reflects that this is a unitless size and not a pixel resolution. The key could also come from `set_theme!` calls or related theming functions.\n└ @ Makie ~/.julia/packages/Makie/VRavR/src/scenes.jl:220\n```\n:::\n:::\n\n\nInstead, we can modify our calculation of the variance in the `gmm!!` function to calculate the full covariance matrix:\n\n::: {#2bf07107 .cell execution_count=16}\n``` {.julia .cell-code code-fold=\"true\" code-summary=\"GMM (full covariance)\"}\nfunction gmm!!(gaussians::Vector{FullNormal}, data_labels, data, max_steps=100, tol=1e-5)\n k = length(gaussians)\n ll_shift = 10.0\n nsteps = 0\n gweights = ones(k) ./ k\n data_weights = zeros(length(data_labels),k)\n\n while (ll_shift > tol) & (nsteps < max_steps)\n # Compute the pdfs for each datapoint to each gaussian and assign the most probable\n cluster_probs = hcat([pdf(g, data) for g in gaussians]...)\n ll_old = sum(log.(cluster_probs * gweights))\n all_weight = cluster_probs * gweights\n data_weights .= gweights' .* cluster_probs ./ all_weight\n\n # Reassign gaussians to weighted means and std devs\n for i in 1:k\n Nk = sum(data_weights[:,i])\n new_mean = sum(data .* data_weights[:,i]', dims=2) ./ Nk\n new_std = mean([(d .- new_mean)*(d .- new_mean)' for d in eachcol(data)], weights(data_weights[:,i]))\n gaussians[i] = MvNormal(vec(new_mean), new_std)\n end\n\n # Check change in log likelihood for convergence\n cluster_probs = hcat([pdf(g, data) for g in gaussians]...)\n ll_new = sum(log.(cluster_probs * gweights))\n ll_shift = ll_new - ll_old\n ll_old = ll_new\n nsteps += 1\n end\n data_labels .= [argmax(d) for d in eachrow(data_weights)]\n return (ll_shift < tol)\nend;\n```\n:::\n\n\n::: {#1d8ea610 .cell execution_count=17}\n``` {.julia .cell-code}\nRandom.seed!(1)\nk = 5\nsmin = minimum(data_1, dims=2)\nsmax = maximum(data_1, dims=2)\nmeans = (smax - smin) .* rand(2,k) .+ smin\nvars = []\nfor i in 1:k\n A = I(2) .+ rand(2,2)\n push!(vars,A' * A)\nend\ngaussians = [MvNormal(means[:,i],vars[i]) for i in 1:k]\ndata_labels = ones(Int64, size(data_2,2))\n\nanimate_gmm!!(gaussians, data_labels, data_2, samples_2, \"gifs/gmm3.gif\");\n```\n\n::: {.cell-output .cell-output-stderr}\n```\n┌ Warning: Found `resolution` in the theme when creating a `Scene`. The `resolution` keyword for `Scene`s and `Figure`s has been deprecated. Use `Figure(; size = ...` or `Scene(; size = ...)` instead, which better reflects that this is a unitless size and not a pixel resolution. The key could also come from `set_theme!` calls or related theming functions.\n└ @ Makie ~/.julia/packages/Makie/VRavR/src/scenes.jl:220\n```\n:::\n:::\n\n\n![](gifs/gmm3.gif)\n\nThe clusters aren't perfectly recovered (due to the initialization), but you can see that the Gaussians cover the data much more precisely.\n\n## Initializing clusters\n\nNow that we have fully fledged approaches for clustering with expectation maximization, let's consider better initializations:\n\n1. **Forgy's method**: Sample $k$ points from the data as starting centers\n2. **Random partitions**: Randomly assign cluster ids then average to yield the initial points\n3. **kmeans++**: Choose each mean from the data a large distance from the previously selected (proportional to the squared distance)\n\n::: {#e0957e2f .cell execution_count=18}\n``` {.julia .cell-code code-fold=\"true\" code-summary=\"Plotting function\"}\nfunction plot_centers!(fig, centers, true_data)\n n_clusters = length(true_data)\n k = size(centers, 2)\n\n ax1 = Axis(fig[1,1])\n for i in 1:n_clusters\n scatter!(ax1, true_data[i][1,:], true_data[i][2,:], color=cmap[i], label=\"True cluster $i\")\n end\n for i in 1:k\n scatter!(ax1, [centers[1,i]], [centers[2,i]], color=:black, marker=:star5, markersize=10, label=nothing, strokewidth=2)\n end\n hidedecorations!(ax1, ticklabels=false)\n axislegend(ax1; position=:rb)\n return fig\nend;\n```\n:::\n\n\n::: {#6394d29d .cell execution_count=19}\n``` {.julia .cell-code code-fold=\"true\" code-summary=\"Initialization functions\"}\nfunction forgy_init(data, k)\n return hcat(sample(eachcol(data), k, replace=false)...)\nend\nfunction random_init(data, k)\n cluster_ids = sample(1:k, size(data,2))\n return hcat([mean(data[:,cluster_ids .== i], dims=2) for i in 1:k]...)\nend\nfunction kpp_init(data, k)\n means = [sample(eachcol(data))]\n for i in 2:k\n dists = sum(hcat([norm.(eachcol(data .- m)).^2 for m in means]...), dims=2)\n dists ./= sum(dists)\n new_mean = sample(eachcol(data), weights(dists))\n push!(means, new_mean)\n end\n return hcat(means...)\nend;\n```\n:::\n\n\n::: {#351fac81 .cell execution_count=20}\n``` {.julia .cell-code}\nRandom.seed!(1)\n# Forgy\nfig = Figure(size=(600,400))\ncenters = forgy_init(data_1, 5)\nplot_centers!(fig, centers, samples_1)\nfig.content[1].title = \"Forgy\"\ndisplay(fig);\n\n# Random partition\nfig = Figure(size=(600,400))\ncenters = random_init(data_1, 5)\nplot_centers!(fig, centers, samples_1)\nfig.content[1].title = \"Random partition\"\ndisplay(fig);\n\n# Kpp\nfig = Figure(size=(600,400))\ncenters = kpp_init(data_1, 5)\nplot_centers!(fig, centers, samples_1)\nfig.content[1].title = \"Kpp\";\ndisplay(fig);\n```\n\n::: {.cell-output .cell-output-stderr}\n```\n┌ Warning: Found `resolution` in the theme when creating a `Scene`. The `resolution` keyword for `Scene`s and `Figure`s has been deprecated. Use `Figure(; size = ...` or `Scene(; size = ...)` instead, which better reflects that this is a unitless size and not a pixel resolution. The key could also come from `set_theme!` calls or related theming functions.\n└ @ Makie ~/.julia/packages/Makie/VRavR/src/scenes.jl:220\n┌ Warning: Found `resolution` in the theme when creating a `Scene`. The `resolution` keyword for `Scene`s and `Figure`s has been deprecated. Use `Figure(; size = ...` or `Scene(; size = ...)` instead, which better reflects that this is a unitless size and not a pixel resolution. The key could also come from `set_theme!` calls or related theming functions.\n└ @ Makie ~/.julia/packages/Makie/VRavR/src/scenes.jl:220\n┌ Warning: Found `resolution` in the theme when creating a `Scene`. The `resolution` keyword for `Scene`s and `Figure`s has been deprecated. Use `Figure(; size = ...` or `Scene(; size = ...)` instead, which better reflects that this is a unitless size and not a pixel resolution. The key could also come from `set_theme!` calls or related theming functions.\n└ @ Makie ~/.julia/packages/Makie/VRavR/src/scenes.jl:220\n```\n:::\n\n::: {.cell-output .cell-output-display}\n![](em_files/figure-html/cell-21-output-2.svg){}\n:::\n\n::: {.cell-output .cell-output-display}\n![](em_files/figure-html/cell-21-output-3.svg){}\n:::\n\n::: {.cell-output .cell-output-display}\n![](em_files/figure-html/cell-21-output-4.svg){}\n:::\n:::\n\n\nIn this case, it is immediately obvious that the kmeans++ technique yielded the best, most spread centers that will work well with either our kmeans or GMM algorithms.\nOn the other hand, the Forgy yielded data which was too clustered, something that can easily occur if there is an imbalance in representation between samples from different clusters.\nThe Random partition approach gave points which were somewhat reasonable, but very centered.\nIn this case, it is possible that kmeans or GMM does capture the correct cluster partitions, but also possible that some clusters are grouped.\n\n## Conclusion\nNow, given improved initialization approaches, let's rerun our kmeans and GMM approaches with random seeds that previously posed challenges:\n\n### Kmeans\n\n::: {#92242111 .cell execution_count=21}\n``` {.julia .cell-code}\n# For dataset 1\nRandom.seed!(2)\nk = 5\ncenters = kpp_init(data_1, k)\ndata_labels = ones(size(data_1,2))\n\nkmeans!!(centers, data_labels, data_1)\n\nfig = Figure(size=(800,400))\nplot_cluster_results!(fig, centers, data_labels, data_1, samples_1)\ndisplay(fig);\n```\n\n::: {.cell-output .cell-output-stderr}\n```\n┌ Warning: Found `resolution` in the theme when creating a `Scene`. The `resolution` keyword for `Scene`s and `Figure`s has been deprecated. Use `Figure(; size = ...` or `Scene(; size = ...)` instead, which better reflects that this is a unitless size and not a pixel resolution. The key could also come from `set_theme!` calls or related theming functions.\n└ @ Makie ~/.julia/packages/Makie/VRavR/src/scenes.jl:220\n```\n:::\n\n::: {.cell-output .cell-output-display}\n![](em_files/figure-html/cell-22-output-2.svg){}\n:::\n:::\n\n\n::: {#7523c32c .cell execution_count=22}\n``` {.julia .cell-code}\n# For dataset 2\nRandom.seed!(2)\nk = 5\ncenters = kpp_init(data_2, k)\ndata_labels = ones(size(data_2,2))\n\nkmeans!!(centers, data_labels, data_2)\n\nfig = Figure(size=(800,400))\nplot_cluster_results!(fig, centers, data_labels, data_2, samples_2)\n```\n\n::: {.cell-output .cell-output-stderr}\n```\n┌ Warning: Found `resolution` in the theme when creating a `Scene`. The `resolution` keyword for `Scene`s and `Figure`s has been deprecated. Use `Figure(; size = ...` or `Scene(; size = ...)` instead, which better reflects that this is a unitless size and not a pixel resolution. The key could also come from `set_theme!` calls or related theming functions.\n└ @ Makie ~/.julia/packages/Makie/VRavR/src/scenes.jl:220\n```\n:::\n\n::: {.cell-output .cell-output-display execution_count=23}\n![](em_files/figure-html/cell-23-output-2.svg){}\n:::\n:::\n\n\n### GMM\n\n::: {#a4aeb5b1 .cell execution_count=23}\n``` {.julia .cell-code}\n# For dataset 1\nRandom.seed!(2)\nk = 5\nmeans = kpp_init(data_1, k)\nvars = []\nfor i in 1:k\n A = I(2) .+ rand(2,2)\n push!(vars,A' * A)\nend\ngaussians = [MvNormal(means[:,i],vars[i]) for i in 1:k]\ndata_labels = ones(size(data_2,2))\n\ngmm!!(gaussians, data_labels, data_1)\n\nfig = Figure(size=(800,400))\nplot_gmm_results!(fig, gaussians, data_labels, data_1, samples_1)\n```\n\n::: {.cell-output .cell-output-stderr}\n```\n┌ Warning: Found `resolution` in the theme when creating a `Scene`. The `resolution` keyword for `Scene`s and `Figure`s has been deprecated. Use `Figure(; size = ...` or `Scene(; size = ...)` instead, which better reflects that this is a unitless size and not a pixel resolution. The key could also come from `set_theme!` calls or related theming functions.\n└ @ Makie ~/.julia/packages/Makie/VRavR/src/scenes.jl:220\n```\n:::\n\n::: {.cell-output .cell-output-display execution_count=24}\n![](em_files/figure-html/cell-24-output-2.svg){}\n:::\n:::\n\n\n::: {#62f5454a .cell execution_count=24}\n``` {.julia .cell-code}\n# For dataset 2\nRandom.seed!(2)\nk = 5\nmeans = kpp_init(data_2, k)\nvars = []\nfor i in 1:k\n A = I(2) .+ rand(2,2)\n push!(vars,A' * A)\nend\ngaussians = [MvNormal(means[:,i],vars[i]) for i in 1:k]\ndata_labels = ones(size(data_2,2))\n\ngmm!!(gaussians, data_labels, data_2)\n\nfig = Figure(size=(800,400))\nplot_gmm_results!(fig, gaussians, data_labels, data_2, samples_2)\n```\n\n::: {.cell-output .cell-output-stderr}\n```\n┌ Warning: Found `resolution` in the theme when creating a `Scene`. The `resolution` keyword for `Scene`s and `Figure`s has been deprecated. Use `Figure(; size = ...` or `Scene(; size = ...)` instead, which better reflects that this is a unitless size and not a pixel resolution. The key could also come from `set_theme!` calls or related theming functions.\n└ @ Makie ~/.julia/packages/Makie/VRavR/src/scenes.jl:220\n```\n:::\n\n::: {.cell-output .cell-output-display execution_count=25}\n![](em_files/figure-html/cell-25-output-2.svg){}\n:::\n:::\n\n\nGreat fits!\n\n", + "supporting": [ + "em_files" + ], + "filters": [], + "includes": {} + } +} \ No newline at end of file diff --git a/_freeze/other/fun/ml/em/figure-html/cell-10-output-2.svg b/_freeze/other/fun/ml/em/figure-html/cell-10-output-2.svg new file mode 100644 index 0000000..31703c9 --- /dev/null +++ b/_freeze/other/fun/ml/em/figure-html/cell-10-output-2.svg @@ -0,0 +1,901 @@ + + diff --git a/_freeze/other/fun/ml/em/figure-html/cell-19-output-2.svg b/_freeze/other/fun/ml/em/figure-html/cell-19-output-2.svg new file mode 100644 index 0000000..7dc41ba --- /dev/null +++ b/_freeze/other/fun/ml/em/figure-html/cell-19-output-2.svg @@ -0,0 +1,939 @@ + + diff --git a/_freeze/other/fun/ml/em/figure-html/cell-20-output-2.svg b/_freeze/other/fun/ml/em/figure-html/cell-20-output-2.svg new file mode 100644 index 0000000..a3c760b --- /dev/null +++ b/_freeze/other/fun/ml/em/figure-html/cell-20-output-2.svg @@ -0,0 +1,1413 @@ + + diff --git a/_freeze/other/fun/ml/em/figure-html/cell-21-output-1.svg b/_freeze/other/fun/ml/em/figure-html/cell-21-output-1.svg new file mode 100644 index 0000000..9d2c151 --- /dev/null +++ b/_freeze/other/fun/ml/em/figure-html/cell-21-output-1.svg @@ -0,0 +1,1465 @@ + + diff --git a/_freeze/other/fun/ml/em/figure-html/cell-21-output-2.svg b/_freeze/other/fun/ml/em/figure-html/cell-21-output-2.svg new file mode 100644 index 0000000..de8579f --- /dev/null +++ b/_freeze/other/fun/ml/em/figure-html/cell-21-output-2.svg @@ -0,0 +1,488 @@ + + diff --git a/_freeze/other/fun/ml/em/figure-html/cell-21-output-3.svg b/_freeze/other/fun/ml/em/figure-html/cell-21-output-3.svg new file mode 100644 index 0000000..58c6f71 --- /dev/null +++ b/_freeze/other/fun/ml/em/figure-html/cell-21-output-3.svg @@ -0,0 +1,547 @@ + + diff --git a/_freeze/other/fun/ml/em/figure-html/cell-21-output-4.svg b/_freeze/other/fun/ml/em/figure-html/cell-21-output-4.svg new file mode 100644 index 0000000..ace9831 --- /dev/null +++ b/_freeze/other/fun/ml/em/figure-html/cell-21-output-4.svg @@ -0,0 +1,476 @@ + + diff --git a/_freeze/other/fun/ml/em/figure-html/cell-22-output-2.svg b/_freeze/other/fun/ml/em/figure-html/cell-22-output-2.svg new file mode 100644 index 0000000..4038f9c --- /dev/null +++ b/_freeze/other/fun/ml/em/figure-html/cell-22-output-2.svg @@ -0,0 +1,896 @@ + + diff --git a/_freeze/other/fun/ml/em/figure-html/cell-23-output-2.svg b/_freeze/other/fun/ml/em/figure-html/cell-23-output-2.svg new file mode 100644 index 0000000..ef0eccb --- /dev/null +++ b/_freeze/other/fun/ml/em/figure-html/cell-23-output-2.svg @@ -0,0 +1,902 @@ + + diff --git a/_freeze/other/fun/ml/em/figure-html/cell-24-output-2.svg b/_freeze/other/fun/ml/em/figure-html/cell-24-output-2.svg new file mode 100644 index 0000000..cdea130 --- /dev/null +++ b/_freeze/other/fun/ml/em/figure-html/cell-24-output-2.svg @@ -0,0 +1,1028 @@ + + diff --git a/_freeze/other/fun/ml/em/figure-html/cell-25-output-2.svg b/_freeze/other/fun/ml/em/figure-html/cell-25-output-2.svg new file mode 100644 index 0000000..5ab78a7 --- /dev/null +++ b/_freeze/other/fun/ml/em/figure-html/cell-25-output-2.svg @@ -0,0 +1,1108 @@ + + diff --git a/_freeze/other/fun/ml/em/figure-html/cell-5-output-1.svg b/_freeze/other/fun/ml/em/figure-html/cell-5-output-1.svg new file mode 100644 index 0000000..f3e309f --- /dev/null +++ b/_freeze/other/fun/ml/em/figure-html/cell-5-output-1.svg @@ -0,0 +1,277 @@ + + diff --git a/_freeze/other/fun/ml/em/figure-html/cell-5-output-2.svg b/_freeze/other/fun/ml/em/figure-html/cell-5-output-2.svg new file mode 100644 index 0000000..6c85601 --- /dev/null +++ b/_freeze/other/fun/ml/em/figure-html/cell-5-output-2.svg @@ -0,0 +1,636 @@ + + diff --git a/_freeze/other/fun/ml/em/figure-html/cell-6-output-1.svg b/_freeze/other/fun/ml/em/figure-html/cell-6-output-1.svg new file mode 100644 index 0000000..7094284 --- /dev/null +++ b/_freeze/other/fun/ml/em/figure-html/cell-6-output-1.svg @@ -0,0 +1,150 @@ + + diff --git a/_freeze/other/fun/ml/em/figure-html/cell-7-output-1.svg b/_freeze/other/fun/ml/em/figure-html/cell-7-output-1.svg new file mode 100644 index 0000000..74419bd --- /dev/null +++ b/_freeze/other/fun/ml/em/figure-html/cell-7-output-1.svg @@ -0,0 +1,150 @@ + + diff --git a/_freeze/other/fun/ml/em/figure-html/cell-7-output-2.svg b/_freeze/other/fun/ml/em/figure-html/cell-7-output-2.svg new file mode 100644 index 0000000..2efcd32 --- /dev/null +++ b/_freeze/other/fun/ml/em/figure-html/cell-7-output-2.svg @@ -0,0 +1,896 @@ + + diff --git a/_freeze/other/fun/ml/em/figure-html/cell-8-output-2.svg b/_freeze/other/fun/ml/em/figure-html/cell-8-output-2.svg new file mode 100644 index 0000000..2efcd32 --- /dev/null +++ b/_freeze/other/fun/ml/em/figure-html/cell-8-output-2.svg @@ -0,0 +1,896 @@ + + diff --git a/_freeze/other/fun/ml/em/figure-html/cell-9-output-2.svg b/_freeze/other/fun/ml/em/figure-html/cell-9-output-2.svg new file mode 100644 index 0000000..4294dd3 --- /dev/null +++ b/_freeze/other/fun/ml/em/figure-html/cell-9-output-2.svg @@ -0,0 +1,895 @@ + + diff --git a/_quarto.yml b/_quarto.yml index f06c39e..b982e3f 100644 --- a/_quarto.yml +++ b/_quarto.yml @@ -61,6 +61,8 @@ website: - section: "Just for fun" href: other/fun/fun.qmd contents: + - text: "Expectation maximization with GMM" + href: other/fun/ml/em.qmd - text: "Polygon angles" href: other/fun/doodles/polygon_angles.qmd - text: "Geocentrism" diff --git a/docs/other/fun/fun.html b/docs/other/fun/fun.html index 09dde11..f17435c 100644 --- a/docs/other/fun/fun.html +++ b/docs/other/fun/fun.html @@ -181,6 +181,10 @@