diff --git a/src/Clustering.jl b/src/Clustering.jl index 2da54517..2e9e6393 100644 --- a/src/Clustering.jl +++ b/src/Clustering.jl @@ -65,7 +65,10 @@ module Clustering Hclust, hclust, cutree, # MCL - mcl, MCLResult + mcl, MCLResult, + + # utils + assign_clusters ## source files diff --git a/src/kmeans.jl b/src/kmeans.jl index 3f96403f..7b5f6552 100644 --- a/src/kmeans.jl +++ b/src/kmeans.jl @@ -1,5 +1,5 @@ # K-means algorithm - +using Distances #### Interface # C is the type of centers, an (abstract) matrix of size (d x k) @@ -391,3 +391,47 @@ function repick_unused_centers(X::AbstractMatrix{<:Real}, # in: the data matrix tcosts = min(tcosts, ds) end end + + +""" + assign_clusters(X::AbstractMatrix{<:Real}, R::KmeansResult; kwargs...) -> Vector{Int} + +Assign the samples specified as the columns of `X` to the corresponding clusters from `R`. + +# Arguments +- `X`: Input data to be clustered. +- `R`: Fitted clustering result. + +# Keyword arguments +- `distance`: SemiMertric used to compute distances between vectors and clusters centroids. +- `pairwise_computation`: Boolean specifying whether to compute and store pairwise distances. + +""" +function assign_clusters( + X::AbstractMatrix{T}, + R::KmeansResult; + distance::SemiMetric = SqEuclidean(), + pairwise_computation::Bool = true) where {T} + + if pairwise_computation + Xdist = pairwise(distance, X, R.centers, dims=2) + cluster_assignments = partialsortperm.(eachrow(Xdist), 1) + else + cluster_assignments = zeros(Int, size(X, 2)) + Threads.@threads for n in axes(X, 2) + min_dist = typemax(T) + cluster_assignment = 0 + + for k in axes(R.centers, 2) + dist = distance(@view(X[:, n]), @view(R.centers[:, k])) + if dist < min_dist + min_dist = dist + cluster_assignment = k + end + end + cluster_assignments[n] = cluster_assignment + end + end + + return cluster_assignments +end diff --git a/src/utils.jl b/src/utils.jl index b832e86b..efdf2d30 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,6 +1,6 @@ # Common utilities - ##### common types +using Distances """ ClusteringResult @@ -70,3 +70,31 @@ function updatemin!(r::AbstractArray, x::AbstractArray) end return r end + + +""" + assign_clusters(X::AbstractMatrix{<:Real}, R::ClusteringResult; kwargs...) -> Vector{Int} + +Assign the samples specified as the columns of `X` to the corresponding clusters from `R`. + +# Arguments +- `X`: Input data to be clustered. +- `R`: Fitted clustering result. + +# Keyword arguments +- Cluster specific keyword arguments. For example, see the `assign_clusters` method in + [`kmeans`](@ref) for the description of optional `kwargs`. + +""" +function assign_clusters( + X::AbstractMatrix{T}, + R::ClusteringResult; + distance::SemiMetric = SqEuclidean(), + pairwise_computation::Bool = true) where {T} + + if !(typeof(R) <: KmeansResult) + throw(MethodError(assign_clusters, + "NotImplemented: assign_clusters not implemented for R of type $(typeof(R))")) + end + +end \ No newline at end of file diff --git a/test/kmeans.jl b/test/kmeans.jl index 6c4d25e6..e10756a9 100644 --- a/test/kmeans.jl +++ b/test/kmeans.jl @@ -204,4 +204,15 @@ end end end +@testset "get cluster assigments" begin + X = rand(5, 100) + R = kmeans(X, 10; maxiter=200) + reassigned_clusters = assign_clusters(X, R; pairwise_computation=true); + @test R.assignments == reassigned_clusters + + reassigned_clusters2 = assign_clusters(X, R; pairwise_computation=false); + @test R.assignments == reassigned_clusters2 + +end + end diff --git a/test/runtests.jl b/test/runtests.jl index 1f9d483a..5e2cbbb1 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -6,7 +6,8 @@ using SparseArrays using StableRNGs using Statistics -tests = ["seeding", +tests = ["utils", + "seeding", "kmeans", "kmedoids", "affprop", diff --git a/test/utils.jl b/test/utils.jl new file mode 100644 index 00000000..97917e5f --- /dev/null +++ b/test/utils.jl @@ -0,0 +1,12 @@ +using Test +using Clustering +using Distances + +@testset "get cluster assigments not implemented method" begin + + X = rand(10,5) + dist = pairwise(SqEuclidean(), X, dims=2) + R = kmedoids!(dist, [1, 2, 3]) + + @test_throws MethodError assign_clusters(X, R); +end