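This article was first published on my personal blog, QIMING.INFO. Please include a link and attribution when reposting.
In an earlier post on this site, "Implementing a Simple Neural Network with TensorFlow", we used TensorFlow to classify the MNIST handwritten digit dataset and reached an accuracy of 92%. This post optimizes that network and raises the test accuracy to about 98%.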
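1 Optimization Ideas
When optimizing a neural network, the main options to consider are:
- A suitable loss function
- A suitable activation function
- A suitable optimizer
- The number of layers in the network
- The learning rate setting
- Handling overfitting
- More training samples and more training epochs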
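In this example, cross-entropy is a better fit than the quadratic cost as the loss function, tanh() is used as the activation function, and Adam is chosen as the optimizer.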
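One way to see why cross-entropy suits a classifier better than the quadratic cost is to compare their gradients when the output is confidently wrong. The sketch below is a simplification, not the network from this post: it assumes a single sigmoid output unit instead of the 10-way softmax, and the function names are illustrative.

```python
import math

def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

def quadratic_grad(z, y):
    """Gradient w.r.t. z of 0.5 * (sigmoid(z) - y) ** 2."""
    a = sigmoid(z)
    return (a - y) * a * (1.0 - a)

def cross_entropy_grad(z, y):
    """Gradient w.r.t. z of -(y*log(a) + (1-y)*log(1-a)), with a = sigmoid(z)."""
    return sigmoid(z) - y

# A saturated, badly wrong output: the target is 1 but the unit outputs about 0.007.
z, y = -5.0, 1.0
grad_quadratic = quadratic_grad(z, y)
grad_cross_entropy = cross_entropy_grad(z, y)
print("quadratic cost gradient:", grad_quadratic)
print("cross-entropy gradient: ", grad_cross_entropy)
```

With the quadratic cost, the gradient carries an extra factor a * (1 - a) that shrinks toward zero as the output saturates, so learning slows down exactly when the error is largest. The cross-entropy gradient does not have that factor.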
More layers are not always better: an overly complex network applied to a small dataset overfits very easily. This example uses two hidden layers.
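To get a feel for the size of the network used here, the sketch below counts the trainable parameters of a fully connected network with layer sizes 784, 1000, 500 and 10, the sizes used in the code in section 2. The helper name `count_parameters` is made up for this illustration.

```python
def count_parameters(layer_sizes):
    """Count the weights and biases of a fully connected network."""
    total = 0
    for n_in, n_out in zip(layer_sizes, layer_sizes[1:]):
        total += n_in * n_out  # weight matrix between two adjacent layers
        total += n_out         # bias vector of the receiving layer
    return total

layer_sizes = [784, 1000, 500, 10]
total_params = count_parameters(layer_sizes)
print(total_params)  # 1290510
```

That is roughly 1.29 million parameters trained on tens of thousands of images, which is why overfitting is a real concern even with only two hidden layers.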
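When setting the learning rate, a value that is too large makes the parameters oscillate instead of converging to a minimum, while a value that is too small slows optimization considerably. So we start with a relatively large learning rate to reach a reasonably good solution quickly, then reduce it gradually as training proceeds so that the model is more stable in the later stages.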
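The schedule described above can be sketched in plain Python. It assumes the same constants as the training code in this post: an initial learning rate of 0.001, multiplied by 0.95 after each epoch. The function name `decayed_learning_rate` is illustrative.

```python
def decayed_learning_rate(epoch, base_lr=0.001, decay=0.95):
    """Exponential decay: shrink the learning rate by a constant factor per epoch."""
    return base_lr * (decay ** epoch)

schedule = [decayed_learning_rate(epoch) for epoch in range(51)]
print("epoch 0: ", schedule[0])
print("epoch 25:", schedule[25])
print("epoch 50:", schedule[50])
```

By the last epoch the learning rate has dropped to less than a tenth of its starting value (about 7.7e-5).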
To guard against overfitting, this example uses dropout.
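Dropout randomly zeroes a fraction of the units during training so that the network cannot rely on any single one. The sketch below shows the "inverted" variant in plain Python, where surviving values are scaled by 1 / keep_prob so the expected activation stays the same; this is also how `tf.nn.dropout` in TensorFlow 1.x treats `keep_prob`. The helper `dropout` here is illustrative, not the library function.

```python
import random

def dropout(values, keep_prob, rng):
    """Inverted dropout: keep each value with probability keep_prob and
    scale the survivors by 1 / keep_prob; the rest become 0."""
    return [v / keep_prob if rng.random() < keep_prob else 0.0 for v in values]

rng = random.Random(0)  # fixed seed so the sketch is repeatable
activations = [1.0] * 10000

train_out = dropout(activations, keep_prob=0.7, rng=rng)  # training-time setting
eval_out = dropout(activations, keep_prob=1.0, rng=rng)   # evaluation-time setting

print("fraction kept:", sum(1 for v in train_out if v != 0.0) / len(train_out))
print("mean after dropout:", sum(train_out) / len(train_out))
```

With keep_prob set to 1.0, as in the accuracy evaluation, dropout leaves the activations unchanged.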
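In deep learning, more training samples solve many problems, but that does not apply here because this example already uses all of the MNIST training data. We can, however, train for more epochs: this example raises the count from 21 in the previous post to 51.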
All right, let's write some code and see how it does.
2 Code and Notes
import tensorflow as tf
import numpy as np
from tensorflow.examples.tutorials.mnist import input_data
# Load the dataset
mnist = input_data.read_data_sets("MNIST_data", one_hot=True)
# Size of each batch
batch_size = 100
# Number of batches per epoch
n_batch = mnist.train.num_examples // batch_size
# Placeholders for the input images and the one-hot labels
x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10])
# Placeholder for the dropout keep probability
keep_prob = tf.placeholder(tf.float32)
# Learning rate stored in a variable so it can be changed during training
lr = tf.Variable(0.001, dtype=tf.float32)
# Build the network
# First hidden layer: 1000 units
W1 = tf.Variable(tf.truncated_normal([784, 1000], stddev=0.1))
b1 = tf.Variable(tf.zeros([1000]) + 0.1)
L1 = tf.nn.tanh(tf.matmul(x, W1) + b1)
L1_drop = tf.nn.dropout(L1, keep_prob)
# Second hidden layer: 500 units
W2 = tf.Variable(tf.truncated_normal([1000, 500], stddev=0.1))
b2 = tf.Variable(tf.zeros([500]) + 0.1)
L2 = tf.nn.tanh(tf.matmul(L1_drop, W2) + b2)
L2_drop = tf.nn.dropout(L2, keep_prob)
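# Output layer: keep the raw scores (logits) separate from the softmax output
W3 = tf.Variable(tf.truncated_normal([500, 10], stddev=0.1))
b3 = tf.Variable(tf.zeros([10]) + 0.1)
logits = tf.matmul(L2_drop, W3) + b3
prediction = tf.nn.softmax(logits)
# Cross-entropy loss; softmax_cross_entropy_with_logits applies softmax itself,
# so it takes the raw logits rather than prediction
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=logits))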
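# Train with the Adam optimizer
train_step = tf.train.AdamOptimizer(lr).minimize(loss)
# Initialize the variables
init = tf.global_variables_initializer()
# Boolean tensor: True where the predicted class matches the label
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(prediction, 1))  # argmax returns the index of the largest value along the given axis
# Accuracy: the mean of the boolean results cast to float
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))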
with tf.Session() as sess:
    sess.run(init)
    for epoch in range(51):
        # Lower the learning rate at the start of every epoch
        sess.run(tf.assign(lr, 0.001 * (0.95 ** epoch)))
        for batch in range(n_batch):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            sess.run(train_step, feed_dict={x: batch_xs, y: batch_ys, keep_prob: 0.7})
        # Accuracy on the test data (dropout disabled)
        test_acc = sess.run(accuracy, feed_dict={x: mnist.test.images, y: mnist.test.labels, keep_prob: 1.0})
        # Accuracy on the training data (dropout disabled)
        train_acc = sess.run(accuracy, feed_dict={x: mnist.train.images, y: mnist.train.labels, keep_prob: 1.0})
        # Print the epoch number, test accuracy and training accuracy
        print("Iter " + str(epoch) + ",Testing Accuracy " + str(test_acc) + ",Training Accuracy " + str(train_acc))
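One detail of the loss deserves care: `tf.nn.softmax_cross_entropy_with_logits` applies softmax internally, so it expects raw, unnormalized scores. The pure-Python sketch below (helper names are illustrative) shows what happens to the loss of a confident, correct prediction when softmax is applied twice.

```python
import math

def softmax(scores):
    """Numerically stable softmax over a list of scores."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(probs, true_index):
    """Cross-entropy for a one-hot label whose hot entry is at true_index."""
    return -math.log(probs[true_index])

# A confident, correct prediction over 10 classes (class 0 is the true label).
logits = [10.0] + [0.0] * 9
loss_single = cross_entropy(softmax(logits), 0)
loss_double = cross_entropy(softmax(softmax(logits)), 0)
print("softmax once: ", loss_single)
print("softmax twice:", loss_double)
```

When softmax is applied twice, the inputs to the second softmax are confined to the range [0, 1], so with 10 classes the loss cannot fall below ln(1 + 9/e), about 1.46, however confident the network is.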
3 Results
Iter 0,Testing Accuracy 0.9439,Training Accuracy 0.9438
Iter 1,Testing Accuracy 0.9515,Training Accuracy 0.9538364
Iter 2,Testing Accuracy 0.9582,Training Accuracy 0.96207273
Iter 3,Testing Accuracy 0.9616,Training Accuracy 0.9679273
Iter 4,Testing Accuracy 0.9659,Training Accuracy 0.9701818
Iter 5,Testing Accuracy 0.9668,Training Accuracy 0.9737818
Iter 6,Testing Accuracy 0.9691,Training Accuracy 0.9764364
Iter 7,Testing Accuracy 0.9718,Training Accuracy 0.979
Iter 8,Testing Accuracy 0.9707,Training Accuracy 0.9800364
Iter 9,Testing Accuracy 0.9716,Training Accuracy 0.98210907
Iter 10,Testing Accuracy 0.9744,Training Accuracy 0.9829818
Iter 11,Testing Accuracy 0.973,Training Accuracy 0.98376364
Iter 12,Testing Accuracy 0.9743,Training Accuracy 0.9856
Iter 13,Testing Accuracy 0.9749,Training Accuracy 0.9863091
Iter 14,Testing Accuracy 0.9755,Training Accuracy 0.9862546
Iter 15,Testing Accuracy 0.974,Training Accuracy 0.98661816
Iter 16,Testing Accuracy 0.9763,Training Accuracy 0.9874
Iter 17,Testing Accuracy 0.9751,Training Accuracy 0.9886909
Iter 18,Testing Accuracy 0.9768,Training Accuracy 0.98914546
Iter 19,Testing Accuracy 0.9756,Training Accuracy 0.98987275
Iter 20,Testing Accuracy 0.9766,Training Accuracy 0.9896182
Iter 21,Testing Accuracy 0.9771,Training Accuracy 0.9906545
Iter 22,Testing Accuracy 0.9786,Training Accuracy 0.9912364
Iter 23,Testing Accuracy 0.9781,Training Accuracy 0.99152726
Iter 24,Testing Accuracy 0.9782,Training Accuracy 0.9915636
Iter 25,Testing Accuracy 0.9778,Training Accuracy 0.9921273
Iter 26,Testing Accuracy 0.9799,Training Accuracy 0.99243635
Iter 27,Testing Accuracy 0.979,Training Accuracy 0.99258184
Iter 28,Testing Accuracy 0.9798,Training Accuracy 0.99285454
Iter 29,Testing Accuracy 0.9784,Training Accuracy 0.99294543
Iter 30,Testing Accuracy 0.9789,Training Accuracy 0.99307275
Iter 31,Testing Accuracy 0.9794,Training Accuracy 0.99325454
Iter 32,Testing Accuracy 0.9786,Training Accuracy 0.9934545
Iter 33,Testing Accuracy 0.9791,Training Accuracy 0.9937818
Iter 34,Testing Accuracy 0.9797,Training Accuracy 0.9938545
Iter 35,Testing Accuracy 0.9799,Training Accuracy 0.9941273
Iter 36,Testing Accuracy 0.9802,Training Accuracy 0.99407274
Iter 37,Testing Accuracy 0.9807,Training Accuracy 0.99438184
Iter 38,Testing Accuracy 0.9814,Training Accuracy 0.9944182
Iter 39,Testing Accuracy 0.9805,Training Accuracy 0.99447274
Iter 40,Testing Accuracy 0.9809,Training Accuracy 0.9945091
Iter 41,Testing Accuracy 0.9813,Training Accuracy 0.9946182
Iter 42,Testing Accuracy 0.9811,Training Accuracy 0.99474543
Iter 43,Testing Accuracy 0.9809,Training Accuracy 0.9948364
Iter 44,Testing Accuracy 0.9812,Training Accuracy 0.99485457
Iter 45,Testing Accuracy 0.9814,Training Accuracy 0.99487275
Iter 46,Testing Accuracy 0.9824,Training Accuracy 0.9948909
Iter 47,Testing Accuracy 0.9817,Training Accuracy 0.9950182
Iter 48,Testing Accuracy 0.982,Training Accuracy 0.9950909
Iter 49,Testing Accuracy 0.9821,Training Accuracy 0.9951091
Iter 50,Testing Accuracy 0.982,Training Accuracy 0.9951091
As the output shows, after 51 epochs of training the accuracy reaches 98.2% on the test data and 99.5% on the training data.
4 References
[1] @Bilibili. 深度学习框架Tensorflow学习与应用. 2018-03
[2] Zheng Zeyu, Liang Bowen, Gu Siyu. TensorFlow:实战Google深度学习框架(第2版)[M]. Beijing: Publishing House of Electronics Industry. 2018-02