背景:
目标:
程序:
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data#下载数据集
mnist = input_data.read_data_sets("C:/Users/WangT/Desktop/MNIST_data",one_hot=True)#导入数据
batch_size = 100#定义每一次批次处理的数据大小
n_batch = mnist.train.num_examples // batch_size
#计算分批处理次数,//是整除的除数,结果始终为整数,区别于/
#mnist.train.num_examples是训练集的数据大小,类似还有mnist.validation.num_examples, mnist.test.num_examples.
x = tf.placeholder(tf.float32,[None,784])
y = tf.placeholder(tf.float32,[None,10])
#placeholder占位符,希望能输入任意数量的MNIST图像,每一张图像展平为784维的向量,用2维浮点数张量来表示这些图,这个张量的形状是【none,784】,此处None表示此张量的第一个维度可以是任意长度的。
W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))
#模型的参数,可以用Variable表示,可以计算输入值,也可以在计算中被修改,此处用全为零的张量来初始化w和b
prediction = tf.nn.softmax(tf.matmul(x,W)+b)
#得到预测结果
loss = tf.reduce_mean(tf.square(y - prediction))
#损失函数,评估模型好坏,tf.square是平方,tf.reduce_mean是取平均值
train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss)
#tf使用梯度下降法,以0.2的学习速率,不断修改模型参数来最小化loss
init = tf.global_variables_initializer()
#添加一个操作来初始化变量
correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(prediction,1))
#tf.equal()比对两个数,相同返回true不同返回false,tf.argmax(y,1)返回y最大时对应的x
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
#tf.cast()将上述结果每一次转换成浮点型,累加并取平均值,得到准确率
with tf.Session() as sess:#定义对话
sess.run(init)
for epoch in range (21):#模型循环训练21次
for batch in range(n_batch):#每次训练要循环n_batch批次
batch_xs,batch_ys = mnist.train.next_batch(batch_size)#读取训练集的下一批数据
sess.run(train_step, feed_dict={x:batch_xs,y:batch_ys})#运行模型训练
acc = sess.run(accuracy, feed_dict={x:mnist.test.images, y:mnist.test.labels})#每训练一次输出一次准确率,利用的是测试集的数据
print("Iter"+str(epoch)+",Testing Accuracy"+str(acc))
结果:
Iter0,Testing Accuracy0.8308
Iter1,Testing Accuracy0.8711
Iter2,Testing Accuracy0.8814
Iter3,Testing Accuracy0.8885
Iter4,Testing Accuracy0.8943
Iter5,Testing Accuracy0.8965
Iter6,Testing Accuracy0.8994
Iter7,Testing Accuracy0.9011
Iter8,Testing Accuracy0.9036
Iter9,Testing Accuracy0.9055
Iter10,Testing Accuracy0.9064
Iter11,Testing Accuracy0.9071
Iter12,Testing Accuracy0.908
Iter13,Testing Accuracy0.9086
Iter14,Testing Accuracy0.9098
Iter15,Testing Accuracy0.9106
Iter16,Testing Accuracy0.9116
Iter17,Testing Accuracy0.9129
Iter18,Testing Accuracy0.9131
Iter19,Testing Accuracy0.914
Iter20,Testing Accuracy0.9138
新的学习: