-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathts_clustering.py
119 lines (94 loc) · 4.09 KB
/
ts_clustering.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from rich import print
from tslearn.clustering import TimeSeriesKMeans
from tslearn.piecewise import SymbolicAggregateApproximation
from tslearn.preprocessing import TimeSeriesScalerMeanVariance
from music import MusicDB
# Keyword arguments for saving figures as tightly cropped 300-dpi PNGs.
# NOTE(review): not referenced in the visible part of this file; presumably
# meant to be splatted into ``plt.savefig(path, **savefig_options)``.
savefig_options = {"format": "png", "dpi": 300, "bbox_inches": "tight"}

# Render every figure at 300 dpi.
mpl.rcParams["figure.dpi"] = 300
def do_sax_kmeans(
    params,
    n_segments=130,
    alphabet_size=20,
    metric="dtw",
    max_iter=50,
    random_state=5138,
):
    """Run time-series k-means on the SAX representation of a dataset.

    The dataset and the number of clusters arrive packed in one tuple so the
    function can be handed directly to ``map``-style helpers that pass a
    single argument per task (e.g. a grid search over ``k``).

    Args:
        params: ``(dataset, k)`` pair. ``dataset`` is the time-series
            collection passed to ``SymbolicAggregateApproximation``;
            ``k`` is the number of clusters.
        n_segments: Number of SAX segments each series is reduced to.
        alphabet_size: Number of SAX symbols (``alphabet_size_avg``).
        metric: Distance used by ``TimeSeriesKMeans`` (``"dtw"`` by
            default; ``"euclidean"`` is the alternative the original
            experiment compared against).
        max_iter: Maximum number of k-means iterations.
        random_state: Seed that makes the clustering reproducible.

    Returns:
        A ``(cluster_centers, labels, inertia)`` tuple, with the inertia
        rounded to 4 decimal places. The model is fitted on the SAX-encoded
        series, so the centroids live in SAX space rather than on the scale
        of the input data.
    """
    dataset, k = params

    # Reduce dimensionality first: clustering runs on the SAX symbols.
    sax = SymbolicAggregateApproximation(
        n_segments=n_segments, alphabet_size_avg=alphabet_size
    )
    ts_sax = sax.fit_transform(dataset)

    model = TimeSeriesKMeans(
        n_clusters=k,
        metric=metric,
        max_iter=max_iter,
        random_state=random_state,
    )
    model.fit(ts_sax)

    return (
        model.cluster_centers_,
        model.labels_,
        round(model.inertia_, 4),
    )
def _plot_genre_distribution(genre_counts, cluster_id):
    """Show a bar chart of the per-genre track counts of one cluster."""
    subset = genre_counts[genre_counts["ClusterLabel"] == cluster_id]
    if subset.empty:
        # Nothing was assigned to this cluster; an empty bar plot would fail.
        print(f"No tracks in cluster {cluster_id}, skipping plot")
        return
    subset = subset.sort_values(by=["size"])
    subset["size"].plot(kind="bar", x="genre")
    plt.title(f"Tracks genre distribution - cluster {cluster_id}")
    plt.xticks(rotation=18)
    plt.show()


def main():
    """Cluster the music time series and inspect the genre mix per cluster.

    Task: on the dataset created, compute clustering based on
    Euclidean/Manhattan and DTW distances and compare the results. To perform
    the clustering you can choose among different distance functions and
    clustering algorithms. Remember that you can reduce the dimensionality
    through approximation. Analyze the clusters and highlight similarities
    and differences.

    Side effects: shows several matplotlib figures, prints intermediate
    tables and writes ``centroidiclusters.csv`` and ``musicluster.csv`` to
    the current directory.
    """
    sns.set()

    n_clusters = 8
    clusters_to_plot = (1, 4, 5, 6)

    musi = MusicDB()

    # Z-normalise every series (zero mean, unit variance). SAX derives its
    # breakpoints from a standard normal distribution, so it expects
    # normalised input.
    scaler = TimeSeriesScalerMeanVariance(mu=0.0, std=1.0)
    ts = scaler.fit_transform(musi.df)

    # TODO: grid search over k - build a list of (ts, k) tuples and map
    # do_sax_kmeans over it (possibly with multiprocessing).
    centroids, labels, inertia = do_sax_kmeans((ts, n_clusters))

    musi.feat["ClusterLabel"] = labels
    musi.feat = musi.feat.drop(["enc_genre"], axis=1)

    # One line per centroid.
    plt.plot(np.squeeze(centroids).T)
    plt.show()

    # One row per centroid, first feature dimension only. Built in one step
    # because DataFrame.append no longer exists in pandas >= 2.0.
    df_centroids = pd.DataFrame(centroids[:, :, 0])
    print(df_centroids)
    print(musi.feat)

    # Count tracks per (genre, cluster) pair and drop combinations that
    # never occur.
    genre_counts = musi.feat.groupby(
        ["genre", "ClusterLabel"], as_index=False
    ).size()
    genre_counts = genre_counts[genre_counts["size"] != 0]
    genre_counts = genre_counts.sort_values(by=["ClusterLabel"])
    # Index by genre so the bar plots are labelled with genre names.
    genre_counts.index = genre_counts["genre"]

    for cluster_id in clusters_to_plot:
        _plot_genre_distribution(genre_counts, cluster_id)

    print(genre_counts)

    df_centroids.to_csv("centroidiclusters.csv", index=False)
    genre_counts.to_csv("musicluster.csv", index=False)


if __name__ == "__main__":
    main()