Hey
如何处理多分类问题
有关多分类问题有两种方法
二分类器只能区分两个类,而多类分类器(也被叫做多项式分类器)可以区分多于两个类。
一些算法(比如随机森林分类器或者朴素贝叶斯分类器)可以直接处理多类分类问题。其他一些算法(比如 SVM 分类器或者线性分类器)则是严格的二分类器。然后,有许多策略可以让你用二分类器去执行多类分类。
举例子,创建一个可以将图片分成 10 类(从 0 到 9)的系统的一个方法是:训练10个二分类器,每一个对应一个数字(探测器 0,探测器 1,探测器 2,以此类推)。然后当你想对某张图片进行分类的时候,让每一个分类器对这个图片进行分类,选出决策分数最高的那个分类器。这叫做“一对所有”(OvA)策略(也被叫做“一对其他”)。
另一个策略是对每一对数字都训练一个二分类器:一个分类器用来处理数字 0 和数字 1,一个用来处理数字 0 和数字 2,一个用来处理数字 1 和 2,以此类推。这叫做“一对一”(OvO)策略。如果有 N 个类。你需要训练N*(N-1)/2
个分类器。对于 MNIST 问题,需要训练 45 个二分类器!当你想对一张图片进行分类,你必须将这张图片跑在全部45个二分类器上。然后看哪个类胜出。OvO 策略的主要优点是:每个分类器只需要在训练集的部分数据上面进行训练。这部分数据是它所需要区分的那两个类对应的数据。
一些算法(比如 SVM 分类器)在训练集的大小上很难扩展,所以对于这些算法,OvO 是比较好的,因为它可以在小的数据集上面可以更多地训练,较之于巨大的数据集而言。但是,对于大部分的二分类器来说,OvA 是更好的选择。
Scikit-Learn 可以探测出你想使用一个二分类器去完成多分类的任务,它会自动地执行 OvA(除了 SVM 分类器,它使用 OvO)。让我们试一下SGDClassifier
.
1 | # y_train, not y_train_5 sgd_clf.fit(X_train, y_train) |
很容易。上面的代码在训练集上训练了一个SGDClassifier
。这个分类器处理原始的目标class,从 0 到 9(y_train
),而不是仅仅探测是否为 5 (y_train_5
)。然后它做出一个判断(在这个案例下只有一个正确的数字)。在幕后,Scikit-Learn 实际上训练了 10 个二分类器,每个分类器都产到一张图片的决策数值,选择数值最高的那个类。
为了证明这是真实的,你可以调用decision_function()
方法。不是返回每个样例的一个数值,而是返回 10 个数值,一个数值对应于一个类。
1 | some_digit_scores = sgd_clf.decision_function([some_digit]) |
最高数值是对应于类别 5 :
1 | np.argmax(some_digit_scores) |
一个分类器被训练好了之后,它会保存目标类别列表到它的属性
classes_
中去,按照值排序。在本例子当中,在classes_
数组当中的每个类的索引方便地匹配了类本身,比如,索引为 5 的类恰好是类别 5 本身。但通常不会这么幸运。
如果你想强制 Scikit-Learn 使用 OvO 策略或者 OvA 策略,你可以使用OneVsOneClassifier
类或者OneVsRestClassifier
类。创建一个样例,传递一个二分类器给它的构造函数。举例子,下面的代码会创建一个多类分类器,使用 OvO 策略,基于SGDClassifier
。
1 | from sklearn.multiclass import OneVsOneClassifier |
训练一个RandomForestClassifier
同样简单:
1 | forest_clf.fit(X_train, y_train) |
这次 Scikit-Learn 没有必要去运行 OvO 或者 OvA,因为随机森林分类器能够直接将一个样例分到多个类别。你可以调用predict_proba()
,得到样例对应的类别的概率值的列表:
1 | forest_clf.predict_proba([some_digit]) |
你可以看到这个分类器相当确信它的预测:在数组的索引 5 上的 0.8,意味着这个模型以 80% 的概率估算这张图片代表数字 5。它也认为这个图片可能是数字 0 或者数字 3,分别都是 10% 的几率。
现在当然你想评估这些分类器。像平常一样,你想使用交叉验证。让我们用cross_val_score()
来评估SGDClassifier
的精度。
1 | 3, scoring="accuracy") cross_val_score(sgd_clf, X_train, y_train, cv= |
在所有测试折(test fold)上,它有 84% 的精度。如果你是用一个随机的分类器,你将会得到 10% 的正确率。所以这不是一个坏的分数,但是你可以做的更好。举例子,简单将输入正则化,将会提高精度到 90% 以上。
1 | from sklearn.preprocessing import StandardScaler |