提问人:John Sall 提问时间:5/11/2019 最后编辑:sentenceJohn Sall 更新时间:1/5/2022 访问量:33951
如何绘制多类分类器的精度和召回率?
How to plot precision and recall of multiclass classifier?
问:
我正在使用 scikit learn,我想绘制精度和召回率曲线。我使用的分类器是 .scikit learn 文档中的所有资源都使用二元分类。另外,我可以绘制多类的 ROC 曲线吗?RandomForestClassifier
另外,我只找到了用于多标签的 SVM,它有一个没有的decision_function
RandomForest
答:
53赞
sentence
5/12/2019
#1
来自 scikit-learn 文档:
- 精确召回:
精确召回率曲线通常用于二元分类,以 研究分类器的输出。为了扩展 多类或 多标签分类,需要对输出进行二值化。 每个标签可以绘制一条曲线,但也可以绘制一条曲线 通过考虑标签的每个元素得出精确召回率曲线 作为二元预测的指标矩阵(微平均)。
ROC 曲线通常用于二元分类,以研究 分类器的输出。为了将 ROC 曲线和 ROC 面积扩展到 多类或多标签分类,需要二值化 输出。每个标签可以绘制一条 ROC 曲线,但也可以绘制一条 ROC 曲线 通过考虑标签指示器的每个元素来绘制 ROC 曲线 矩阵作为二元预测(微平均)。
因此,您应该对输出进行二值化,并考虑每个类的精度召回率和 roc 曲线。此外,您将使用 predict_proba
来获取类概率。
我将代码分为三个部分:
- 常规设置、学习和预测
- 精确召回率曲线
- ROC曲线
1. 常规设置、学习和预测
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.multiclass import OneVsRestClassifier
from sklearn.metrics import precision_recall_curve, roc_curve
from sklearn.preprocessing import label_binarize
import matplotlib.pyplot as plt
#%matplotlib inline
mnist = fetch_openml("mnist_784")
y = mnist.target
y = y.astype(np.uint8)
n_classes = len(set(y))
Y = label_binarize(mnist.target, classes=[*range(n_classes)])
X_train, X_test, y_train, y_test = train_test_split(mnist.data,
Y,
random_state = 42)
clf = OneVsRestClassifier(RandomForestClassifier(n_estimators=50,
max_depth=3,
random_state=0))
clf.fit(X_train, y_train)
y_score = clf.predict_proba(X_test)
2. 精确召回曲线
# precision recall curve
precision = dict()
recall = dict()
for i in range(n_classes):
precision[i], recall[i], _ = precision_recall_curve(y_test[:, i],
y_score[:, i])
plt.plot(recall[i], precision[i], lw=2, label='class {}'.format(i))
plt.xlabel("recall")
plt.ylabel("precision")
plt.legend(loc="best")
plt.title("precision vs. recall curve")
plt.show()
3. ROC曲线
# roc curve
fpr = dict()
tpr = dict()
for i in range(n_classes):
fpr[i], tpr[i], _ = roc_curve(y_test[:, i],
y_score[:, i]))
plt.plot(fpr[i], tpr[i], lw=2, label='class {}'.format(i))
plt.xlabel("false positive rate")
plt.ylabel("true positive rate")
plt.legend(loc="best")
plt.title("ROC curve")
plt.show()
评论
3赞
John Sall
5/12/2019
为什么使用 OneVsRestClassifier?RandomForest 不是已经支持多类了吗?
0赞
John Sall
5/12/2019
当我运行第一部分时,我遇到了这些错误:UserWarning:所有训练示例中都存在标签不 0 UserWarning:所有训练示例中都存在标签不 1 UserWarning:所有训练示例中都存在标签不 2
0赞
sentence
5/12/2019
请注意,警告不是错误。考虑到这一行,您应该在数据集中提供类。在我的示例中,类是 .Y = label_binarize(mnist.target, classes=[*range(n_classes)])
[0,1,2,...,9]
0赞
Sole Galli
7/26/2021
如何使用微平均线创建 PR 曲线或 ROC 曲线?据我所知,如果你有 3 个类,你会得到 3 个概率向量,每个类的概率为 1。然后将观察结果分配给概率最高的班级。也就是说,独立于阈值。但是对于 ROC 和 PR 曲线,您需要一个阈值,那么您将如何进行微平均值呢?如何根据特定阈值将观察任务分配给班级?
0赞
Federico Gentile
1/5/2022
我只是试图反向计算精度并在阈值等于 0 时召回,看看它是否与 classification_report() 函数给出的结果匹配,但它返回的结果却截然不同。我在这里解决这个问题:stats.stackexchange.com/questions/559203/......
评论