在纪元结束时阻塞

在纪元结束时阻塞

本文介绍了张量流 shuffle_batch() 在纪元结束时阻塞的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我正在使用 tf.train.shuffle_batch() 来创建批量输入图像.它包含一个 min_after_dequeue 参数,以确保内部队列中有指定数量的元素,如果没有则阻塞其他所有元素.

I'm using tf.train.shuffle_batch() to create batches of input images. It includes a min_after_dequeue parameter that makes sure there's a specified number of elements inside the internal queue, and blocks everything else if there isn't.

images, label_batch = tf.train.shuffle_batch(
  [image, label],
  batch_size=FLAGS.batch_size,
  num_threads=num_preprocess_threads,
  capacity=FLAGS.min_queue_size + 3 * FLAGS.batch_size,
  min_after_dequeue=FLAGS.min_queue_size)

在一个 epoch 结束时,当我进行评估时(我确定这在训练中也是如此,但我还没有测试过),一切都被阻止了.我发现在同一时刻,内部 shuffle 批处理队列将剩下少于 min_after_dequeue 元素.此时在程序中,我希望仅将剩余元素出列,但我不确定如何.

At the end of an epoch, when I'm doing evaluation (I'm sure this is also true in training but I haven't tested it), everything blocks. I figured out it's at the same moment the internal shuffle batch queue would be left with less than min_after_dequeue elements. At this time in the program I would ideally like to just dequeue the remaining elements but I'm not sure how.

显然,当您知道没有更多元素要使用 .close() 方法排队时,可以关闭 TF 队列中的这种类型的阻塞.但是,由于底层队列隐藏在函数内部,我该如何调用该方法?

Apparently this type of blocking inside TF queues can be shut off when you know there's no more elements to enqueue with the .close() method. However, since the underlying queue is hidden inside the function, how do I call that method?

推荐答案

这是我最终开始工作的代码,尽管有一堆警告说我排队的元素被取消了.

Here's the code that I eventually got to work, although with a bunch of warnings that elements I enqueued were cancelled.

lv = tf.constant(label_list)

label_fifo = tf.FIFOQueue(len(filenames),tf.int32,shapes=[[]])
# if eval_data:
    # num_epochs = 1
# else:
    # num_epochs = None
file_fifo = tf.train.string_input_producer(filenames, shuffle=False, capacity=len(filenames))
label_enqueue = label_fifo.enqueue_many([lv])


reader = tf.WholeFileReader()
result.key, value = reader.read(file_fifo)
image = tf.image.decode_jpeg(value, channels=3)
image.set_shape([128,128,3])
result.uint8image = image
result.label = label_fifo.dequeue()

images, label_batch = tf.train.shuffle_batch(
  [result.uint8image, result.label],
  batch_size=FLAGS.batch_size,
  num_threads=num_preprocess_threads,
  capacity=FLAGS.min_queue_size + 3 * FLAGS.batch_size,
  min_after_dequeue=FLAGS.min_queue_size)

#in eval file:
label_enqueue, images, labels = load_input.inputs()
#restore from checkpoint in between
coord = tf.train.Coordinator()
try:
  threads = []
  for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
    threads.extend(qr.create_threads(sess, coord=coord, daemon=True,
                                     start=True))

  num_iter = int(math.ceil(FLAGS.num_examples / FLAGS.batch_size))
  true_count = 0  # Counts the number of correct predictions.
  total_sample_count = num_iter * FLAGS.batch_size

  sess.run(label_enqueue)
  step = 0
  while step < num_iter and not coord.should_stop():
    end_epoch = False
    if step > 0:
        for qr in tf.get_collection(tf.GraphKeys.QUEUE_RUNNERS):
            #check if not enough elements in queue
            size = qr._queue.size().eval()
            if size - FLAGS.batch_size < FLAGS.min_queue_size:
                end_epoch = True
    if end_epoch:
        #enqueue more so that we can finish
        sess.run(label_enqueue)
    #actually run step
    predictions = sess.run([top_k_op])

这篇关于张量流 shuffle_batch() 在纪元结束时阻塞的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持!

08-12 12:59
查看更多