forked from asayeed/lt2326-h24-wa_modeling
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcluster.py
151 lines (124 loc) · 5.28 KB
/
cluster.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
# cluster
import os
import torch
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import tqdm
from wikiart import WikiArtDataset, WikiArtPart2
import json
import argparse
import numpy as np
from sklearn.cluster import KMeans
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
import pandas as pd
parser = argparse.ArgumentParser()
parser.add_argument('-c', '--config', help='configuration file',
                    default='config.json')
# parse_known_args collects unrecognised CLI arguments in `unknown` instead
# of exiting with an error.
args, unknown = parser.parse_known_args()
# Read the JSON config through a context manager so the file handle is
# closed deterministically; JSON text is UTF-8 by specification.
with open(args.config, encoding='utf-8') as config_file:
    config = json.load(config_file)
device = config['device']
print('Running...')
def get_encodings(testingdir: str = config['testingdir'], device='cpu',
                  predict: bool = True, encoder=None) -> dict:
    """Load images from ``testingdir`` and optionally encode them with a model.

    Args:
        testingdir: Directory passed to ``WikiArtDataset``. The default is the
            config's ``testingdir`` entry, evaluated once at import time.
        device: Device string passed to ``WikiArtDataset``.
        predict: If False, skip the encoding step and only collect the gold
            labels (used when encodings were pre-computed and loaded from
            disk).
        encoder: Model used to encode each image; calling it must return a
            2-tuple whose first element is the encoding. Defaults to the
            module-level ``model`` created in the ``__main__`` block.

    Returns:
        dict with keys ``'encodings'`` (one model output per image, empty
        when ``predict`` is False) and ``'y_gold'`` (the label batch yielded
        by the loader for each image).
    """
    if predict and encoder is None:
        # Fall back to the global set up under __main__ (original behaviour).
        encoder = model
    testingdataset = WikiArtDataset(testingdir, device)
    loader = DataLoader(testingdataset, batch_size=1)
    encodings_dict = {'encodings': list(), 'y_gold': list()}
    # Inference only: without no_grad every stored encoding would keep its
    # autograd graph (and the activations it references) alive.
    with torch.no_grad():
        for X, true_y in tqdm.tqdm(loader):
            encodings_dict['y_gold'].append(true_y)
            if predict:
                encoded_img, _ = encoder(X)
                encodings_dict['encodings'].append(encoded_img)
    return encodings_dict
def format_encodings(encodings: pd.Series) -> list:
    """Flatten and standard-scale encodings; return them as a list of ndarrays.

    Each element is flattened to 1-D and the rows are stacked into an
    (n_images, n_features) matrix. ``StandardScaler`` then standardizes every
    feature column using its mean and standard deviation across images.

    Args:
        encodings: Series whose elements are numpy arrays or torch tensors,
            all with the same number of elements.

    Returns:
        One scaled 1-D ndarray per input element, in input order.
    """
    print('Formatting encodings...')
    matrix = np.stack([_to_flat_numpy(encoding) for encoding in encodings])
    # A single fit_transform is numerically the same as fitting and then
    # transforming row by row, but avoids n separate sklearn calls.
    scaled = StandardScaler().fit_transform(matrix)
    return list(scaled)


def _to_flat_numpy(encoding) -> np.ndarray:
    """Return a numpy array or torch tensor as a flat 1-D ndarray."""
    if isinstance(encoding, torch.Tensor):
        # force=True detaches from autograd and copies to CPU when needed;
        # np.ndarray.flatten cannot be applied to a tensor directly.
        encoding = encoding.numpy(force=True)
    return np.asarray(encoding).ravel()
def pca_reduce(encodings_scaled: list, n_components: int = 1,
               random_state=None) -> np.ndarray:
    """Project scaled encodings onto their leading principal component(s).

    Args:
        encodings_scaled: Sequence of equal-length 1-D arrays, e.g. the output
            of ``format_encodings``.
        n_components: Number of principal components to keep. The default of
            1 reduces each encoding to a single value, which is what
            ``plot_cluster`` plots on its x axis.
        random_state: Forwarded to ``sklearn.decomposition.PCA``. Depending on
            input size, PCA's default solver may be randomized, so pass an int
            for reproducible output; ``None`` keeps the previous behaviour.

    Returns:
        Array of shape (n_samples, n_components).
    """
    print('Running PCA...')
    pca = PCA(n_components=n_components, random_state=random_state)
    return pca.fit_transform(encodings_scaled)
def cluster(encodings_scaled: list, n_clusters: int = 27,
            random_state=None) -> np.ndarray:
    """Run KMeans on flattened, scaled image encodings and return cluster IDs.

    Args:
        encodings_scaled: Sequence of equal-length 1-D arrays, e.g. the output
            of ``format_encodings``.
        n_clusters: Number of KMeans clusters. Defaults to 27, the same class
            count ``plot_cluster`` assumes for the gold labels.
        random_state: Forwarded to ``sklearn.cluster.KMeans``; pass an int for
            reproducible centroid initialisation. ``None`` keeps the previous
            (unseeded) behaviour.

    Returns:
        1-D array with one cluster ID per encoding, in input order.
    """
    print('Running KMeans...')
    kmeans = KMeans(n_clusters=n_clusters, random_state=random_state)
    kmeans.fit(encodings_scaled)
    # Predict on a float64 copy of the same data (as the original did).
    return kmeans.predict(np.asarray(encodings_scaled, dtype=np.float64))
def plot_cluster(data_df, n_classes: int = 27):
    """Plot PCA-reduced image encodings against their predicted KMeans cluster.

    Each point is one image: x is its PCA value and y its KMeans cluster ID.
    Colour (and, cycling through four marker shapes, shape) encodes the
    image's gold art-style class.

    Args:
        data_df: DataFrame with columns ``'encodings_PCA'`` (one PCA value per
            row, scalar or length-1 array), ``'cluster_pred'`` (cluster IDs)
            and ``'y_gold'`` (gold labels convertible with ``int``).
        n_classes: Number of gold classes; fixes the colour scale and colour
            bar range. Gold labels must lie in ``range(n_classes)``.
    """
    # Positional arrays avoid label-based chained indexing (data_df[col][i]),
    # which breaks if the index is not 0..n-1.
    x = np.asarray(data_df['encodings_PCA'].tolist(), dtype=float).ravel()
    y = np.asarray(data_df['cluster_pred'])
    gold = np.array([int(label) for label in data_df['y_gold']])

    # Wide rectangular figure to spread the data points apart.
    fig = plt.figure()
    fig.set_figwidth(15)
    ax = fig.add_subplot(111)

    markers = ['^', 'o', '*', '+']
    cmap = plt.colormaps['hsv']
    colors = cmap(np.linspace(0, 1, n_classes))  # class c -> cmap(c/(n-1))
    # One scatter call per gold class instead of one per point.
    for cls in np.unique(gold):
        mask = gold == cls
        ax.scatter(x[mask], y[mask],
                   marker=markers[cls % len(markers)],
                   color=colors[cls])

    ax.set_title('K-Means clustering on scaled encodings')
    ax.set_xlabel('Image PCA value')
    ax.set_ylabel('KMeans cluster ID')

    # Back the colour bar with the same class->colour mapping used above so
    # each tick sits at the colour of the class it names. (Previously the bar
    # spanned an unscaled 0..1 range with ticks at k/7, which drifted from
    # the real class positions c/26.)
    mappable = plt.cm.ScalarMappable(
        norm=plt.Normalize(vmin=0, vmax=n_classes - 1), cmap=cmap)
    mappable.set_array([])
    cbar = fig.colorbar(mappable, ax=ax, label='True class from dataset')
    ticks = list(range(0, n_classes, 4))
    cbar.set_ticks(ticks)
    cbar.set_ticklabels([str(t) for t in ticks])
    plt.show()
if __name__=='__main__':
    # Load the trained encoder. get_encodings() reads it through the
    # module-level name ``model``.
    model = WikiArtPart2()
    # map_location lets a checkpoint saved on another device (e.g. a GPU)
    # load onto the configured device.
    model.load_state_dict(torch.load(config['modelfile2'],
                                     map_location=config['device'],
                                     weights_only=True))
    model = model.to(config['device'])
    model.eval()
    encodings_path = config['encodingsfile']
    # Get encodings and their true class, either loaded from file, or created anew.
    if os.path.isfile(encodings_path):
        print('Loading encodings from file...')
        encodings = np.load(encodings_path)
        # Run the loader only to retrieve the gold labels.
        encodings_dict = get_encodings(predict=False)
        if len(encodings) != len(encodings_dict['y_gold']):
            raise ValueError(
                f"{encodings_path} holds {len(encodings)} encodings but the "
                f"test set has {len(encodings_dict['y_gold'])} images; delete "
                f"the file to re-encode.")
        encodings_dict['encodings'] = list(encodings)
    else:
        # Use model to encode images in testdir
        print('Encoding images from testdir using model...')
        encodings_dict = get_encodings(device=config['device'])
        encodings_matrix = np.array(
            [encoding.numpy(force=True)
             for encoding in encodings_dict['encodings']])
        # Downstream steps expect numpy arrays (as in the loaded-file branch),
        # not torch tensors.
        encodings_dict['encodings'] = list(encodings_matrix)
        # Save encodings to file
        if encodings_path:
            # Write through a file handle: np.save(<str>) appends '.npy' to a
            # path lacking that suffix, and the isfile() check above would
            # then never find the cached file.
            with open(encodings_path, 'wb') as encodings_file:
                np.save(encodings_file, encodings_matrix)
    ## Cluster and plot encoded images
    # Format data in df
    data_df = pd.DataFrame(encodings_dict)
    # Flatten and scale encodings
    scaled_encodings = format_encodings(data_df['encodings'])
    data_df['encodings_scaled'] = scaled_encodings
    # Cluster encodings, then perform PCA for plotting against clusters
    data_df['cluster_pred'] = cluster(scaled_encodings)
    data_df['encodings_PCA'] = pca_reduce(scaled_encodings)
    # Plot clusters against PCA-reduced encodings
    plot_cluster(data_df)
    print('--- end ---')