我正在尝试使用tf.train.shuffle_batch
使用TensorFlow 1.0消耗TFRecord
文件中的数据批次。相关功能为:
def tfrecord_to_graph_ops(filenames_list):
file_queue = tf.train.string_input_producer(filenames_list)
reader = tf.TFRecordReader()
_, tfrecord = reader.read(file_queue)
tfrecord_features = tf.parse_single_example(
tfrecord,
features={'targets': tf.FixedLenFeature([], tf.string)}
)
## if no reshaping: `ValueError: All shapes must be fully defined` in
## `tf.train.shuffle_batch`
targets = tf.decode_raw(tfrecord_features['targets'], tf.uint8)
## if using `strided_slice`, always get the first record
# targets = tf.cast(
# tf.strided_slice(targets, [0], [1]),
# tf.int32
# )
## error on shapes being fully defined
# targets = tf.reshape(targets, [])
## get us: Invalid argument: Shape mismatch in tuple component 0.
## Expected [1], got [1000]
targets.set_shape([1])
return targets
def batch_generator(filenames_list, batch_size=BATCH_SIZE):
targets = tfrecord_to_graph_ops(filenames_list)
targets_batch = tf.train.shuffle_batch(
[targets],
batch_size=batch_size,
capacity=(20 * batch_size),
min_after_dequeue=(2 * batch_size)
)
targets_batch = tf.one_hot(
indices=targets_batch, depth=10, on_value=1, off_value=0
)
return targets_batch
def examine_batches(targets_batch):
with tf.Session() as sess:
coord = tf.train.Coordinator()
threads = tf.train.start_queue_runners(coord=coord)
for _ in range(10):
targets = sess.run([targets_batch])
print(targets)
coord.request_stop()
coord.join(threads)
该代码通过
examine_batches()
输入,并已交给batch_generator()
输出。 batch_generator()
调用tfrecord_to_graph_ops()
,我认为问题出在该函数中。我在打电话
targets = tf.decode_raw(tfrecord_features['targets'], tf.uint8)
在具有1,000个字节(数字0-9)的文件上。如果我在会话中对此调用
eval()
,它将显示所有1000个元素。但是,如果我尝试将其放入批处理生成器中,则会崩溃。如果不重塑
targets
,则在调用ValueError: All shapes must be fully defined
时会收到类似tf.train.shuffle_batch
的错误。如果我调用targets.set_shape([1])
,使人联想到Google的CIFAR-10 example code,则会在Invalid argument: Shape mismatch in tuple component 0. Expected [1], got [1000]
中收到类似tf.train.shuffle_batch
的错误。我还尝试使用tf.strided_slice
剪切大量原始数据-这不会崩溃,但只会导致第一个事件一遍又一遍。什么是正确的方法?要从
TFRecord
文件提取批次?请注意,我可以手动编写一个将原始字节数据切碎并进行某种批处理的函数-如果我使用
feed_dict
方法将数据放入图中特别容易-但我正在尝试学习如何使用TensorFlow的函数TFRecord
文件以及如何使用其内置的批处理功能。谢谢!
最佳答案
Allen Lavoie在评论中指出了正确的解决方案。重要的缺失部分是enqueue_many=True
作为tf.train.shuffle_batch()
的参数。编写这些函数的正确方法是:
def tfrecord_to_graph_ops(filenames_list):
file_queue = tf.train.string_input_producer(filenames_list)
reader = tf.TFRecordReader()
_, tfrecord = reader.read(file_queue)
tfrecord_features = tf.parse_single_example(
tfrecord,
features={'targets': tf.FixedLenFeature([], tf.string)}
)
targets = tf.decode_raw(tfrecord_features['targets'], tf.uint8)
targets = tf.reshape(targets, [-1])
return targets
def batch_generator(filenames_list, batch_size=BATCH_SIZE):
targets = tfrecord_to_graph_ops(filenames_list)
targets_batch = tf.train.shuffle_batch(
[targets],
batch_size=batch_size,
capacity=(20 * batch_size),
min_after_dequeue=(2 * batch_size),
enqueue_many=True
)
return targets_batch