Skip to content

Commit

Permalink
Add clustering options
Browse files Browse the repository at this point in the history
  • Loading branch information
Muennighoff committed Jun 20, 2024
1 parent 9665a3c commit 22f0873
Show file tree
Hide file tree
Showing 3 changed files with 175 additions and 56 deletions.
36 changes: 27 additions & 9 deletions models.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,18 +100,18 @@ def retrieve(self, query, model_name, topk=1):
docs = [[query, "Title: " + docs[0][0]["title"] + "\n\n" + "Passage: " + docs[0][0]["text"]]]
return docs

def clustering_parallel(self, prompt, model_A, model_B, ncluster=1, ndim="3D"):
def clustering_parallel(self, prompt, model_A, model_B, ncluster=1, ndim="3D", dim_method="PCA", clustering_method="KMeans"):
if model_A == "" and model_B == "":
model_names = random.sample(list(self.model_meta.keys()), 2)
else:
model_names = [model_A, model_B]

with concurrent.futures.ThreadPoolExecutor() as executor:
futures = [executor.submit(self.clustering, prompt, model, ncluster, ndim) for model in model_names]
futures = [executor.submit(self.clustering, prompt, model, ncluster, ndim, dim_method, clustering_method) for model in model_names]
results = [future.result() for future in futures]
return results[0], results[1], model_names[0], model_names[1]

def clustering(self, queries, model_name, ncluster=1, ndim="3D", method="PCA"):
def clustering(self, queries, model_name, ncluster=1, ndim="3D", dim_method="PCA", clustering_method="KMeans"):
"""
Sources:
- https://www.gradio.app/guides/plot-component-for-maps
Expand All @@ -120,8 +120,10 @@ def clustering(self, queries, model_name, ncluster=1, ndim="3D", method="PCA"):
"""
import pandas as pd
import plotly.express as px
from sklearn.cluster import KMeans
from sklearn.cluster import KMeans, MiniBatchKMeans
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from umap import UMAP

if len(queries) == 1:
# No need to do PCA; just return a 1D plot
Expand All @@ -132,24 +134,40 @@ def clustering(self, queries, model_name, ncluster=1, ndim="3D", method="PCA"):
elif (ndim == "2D") or (len(queries) < 4):
model = self.load_model(model_name)
emb = model.encode(queries)
vis_dims = PCA(n_components=2).fit_transform(emb)
if dim_method == "UMAP":
vis_dims = UMAP(n_components=2).fit_transform(emb)
elif dim_method == "TSNE":
vis_dims = TSNE(n_components=2, perplexity=min(30.0, len(queries)//2)).fit_transform(emb)
else:
vis_dims = PCA(n_components=2).fit_transform(emb)
data = {"txt": queries, "x": vis_dims[:, 0], "y": vis_dims[:, 1]}
if ncluster > 1:
data["cluster"] = KMeans(n_clusters=ncluster, n_init='auto', random_state=0).fit_predict(emb).tolist()
if clustering_method == "MiniBatchKMeans":
data["cluster"] = MiniBatchKMeans(n_clusters=ncluster, n_init="auto", random_state=0).fit_predict(emb).tolist()
else:
data["cluster"] = KMeans(n_clusters=ncluster, n_init='auto', random_state=0).fit_predict(emb).tolist()
df = pd.DataFrame(data)
df["txt"] = df["txt"].str[:90]
if ncluster > 1:
fig = px.scatter(df, x="x", y="y", color="cluster", template="plotly_dark", hover_name="txt")
else:
else:
fig = px.scatter(df, x="x", y="y", template="plotly_dark", hover_name="txt")
fig.update_traces(marker=dict(size=12))
else:
model = self.load_model(model_name)
emb = model.encode(queries)
vis_dims = PCA(n_components=3).fit_transform(emb)
if dim_method == "UMAP":
vis_dims = UMAP(n_components=3).fit_transform(emb)
elif dim_method == "TSNE":
vis_dims = TSNE(n_components=3, perplexity=min(30.0, len(queries)//2)).fit_transform(emb)
else:
vis_dims = PCA(n_components=3).fit_transform(emb)
data = {"txt": queries, "x": vis_dims[:, 0], "y": vis_dims[:, 1], "z": vis_dims[:, 2]}
if ncluster > 1:
data["cluster"] = KMeans(n_clusters=ncluster, n_init='auto', random_state=0).fit_predict(emb).tolist()
if clustering_method == "MiniBatchKMeans":
data["cluster"] = MiniBatchKMeans(n_clusters=ncluster, n_init="auto", random_state=0).fit_predict(emb).tolist()
else:
data["cluster"] = KMeans(n_clusters=ncluster, n_init='auto', random_state=0).fit_predict(emb).tolist()
df = pd.DataFrame(data)
df["txt"] = df["txt"].str[:90]
if ncluster > 1:
Expand Down
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
fire
gradio
mteb
plotly
plotly
umap-learn
Loading

0 comments on commit 22f0873

Please sign in to comment.