本文介绍了科学工具包(Sci-kit)了解如何为混淆矩阵打印标签?的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

所以我正在使用sci-kit学习对一些数据进行分类.我有13个不同的类值/将数据分类到的类别.现在,我已经可以使用交叉验证并打印混淆矩阵.但是,它仅显示TP和FP等而没有类标签,因此我不知道哪个类是什么.下面是我的代码和输出:

So I'm using sci-kit learn to classify some data. I have 13 different class values/categorizes to classify the data to. Now I have been able to use cross validation and print the confusion matrix. However, it only shows the TP and FP etc without the classlabels, so I don't know which class is what. Below is my code and my output:

def classify_data(df, feature_cols, file):
    nbr_folds = 5
    RANDOM_STATE = 0
    attributes = df.loc[:, feature_cols]  # Also known as x
    class_label = df['task']  # Class label, also known as y.
    file.write("\nFeatures used: ")
    for feature in feature_cols:
        file.write(feature + ",")
    print("Features used", feature_cols)

    sampler = RandomOverSampler(random_state=RANDOM_STATE)
    print("RandomForest")
    file.write("\nRandomForest")
    rfc = RandomForestClassifier(max_depth=2, random_state=RANDOM_STATE)
    pipeline = make_pipeline(sampler, rfc)
    class_label_predicted = cross_val_predict(pipeline, attributes, class_label, cv=nbr_folds)
    conf_mat = confusion_matrix(class_label, class_label_predicted)
    print(conf_mat)
    accuracy = accuracy_score(class_label, class_label_predicted)
    print("Rows classified: " + str(len(class_label_predicted)))
    print("Accuracy: {0:.3f}%\n".format(accuracy * 100))
    file.write("\nClassifier settings:" + str(pipeline) + "\n")
    file.write("\nRows classified: " + str(len(class_label_predicted)))
    file.write("\nAccuracy: {0:.3f}%\n".format(accuracy * 100))
    file.writelines('\t'.join(str(j) for j in i) + '\n' for i in conf_mat)

#Output
Rows classified: 23504
Accuracy: 17.925%
0   372 46  88  5   73  0   536 44  317 0   200 127
0   501 29  85  0   136 0   655 9   154 0   172 67
0   97  141 78  1   56  0   336 37  429 0   435 198
0   135 74  416 5   37  0   507 19  323 0   128 164
0   247 72  145 12  64  0   424 21  296 0   304 223
0   190 41  36  0   178 0   984 29  196 0   111 43
0   218 13  71  7   52  0   917 139 177 0   111 103
0   215 30  84  3   71  0   1175    11  55  0   102 62
0   257 55  156 1   13  0   322 184 463 0   197 160
0   188 36  104 2   34  0   313 99  827 0   69  136
0   281 80  111 22  16  0   494 19  261 0   313 211
0   207 66  87  18  58  0   489 23  157 0   464 239
0   113 114 44  6   51  0   389 30  408 0   338 315

如您所见,您无法真正知道哪一列是什么,并且打印内容也未对齐",因此很难理解.

As you can see, you can't really know what column is what and the print is also "misaligned" so it's difficult to understand.

有没有办法打印标签?

推荐答案

来自文档,似乎没有这样的选项可以打印混淆矩阵的行和列标签.但是,您可以使用参数 labels = ...

From the doc, it seems that there is no such option to print the rows and column labels of the confusion matrix. However, you can specify the label order using argument labels=...

示例:

from sklearn.metrics import confusion_matrix

y_true = ['yes','yes','yes','no','no','no']
y_pred = ['yes','no','no','no','no','no']
print(confusion_matrix(y_true, y_pred))
# Output:
# [[3 0]
#  [2 1]]
print(confusion_matrix(y_true, y_pred, labels=['yes', 'no']))
# Output:
# [[1 2]
#  [0 3]]

如果要打印带有标签的混淆矩阵,则可以尝试 pandas 并设置 index columns DataFrame .

If you want to print the confusion matrix with labels, you may try pandas and set the index and columns of the DataFrame.

import pandas as pd
cmtx = pd.DataFrame(
    confusion_matrix(y_true, y_pred, labels=['yes', 'no']),
    index=['true:yes', 'true:no'],
    columns=['pred:yes', 'pred:no']
)
print(cmtx)
# Output:
#           pred:yes  pred:no
# true:yes         1        2
# true:no          0        3

unique_label = np.unique([y_true, y_pred])
cmtx = pd.DataFrame(
    confusion_matrix(y_true, y_pred, labels=unique_label),
    index=['true:{:}'.format(x) for x in unique_label],
    columns=['pred:{:}'.format(x) for x in unique_label]
)
print(cmtx)
# Output:
#           pred:no  pred:yes
# true:no         3         0
# true:yes        2         1

这篇关于科学工具包(Sci-kit)了解如何为混淆矩阵打印标签?的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持!

08-13 18:57