图片由作者提供
当我们拥有不平衡的分类数据集时,模型学习决策边界的少数类样本很少。这也会整体影响模型的表现。
1. Google 网络安全证书 - 快速进入网络安全职业轨道。
2. Google 数据分析专业证书 - 提升你的数据分析技能
3. Google IT 支持专业证书 - 支持你的组织的 IT 工作
你可以通过过采样少数类来解决这个问题,你可以通过复制训练数据集中少数类的样本来实现。这将平衡类别分布,但不会提高模型性能,因为它没有为模型提供额外的信息。
那么,你如何同时平衡类别分布并提高模型性能呢?通过使用 SMOTE(合成少数类过采样技术)从少数类中合成新的样本。
SMOTE(合成少数类过采样技术)是一种平衡数据集中类别分布的过采样方法。它选择接近特征空间的少数样本。然后,它在特征空间中的样本之间绘制一条线,并在这条线上生成一个新的样本。
简单来说,算法从少数类中选择一个随机样本,并使用 K 最近邻选择一个随机邻居。在特征空间中,这个合成样本在两个样本之间创建。
使用 SMOTE 有一个缺点,因为在创建合成样本时,它没有考虑多数类。这可能会导致类别之间有很强的重叠。
让我们通过使用 Imbalanced-Learn 库来观察 SMOTE 的实际效果。
%pip install imbalanced-learn
**注意:**我们使用 Deepnote 笔记本来运行实验。
我们将使用 sci-kit learn 的数据集模块中的 make_classification 来创建一个不平衡的分类数据集。
from collections import Counter
from sklearn.datasets import make_classification
from imblearn.over_sampling import SMOTE
import matplotlib.pyplot as plt
# create a binary classification dataset
X, y = make_classification(
n_samples=1000,
n_features=2,
n_redundant=0,
n_clusters_per_class=1,
weights=[0.98],
random_state=125,
)
labels = Counter(y)
print("y labels after oversampling")
print(labels)
正如我们观察到的,样本总数为 1K。970 个属于0标签,只有 30 个属于1。
y labels after oversampling
Counter({0: 970, 1: 30})
然后我们将使用 matplotlib 的 pyplot 来可视化数据集。
正如我们所见,图表上只有少量的黄色点(1),而紫色点则更多。这是一个明显的不平衡数据集的例子。
plt.scatter(X[:, 0], X[:, 1], marker="o", c=y, s=50, edgecolor="k");
在我们使用过采样平衡数据集之前,我们需要为模型性能设定基准。
我们将使用决策树分类模型在数据集上进行 10 折 3 次交叉验证来进行训练和评估。简而言之,我们将在数据集上训练和评估 30 个模型。
RepeatedStratifiedKFold 中的分层意味着每次交叉验证的划分都具有与原始数据集相同的类别分布。
import numpy as np
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import RepeatedStratifiedKFold
from sklearn.tree import DecisionTreeClassifier
model = DecisionTreeClassifier()
cv = RepeatedStratifiedKFold(n_splits=10, n_repeats=3, random_state=1)
result = cross_val_score(model, X, y, scoring="roc_auc", cv=cv, n_jobs=-1)
print("Mean AUC: %.3f" % np.mean(result))
我们得到了ROC AUC平均得分为0.626,这个结果相当低。
Mean AUC: 0.626
我们现在将应用过采样方法 SMOTE 来平衡数据集。我们将使用imbalanced-learn的 SMOTE 函数,并提供特征(X)和标签(y)。
over = SMOTE()
X, y = over.fit_resample(X, y)
labels = Counter(y)
print("y labels after oversampling")
print(labels)
现在 0 和 1 标签的样本数量已经平衡,每个标签都有 970 个样本。
y labels after oversampling
Counter({0: 970, 1: 970})
让我们可视化合成平衡的数据集。我们可以清楚地看到黄色和紫色的点数相等。
plt.scatter(X[:, 0], X[:, 1], marker="o", c=y, s=50, edgecolor="k");
我们现在将在合成数据集上训练模型并评估结果。我们保持一切不变,以便将其与基准结果进行比较。
model = DecisionTreeClassifier()
cv = RepeatedStratifiedKFold(n_splits=10, n_repeats=3, random_state=1)
result = cross_val_score(model, X, y, scoring="roc_auc", cv=cv, n_jobs=-1)
print("Mean AUC: %.3f" % np.mean(result))
经过几秒钟的训练后,我们得到了改进的结果或ROC AUC平均得分为0.834。这清楚地表明过采样确实能提高模型性能。
Mean AUC: 0.834
原始的 SMOTE 论文建议将过采样(SMOTE)与多数类的欠采样相结合,因为 SMOTE 在创建新样本时不考虑多数类。少数类的过采样(SMOTE)和多数类的欠采样的组合可以给我们更好的结果。
在本教程中,我们了解了为什么使用 SMOTE 及其工作原理。我们还学习了 imbalanced-learn 库及其如何用于提高模型性能和平衡类分布。
希望你喜欢我的工作,别忘了关注我在社交媒体上,以了解有关数据科学、机器学习、自然语言处理、MLOps、Python、Julia、R 和 Tableau 的内容。
Abid Ali Awan (@1abidaliawan)是一位认证的数据科学专业人士,热衷于构建机器学习模型。目前,他专注于内容创作,并撰写有关机器学习和数据科学技术的技术博客。Abid 拥有技术管理硕士学位和电信工程学士学位。他的愿景是使用图神经网络构建一个 AI 产品,帮助那些在精神疾病方面挣扎的学生。