Python 【机器学习】 进阶 之 【实战案例】MNIST手写数字分类处理 之 [ 训练二分类器 ] [ 性能评估 ] [ 准确率与召回率 ] | 1/2

目录

Python 【机器学习】 进阶 之 【实战案例】MNIST手写数字分类处理 之 [ 训练二分类器 ] [ 性能评估 ] [ 准确率与召回率 ] | 1/2

 一、简单介绍

二、机器学习

1、为什么使用机器学习?

2、机器学习系统的类型,及其对应的学习算法

3、机器学习可利用的开源数据

三、MNIST 手写数字

四、训练一个二分类器

五、对性能的评估

1、实现交叉验证

2、混淆矩阵

3、准确率与召回率

4、准确率/召回率之间的折衷

附录:

一、一些知识点

1、StratifiedKFold

二、源码工程

三、该案例的环境 package 信息如下


 一、简单介绍

Python是一种跨平台的计算机程序设计语言。是一种面向对象的动态类型语言,最初被设计用于编写自动化脚本(shell),随着版本的不断更新和语言新功能的添加,越多被用于独立的、大型项目的开发。Python是一种解释型脚本语言,可以应用于以下领域: Web 和 Internet开发、科学计算和统计、人工智能、教育、桌面界面开发、软件开发、后端开发、网络爬虫。

通过 Python 进行机器学习,开发者可以利用其丰富的工具和库来处理数据、构建模型、评估模型性能,并将模型部署到实际应用中。Python 的易用性和庞大的社区支持使得机器学习在各个领域都得到了广泛的应用和发展。

二、机器学习

机器学习(Machine Learning)是人工智能(AI)的一个分支领域,其核心思想是通过计算机系统的学习和自动化推理,使计算机能够从数据中获取知识和经验,并利用这些知识和经验进行模式识别、预测和决策。机器学习算法能够自动地从数据中学习并改进自己的性能,而无需明确地编程。这一过程涉及对大量输入数据的分析和解释,以识别数据中的模式和趋势,并生成可以应用于新数据的预测模型。

1、为什么使用机器学习?

2、机器学习系统的类型,及其对应的学习算法

3、机器学习可利用的开源数据

(注意:代码执行的时候,可能需要科学上网)

三、MNIST 手写数字

MNIST手写数字分类是使用机器学习或深度学习技术来识别和分类手写数字图像的任务。在Python中,这通常涉及到使用特定的库和框架来加载数据、训练模型、进行预测和评估性能。

MNIST数据集的结构和特点

接下来,我们将会使用 MNIST 这个数据集,它有着 70000 张规格较小的手写数字图片,由美国的高中生和美国人口调查局的职员手写而成。这相当于机器学习当中的“Hello World”,人们无论什么时候提出一个新的分类算法,都想知道该算法在这个数据集上的表现如何。机器学习的初学者迟早也会处理 MNIST 这个数据集。

Scikit-Learn 提供了许多辅助函数,以便于下载流行的数据集。MNIST 是其中一个。下面的代码获取 MNIST:

def sort_by_target(mnist):
    # 创建一个列表,包含训练集中目标标签和对应的索引,然后按目标标签排序
    reorder_train = np.array(sorted([(target, i) for i, target in enumerate(mnist.target[:60000])]))[:, 1]
    
    # 创建一个列表,包含测试集中目标标签和对应的索引,然后按目标标签排序
    reorder_test = np.array(sorted([(target, i) for i, target in enumerate(mnist.target[60000:])]))[:, 1]
    
    # 根据排序后的索引重新排序训练集的数据和目标标签
    mnist.data[:60000] = mnist.data[reorder_train]
    mnist.target[:60000] = mnist.target[reorder_train]
    
    # 根据排序后的索引重新排序测试集的数据和目标标签
    # 注意这里使用了reorder_test索引加上60000,因为测试集的索引是从60000开始的
    mnist.data[60000:] = mnist.data[reorder_test + 60000]
    mnist.target[60000:] = mnist.target[reorder_test + 60000]

import numpy as np

try:
    # 尝试从scikit-learn的fetch_openml函数导入MNIST数据集
    from sklearn.datasets import fetch_openml
    # 使用fetch_openml()函数加载MNIST数据集,指定版本为1,启用缓存,不以数据框形式返回
    mnist = fetch_openml('mnist_784', version=1, cache=True, as_frame=False)
    
    # 将目标(target)列的数据类型转换为np.int8,因为fetch_openml()返回的目标是字符串类型
    mnist.target = mnist.target.astype(np.int8)
    
    # 调用自定义的sort_by_target函数对数据集进行排序
    sort_by_target(mnist) # fetch_openml()返回的数据集是未排序的
except ImportError:
    # 如果fetch_openml无法导入(例如scikit-learn版本不支持),则尝试使用fetch_mldata函数
    from sklearn.datasets import fetch_mldata
    # 使用fetch_mldata()函数加载MNIST数据集
    mnist = fetch_mldata('MNIST original')
    
# 这里的mnist是一个字典,包含"data"和"target"两个键,分别对应数据集和目标
mnist["data"], mnist["target"]

运行结果:

(array([[0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        ...,
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0],
        [0, 0, 0, ..., 0, 0, 0]]),
 array([0, 0, 0, ..., 9, 9, 9], dtype=int8))

一般而言,由 sklearn 加载的数据集有着相似的字典结构,这包括:

  • DESCR键描述数据集
  • data键存放一个数组,数组的一行表示一个样例,一列表示一个特征
  • target键存放一个标签数组

让我们看一下这些数组:

# 获取mnist数据集中图像数据的形状
# mnist.data是一个NumPy数组,shape属性返回一个元组,表示数组的维度大小
# 对于MNIST数据集,mnist.data的形状通常是(60000, 784)或(60000,),具体取决于数据是否被展平
mnist.data.shape

运行结果:

(70000, 784)
# 从MNIST数据集中分离特征和标签
# mnist["data"]是所有图像的像素值,mnist["target"]是对应的标签
X, y = mnist["data"], mnist["target"]

# 打印特征数据的形状
# X.shape将返回一个元组,表示X的维度大小,通常对于MNIST数据集,形状是(60000, 784)
# 其中60000是样本数量,784是每个样本的像素值数量(28x28像素的图像被展平)
print(X.shape)

# 打印标签数据的形状
# y.shape将返回一个元组,表示y的维度大小
# 对于MNIST数据集,如果标签是一维数组,形状是(60000,)
# 如果标签是二维的(例如,one-hot编码),形状可能是(60000, 10)
print(y.shape)

运行结果:

(70000, 784)
(70000,)

MNIST 有 70000 张图片,每张图片有 784 个特征。这是因为每个图片都是28*28像素的,并且每个像素的值介于 0~255 之间。让我们看一看数据集的某一个数字。你只需要将某个实例的特征向量,reshape28*28的数组,然后使用 Matplotlib 的imshow函数展示出来。

# 导入matplotlib的pyplot模块,这是用于绘图的主要模块
# 给它一个别名plt,这样我们就可以使用plt来访问这个模块的所有函数
import matplotlib.pyplot as plt

# 导入必要的其他库,例如numpy,如果还没有导入的话
import numpy as np

# 假设X是已经加载的MNIST数据集中的特征矩阵
# some_digit是X中的一个784维向量,表示一张28x28像素的图像
some_digit = X[36000]

# 使用reshape方法将784维向量重新排列成28x28的二维数组
some_digit_image = some_digit.reshape(28, 28)

# 使用matplotlib的imshow函数显示图像
# cmap设置为binary,这通常用于显示黑白图像
# interpolation设置为"nearest",表示使用最近邻插值方法
plt.imshow(some_digit_image, cmap=plt.cm.binary, interpolation="nearest")

# 关闭图像的坐标轴显示
plt.axis("off")

# 保存图像,函数savefig用于保存当前的绘图
# 需要确保save_fig函数已经被定义,并且能够正确保存图像
plt.savefig("images/some_digit_plot.png",bbox_inches='tight')

# 显示图像
plt.show()

运行结果:

Python 【机器学习】 进阶 之 【实战案例】MNIST手写数字分类处理 之 [ 训练二分类器 ] [ 性能评估 ] [ 准确率与召回率 ] | 1/2-LMLPHP

这看起来像个 5,实际上它的标签告诉我们:

# 访问MNIST数据集中第36000个样本的标签
# y是一个数组,包含了数据集中所有样本的标签
# 假设y是一个一维数组,索引36000将返回该索引处的标签值
y[36000]

运行结果:

np.int8(5)
def plot_digit(data):
    # 将输入数据reshape成28x28的二维数组,以恢复为图像的原始尺寸
    # 这个操作假设输入数据是一个784维的一维数组
    image = data.reshape(28, 28)
    
    # 使用matplotlib的imshow函数显示图像
    # cmap设置为binary,这通常用于显示黑白图像
    # interpolation设置为"nearest",表示使用最近邻插值方法
    plt.imshow(image, cmap=plt.cm.binary, interpolation="nearest")
    
    # 关闭图像的坐标轴显示,使图像看起来更清晰
    plt.axis("off")


# EXTRA
def plot_digits(instances, images_per_row=10, **options):
    # 定义图像的尺寸,MNIST图像是28x28像素
    size = 28
    
    # 限制每行显示的图像数量,不超过实例的数量
    images_per_row = min(len(instances), images_per_row)
    
    # 将每个实例重塑成28x28的二维数组
    images = [instance.reshape(size, size) for instance in instances]
    
    # 计算需要显示的行数
    n_rows = (len(instances) - 1) // images_per_row + 1
    
    # 初始化一个列表,用于存储每一行的图像
    row_images = []
    
    # 计算需要添加的空白图像数量,以填满最后一行
    n_empty = n_rows * images_per_row - len(instances)
    
    # 添加空白图像以确保最后一行完整
    images.append(np.zeros((size, size * n_empty)))
    
    # 遍历每一行,将图像按行拼接起来
    for row in range(n_rows):
        # 获取当前行的图像列表
        rimages = images[row * images_per_row : (row + 1) * images_per_row]
        # 将当前行的图像水平拼接
        row_images.append(np.concatenate(rimages, axis=1))
    
    # 垂直拼接所有行的图像,形成一个大的图像
    image = np.concatenate(row_images, axis=0)
    
    # 使用matplotlib的imshow函数显示拼接后的图像
    plt.imshow(image, cmap=plt.cm.binary, **options)
    
    # 关闭图像的坐标轴显示
    plt.axis("off")
# 设置matplotlib图形的尺寸为9x9英寸
plt.figure(figsize=(9, 9))

# 使用numpy的r_函数来创建一个新的数组,这个数组包含从X中每隔一定步长选取的图像
# 这里选取了三组图像,每组图像的起始索引和步长都不同
# X是MNIST数据集的特征矩阵,其中包含了所有图像的像素值
example_images = np.r_[X[:12000:600], X[13000:30600:600], X[30600:60000:590]]

# 调用plot_digits函数来绘制这些图像
# 每行显示10个图像
plot_digits(example_images, images_per_row=10)

# 假设 savefig 是一个已经定义的函数,用于保存当前的图形
# 这里需要确保save_fig函数能够正确执行,并且有保存图形的权限
plt.savefig("images/more_digits_plot.png",bbox_inches='tight')

# 显示图形
plt.show()

运行结果:

图 3-1 展示了一些来自 MNIST 数据集的图片。当你处理更加复杂的分类任务的时候,它会让你更有感觉。

先等一下!你总是应该先创建测试集,并且在验证数据之前先把测试集晾到一边。MNIST 数据集已经事先被分成了一个训练集(前 60000 张图片)和一个测试集(最后 10000 张图片)

# 将特征数据X分割为训练集和测试集
# X_train包含前60000个样本的特征数据
X_train = X[:60000]

# X_test包含后10000个样本的特征数据(假设X的总样本数为70000)
X_test = X[60000:]

# 将标签数据y分割为训练集和测试集
# y_train包含前60000个样本的标签
y_train = y[:60000]

# y_test包含后10000个样本的标签
y_test = y[60000:]

让我们打乱训练集。这可以保证交叉验证的每一折都是相似(你不会期待某一折缺少某类数字)。而且,一些学习算法对训练样例的顺序敏感,当它们在一行当中得到许多相似的样例,这些算法将会表现得非常差。打乱数据集将保证这种情况不会发生。

# 导入NumPy库,并给它一个常用的别名np
import numpy as np

# 使用np.random.permutation生成一个从0到59999的随机排列
# 这个排列将用于随机打乱训练数据集的索引
shuffle_index = np.random.permutation(60000)

# 使用上面生成的随机索引来打乱X_train和y_train
# 这样,每个特征数据X_train[i]和对应的标签y_train[i]都会根据shuffle_index进行重新排序
# 这有助于防止模型训练过程中的过拟合,并确保训练数据的多样性
X_train, y_train = X_train[shuffle_index], y_train[shuffle_index]

四、训练一个二分类器

现在我们简化一下问题,只尝试去识别一个数字,比如说,数字 5。这个“数字 5 检测器”就是一个二分类器,能够识别两类别,“是 5”和“非 5”。让我们为这个分类任务创建目标向量:

# 使用比较操作符'=='和训练集标签y_train创建一个布尔数组y_train_5
# 这个数组将包含与y_train中每个元素对应的布尔值,如果标签等于5,则为True,否则为False
y_train_5 = (y_train == 5)

# 使用比较操作符'=='和测试集标签y_test创建一个布尔数组y_test_5
# 这个数组将包含与y_test中每个元素对应的布尔值,如果标签等于5,则为True,否则为False
y_test_5 = (y_test == 5)

现在让我们挑选一个分类器去训练它。用随机梯度下降分类器 SGD,是一个不错的开始。使用 Scikit-Learn 的SGDClassifier类。这个分类器有一个好处是能够高效地处理非常大的数据集。这部分原因在于 SGD 一次只处理一条数据,这也使得 SGD 适合在线学习(online learning)。我们在稍后会看到它。让我们创建一个SGDClassifier和在整个数据集上训练它。

from sklearn.linear_model import SGDClassifier
import numpy as np

# 创建SGDClassifier实例,设置一些参数
# max_iter=5:设置迭代次数为5
# tol=0:设置容忍度为0,这意味着训练将不会因收敛而提前停止
# random_state=42:设置随机状态为42,以确保结果的可重复性
sgd_clf = SGDClassifier(max_iter=5, tol=0, random_state=42)

# 使用fit方法训练分类器
# X_train是训练数据的特征集
# y_train_5是一个布尔数组,表示训练数据的标签是否为5
# 这里只训练模型以识别数字5,忽略了其他标签的样本
sgd_clf.fit(X_train, y_train_5)

运行结果:

Python 【机器学习】 进阶 之 【实战案例】MNIST手写数字分类处理 之 [ 训练二分类器 ] [ 性能评估 ] [ 准确率与召回率 ] | 1/2-LMLPHP

现在你可以用它来查出数字 5 的图片。

# 使用训练好的SGDClassifier模型sgd_clf对单个图像some_digit进行预测
# some_digit是一个784维的NumPy数组,表示一个已经预处理和标准化的手写数字图像
# 这个数组需要与训练数据具有相同的格式和特征
# predict方法将返回模型预测的标签
prediction = sgd_clf.predict([some_digit])
prediction

运行结果:

array([ True])

分类器猜测这个数字代表 5(True)。看起来在这个例子当中,它猜对了。现在让我们评估这个模型的性能。

五、对性能的评估

评估一个分类器,通常比评估一个回归器更加玄学。所以我们将会花大量的篇幅在这个话题上。有许多量度性能的方法,所以拿来一杯咖啡和准备学习许多新概念和首字母缩略词吧。

评估一个模型的好方法是使用交叉验证。

1、实现交叉验证

在交叉验证过程中,有时候你会需要更多的控制权,相较于函数cross_val_score()或者其他相似函数所提供的功能。这种情况下,你可以实现你自己版本的交叉验证。事实上它相当简单。以下代码粗略地做了和cross_val_score()相同的事情,并且输出相同的结果。

from sklearn.model_selection import StratifiedKFold
from sklearn.base import clone

# 创建StratifiedKFold对象,用于分层K折交叉验证
# n_splits=3:设置折数为3
# random_state=42:设置随机状态为42,确保结果的可重复性
# shuffle=True:在分割前对数据进行打乱
skfolds = StratifiedKFold(n_splits=3, random_state=42, shuffle=True)

# 遍历StratifiedKFold生成的每个训练集和测试集索引
for train_index, test_index in skfolds.split(X_train, y_train_5):
    # 使用clone函数复制原始的SGDClassifier模型
    clone_clf = clone(sgd_clf)

    # 根据生成的索引获取当前折的训练集和测试集数据
    X_train_folds = X_train[train_index]
    y_train_folds = y_train_5[train_index]  # 注意这里y_train_5已经是一个布尔数组
    X_test_fold = X_train[test_index]
    y_test_fold = y_train_5[test_index]     # 同上

    # 使用当前折的训练集数据训练复制的模型
    clone_clf.fit(X_train_folds, y_train_folds)

    # 使用训练好的模型对当前折的测试集进行预测
    y_pred = clone_clf.predict(X_test_fold)

    # 计算预测正确的数量
    n_correct = sum(y_pred == y_test_fold)

    # 打印当前折的准确率,即预测正确的样本数占总样本的比例
    print(n_correct / len(y_pred))

运行结果:

Python 【机器学习】 进阶 之 【实战案例】MNIST手写数字分类处理 之 [ 训练二分类器 ] [ 性能评估 ] [ 准确率与召回率 ] | 1/2-LMLPHP

让我们使用cross_val_score()函数来评估SGDClassifier模型,同时使用 K 折交叉验证,此处让k=3。记住:K 折交叉验证意味着把训练集分成 K 折(此处 3 折),然后使用一个模型对其中一折进行预测,对其他折进行训练。

from sklearn.model_selection import cross_val_score

# 使用cross_val_score函数执行交叉验证
# sgd_clf是之前创建并配置好的SGDClassifier模型
# X_train是训练数据的特征集
# y_train_5是布尔数组,表示训练数据标签是否为5
# cv=3:设置交叉验证的折数为3
# scoring="accuracy":设置评分标准为准确率
# 这个函数将返回模型在每个折上的准确率得分
scores = cross_val_score(sgd_clf, X_train, y_train_5, cv=3, scoring="accuracy")
scores

运行结果:

array([0.93745, 0.96445, 0.9504 ])

哇!在交叉验证上有大于 93% 的精度(accuracy)?这看起来很令人吃惊。先别高兴,让我们来看一个非常笨的分类器去分类,看看其在“非 5”这个类上的表现。

from sklearn.base import BaseEstimator

# 定义一个名为Never5Classifier的类,它继承自scikit-learn的BaseEstimator类
class Never5Classifier(BaseEstimator):
    
    def fit(self, X, y=None):
        # fit方法不接受任何参数,不执行任何操作
        # 这个方法是必需的,因为它是BaseEstimator接口的一部分
        pass
    
    def predict(self, X):
        # predict方法接受输入数据X,并返回一个布尔数组
        # 这个数组的长度与输入数据X的样本数相同,但总是预测为0(False)
        # 即这个分类器总是预测结果为非5(非True)
        # dtype=bool确保返回数组的数据类型为布尔型
        return np.zeros((len(X), 1), dtype=bool)

你能猜到这个模型的精度吗?揭晓谜底:

# 创建Never5Classifier类的实例,命名为never_5_clf
never_5_clf = Never5Classifier()

# 使用cross_val_score函数执行交叉验证
# never_5_clf是Never5Classifier的实例,它总是预测0(即非5)
# X_train是训练数据的特征集
# y_train_5是布尔数组,表示训练数据标签是否为5
# cv=3:设置交叉验证的折数为3
# scoring="accuracy":设置评分标准为准确率
# 这个函数将返回never_5_clf模型在每个折上的准确率得分
scores = cross_val_score(never_5_clf, X_train, y_train_5, cv=3, scoring="accuracy")
scores

运行结果:

array([0.91085, 0.90765, 0.91045])

没错,这个笨的分类器也有 90% 的精度。这是因为只有 10% 的图片是数字 5,所以你总是猜测某张图片不是 5,你也会有 90% 的可能性是对的。

这证明了为什么精度通常来说不是一个好的性能度量指标,特别是当你处理有偏差的数据集,比方说其中一些类比其他类频繁得多。

2、混淆矩阵

对分类器来说,一个好得多的性能评估指标是混淆矩阵。大体思路是:输出类别 A 被分类成类别 B 的次数。举个例子,为了知道分类器将 5 误分为 3 的次数,你需要查看混淆矩阵的第五行第三列。

为了计算混淆矩阵,首先你需要有一系列的预测值,这样才能将预测值与真实值做比较。你或许想在测试集上做预测。但是我们现在先不碰它。(记住,只有当你处于项目的尾声,当你准备上线一个分类器的时候,你才应该使用测试集)。相反,你应该使用cross_val_predict()函数

from sklearn.model_selection import cross_val_predict

# 使用cross_val_predict函数执行交叉验证预测
# sgd_clf是之前创建并配置好的SGDClassifier模型
# X_train是训练数据的特征集
# y_train_5是布尔数组,表示训练数据标签是否为5
# cv=3:设置交叉验证的折数为3
# 这个函数将返回在每个折上使用训练数据预测的标签
y_train_pred = cross_val_predict(sgd_clf, X_train, y_train_5, cv=3)

就像 cross_val_score()cross_val_predict()也使用 K 折交叉验证。它不是返回一个评估分数,而是返回基于每一个测试折做出的一个预测值。这意味着,对于每一个训练集的样例,你得到一个干净的预测(“干净”是说一个模型在训练过程当中没有用到测试集的数据)。

现在使用 confusion_matrix()函数,你将会得到一个混淆矩阵。传递目标类(y_train_5)和预测类(y_train_pred)给它。

from sklearn.metrics import confusion_matrix

# 使用confusion_matrix函数计算模型预测结果y_train_pred与实际标签y_train_5之间的混淆矩阵
# y_train_5是训练数据的实际标签的布尔数组,表示标签是否为5
# y_train_pred是使用cross_val_predict得到的预测结果
# 这个函数将返回一个混淆矩阵,显示了预测为正类(即5)和负类(非5)的样本数量
# 混淆矩阵可以帮助我们了解模型在哪些情况下容易混淆,以及它的召回率和精确度
conf_mat = confusion_matrix(y_train_5, y_train_pred)
conf_mat

运行结果:

array([[53352,  1227],
       [ 1727,  3694]])

混淆矩阵中的每一行表示一个实际的类, 而每一列表示一个预测的类。该矩阵的第一行认为“非 5”(反例)中的 53352 张被正确归类为 “非 5”(他们被称为真反例,true negatives), 而其余 1227 被错误归类为"是 5" (假正例,false positives)。第二行认为“是 5” (正例)中的 1727 被错误地归类为“非 5”(假反例,false negatives),其余 3694 正确分类为 “是 5”类(真正例,true positives)。一个完美的分类器将只有真反例和真正例,所以混淆矩阵的非零值仅在其主对角线(左上至右下)。

# 假设y_train_5是训练数据的实际标签的布尔数组
# y_train_perfect_predictions是完美预测结果的数组,这里直接将实际标签赋值给它
# 这表示我们假设模型的预测是完美的,即预测结果与实际标签完全一致
y_train_perfect_predictions = y_train_5

# 使用confusion_matrix函数计算完美预测情况下的混淆矩阵
# 这里,真实的标签y_train_5和预测的完美标签y_train_perfect_predictions是相同的
# 因此,混淆矩阵将只包含真正例(TP)和真负例(TN),不会有假正例(FP)或假负例(FN)
# 这个混淆矩阵提供了模型在没有误差时的理论上限性能
conf_mat_perfect = confusion_matrix(y_train_5, y_train_perfect_predictions)
conf_mat_perfect

运行结果:

array([[54579,     0],
       [    0,  5421]])

混淆矩阵可以提供很多信息。有时候你会想要更加简明的指标。一个有趣的指标是正例预测的精度,也叫做分类器的准确率(precision)。

Python 【机器学习】 进阶 之 【实战案例】MNIST手写数字分类处理 之 [ 训练二分类器 ] [ 性能评估 ] [ 准确率与召回率 ] | 1/2-LMLPHP

公式 3-1 准确率

其中TP是真正例的数目,FP是假正例的数目。

想要一个完美的准确率,一个平凡的方法是构造一个单一正例的预测和确保这个预测是正确的(precision = 1/1 = 100%)。但是这什么用,因为分类器会忽略所有样例,除了那一个正例。所以准确率一般会伴随另一个指标一起使用,这个指标叫做召回率(recall),也叫做敏感度(sensitivity)或者真正例率(true positive rate, TPR)。这是正例被分类器正确探测出的比率。

Python 【机器学习】 进阶 之 【实战案例】MNIST手写数字分类处理 之 [ 训练二分类器 ] [ 性能评估 ] [ 准确率与召回率 ] | 1/2-LMLPHP

公式 3-2 Recall

FN是假反例的数目。

如果你对于混淆矩阵感到困惑,图 3-2 将对你有帮助

3、准确率与召回率

Scikit-Learn 提供了一些函数去计算分类器的指标,包括准确率和召回率。

from sklearn.metrics import precision_score, recall_score

# 使用precision_score函数计算模型预测的精确度
# y_train_5是训练数据的实际标签的布尔数组,表示标签是否为5
# y_train_pred是使用cross_val_predict得到的预测结果
# 精确度(precision)是在所有被模型预测为正类(即5)的样本中,实际为正类的比例
# 这个指标有助于评估模型在预测正类时的准确性
precision = precision_score(y_train_5, y_train_pred)
precision

运行结果:

np.float64(0.7506604348709612)
3694/(3694+1227)

运行结果:

0.7506604348709612
from sklearn.metrics import recall_score

# 使用recall_score函数计算模型预测的召回率
# y_train_5是训练数据的实际标签的布尔数组,表示标签是否为5
# y_train_pred是使用cross_val_predict得到的预测结果
# 召回率(recall)是在所有实际为正类(即5)的样本中,被模型正确预测为正类的比例
# 这个指标有助于评估模型捕捉所有正类样本的能力,即使在数据不平衡的情况下
recall = recall_score(y_train_5, y_train_pred)
recall

运行结果:

np.float64(0.6814240914960339)
3694/(3694+1727)

运行结果:

0.6814240914960339

当你去观察精度的时候,你的“数字 5 探测器”看起来还不够好。当它声明某张图片是 5 的时候,它只有 75% 的可能性是正确的。而且,它也只检测出“是 5”类图片当中的 68%。

通常结合准确率和召回率会更加方便,这个指标叫做“F1 值”,特别是当你需要一个简单的方法去比较两个分类器的优劣的时候。F1 值是准确率和召回率的调和平均。普通的平均值平等地看待所有的值,而调和平均会给小的值更大的权重。所以,要想分类器得到一个高的 F1 值,需要召回率和准确率同时高。

Python 【机器学习】 进阶 之 【实战案例】MNIST手写数字分类处理 之 [ 训练二分类器 ] [ 性能评估 ] [ 准确率与召回率 ] | 1/2-LMLPHP

公式 3-3 F1 值

为了计算 F1 值,简单调用f1_score()

from sklearn.metrics import f1_score

# 使用f1_score函数计算模型预测的F1分数
# y_train_5是训练数据的实际标签的布尔数组,表示标签是否为5
# y_train_pred是使用cross_val_predict得到的预测结果
# F1分数是精确度和召回率的调和平均数,它在两者之间取得一个平衡
# 当精确度和召回率相差较大时,F1分数能提供一个比单独的精确度或召回率更全面的性能指标
f1 = f1_score(y_train_5, y_train_pred)
f1

运行结果:

np.float64(0.7143685940823825)

F1 支持那些有着相近准确率和召回率的分类器。这不会总是你想要的。有的场景你会绝大程度地关心准确率,而另外一些场景你会更关心召回率。举例子,如果你训练一个分类器去检测视频是否适合儿童观看,你会倾向选择那种即便拒绝了很多好视频、但保证所保留的视频都是好(高准确率)的分类器,而不是那种高召回率、但让坏视频混入的分类器(这种情况下你或许想增加人工去检测分类器选择出来的视频)。另一方面,加入你训练一个分类器去检测监控图像当中的窃贼,有着 30% 准确率、99% 召回率的分类器或许是合适的(当然,警卫会得到一些错误的报警,但是几乎所有的窃贼都会被抓到)。

不幸的是,你不能同时拥有两者。增加准确率会降低召回率,反之亦然。这叫做准确率与召回率之间的折衷。

4、准确率/召回率之间的折衷

为了弄懂这个折衷,我们看一下SGDClassifier是如何做分类决策的。对于每个样例,它根据决策函数计算分数,如果这个分数大于一个阈值,它会将样例分配给正例,否则它将分配给反例。图 3-3 显示了几个数字从左边的最低分数排到右边的最高分。假设决策阈值位于中间的箭头(介于两个 5 之间):您将发现 4 个真正例(数字 5)和一个假正例(数字 6)在该阈值的右侧。因此,使用该阈值,准确率为 80%(4/5)。但实际有 6 个数字 5,分类器只检测 4 个, 所以召回是 67%(4/6)。现在,如果你 提高阈值(移动到右侧的箭头),假正例(数字 6)成为一个真反例,从而提高准确率(在这种情况下高达 100%),但一个真正例 变成假反例,召回率降低到 50%。相反,降低阈值可提高召回率、降低准确率。

Scikit-Learn 不让你直接设置阈值,但是它给你提供了设置决策分数的方法,这个决策分数可以用来产生预测。它不是调用分类器的predict()方法,而是调用decision_function()方法。这个方法返回每一个样例的分数值,然后基于这个分数值,使用你想要的任何阈值做出预测。

# 使用SGDClassifier模型的decision_function方法为单个图像some_digit生成分数
# some_digit是一个784维的NumPy数组,表示一个手写数字图像
# decision_function方法返回的是一个数组,其中包含模型对输入样本的原始决策分数
# 对于SGDClassifier,这个分数表示模型预测样本为正类(例如数字5)的相对可能性
# 返回的y_scores将是一个一维数组,包含对应于每个类别的分数
y_scores = sgd_clf.decision_function([some_digit])

# 打印y_scores,这将显示模型对some_digit属于每个类别的评分
y_scores

运行结果:

array([192711.73865475])
# 设置一个阈值为0
# 在使用决策函数分数进行预测时,阈值用于确定样本应该被分类到哪一类
# 如果模型是用于二分类问题,且决策函数的分数表示为正类的概率或对数几率,阈值通常设为0
threshold = 0

# 使用决策函数的分数y_scores和阈值threshold来生成预测
# y_some_digit_pred是一个布尔数组,表示模型预测的样本是否属于正类
# 如果y_scores中的分数大于阈值,则预测为正类(例如数字5),否则为负类(非5)
# 这里因为我们只考虑了是否为数字5,所以结果是二元的(True或False)
y_some_digit_pred = (y_scores > threshold)
y_some_digit_pred

运行结果:

array([ True])

SGDClassifier用了一个等于 0 的阈值,所以前面的代码返回了跟predict()方法一样的结果(都返回了true)。让我们提高这个阈值:

# 设置一个较高的阈值200000
# 在实际应用中,阈值的选择取决于模型输出的分数范围和问题的具体需求
# 较高的阈值意味着只有当模型对样本属于正类(例如数字5)的预测非常有信心时
# 才会将其分类为正类,这可能会降低召回率但提高精确度
threshold = 200000

# 使用决策函数的分数y_scores和设置的阈值threshold来生成预测
# y_some_digit_pred是一个布尔数组,根据y_scores是否大于阈值来预测样本是否属于正类
# 如果y_scores中的分数大于阈值,则预测为正类(例如数字5)
# 否则,预测为负类(非5)
y_some_digit_pred = (y_scores > threshold)

# 打印y_some_digit_pred,这将显示模型对some_digit图像的预测结果
# 结果将是True或False,表示模型是否预测some_digit为正类
y_some_digit_pred

运行结果:

array([False])

这证明了提高阈值会降调召回率。这个图片实际就是数字 5,当阈值等于 0 的时候,分类器可以探测到这是一个 5,当阈值提高到 20000 的时候,分类器将不能探测到这是数字 5。

那么,你应该如何使用哪个阈值呢?首先,你需要再次使用cross_val_predict()得到每一个样例的分数值,但是这一次指定返回一个决策分数,而不是预测值。

# 使用cross_val_predict函数执行交叉验证预测
# sgd_clf是之前创建并配置好的SGDClassifier模型
# X_train是训练数据的特征集
# y_train_5是布尔数组,表示训练数据标签是否为5
# cv=3:设置交叉验证的折数为3
# method="decision_function":指定cross_val_predict使用decision_function方法
#   来获取模型的决策函数分数,而不是直接的预测标签
#   这些分数表示模型预测样本属于各个类别的原始分数
y_scores = cross_val_predict(sgd_clf, X_train, y_train_5, cv=3, method="decision_function")

现在有了这些分数值。对于任何可能的阈值,使用precision_recall_curve(),你都可以计算准确率和召回率:

from sklearn.metrics import precision_recall_curve

# 使用precision_recall_curve函数计算精确度-召回率曲线
# y_train_5是训练数据的实际标签的布尔数组,表示标签是否为5
# y_scores是使用cross_val_predict得到的决策函数分数
# 这个函数返回三个数组:
# precisions:在不同的阈值下计算的精确度值
# recalls:在不同的阈值下计算的召回率值
# thresholds:与每个精确度和召回率值对应的阈值
precisions, recalls, thresholds = precision_recall_curve(y_train_5, y_scores)

最后,你可以使用 Matplotlib 画出准确率和召回率(图 3-4),这里把准确率和召回率当作是阈值的一个函数。

import matplotlib.pyplot as plt
from sklearn.metrics import precision_recall_curve

# 自定义函数plot_precision_recall_vs_threshold,用于绘制精确度-召回率曲线
def plot_precision_recall_vs_threshold(precisions, recalls, thresholds):
    # 绘制精确度曲线,使用蓝色虚线表示
    # precisions[:-1]:去掉最后一个点,因为它可能是一个不稳定的值
    plt.plot(thresholds, precisions[:-1], "b--", label="Precision", linewidth=2)
    
    # 绘制召回率曲线,使用绿色实线表示
    # recalls[:-1]:同样去掉最后一个点
    plt.plot(thresholds, recalls[:-1], "g-", label="Recall", linewidth=2)
    
    # 设置x轴标签和字体大小
    plt.xlabel("Threshold", fontsize=16)
    
    # 添加图例,设置位置和字体大小
    plt.legend(loc="upper left", fontsize=16)
    
    # 设置y轴显示的范围在0到1之间
    plt.ylim([0, 1])

# 设置绘图的大小为8x4英寸
plt.figure(figsize=(8, 4))

# 调用自定义函数绘制精确度-召回率曲线
plot_precision_recall_vs_threshold(precisions, recalls, thresholds)

# 设置x轴显示的范围,根据实际的阈值范围进行调整
plt.xlim([-700000, 700000])

# 假设 savefig 是一个已经定义的函数,用于保存当前的图形
# 这里需要确保save_fig函数能够正确执行,并且有保存图形的权限
plt.savefig("images/precision_recall_vs_threshold_plot.png",bbox_inches='tight')

# 显示图形
plt.show()

运行结果:

现在你可以选择适合你任务的最佳阈值。

# 使用cross_val_predict函数和decision_function方法获得的y_scores进行预测
# 阈值设为0,意味着任何决策函数分数大于或等于0的样本都将被预测为正类
# y_train_pred是早前使用cross_val_predict(sgd_clf, X_train, y_train_5, cv=3)得到的预测结果
# 此处的比较操作(y_scores > 0)会为每个样本生成一个布尔数组,表示样本是否根据决策函数分数被预测为正类

# all()函数用于检查比较操作结果的布尔数组是否全为True
# 如果all()返回True,则表示使用阈值0得到的预测结果与y_train_pred完全一致
# 如果有任何不一致,则all()将返回False
is_perfect_match = (y_train_pred == (y_scores > 0)).all()
is_perfect_match

运行结果:

np.True_

另一个选出好的准确率/召回率折衷的方法是直接画出准确率对召回率的曲线,如图 3-5 所示。

import matplotlib.pyplot as plt

# 自定义函数plot_precision_vs_recall,用于绘制精确度-召回率曲线
def plot_precision_vs_recall(precisions, recalls):
    """
    绘制精确度-召回率曲线的函数。
    
    参数:
    precisions : ndarray
        在不同召回率水平上的精确度值数组。
    recalls : ndarray
        召回率值数组,通常从0到1。
    """
    # 使用蓝色实线绘制召回率与精确度之间的关系
    plt.plot(recalls, precisions, "b-", linewidth=2)
    
    # 设置x轴标签为"Recall",字体大小为16
    plt.xlabel("Recall", fontsize=16)
    
    # 设置y轴标签为"Precision",字体大小为16
    plt.ylabel("Precision", fontsize=16)
    
    # 设置轴的范围在[0, 1]之间,确保可以完整显示曲线
    plt.axis([0, 1, 0, 1])

# 设置绘图的尺寸为8x6英寸
plt.figure(figsize=(8, 6))

# 调用自定义函数plot_precision_vs_recall绘制精确度-召回率曲线
# 这里precisions和recalls应该是之前计算或获取的精确度和召回率数组
plot_precision_vs_recall(precisions, recalls)

# 假设 savefig 是一个已经定义的函数,用于保存当前的图形
# 这里需要确保save_fig函数能够正确执行,并且有保存图形的权限
plt.savefig("images/precision_vs_recall_plot.png",bbox_inches='tight')

# 显示图形,使用plt.show()可以在绘图窗口中查看图形
plt.show()

运行结果:

可以看到,在召回率在 80% 左右的时候,准确率急剧下降。你可能会想选择在急剧下降之前选择出一个准确率/召回率折衷点。比如说,在召回率 60% 左右的点。当然,这取决于你的项目需求。

我们假设你决定达到 90% 的准确率。你查阅第一幅图(放大一些),在 70000 附近找到一个阈值。为了作出预测(目前为止只在训练集上预测),你可以运行以下代码,而不是运行分类器的predict()方法。

# 使用决策函数的分数y_scores和阈值70000来生成预测
# y_scores是使用cross_val_predict方法和decision_function参数得到的决策函数分数
# 阈值70000是一个非常高的值,这意味着只有当模型对样本属于正类(例如数字5)的预测非常有信心时
# 才会将其分类为正类,这可能会显著降低召回率但提高精确度
# y_train_pred_90是一个布尔数组,表示模型预测的样本是否属于正类,基于阈值70000
y_train_pred_90 = (y_scores > 70000)

让我们检查这些预测的准确率和召回率:

from sklearn.metrics import precision_score

# 使用precision_score函数计算在阈值70000下的模型预测精确度
# y_train_5是训练数据的实际标签的布尔数组,表示标签是否为5
# y_train_pred_90是在阈值70000下模型的预测结果布尔数组
# 由于70000是一个远高于y_scores可能值的阈值,y_train_pred_90可能很少有True的预测
# 这可能会导致一个非常高的精确度(如果至少有一个True预测),或者如果没有任何预测为True,则为0
# 这种阈值的选择可以帮助我们了解在非常保守的分类策略下模型的表现
precision_90 = precision_score(y_train_5, y_train_pred_90)
precision_90

运行结果:

np.float64(0.8366812227074236)
from sklearn.metrics import recall_score

# 使用recall_score函数计算在阈值70000下的模型预测召回率
# y_train_5是训练数据的实际标签的布尔数组,表示标签是否为5
# y_train_pred_90是在阈值70000下模型的预测结果布尔数组
# 由于70000是一个远高于y_scores可能值的阈值,y_train_pred_90可能很少有True的预测
# 这将导致召回率非常低,因为大多数实际为正类的样本可能没有被预测为正类
# 召回率衡量的是所有实际为正类的样本中,被模型正确预测为正类的比例
recall_90 = recall_score(y_train_5, y_train_pred_90)
recall_90

运行结果:

np.float64(0.560044272274488)

如果有人说“让我们达到 99% 的准确率”,你应该问“相应的召回率是多少?”

附录:

一、一些知识点

1、StratifiedKFold

StratifiedKFoldscikit-learn 库中的一个类,它用于实现分层K折交叉验证(Stratified Cross-Validation)。这种交叉验证方法特别适用于分类问题,尤其是在数据集中的类别分布不均衡时。以下是 StratifiedKFold 的一些关键点:

二、源码工程

GitHub - XANkui/PythonMachineLearnIntermediateLevel: Python 机器学习是利用 Python 编程语言中的各种工具和库来实现机器学习算法和技术的过程。Python 是一种功能强大且易于学习和使用的编程语言,因此成为了机器学习领域的首选语言之一。这里我们一起开始一场Python 机器学习进阶之旅。

下的 02HandwritingDatabaseClassificationHandler

三、该案例的环境 package 信息如下

Package                   Version
------------------------- --------------
anyio                     4.4.0
argon2-cffi               23.1.0
argon2-cffi-bindings      21.2.0
arrow                     1.3.0
asttokens                 2.4.1
async-lru                 2.0.4
attrs                     23.2.0
Babel                     2.15.0
beautifulsoup4            4.12.3
bleach                    6.1.0
certifi                   2024.7.4
cffi                      1.16.0
charset-normalizer        3.3.2
colorama                  0.4.6
comm                      0.2.2
contourpy                 1.2.1
cycler                    0.12.1
debugpy                   1.8.2
decorator                 5.1.1
defusedxml                0.7.1
executing                 2.0.1
fastjsonschema            2.20.0
fonttools                 4.53.1
fqdn                      1.5.1
h11                       0.14.0
httpcore                  1.0.5
httpx                     0.27.0
idna                      3.7
ipykernel                 6.29.5
ipython                   8.26.0
ipywidgets                8.1.3
isoduration               20.11.0
jedi                      0.19.1
Jinja2                    3.1.4
joblib                    1.4.2
json5                     0.9.25
jsonpointer               3.0.0
jsonschema                4.23.0
jsonschema-specifications 2023.12.1
jupyter                   1.0.0
jupyter_client            8.6.2
jupyter-console           6.6.3
jupyter_core              5.7.2
jupyter-events            0.10.0
jupyter-lsp               2.2.5
jupyter_server            2.14.2
jupyter_server_terminals  0.5.3
jupyterlab                4.2.4
jupyterlab_pygments       0.3.0
jupyterlab_server         2.27.3
jupyterlab_widgets        3.0.11
kiwisolver                1.4.5
MarkupSafe                2.1.5
matplotlib                3.9.1
matplotlib-inline         0.1.7
mistune                   3.0.2
nbclient                  0.10.0
nbconvert                 7.16.4
nbformat                  5.10.4
nest-asyncio              1.6.0
notebook                  7.2.1
notebook_shim             0.2.4
numpy                     2.0.1
overrides                 7.7.0
packaging                 24.1
pandas                    2.2.2
pandocfilters             1.5.1
parso                     0.8.4
pillow                    10.4.0
pip                       24.1.2
platformdirs              4.2.2
prometheus_client         0.20.0
prompt_toolkit            3.0.47
psutil                    6.0.0
pure_eval                 0.2.3
pycparser                 2.22
Pygments                  2.18.0
pyparsing                 3.1.2
python-dateutil           2.9.0.post0
python-json-logger        2.0.7
pytz                      2024.1
pywin32                   306
pywinpty                  2.0.13
PyYAML                    6.0.1
pyzmq                     26.0.3
qtconsole                 5.5.2
QtPy                      2.4.1
referencing               0.35.1
requests                  2.32.3
rfc3339-validator         0.1.4
rfc3986-validator         0.1.1
rpds-py                   0.19.1
scikit-learn              1.5.1
scipy                     1.14.0
Send2Trash                1.8.3
setuptools                70.1.1
six                       1.16.0
sniffio                   1.3.1
soupsieve                 2.5
stack-data                0.6.3
terminado                 0.18.1
threadpoolctl             3.5.0
tinycss2                  1.3.0
tornado                   6.4.1
traitlets                 5.14.3
types-python-dateutil     2.9.0.20240316
typing_extensions         4.12.2
tzdata                    2024.1
uri-template              1.3.0
urllib3                   2.2.2
wcwidth                   0.2.13
webcolors                 24.6.0
webencodings              0.5.1
websocket-client          1.8.0
wheel                     0.43.0
widgetsnbextension        4.0.11

08-06 10:56