Is it possible to add extra columns to confusion matrix?

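Asked by: Jan D.M. | Asked: 4/18/2021 | Last edited by: Jan D.M. | Updated: 4/18/2021 | Views: 608

Question:

I have built a multi-class classifier, and now I want to show the confusion matrix together with the accuracy of each class in a clean way.

I already found a function in sklearn that displays the confusion matrix, sklearn.metrics.plot_confusion_matrix, but I don't see a way to add an extra column where I could put the accuracy of each class/row.

Here is an example of how to plot a confusion matrix: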

import matplotlib.pyplot as plt  
from sklearn.datasets import make_classification
from sklearn.metrics import plot_confusion_matrix
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC
X, y = make_classification(random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
clf = SVC(random_state=0)
clf.fit(X_train, y_train)
plot_confusion_matrix(clf, X_test, y_test)  
plt.show() 

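In the image below, I sketched in Paint what I mean by "adding an extra column":

Is there a way to modify this example to add the extra column? Or is there another library that supports what I want to do?

python-3.x scikit-learn floating-accuracy confusion-matrix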

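Comments

0 votes Alexander L. Hayes 4/18/2021
Do you want a normalized confusion matrix? PyCM also has more general confusion matrix options. If neither of those fits, could you draw a picture of what you mean? I'm having trouble understanding what "add an extra column where I can put the accuracy of each class/row" refers to.
1 vote Jan D.M. 4/18/2021
@AlexanderL.Hayes: I added a quick drawing in Paint to show what I mean.

Answer:

1 vote Alexander L. Hayes 4/18/2021 #1

There doesn't seem to be anything that does this out of the box, so I wrote one: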

import numpy as np


def plot_class_accuracies(plotted_cm, axis, display_labels=None, cmap="viridis"):
    """Plot per-class accuracy as a one-column heatmap next to a confusion matrix.

    plotted_cm : instance of `ConfusionMatrixDisplay`
        Result of `sklearn.metrics.plot_confusion_matrix`
    axis : matplotlib `AxesSubplot`
        Result of `fig, (ax1, ax2) = plt.subplots(1, 2)`
    display_labels : list of labels or None
        Human-readable class names
    cmap : colormap, optional
        Optional colormap; pass the same one used for the confusion matrix
    """
    cmatrix = plotted_cm.confusion_matrix
    # Per-class accuracy (recall): correct predictions / true samples of each class
    normalized_cmatrix = np.diag(cmatrix) / np.sum(cmatrix, axis=1)
    n_classes = len(normalized_cmatrix)

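    # Text colours: the two ends of the confusion matrix's colormap,
    # switched at the midpoint of the accuracy range
    cmap_min, cmap_max = plotted_cm.im_.cmap(0), plotted_cm.im_.cmap(256)
    thresh = (normalized_cmatrix.max() + normalized_cmatrix.min()) / 2.0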

    if display_labels is None:
        labels = np.arange(n_classes)
    else:
        labels = display_labels

    axis.imshow(
        normalized_cmatrix.reshape(n_classes, 1),
        interpolation="nearest",
        cmap=cmap,
    )

    for i, value in enumerate(normalized_cmatrix):
        color = cmap_min if value > thresh else cmap_max
        axis.text(0, i, format(value, ".2g"), ha="center", va="center", color=color)

    axis.set(
        yticks=np.arange(len(normalized_cmatrix)),
        ylabel="True label",
        xlabel="Class accuracy",
        yticklabels=labels,
    )
    axis.tick_params(
        axis="x", bottom=False, labelbottom=False,
    )
    axis.set_ylim((len(normalized_cmatrix) - 0.5, -0.5))

Assuming the function above is saved in a file named cmatrix.py, it can be used like this:

from sklearn.linear_model import LogisticRegression
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import plot_confusion_matrix

# Import `plot_class_accuracies` from `cmatrix.py`
from cmatrix import plot_class_accuracies

if __name__ == "__main__":

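    # Dummy classifier for demonstration: `predict` returns its input unchanged,
    # so the array passed as `X_test` acts directly as the predicted labels.
    class ExampleClassifier(LogisticRegression):
        def __init__(self):
            self.classes_ = None
        def predict(self, X_test):
            self.classes_ = np.unique(X_test)
            return X_test

    X_test = np.array([0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 2])  # predictions
    y_test = np.array([0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 3, 3, 3])  # true labels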

    fig, (ax1, ax2) = plt.subplots(1, 2)
    clf = ExampleClassifier()

    disp = plot_confusion_matrix(
        clf, X_test, y_test, ax=ax1, cmap=plt.cm.Blues, normalize="true"
    )

    plot_class_accuracies(disp, ax2, cmap=plt.cm.Blues)
    plt.show()

Result:

[Figure: a confusion matrix on the left and a plot of per-class accuracy on the right; the diagonal of the left panel matches the values on the right.]

Here is an example based on the confusion matrix example in the sklearn documentation:

import numpy as np
import matplotlib.pyplot as plt
from sklearn import svm, datasets
from sklearn.model_selection import train_test_split
from sklearn.metrics import plot_confusion_matrix

from cmatrix import plot_class_accuracies

iris = datasets.load_iris()
X = iris.data
y = iris.target
class_names = iris.target_names

X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
classifier = svm.SVC(kernel='linear', C=0.01).fit(X_train, y_train)

fig, (ax1, ax2) = plt.subplots(1, 2)

disp = plot_confusion_matrix(classifier, X_test, y_test,
                             display_labels=class_names,
                             ax=ax1,
                             cmap=plt.cm.Blues)

plot_class_accuracies(disp, ax2, display_labels=class_names, cmap=plt.cm.Blues)

plt.show()

Result:

[Figure: same layout as the previous image, using the iris dataset and showing performance on setosa, versicolor, and virginica.]