如何绘制混淆矩阵

How to plot a confusion matrix

提问人: 提问时间:7/28/2023 最后编辑:desertnaut 更新时间:10/31/2023 访问量:473

问:

我正在尝试使用混淆矩阵评估我的 renet50 模型,但混淆矩阵如下所示:

matrix = confusion_matrix(y_test, y_pred, normalize="pred")
print(matrix)
    
# output
array([[1, 0],
      [1, 2]], dtype=int64)

我正在使用 scikit-learn 来生成混淆矩阵,并使用 tf keras 来制作模型

但是有什么方法可以绘制/可视化混淆矩阵吗?

我已经尝试使用sklearn.metrics.plot_confusion_matrix(matrix)

还有这个: 如何绘制混淆矩阵,但我得到了这个:

tutorial from stackoverflow

python 机器学习 keras scikit-learn 混淆矩阵

评论

1赞 stateMachine 7/28/2023
from sklearn.metrics import ConfusionMatrixDisplay调用该函数并将矩阵作为参数传递,如下所示: 然后只需绘制它: 并显示它。希望这就是你要找的。查看文档disp = ConfusionMatrixDisplay(confusion_matrix=matrix)disp.plot()

答:

1赞 stateMachine 7/28/2023 #1

包括以下导入:

from sklearn.metrics import ConfusionMatrixDisplay 
from matplotlib import pyplot as plt

现在,调用该函数并将 your 作为参数传递,如下所示:ConfusionMatrixDisplaymatrix

disp = ConfusionMatrixDisplay(confusion_matrix=matrix) 
# Then just plot it: 
disp.plot() 
# And show it: 
plt.show()

此外,您可以在函数中将参数设置为 以在图中显示归一化计数。查看文档以获取进一步的参考和其他可接受参数。normalizeTrueConfusionMatrixDisplay