已弃用的 confusion_matrix 方法

Deprecated confusion_matrix method

提问人:Tanner Tolman 提问时间:11/14/2023 最后编辑:desertnautTanner Tolman 更新时间:11/15/2023 访问量:24

问:

我正在学习 Udemy 课程,似乎它的某些方面尚未经过编辑以反映最近的更新。我正在尝试使用朴素贝叶斯为分类问题创建混淆矩阵,但我无法获得使用更新函数所需的材料。

数据集是关于航空公司评论的,我已将其精简为评论文本和标记的情绪(积极、消极、中立)。这是相关的代码,我在更改方法而不是我的参数时收到错误。问题出现在“报告”方法中。

from sklearn.model_selection import train_test_split
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.naive_bayes import MultinomialNB
from sklearn.metrics import confusion_matrix,ConfusionMatrixDisplay,classification_report

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=101)
tfidf = TfidfVectorizer(stop_words='english')
tfidf.fit(X_train)
X_train_tfidf = tfidf.transform(X_train)
X_test_tfidf = tfidf.transform(X_test)

nb = MultinomialNB()
nb.fit(X_train_tfidf,y_train)

def report(model):
    preds = model.predict(X_test_tfidf)
    print(classification_report(y_test,preds))
    plot_confusion_matrix(model,X_test_tfidf,y_test)

print("NB MODEL")
report(nb)

输出:

NB MODEL
              precision    recall  f1-score   support

    negative       0.66      0.99      0.79      1817
     neutral       0.79      0.15      0.26       628
    positive       0.89      0.14      0.24       483

    accuracy                           0.67      2928
   macro avg       0.78      0.43      0.43      2928
weighted avg       0.73      0.67      0.59      2928


TypeError: too many positional arguments

尝试使用新方法创建混淆矩阵,但收到错误。

机器学习 scikit-learn 混淆矩阵

评论


答:

0赞 N0bita 11/14/2023 #1

plot_confusion_matrix在 sklearn 中已弃用和删除

但到位 sklearn 添加了ConfusionMatrixDisplay

from matplotlib import pyplot as plt
def report(model):
    preds = model.predict(X_test_tfidf)
    cm = confusion_matrix(y_test, preds)
    print(classification_report(y_test,preds))
    cm_plot = ConfusionMatrixDisplay(confusion_matrix = cm, display_labels=model.classes_)
    cm_plot.plot()
    plt.show()

这些对代码的更改对我来说效果很好。