# mini_batch_kmeans_sklearn.py
from sklearn.cluster import MiniBatchKMeans
import numpy as np
# MiniBatchKMeans targets large-scale learning: for n_samples > 10k it is
# probably much faster than the default batch KMeans implementation.
# Toy dataset: twelve two-dimensional integer points, one point per row.
X = np.array(
    [
        (1, 2), (1, 4), (1, 0), (4, 2), (4, 0), (4, 4),   # rows 0-5
        (4, 5), (0, 1), (2, 2), (3, 2), (5, 5), (1, -1),  # rows 6-11
    ]
)
print(X.shape)  # (12, 2)
# ------usage-1------
# Incremental training: feed the data by hand as two mini-batches of six
# rows each. partial_fit() hands back the estimator, so the calls chain.
kmeans = MiniBatchKMeans(n_clusters=2, random_state=0, batch_size=6)
kmeans = kmeans.partial_fit(X[:6]).partial_fit(X[6:12])
# Centroids learned from the two partial fits.
cc = kmeans.cluster_centers_
print(cc)  # original author's note: [[1 1] [3 4]] (may vary by scikit-learn version)
# Label two unseen points with the index of their closest centroid.
pred = kmeans.predict([[0, 0], [4, 4]])
print(pred)  # original author's note: [0 1]
# ------usage-2------
# Fit on the whole dataset in a single call.
# n_init is pinned to 3, the historical default. Newer scikit-learn
# releases first warn about the implicit default and then switch it to
# "auto", so leaving it unset makes the result depend on the installed
# version.
kmeans = MiniBatchKMeans(
    n_clusters=2,
    random_state=0,
    batch_size=6,
    max_iter=10,
    n_init=3,
).fit(X)
# Centroids found by the full fit.
cc = kmeans.cluster_centers_
# Original author's recorded output (exact values may still differ between
# scikit-learn versions): [[3.95918367 2.40816327] [1.12195122 1.3902439 ]]
print(cc)
# Label two unseen points with the index of their closest centroid.
pred = kmeans.predict([[0, 0], [4, 4]])
print(pred)  # original author's note: [1 0]