Issue #64: added n_init to kmeans #78

Open
wants to merge 8 commits into
base: master
Changes from 4 commits
32 changes: 25 additions & 7 deletions src/kmeans.jl
@@ -17,6 +17,7 @@ const _kmeans_default_init = :kmpp
 const _kmeans_default_maxiter = 100
 const _kmeans_default_tol = 1.0e-6
 const _kmeans_default_display = :none
+const _kmeans_default_n_init = 10

 function kmeans!{T<:AbstractFloat}(X::Matrix{T}, centers::Matrix{T};
                                    weights=nothing,
@@ -43,18 +44,33 @@ function kmeans(X::Matrix, k::Int;
                 weights=nothing,
                 init=_kmeans_default_init,
                 maxiter::Integer=_kmeans_default_maxiter,
+                n_init::Integer=_kmeans_default_n_init,
Member

I understand that n_init comes from Python's sklearn (#64), but it doesn't sound like the best choice to me.
Maybe something like n_tries, to reflect that the parameter defines how many times the whole algorithm is run, rather than some initialization procedure?

Contributor

@wildart, Sep 28, 2018

Or ntries? And wouldn't it be overkill to run 10 times? I recommend a default value of 1, because usually a quick partitioning is what's required, not necessarily the best one. If one needs to find the best clustering, this parameter can be set to a larger value explicitly.

Contributor

10 is what sklearn does, and it sounds reasonable to me.
It isn't unusual to run k-means thousands of times (that was done as the baseline for the affinity propagation paper).
If someone needs a quick partition, they can ask for it.

The default shouldn't be so sensitive to random factors.

I think 10 strikes the right balance,
though I could see an argument for 3 or 30.
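
To make the trade-off in this thread concrete, here is a minimal sketch of best-of-n restarts. It is not the PR's implementation: `lloyd` and `best_of` are hypothetical helpers written for illustration, `ntries` is the name suggested above, seeding is plain random sampling rather than `:kmpp`, and the syntax assumes Julia 1.x rather than the version this PR targets.

```julia
using Random

# Toy Lloyd's algorithm (hypothetical helper, not Clustering.jl's kmeans!).
# Columns of X are points. Returns (centers, total within-cluster cost).
function lloyd(X::Matrix{Float64}, k::Int, rng::AbstractRNG; maxiter::Int=100)
    n = size(X, 2)
    centers = X[:, randperm(rng, n)[1:k]]      # k distinct random points as seeds
    assignments = zeros(Int, n)
    for _ in 1:maxiter
        changed = false
        # Assignment step: move each point to its nearest center.
        for j in 1:n
            best = argmin([sum(abs2, X[:, j] .- centers[:, c]) for c in 1:k])
            changed |= (best != assignments[j])
            assignments[j] = best
        end
        changed || break                        # converged: no assignment moved
        # Update step: recompute each non-empty cluster's mean.
        for c in 1:k
            members = findall(a -> a == c, assignments)
            isempty(members) && continue
            centers[:, c] = vec(sum(X[:, members], dims=2)) ./ length(members)
        end
    end
    cost = sum(sum(abs2, X[:, j] .- centers[:, assignments[j]]) for j in 1:n)
    return centers, cost
end

# Run `ntries` independent restarts and report the lowest cost plus all costs.
function best_of(X::Matrix{Float64}, k::Int, ntries::Int; rng=MersenneTwister(0))
    ntries > 0 || throw(ArgumentError("ntries must be greater than 0"))
    costs = [lloyd(X, k, rng)[2] for _ in 1:ntries]
    return minimum(costs), costs
end

X = rand(MersenneTwister(1), 2, 200)
best, costs = best_of(X, 5, 10)
println("best of 10: ", best, "   first run alone: ", costs[1])
```

With a fixed seed, the first restart is exactly what a default of 1 would return, so the best-of-n cost can never be worse than it. The run time grows linearly with `ntries`, which is the cost the two positions above are weighing.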

                 tol::Real=_kmeans_default_tol,
                 display::Symbol=_kmeans_default_display)

Member

Another unrelated whitespace change


Member

One last remaining extraneous newline.

     m, n = size(X)
     (2 <= k < n) || error("k must have 2 <= k < n.")
-    iseeds = initseeds(init, X, k)
-    centers = copyseeds(X, iseeds)
-    kmeans!(X, centers;
-            weights=weights,
-            maxiter=maxiter,
-            tol=tol,
-            display=display)
+    n_init > 0 || throw(ArgumentError("n_init must be greater than 0"))
+
+    lowestcost::Float64 = Inf
+    local bestresult::KmeansResult
+
+    for i = 1:n_init
+        iseeds = initseeds(init, X, k)
+        centers = copyseeds(X, iseeds)
+        result = kmeans!(X, centers;
+                         weights=weights,
+                         maxiter=maxiter,
+                         tol=tol,
+                         display=display)
+
+        if result.totalcost < lowestcost
+            lowestcost = result.totalcost
+            bestresult = result
+        end
+    end
+    return bestresult
 end

#### Core implementation
@@ -72,6 +88,8 @@ function _kmeans!{T<:AbstractFloat}(
tol::Real, # in: tolerance of change at convergence
displevel::Int) # in: the level of display



Member

Please remove the excess whitespace here and above. The change is unrelated to the PR.

Member

Trailing whitespace. I don't know what editor you use, but I think Atom trims trailing whitespace by default, and in Vim you can do :%s/\s\+$//g to remove it.

# initialize

k = size(centers, 2)
6 changes: 3 additions & 3 deletions test/kmeans.jl
@@ -12,7 +12,7 @@ k = 10
 x = rand(m, n)

 # non-weighted
-r = kmeans(x, k; maxiter=50)
+r = kmeans(x, k; maxiter=50, n_init=2)
 @test isa(r, KmeansResult{Float64})
 @test size(r.centers) == (m, k)
 @test length(r.assignments) == n
@@ -24,7 +24,7 @@ r = kmeans(x, k; maxiter=50)
 @test_approx_eq sum(r.costs) r.totalcost

 # non-weighted (float32)
-r = kmeans(@compat(map(Float32, x)), k; maxiter=50)
+r = kmeans(@compat(map(Float32, x)), k; maxiter=50, n_init=2)
 @test isa(r, KmeansResult{Float32})
 @test size(r.centers) == (m, k)
 @test length(r.assignments) == n
@@ -37,7 +37,7 @@ r = kmeans(@compat(map(Float32, x)), k; maxiter=50)

 # weighted
 w = rand(n)
-r = kmeans(x, k; maxiter=50, weights=w)
+r = kmeans(x, k; maxiter=50, weights=w, n_init=2)
 @test isa(r, KmeansResult{Float64})
 @test size(r.centers) == (m, k)
 @test length(r.assignments) == n