Skip to content

Commit

Permalink
Merge pull request #124 from alyst/enh_silh
Browse files Browse the repository at this point in the history
Some silhouettes() improvements
  • Loading branch information
alyst authored Aug 19, 2018
2 parents 4b4105b + fb66d14 commit f153e4c
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 18 deletions.
8 changes: 4 additions & 4 deletions LICENSE.md
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
Clustering.jl is licensed under the MIT License:

> Copyright (c) 2012: all contributors to the package.
>
> Copyright (c) 2012-2018: all contributors to the package.
>
> Permission is hereby granted, free of charge, to any person obtaining
> a copy of this software and associated documentation files (the
> "Software"), to deal in the Software without restriction, including
> without limitation the rights to use, copy, modify, merge, publish,
> distribute, sublicense, and/or sell copies of the Software, and to
> permit persons to whom the Software is furnished to do so, subject to
> the following conditions:
>
>
> The above copyright notice and this permission notice shall be
> included in all copies or substantial portions of the Software.
>
>
> THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
> EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
> MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
Expand Down
57 changes: 44 additions & 13 deletions src/silhouette.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,10 @@
# this function returns r of size (k, n), such that
# r[i, j] is the sum of distances of all points from cluster i to sample j
#
function sil_aggregate_dists(k::Int, a::AbstractVector{Int}, dists::DenseMatrix{T}) where T<:Real
function sil_aggregate_dists(k::Int, a::AbstractVector{Int}, dists::AbstractMatrix{T}) where T<:Real
n = length(a)
r = zeros(k, n)
for j = 1:n
(1 <= a[j] <= k) || error("a[j] should have 1 <= a[j] <= k.")
end
S = typeof((one(T)+one(T))/2)
r = zeros(S, k, n)
@inbounds for j = 1:n
for i = 1:j-1
r[a[i],j] += dists[i,j]
Expand All @@ -22,13 +20,37 @@ function sil_aggregate_dists(k::Int, a::AbstractVector{Int}, dists::DenseMatrix{
end


function silhouettes(assignments::Vector{Int},
counts::AbstractVector{Int},
dists::DenseMatrix{T}) where T<:Real
"""
silhouettes(assignments::AbstractVector, [counts,] dists)
silhouettes(clustering::ClusteringResult, dists)
Compute silhouette values for individual points w.r.t. given clustering.
* `assignments` the vector of point assignments (cluster indices)
* `counts` the optional vector of cluster sizes (how many points assigned to each cluster; should match `assignments`)
* `clustering` the output of some clustering method
* `dists` point×point pairwise distance matrix
Returns a vector of silhouette values for each individual point.
`mean(silhouettes(...))` could be used as a measure of clustering quality;
higher values indicate better separation of clusters w.r.t. distances provided in `dists`.
#### References
1. [Silhouette Wikipedia page](http://en.wikipedia.org/wiki/Silhouette_(clustering)).
2. Peter J. Rousseeuw (1987). "Silhouettes: a Graphical Aid to the Interpretation and Validation of Cluster Analysis". Computational and Applied Mathematics. 20: 53–65.
"""
function silhouettes(assignments::AbstractVector{<:Integer},
counts::AbstractVector{<:Integer},
dists::AbstractMatrix{T}) where T<:Real

n = length(assignments)
k = length(counts)
size(dists) == (n, n) || throw(DimensionMismatch("Inconsistent array dimensions."))
for j = 1:n
(1 <= assignments[j] <= k) || throw(ArgumentError("Bad assignments[$j]=$(assignments[j]): should be in 1:$k range."))
end
sum(counts) == n || throw(ArgumentError("Mismatch between assignments ($n) and counts (sum(counts)=$(sum(counts)))."))
size(dists) == (n, n) || throw(DimensionMismatch("The size of a distance matrix ($(size(dists))) doesn't match the length of assignment vector ($n)."))

# compute average distance from each cluster to each point --> r
r = sil_aggregate_dists(k, assignments, dists)
Expand All @@ -50,14 +72,15 @@ function silhouettes(assignments::Vector{Int},
# compute a and b
# a: average distance w.r.t. the assigned cluster
# b: the minimum average distance w.r.t. other cluster
a = Vector{Float64}(undef, n)
b = Vector{Float64}(undef, n)
S = eltype(r)
a = Vector{S}(undef, n)
b = Vector{S}(undef, n)

for j = 1:n
l = assignments[j]
a[j] = r[l, j]

v = Inf
v = S(Inf)
p = -1
for i = 1:k
@inbounds rij = r[i,j]
Expand All @@ -81,5 +104,13 @@ function silhouettes(assignments::Vector{Int},
return sil
end

silhouettes(R::ClusteringResult, dists::DenseMatrix) =
silhouettes(R::ClusteringResult, dists::AbstractMatrix) =
silhouettes(assignments(R), counts(R), dists)

function silhouettes(assignments::AbstractVector{<:Integer}, dists::AbstractMatrix)
counts = fill(0, maximum(assignments))
for a in assignments
counts[a] += 1
end
silhouettes(assignments, counts, dists)
end
11 changes: 10 additions & 1 deletion test/silhouette.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,16 @@ local D = [0 1 2 3
local a = [1, 1, 2, 2]
local c = [2, 2]

@test silhouettes(a, c, D) [1.5/2.5, 0.5/1.5, 0.5/1.5, 1.5/2.5]
@testset "Input checks" begin
@test_skip silhouettes(a, [1, 1, 2], D) # should throw because cluster counts are inconsistent
@test_throws ArgumentError silhouettes(a, [3, 2], D)
@test_throws ArgumentError silhouettes([1, 1, 3, 2], [2, 2], D)
@test_throws DimensionMismatch silhouettes([1, 1, 2, 2, 2], [2, 3], D)
end

@test @inferred(silhouettes(a, c, D)) [1.5/2.5, 0.5/1.5, 0.5/1.5, 1.5/2.5]
@test @inferred(silhouettes(a, c, convert(Matrix{Float32}, D))) isa AbstractVector{Float32}
@test silhouettes(a, D) == silhouettes(a, c, D) # c is optional

a = [1, 2, 1, 2]
c = [2, 2]
Expand Down

0 comments on commit f153e4c

Please sign in to comment.