Question
The TensorFlow documentation for MNIST recommends several different ways to load the MNIST dataset:
- https://www.tensorflow.org/tutorials/layers
- https://www.tensorflow.org/versions/r1.2/get_started/mnist/beginners
- https://www.tensorflow.org/versions/r1.2/get_started/mnist/pros
All of the approaches described in the documentation produce many deprecation warnings in TensorFlow 1.8.
This is how I currently load MNIST and create batches for training:
class MNIST:
    def __init__(self, optimizer):
        ...
        self.mnist_dataset = input_data.read_data_sets("/tmp/data/", one_hot=True)
        self.test_data = self.mnist_dataset.test.images.reshape((-1, self.timesteps, self.num_input))
        self.test_label = self.mnist_dataset.test.labels
        ...

    def train_run(self, sess):
        batch_input, batch_output = self.mnist_dataset.train.next_batch(self.batch_size, shuffle=True)
        batch_input = batch_input.reshape((self.batch_size, self.timesteps, self.num_input))
        _, loss = sess.run(fetches=[self.train_step, self.loss],
                           feed_dict={self.input_placeholder: batch_input,
                                      self.output_placeholder: batch_output})
        ...

    def test_run(self, sess):
        loss = sess.run(fetches=[self.loss],
                        feed_dict={self.input_placeholder: self.test_data,
                                   self.output_placeholder: self.test_label})
        ...
How could I do exactly the same thing using only current methods?
I couldn't find any documentation on this.
It seems to me that the new way is something along the lines of:
train, test = tf.keras.datasets.mnist.load_data()
self.mnist_train_ds = tf.data.Dataset.from_tensor_slices(train)
self.mnist_test_ds = tf.data.Dataset.from_tensor_slices(test)
But how can I use these datasets in my train_run and test_run methods?
Recommended answer
An example of loading the MNIST dataset using the TF dataset API:
Create MNIST datasets for the training and validation images (the test split is used for validation here):
You can create a dataset from NumPy inputs using either Dataset.from_tensor_slices or Dataset.from_generator. Dataset.from_tensor_slices adds the whole dataset to the computational graph, so we will use Dataset.from_generator instead.
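To put rough numbers on that trade-off, the following back-of-the-envelope sketch estimates how many bytes the MNIST training images would occupy if they were embedded in the graph as a constant. The 60,000 x 28 x 28 shape is the standard MNIST training split; the 2 GB figure is the protocol buffer size limit that applies to a serialized graph. The helper name `embedded_bytes` is made up for this illustration.

```python
# Rough size of the MNIST training images when embedded in the graph.
NUM_IMAGES = 60000          # MNIST training split
HEIGHT, WIDTH = 28, 28      # pixels per image
GRAPHDEF_LIMIT = 2 * 1024 ** 3  # 2 GB limit on a serialized graph


def embedded_bytes(bytes_per_pixel):
    """Bytes needed to store every training image at the given pixel width."""
    return NUM_IMAGES * HEIGHT * WIDTH * bytes_per_pixel


uint8_bytes = embedded_bytes(1)    # dtype returned by mnist.load_data()
float32_bytes = embedded_bytes(4)  # after converting to tf.float32

print("uint8:   %d bytes" % uint8_bytes)
print("float32: %d bytes" % float32_bytes)
print("fraction of the 2 GB limit (float32): %.3f" % (float32_bytes / GRAPHDEF_LIMIT))
```

MNIST itself stays well under the limit, so from_tensor_slices would still work here; the generator approach mainly avoids copying the data into the graph and scales to datasets that would not fit.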
import tensorflow as tf

# Load the MNIST data as NumPy arrays
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()

def create_mnist_dataset(data, labels, batch_size):
    def gen():
        # Yield one (image, label) pair at a time
        for image, label in zip(data, labels):
            yield image, label
    ds = tf.data.Dataset.from_generator(gen, (tf.float32, tf.int32), ((28, 28), ()))
    return ds.repeat().batch(batch_size)

# Training and validation datasets with different batch sizes
train_dataset = create_mnist_dataset(x_train, y_train, 10)
valid_dataset = create_mnist_dataset(x_test, y_test, 20)
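Note that the generator above yields the examples in a fixed order, whereas the code in the question called next_batch with shuffle=True. One possible way to get shuffling back is sketched below: a generator function that draws a fresh random order every time it is called. The name `make_shuffling_generator` and its `seed` parameter are hypothetical, and the sketch assumes that `repeat()` calls the generator function again for each pass over the data, which is my understanding of Dataset.from_generator but is worth verifying on your TensorFlow version. The built-in `Dataset.shuffle(buffer_size)` is an alternative.

```python
import random


def make_shuffling_generator(data, labels, seed=None):
    """Return a zero-argument generator function that yields every
    (image, label) pair exactly once, in a new random order per call."""
    rng = random.Random(seed)

    def gen():
        order = list(range(len(data)))
        rng.shuffle(order)  # new permutation on every call
        for i in order:
            yield data[i], labels[i]

    return gen


# Tiny stand-in data; with MNIST you would pass x_train and y_train instead.
demo_data = ["img0", "img1", "img2", "img3", "img4"]
demo_labels = [0, 1, 2, 3, 4]

gen = make_shuffling_generator(demo_data, demo_labels, seed=0)
first_pass = list(gen())
second_pass = list(gen())
print(first_pass)
print(second_pass)
```

The returned `gen` has the same shape as the one in create_mnist_dataset, so it could be passed to Dataset.from_generator with the same output types and shapes.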
A feedable iterator that can switch between training and validation:
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(
    handle, train_dataset.output_types, train_dataset.output_shapes)
image, label = iterator.get_next()
train_iterator = train_dataset.make_one_shot_iterator()
valid_iterator = valid_dataset.make_one_shot_iterator()
Sample run:
# A toy network
y = tf.layers.dense(tf.layers.flatten(image), 1, activation=tf.nn.relu)
loss = tf.losses.mean_squared_error(tf.squeeze(y), label)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    # The `Iterator.string_handle()` method returns a tensor that can be evaluated
    # and used to feed the `handle` placeholder.
    train_handle = sess.run(train_iterator.string_handle())
    valid_handle = sess.run(valid_iterator.string_handle())

    # Run training
    train_loss, train_img, train_label = sess.run([loss, image, label],
                                                  feed_dict={handle: train_handle})
    # train_img.shape == (10, 28, 28)

    # Run validation
    valid_pred, valid_img = sess.run([y, image],
                                     feed_dict={handle: valid_handle})
    # valid_img.shape == (20, 28, 28)
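The sample run above evaluates a single batch per handle and has no optimizer. As a rough sketch of how the same feedable-iterator pattern could map onto the train_run and test_run methods from the question, the block below runs several training steps and validates periodically. Several things here are assumptions and not part of the answer: random synthetic arrays stand in for MNIST so nothing is downloaded, the model is a plain linear layer, `run_training` and its parameters are made-up names, and the `tensorflow.compat.v1` import requires TensorFlow 1.14 or later (on 1.8, `import tensorflow as tf` without the `disable_v2_behavior()` call should serve the same purpose).

```python
import numpy as np
import tensorflow.compat.v1 as tf  # assumption: TF >= 1.14

tf.disable_v2_behavior()

# Synthetic stand-in for MNIST (assumption: random data, same shapes).
rng = np.random.RandomState(0)
x_train = rng.rand(200, 28, 28).astype(np.float32)
y_train = rng.randint(0, 10, size=200).astype(np.int32)
x_valid = rng.rand(40, 28, 28).astype(np.float32)
y_valid = rng.randint(0, 10, size=40).astype(np.int32)


def create_dataset(data, labels, batch_size):
    def gen():
        for image, label in zip(data, labels):
            yield image, label
    ds = tf.data.Dataset.from_generator(gen, (tf.float32, tf.int32), ((28, 28), ()))
    return ds.repeat().batch(batch_size)


train_dataset = create_dataset(x_train, y_train, 10)
valid_dataset = create_dataset(x_valid, y_valid, 20)

# Feedable iterator: the fed handle decides which dataset is read.
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(
    handle, train_dataset.output_types, train_dataset.output_shapes)
image, label = iterator.get_next()

train_iterator = train_dataset.make_one_shot_iterator()
valid_iterator = valid_dataset.make_one_shot_iterator()

# Toy linear model with a squared-error loss.
flat = tf.reshape(image, [-1, 28 * 28])
w = tf.get_variable("w", shape=[28 * 28, 1], initializer=tf.zeros_initializer())
b = tf.get_variable("b", shape=[1], initializer=tf.zeros_initializer())
pred = tf.squeeze(tf.matmul(flat, w) + b, axis=1)
loss = tf.reduce_mean(tf.square(pred - tf.cast(label, tf.float32)))
train_step = tf.train.GradientDescentOptimizer(1e-3).minimize(loss)


def run_training(num_steps=50, eval_every=10):
    """Train for num_steps batches; record validation loss every eval_every steps."""
    history = []
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        train_handle = sess.run(train_iterator.string_handle())
        valid_handle = sess.run(valid_iterator.string_handle())
        for step in range(1, num_steps + 1):
            # Equivalent of train_run: one optimizer step on a training batch.
            sess.run(train_step, feed_dict={handle: train_handle})
            if step % eval_every == 0:
                # Equivalent of test_run: loss on one validation batch.
                valid_loss = sess.run(loss, feed_dict={handle: valid_handle})
                history.append((step, float(valid_loss)))
    return history


if __name__ == "__main__":
    print(run_training())
```

Because both datasets end in repeat(), neither iterator runs out, so the loop length is controlled only by num_steps. Inside a class like the one in the question, the two sess.run calls would become the bodies of train_run and test_run, with the handles computed once after the session is created.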