1.导入MNIST数据集
直接使用fetch_mldata会报错,错误信息是python3.7把fetch_mldata方法移除了,所以需要单独下载数据集
从这个网站上下载数据集:
https://github.com/amplab/datascience-sp14/raw/master/lab7/mldata/mnist-original.mat
使用如下方法获取路径:
from sklearn.datasets.base import get_data_home
print (get_data_home()) # 如我的电脑上的目录为: C:\Users\Mr.Wmn\scikit_learn_data\mldata
#下载好mnist-original.mat数据集放到获取的路径里,在输入如下内容便不会报错了
#导入数据集拆分工具
from sklearn.model_selection import train_test_split
#导入数据集获取工具
from sklearn.datasets import fetch_mldata
#导入MLP神经网络
from sklearn.neural_network import MLPClassifier
#导入numpy
import numpy as np
#加载MNIST手写数字数据集
mnist = fetch_mldata('MNIST original')
mnist
{'DESCR': 'mldata.org dataset: mnist-original',
'COL_NAMES': ['label', 'data'],
'target': array([0., 0., 0., ..., 9., 9., 9.]),
'data': array([[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0],
...,
[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0],
[0, 0, 0, ..., 0, 0, 0]], dtype=uint8)}
print('\n\n\n')
print('代码运行结果')
print('====================================\n')
#打印样本数量和样本特征数
print('样本数量:{} 样本特征数:{}'.format(mnist.data.shape[0],mnist.data.shape[1]))
print('\n====================================')
print('\n\n\n')
代码运行结果
==================================== 样本数量:70000 样本特征数:784 ====================================
#建立训练数据集和测试数据集
X = mnist.data/255
y = mnist.target
X_train,X_test,y_train,y_test = train_test_split(X,y,train_size=1000,random_state=62)
print(X_train,X_test,y_train,y_test)
[[0. 0. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]
...
[0. 0. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]] [[0. 0. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]
...
[0. 0. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]] [5. 7. 7. 0. 0. 4. 4. 5. 1. 2. 8. 7. 4. 8. 2. 3. 9. 7. 2. 5. 9. 7. 9. 6.
7. 1. 1. 3. 2. 6. 9. 4. 3. 1. 8. 3. 2. 0. 8. 0. 9. 7. 2. 9. 7. 9. 5. 1.
1. 0. 8. 2. 5. 6. 5. 2. 2. 8. 1. 6. 0. 5. 9. 9. 6. 5. 4. 3. 1. 7. 5. 9.
2. 4. 2. 3. 2. 2. 8. 1. 8. 9. 1. 3. 3. 4. 7. 7. 9. 8. 5. 0. 3. 0. 1. 0.
7. 5. 8. 2. 5. 8. 3. 7. 0. 1. 9. 4. 6. 7. 0. 7. 4. 0. 6. 0. 3. 4. 0. 6.
9. 0. 6. 1. 7. 0. 3. 9. 7. 3. 8. 0. 8. 8. 5. 6. 1. 2. 3. 9. 1. 9. 2. 1.
3. 4. 9. 1. 0. 8. 6. 3. 1. 8. 5. 0. 2. 0. 6. 5. 6. 3. 1. 3. 1. 1. 1. 5.
6. 1. 9. 2. 8. 5. 7. 8. 5. 9. 1. 2. 8. 1. 0. 9. 4. 2. 4. 2. 1. 8. 8. 8.
1. 2. 9. 7. 8. 4. 1. 6. 8. 6. 7. 8. 7. 2. 5. 2. 5. 8. 1. 4. 3. 0. 6. 3.
5. 8. 6. 9. 4. 9. 1. 6. 6. 6. 3. 0. 9. 2. 4. 2. 2. 4. 3. 9. 2. 5. 8. 4.
1. 0. 8. 1. 8. 5. 3. 1. 8. 7. 7. 9. 6. 9. 3. 7. 8. 9. 1. 1. 5. 8. 7. 7.
0. 0. 5. 7. 6. 2. 5. 4. 4. 4. 3. 0. 5. 0. 8. 1. 8. 5. 2. 6. 2. 9. 9. 2.
8. 3. 8. 0. 9. 4. 8. 5. 7. 7. 0. 1. 7. 2. 4. 2. 5. 3. 7. 0. 4. 4. 9. 1.
9. 0. 4. 3. 7. 4. 5. 8. 0. 8. 1. 2. 9. 2. 1. 3. 3. 0. 5. 3. 2. 8. 7. 6.
6. 6. 4. 3. 1. 8. 4. 0. 3. 9. 2. 1. 7. 0. 5. 2. 5. 4. 3. 5. 6. 9. 4. 7.
4. 4. 2. 4. 1. 1. 3. 1. 3. 8. 5. 7. 5. 0. 1. 8. 2. 8. 2. 2. 8. 7. 1. 6.
7. 7. 1. 7. 4. 0. 5. 1. 8. 5. 1. 9. 5. 6. 8. 6. 1. 7. 9. 0. 9. 2. 3. 6.
2. 4. 8. 2. 6. 8. 1. 9. 5. 0. 7. 5. 8. 2. 0. 5. 4. 3. 1. 8. 8. 7. 8. 9.
6. 6. 0. 9. 3. 9. 8. 9. 0. 5. 0. 6. 0. 1. 9. 3. 0. 3. 9. 8. 0. 6. 5. 3.
4. 8. 5. 3. 9. 5. 8. 4. 3. 7. 1. 4. 8. 9. 6. 4. 9. 1. 0. 3. 2. 8. 4. 1.
4. 9. 7. 5. 8. 2. 6. 0. 2. 8. 1. 2. 6. 6. 0. 8. 4. 7. 5. 7. 8. 9. 0. 0.
9. 2. 6. 4. 0. 3. 6. 9. 1. 1. 1. 9. 8. 4. 1. 6. 5. 4. 1. 3. 0. 0. 4. 7.
7. 3. 5. 3. 6. 0. 1. 9. 6. 3. 2. 2. 5. 9. 2. 7. 5. 1. 0. 1. 3. 9. 0. 4.
3. 6. 7. 5. 7. 5. 9. 3. 3. 4. 8. 1. 8. 0. 1. 2. 9. 8. 3. 6. 3. 0. 7. 1.
3. 2. 2. 9. 7. 8. 0. 6. 5. 6. 1. 5. 3. 4. 5. 1. 9. 4. 9. 6. 6. 7. 4. 2.
4. 5. 8. 9. 1. 9. 6. 7. 7. 0. 0. 6. 9. 0. 6. 0. 5. 2. 8. 9. 8. 1. 4. 3.
0. 6. 5. 4. 6. 9. 7. 1. 9. 0. 5. 4. 7. 6. 0. 5. 0. 3. 0. 0. 0. 1. 4. 0.
7. 5. 0. 9. 5. 3. 4. 9. 9. 7. 6. 0. 6. 1. 3. 5. 7. 5. 2. 9. 6. 0. 5. 7.
8. 5. 0. 9. 2. 8. 1. 7. 0. 8. 7. 7. 7. 7. 5. 5. 5. 7. 2. 1. 9. 9. 7. 2.
9. 4. 0. 4. 8. 3. 3. 4. 9. 3. 7. 0. 2. 9. 8. 8. 7. 5. 8. 7. 9. 0. 6. 9.
8. 6. 1. 7. 2. 8. 9. 8. 2. 4. 7. 1. 8. 8. 7. 3. 1. 8. 8. 9. 7. 4. 7. 7.
1. 1. 8. 8. 2. 7. 1. 7. 6. 3. 7. 6. 5. 2. 3. 7. 2. 0. 7. 3. 9. 8. 0. 0.
5. 4. 4. 2. 9. 9. 5. 8. 4. 7. 4. 8. 5. 8. 3. 1. 7. 4. 8. 9. 2. 3. 8. 7.
3. 2. 7. 2. 6. 7. 7. 9. 1. 0. 8. 6. 6. 9. 4. 7. 9. 3. 1. 4. 6. 0. 8. 6.
5. 1. 8. 2. 1. 5. 3. 7. 1. 2. 5. 4. 6. 2. 6. 3. 2. 1. 7. 6. 8. 6. 3. 8.
3. 9. 0. 4. 2. 2. 8. 8. 3. 7. 8. 4. 3. 5. 3. 2. 2. 8. 0. 1. 0. 9. 4. 3.
6. 1. 6. 1. 2. 3. 3. 4. 0. 0. 7. 2. 6. 0. 3. 7. 2. 6. 4. 6. 6. 3. 9. 5.
6. 5. 8. 1. 3. 7. 8. 8. 9. 0. 1. 9. 3. 4. 1. 4. 1. 1. 7. 4. 8. 5. 8. 1.
3. 6. 3. 8. 5. 9. 0. 6. 4. 8. 0. 3. 3. 9. 1. 0. 4. 1. 3. 4. 6. 4. 9. 2.
8. 4. 0. 5. 3. 9. 7. 0. 9. 7. 8. 6. 7. 7. 6. 8. 9. 5. 1. 1. 7. 4. 5. 9.
8. 5. 1. 1. 7. 3. 1. 9. 9. 9. 3. 8. 2. 9. 7. 1. 7. 1. 1. 4. 3. 1. 1. 3.
0. 1. 3. 3. 4. 5. 8. 1. 2. 0. 4. 6. 7. 1. 2. 1.] [8. 2. 5. ... 3. 6. 9.]
2.训练MLP神经网络
#设置神经网络有两个100个节点的隐藏层
mlp_hw = MLPClassifier(solver='lbfgs',hidden_layer_sizes=[100,100],activation='relu',alpha = 1e-5,random_state=62)
#使用数据训练神经网络模型
mlp_hw.fit(X_train,y_train)
print('\n\n\n')
print('代码运行结果')
print('====================================\n')
#打印模型分数
print('测试数据集得分:{:.2f}% '.format(mlp_hw.score(X_test,y_test)*100))
print('\n====================================')
print('\n\n\n')
代码运行结果
==================================== 测试数据集得分:85.79% ====================================
3.使用模型进行数字识别
#导入图像处理工具
from PIL import Image
#打开图像
image = Image.open("4.png").convert('F')
#调整图像的大小
image = image.resize((28,28))
arr = []
#将图像中的像素作为预测数据点的特征
for i in range(28):
for j in range(28):
pixel = 1.0 - float(image.getpixel((j,i)))/255.
arr.append(pixel)
#由于只有一个样本,所以需要进行reshape操作
arr1 = np.array(arr).reshape(1,-1)
#进行图像识别
print('图片中的数字是:{:.0f}'.format(mlp_hw.predict(arr1)[0]))
图片中的数字是:5
总结 :
scikit-learn中的MLP分类和回归在易用性表现不错,但是仅限于处理小数据集,对于更庞大或更复杂的数据集就不怎么友好了.
在计算能力充足并且参数设置合适的情况下,神经网络可以比其他的机器学习算法表现更加优异
其缺点也很明显,如模型训练时间相对更长,对数据预处理的要求较高.
文章引自 : 《深入浅出python机器学习》