我正在尝试学习TensorFlow并在以下位置研究示例:https://github.com/aymericdamien/TensorFlow-Examples/blob/master/notebooks/3_NeuralNetworks/autoencoder.ipynb
然后,我在下面的代码中有一些疑问:
for epoch in range(training_epochs):
# Loop over all batches
for i in range(total_batch):
batch_xs, batch_ys = mnist.train.next_batch(batch_size)
# Run optimization op (backprop) and cost op (to get loss value)
_, c = sess.run([optimizer, cost], feed_dict={X: batch_xs})
# Display logs per epoch step
if epoch % display_step == 0:
print("Epoch:", '%04d' % (epoch+1),
"cost=", "{:.9f}".format(c))
由于mnist只是一个数据集,
mnist.train.next_batch
的确切含义是什么? dataset.train.next_batch
是如何定义的?谢谢!
最佳答案
mnist
对象是从read_data_sets()
模块中定义的tf.contrib.learn
function返回的。 mnist.train.next_batch(batch_size)
方法是通过here实现的,它返回两个数组的元组,其中第一个表示一批batch_size
MNIST图像,第二个表示与这些图像相对应的一批batch-size
标签。
图像以大小为[batch_size, 784]
的二维NumPy数组返回(因为MNIST图像中有784个像素),标签返回的大小为大小为[batch_size]
的一维NumPy数组(如果使用read_data_sets()
调用)或大小为one_hot=False
的二维NumPy数组(如果使用[batch_size, 10]
调用read_data_sets()
)。