Skip to content

Latest commit

 

History

History
105 lines (53 loc) · 7 KB

k-means-faster-lower-error-scikit-learn.md

File metadata and controls

105 lines (53 loc) · 7 KB

K-Means 比 Scikit-learn 快 8 倍,误差低 27 倍,代码仅需 25 行

原文:www.kdnuggets.com/2021/01/k-means-faster-lower-error-scikit-learn.html

评论

Jakub Adamczyk,计算机科学学生,Python 和机器学习爱好者


我们的前三大课程推荐

1. 谷歌网络安全证书 - 快速通道进入网络安全职业。

2. 谷歌数据分析专业证书 - 提升您的数据分析技能

3. 谷歌 IT 支持专业证书 - 支持您的组织的 IT


我上一篇关于 faiss 库的文章中,我展示了如何通过使用Facebook 的 faiss 库将 kNN 加速至比 Scikit-learn 快 300 倍,仅需 20 行代码。但我们可以用它做更多的事情,包括更快且更准确的 K-Means 聚类,代码仅需 25 行!

K-Means 是一种迭代算法,它将数据点聚类到 k 个簇中,每个簇由一个均值/中心点(质心)表示。训练从一些初始猜测开始,然后在两个步骤之间交替进行:分配和更新。

在分配阶段,我们将每个点分配到最近的簇(使用点与质心之间的欧几里得距离)。在更新步骤中,我们通过计算当前步骤中分配给该簇的所有点的均值来重新计算每个质心。

聚类的最终质量计算为簇内距离的总和,对于每个簇,我们计算该簇内点与其质心之间的欧几里得距离的总和。这也称为惯性。

对于预测,我们在新点和质心之间执行 1-最近邻搜索(kNN,k = 1)。

Scikit-learn 与 faiss

在这两个库中,我们需要指定算法的超参数:簇的数量、重启的次数(每次从不同的初始猜测开始)以及最大迭代次数。

从例子中可以看出,该算法的核心是搜索最近邻,特别是最近的质心,适用于训练和预测。这也是 faiss 比 Scikit-learn 快几个数量级的地方!它利用了出色的 C++ 实现、尽可能的并发,甚至可以使用 GPU(如果你需要的话)。

使用 faiss 实现 K-Means 聚类

faiss 的一个伟大特点是它提供了安装和构建说明以及带有示例的优秀文档。安装后,我们可以编写实际的聚类代码。代码相当简单,因为我们只是模仿 Scikit-learn 的 API。

重要元素:

  • faiss 有一个内置的Kmeans类专门用于此任务,但其参数的名称与 Scikit-learn 不同(参见文档)。

  • 我们必须确保使用np.float32类型,因为 faiss 仅支持这种类型。

  • kmeans.obj通过训练返回一个误差列表,因此为了得到最终的误差,就像在 Scikit-learn 中一样,我们使用[-1]索引。

  • 预测是通过Index数据结构完成的,这是 faiss 的基本构建块,所有的最近邻搜索都使用它。

  • 在预测中,我们进行 kNN 搜索,k = 1,返回*自.cluster_centers_中的最近质心的索引(索引[1],因为index.search()*返回距离和索引)。

时间和准确性比较

我选择了一些在 Scikit-learn 中可用的流行数据集进行比较。比较了训练和预测时间。为了更容易阅读,我明确写出基于 faiss 的聚类比 Scikit-learn 快多少倍。为了比较误差,我只写了基于 faiss 的聚类实现了多少倍的更低误差(因为数字较大且不太具参考性)。

所有这些时间都是通过*time.process_time()*函数测量的,该函数测量进程时间而非挂钟时间,以获得更准确的结果。结果是 100 次运行的平均值,除了 MNIST,因为 Scikit-learn 花费的时间太长,我只做了 5 次运行。

训练时间(图像由作者提供)。

预测时间(图像由作者提供)。

训练误差(图像由作者提供)。

如我们所见,对于小数据集的 K-Means 聚类(前 4 个数据集),基于 faiss 的版本训练速度较慢且误差较大。对于预测,它的表现普遍更快。

对于较大的 MNIST 数据集,faiss 明显胜出。训练速度快 20.5 倍是巨大的,特别是因为它将时间从将近 3 分钟减少到不到 8 秒!预测速度快 1.5 倍也是不错的。然而,真正的成就的是误差降低了 27.5 倍。这意味着对于较大的现实世界数据集,基于 faiss 的版本准确性高得多。而且这只需要 25 行代码!

基于此:如果你有一个大(至少几千个样本)的现实世界数据集,基于 faiss 的版本就更好。对于小的玩具数据集,Scikit-learn 是更好的选择;然而,如果你有 GPU,GPU 加速的 faiss 版本可能会更快(我还未检查过以确保公平的 CPU 比较)。

总结

通过 25 行代码,我们可以利用 faiss 库为 K-Means 聚类提供显著的速度和准确性提升,适用于合理大小的数据集。如果需要,你可以使用 GPU、多个 GPU 等进一步提高,faiss 文档中对此做了很好的解释。

原文。已获许可转载。

相关:

更多相关话题