Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added Chinese Whispers clustering #112

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions REQUIRE
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ Compat 0.17
Distances 0.3.1
NearestNeighbors 0.0.3
StatsBase 0.9.0
DataStructures # needs a newer, as-yet-unreleased version
9 changes: 8 additions & 1 deletion src/Clustering.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ module Clustering
using Distances
using NearestNeighbors
using StatsBase
using DataStructures

import Base: show
import StatsBase: IntegerVector, RealVector, RealMatrix, counts
Expand Down Expand Up @@ -54,7 +55,11 @@ module Clustering
Hclust, hclust, cutree,

# MCL
mcl, MCLResult
mcl, MCLResult,

# chinese_whispers
chinese_whispers, ChineseWhispersResult


## source files

Expand All @@ -66,6 +71,8 @@ module Clustering
include("affprop.jl")
include("dbscan.jl")
include("mcl.jl")
include("chinesewhispers.jl")

include("fuzzycmeans.jl")

include("silhouette.jl")
Expand Down
74 changes: 74 additions & 0 deletions src/chinesewhispers.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@


# Abstractions proposed in https://github.com/JuliaLang/julia/issues/26613

# Column indices of `A`.
function colinds(A::AbstractMatrix)
    return indices(A, 2)
end

# Row indices to visit for column `col` of `A`.
# Generic fallback: every row index of the matrix (`col` is not consulted).
function rowinds(A::AbstractMatrix, col::Integer)
    return indices(A, 1)
end

# Sparse specialisation: only the rows that hold a stored entry in column `col`.
function rowinds(A::SparseMatrixCSC, col::Integer)
    stored = nzrange(A, col)
    return rowvals(A)[stored]
end

# Result of a Chinese Whispers clustering run.
# Declared with `mutable struct`: the `type` keyword is deprecated as of Julia 0.6,
# and the two spellings define exactly the same (mutable) type.
mutable struct ChineseWhispersResult <: ClusteringResult
    assignments::Vector{Int}    # cluster index assigned to each node (n)
    counts::Vector{Int}         # number of samples assigned to each cluster (k)
    iterations::Int             # number of elapsed iterations
    converged::Bool             # whether the procedure converged
end

# Build a result from a raw node => label table.
#
# The table is looked up at keys 1:length(raw_assignments). Raw labels are
# renumbered 1, 2, ... in order of first appearance while scanning the nodes in
# increasing order, and the size of every cluster is tallied along the way.
function ChineseWhispersResult(raw_assignments::Associative, iterations, converged)
    labels = getindex.(raw_assignments, 1:length(raw_assignments))
    cluster_of = Dict{eltype(labels), Int}()
    sizes = Int[]
    renumbered = Vector{Int}(length(labels))
    for (node, lbl) in enumerate(labels)
        if !haskey(cluster_of, lbl)
            # First sighting of this raw label: its normalised name is the
            # next unused integer.
            push!(sizes, 0)
            cluster_of[lbl] = length(sizes)
        end
        cluster = cluster_of[lbl]
        sizes[cluster] += 1
        renumbered[node] = cluster
    end
    return ChineseWhispersResult(renumbered, sizes, iterations, converged)
end


# Cluster the nodes of a weighted graph by iterative label propagation.
#
# Arguments
#   sim      : square matrix of edge weights between nodes
#   max_iter : maximum number of sweeps over all nodes (default 100)
#   verbose  : when true, print the label table after every sweep
#
# Every sweep visits the nodes in a freshly shuffled order and relabels each one
# via `update_node_label`. The procedure stops as soon as a full sweep leaves
# every label unchanged.
#
# Returns a `ChineseWhispersResult`; `converged` is true when a sweep without
# any label change was observed, false when `max_iter` sweeps were used up.
# Throws `DimensionMismatch` if `sim` is not square.
function chinese_whispers(sim::AbstractMatrix, max_iter=100; verbose=false)
    size(sim, 1) == size(sim, 2) ||
        throw(DimensionMismatch("sim must be a square matrix, got size $(size(sim))"))

    # Initially all nodes are labelled with their own ID. (nclusters==nnodes)
    node_labels = DefaultDict{Int,Int}(identity; passkey=true)

    for ii in 1:max_iter
        changed = false
        for node in shuffle(colinds(sim))
            old_lbl = node_labels[node]
            new_lbl = update_node_label(node, sim, node_labels)
            node_labels[node] = new_lbl
            # A sweep counts as "changed" when at least one label differs from
            # its previous value (the original tested `==`, inverting the flag).
            changed |= new_lbl != old_lbl
        end

        verbose && println("Iteration: $ii, lbls: $(node_labels)")

        if !changed
            return ChineseWhispersResult(node_labels, ii, true)
        end
    end

    return ChineseWhispersResult(node_labels, max_iter, false)
end

# Choose the label for `node` from the labels of its neighbours.
#
# The weights of the entries in column `node` of `adj` are summed per neighbour
# label, and the label with the largest total is returned. If that total is
# zero (no connection), the node keeps its current label.
function update_node_label(node::N, adj::AbstractMatrix{W}, node_labels::Associative{N, L}) where {N<:Integer, W<:Real, L}
    # Bool weights are accumulated as Int (Bool + Bool is an Int in Julia).
    label_weights = Accumulator(L, W==Bool ? Int : W)

    for neighbour in rowinds(adj, node)
        lbl = node_labels[neighbour]
        # `rowinds(adj, node)` yields row indices of column `node`, so the entry
        # that belongs to this neighbour is adj[neighbour, node]. Reading
        # adj[node, neighbour] would pair the sparsity pattern of one entry with
        # the value of its transpose on asymmetric input.
        label_weights[lbl] += adj[neighbour, node]
    end

    old_lbl = node_labels[node]
    label_weights[old_lbl] += zero(W) # Make sure at least one entry in the weights
    new_lbl, weight = first(most_common(label_weights, 1))
    return weight == 0 ? old_lbl : new_lbl
end

66 changes: 66 additions & 0 deletions test/chinesewhispers.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
using Base.Test
using Distances
using Clustering


@testset "basic seperated graph" begin
    # Node 3 has no edges at all; nodes 2 and 4 are each linked only to node 1.
    graph = [0 1 0 1;
             1 0 0 0;
             0 0 0 0;
             1 0 0 0]
    variants = [("dense", graph), ("sparse", sparse(graph))]
    @testset "$(variant[1])" for variant in variants
        result = chinese_whispers(variant[2])
        labels = assignments(result)
        # The unconnected node must not share a label with any other node.
        for other in (1, 2, 4)
            @test labels[3] != labels[other]
        end

        @test nclusters(result) >= 2
        @test sum(counts(result)) == 4
    end
end

# Two groups of 10 random planar points, the second shifted by (5, 5), are
# expected to be recovered as exactly two clusters of 10 nodes each.
@testset "planar based" begin
srand(1) # make deterministic
coordersA = randn(10, 2)
coordersB = randn(10, 2) .+ [5 5]

# Stack both groups and transpose so that each column is one point (2 x 20).
coords = [coordersA; coordersB]';

# Similarity is the inverse of the pairwise Euclidean distance.
adj = 1./pairwise(Euclidean(), coords)
adj[isinf.(adj)]=0 # no self-similarity (zero distance gives Inf)
# NOTE(review): entries are dropped independently of each other, so `adj` is
# presumably no longer symmetric after this line - confirm that is intended.
adj[rand(size(adj)).<0.6]=0 # remove some connections

res = chinese_whispers(adj)
lbls = assignments(res)
# Nodes 1:10 must all share one label and nodes 11:20 another.
@test all(lbls[1].==(lbls[1:10]))
@test all(lbls[20].==(lbls[11:20]))

@test nclusters(res) == 2
@test counts(res) == [10, 10]
end


# The same symmetric matrix stored as a plain matrix, as a sparse matrix and as
# a `Symmetric` wrapper must produce identical cluster assignments.
@testset "acts the same for all types" begin
examples = [
sprand(500,500,0.3),
sprand(1500,1500,0.1).>0.5, # Boolean elements
rand(200, 200)
]
# Reseed the global RNG before every run so that each call to
# `chinese_whispers` starts from the same random state.
function test_assignments(x)
srand(1)
assignments(chinese_whispers(x))
end

for eg in (examples)
# Force symmetry, then materialise the wrapper with `collect`.
eg = collect(Symmetric(eg))
dense_res = test_assignments(eg)
sparse_res = test_assignments(sparse(eg))
symetric_res = test_assignments(Symmetric(eg))

@test dense_res == sparse_res == symetric_res
end
end
10 changes: 7 additions & 3 deletions test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
include("../src/Clustering.jl")
using Compat
using Base.Test

tests = ["seeding",
"kmeans",
Expand All @@ -11,11 +12,14 @@ tests = ["seeding",
"varinfo",
"randindex",
"hclust",
"mcl"]
"mcl",
"chinesewhispers"
]

# Run every test file listed in `tests`, announcing each one as it starts.
println("Runing tests:")
for t in tests
    fp = "$(t).jl"
    println("* $fp ...")
    # Include each file exactly once, inside its own testset, so results are
    # reported per file (a second bare include would run every test twice).
    @testset "$t" begin
        include(fp)
    end
end