我正在尝试学习TensorFlow并在以下位置研究示例:https://github.com/aymericdamien/TensorFlow-Examples/blob/master/notebooks/3_NeuralNetworks/autoencoder.ipynb

然后,我在下面的代码中有一些疑问:

for epoch in range(training_epochs):
    # Loop over all batches
    for i in range(total_batch):
        batch_xs, batch_ys = mnist.train.next_batch(batch_size)
        # Run optimization op (backprop) and cost op (to get loss value)
        _, c = sess.run([optimizer, cost], feed_dict={X: batch_xs})
    # Display logs per epoch step
    if epoch % display_step == 0:
        print("Epoch:", '%04d' % (epoch+1),
              "cost=", "{:.9f}".format(c))


由于mnist只是一个数据集,mnist.train.next_batch的确切含义是什么? dataset.train.next_batch是如何定义的?

谢谢!

最佳答案

mnist对象是从read_data_sets()模块中定义的tf.contrib.learn function返回的。 mnist.train.next_batch(batch_size)方法是通过here实现的,它返回两个数组的元组,其中第一个表示一批batch_size MNIST图像,第二个表示与这些图像相对应的一批batch-size标签。

图像以大小为[batch_size, 784]的二维NumPy数组返回(因为MNIST图像中有784个像素),标签返回的大小为大小为[batch_size]的一维NumPy数组(如果使用read_data_sets()调用)或大小为one_hot=False的二维NumPy数组(如果使用[batch_size, 10]调用read_data_sets())。

09-10 00:04
查看更多