diff --git a/src/silhouette.jl b/src/silhouette.jl index 2b2c8ea4..70dd64c1 100644 --- a/src/silhouette.jl +++ b/src/silhouette.jl @@ -22,8 +22,8 @@ function sil_aggregate_dists(k::Int, a::AbstractVector{Int}, dists::DenseMatrix{ end -function silhouettes(assignments::Vector{Int}, - counts::AbstractVector{Int}, +function silhouettes(assignments::Vector{Int}, + counts::AbstractVector{Int}, dists::DenseMatrix{T}) where T<:Real n = length(assignments) @@ -69,13 +69,17 @@ function silhouettes(assignments::Vector{Int}, b[j] = v end - # compute silhouette score + # compute silhouette score sil = a # reuse the memory of a for sil for j = 1:n - @inbounds sil[j] = (b[j] - a[j]) / max(a[j], b[j]) + if counts[assignments[j]] == 1 + sil[j] = 0 + else + @inbounds sil[j] = (b[j] - a[j]) / max(a[j], b[j]) + end end return sil end -silhouettes(R::ClusteringResult, dists::DenseMatrix) = - silhouettes(assignments(R), counts(R), dists) +silhouettes(R::ClusteringResult, dists::DenseMatrix) = + silhouettes(assignments(R), counts(R), dists) diff --git a/test/silhouette.jl b/test/silhouette.jl index e86d6c36..71c7930f 100644 --- a/test/silhouette.jl +++ b/test/silhouette.jl @@ -17,3 +17,8 @@ a = [1, 2, 1, 2] c = [2, 2] @test all(isapprox.(silhouettes(a, c, D), [0.0, -0.5, -0.5, 0.0])) + +a = [1, 1, 1, 2] +c = [3, 1] + +@test all(isapprox.(silhouettes(a, c, D), [0.5, 0.5, -1/3, 0.0])) \ No newline at end of file