1. 获取数据集并重新划分数据集
# 获取MNIST数据集
from sklearn.datasets import fetch_openml
mnist = fetch_openml('mnist_784', version=1, cache=True, as_frame=False)
# 查看测试器和标签
X, y = mnist['data'], mnist['target']
X_train, X_test, y_train, y_test = X[:60000], X[60000:], y[:60000], y[60000:]
# 对数据进行洗牌,防止输入许多相似实例导致的执行性能不佳
import numpy as np
shuffle_index = np.random.permutation(60000)
X_train, y_train = X_train[shuffle_index], y_train[shuffle_index]
# 重新创建目标向量(以是5和非5作为二分类标准)
y_train_5 = (y_train == '5')
y_test_5 = (y_test == '5')
2. 使用SGD随机梯度下降进行多分类
some_digit = X[36000]
from sklearn.linear_model import SGDClassifier
sgd_clf = SGDClassifier(random_state=42)
sgd_clf.fit(X_train, y_train)
sgd_clf.predict([some_digit])
3. 对二分类算法强制使用一对一、一对多策略进行多分类
3.1 SGD
# 1. 使用OvO(一对一)策略,基于SGD创建多分类器
from sklearn.multiclass import OneVsOneClassifier
ovo_clf = OneVsOneClassifier(SGDClassifier(random_state=42))
ovo_clf.fit(X_train, y_train)
ovo_clf.predict([some_digit])
3.2 随机森林
# 1. 训练随机森林(因为随机森林本身就可以进行多分类)
from sklearn.ensemble import RandomForestClassifier
forest_clf = RandomForestClassifier(random_state=42)
forest_clf.fit(X_train, y_train)
forest_clf.predict([some_digit])
4. 对模型进行评估(使用准确率)
4.1 数据未标准化
# 1. 使用交叉验证对SGD多分类器进行评估
from sklearn.model_selection import cross_val_score
cross_val_score(sgd_clf, X_train, y_train, cv=3, scoring="accuracy")
4.2 数据标准化后
# 2. 对训练集进行标准化,再进行评估看看
from sklearn.preprocessing import StandardScaler
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train.astype(np.float64)) # 标准化
cross_val_score(sgd_clf, X_train_scaled, y_train, cv=3, scoring="accuracy")
5. 绘制混淆矩阵并进行分类错误分析
5.1 原始混淆矩阵
# 混淆矩阵
from sklearn.model_selection import cross_val_predict
from sklearn.metrics import confusion_matrix
y_train_pred = cross_val_predict(sgd_clf, X_train_scaled, y_train, cv=3)
conf_mx = confusion_matrix(y_train, y_train_pred)
# 绘制混淆矩阵的图像
import matplotlib.pyplot as plt
plt.matshow(conf_mx, cmap=plt.cm.gray)
plt.show()
- 结论:
- 大多数图片都在主对角线上,说明它们被正确分类。
- 数字5稍微暗一点,可能数据集中5的图片比较少,也可能是分类器在5上的执行效果不如其他数字好。
5.2 将正确分类的剔除后只留下错误的
# 将混淆矩阵中的每个值 除以 相应类别中的图片数量
row_sums = conf_mx.sum(axis=1, keepdims=True) # 同行相加
norm_conf_mx = conf_mx / row_sums
# 用0填充对角线,保存错误,重新绘制混淆矩阵
np.fill_diagonal(norm_conf_mx, 0)
plt.matshow(norm_conf_mx, cmap=plt.cm.gray)
plt.show()
- 结论:
- 第8列、第9列比较亮,说明许多图片被错分为8和9;
- 第8行、第9行也偏亮,说明8、9容易和其他数字混淆;
- 行1很暗,说明大多数1都被正确分类;
- 数字5被分成8的数量比8分成5的数量更多。