如何使用谓词功能过滤存储在队列中的数据?例如,假设我们有一个队列来存储特征和标签的张量,而我们只需要满足谓词的那些即可。我尝试了以下实现,但未成功:

feature, label = queue.dequeue()
if (predicate(feature, label)):
    enqueue_op = another_queue.enqueue(feature, label)

最佳答案

最简单的方法是使批处理出队,通过谓词测试运行它们,使用 tf.where 生成与谓词匹配的密集向量,然后使用 tf.gather 收集结果,并将该批列队。如果您希望这种情况自动发生,则可以在第二个队列上启动队列运行器-最简单的方法是使用tf.train.batch:

例子:

import numpy as np
import tensorflow as tf

a = tf.constant(np.array([5, 1, 9, 4, 7, 0], dtype=np.int32))

q = tf.FIFOQueue(6, dtypes=[tf.int32], shapes=[])
enqueue = q.enqueue_many([a])
dequeue = q.dequeue_many(6)
predmatch = tf.less(dequeue, [5])
selected_items = tf.reshape(tf.where(predmatch), [-1])
found = tf.gather(dequeue, selected_items)

secondqueue = tf.FIFOQueue(6, dtypes=[tf.int32], shapes=[])
enqueue2 = secondqueue.enqueue_many([found])
dequeue2 = secondqueue.dequeue_many(3) # XXX, hardcoded

with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  sess.run(enqueue)  # Fill the first queue
  sess.run(enqueue2) # Filter, push into queue 2
  print sess.run(dequeue2) # Pop items off of queue2

谓词产生 bool 向量; tf.where生成真实值索引的密集向量,并且tf.gather基于这些索引从原始张量中收集项。

当然,在此示例中,很多事情都是经过硬编码的,因此您实际上需要进行非硬编码,但是希望它可以显示您要执行的操作的结构(创建过滤管道)。在实践中,您希望在那里的QueueRunners自动保持搅动。使用tf.train.batch对自动处理非常有用-有关更多详细信息,请参见Threading and Queues

关于tensorflow - 如何基于 tensorflow 中的某些谓词从队列中过滤张量?,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/33903569/

10-13 07:42