我要解决的问题如下:
我有一个文件名列表trainimgs。我已经定义了

  • tf.RandomShuffleQueue及其capacity=len(trainimgs)min_after_dequeue=0
  • tf.RandomShuffleQueue预计将由trainimgs填充指定的epochlimit次数。
  • 预期许多线程可以并行工作。每个线程从tf.RandomShuffleQueue中取出一个元素,并对它进行一些操作,然后将其排队到另一个队列中。我说对了。
  • 但是,一旦处理了1 epochtrainimgstf.RandomShuffleQueue为空,只要当前时代e < epochlimit,就必须再次填充队列并且线程必须再次工作。

  • 好消息是:在某些情况下,我已经开始使用它了(请参阅 PS !)

    坏消息是:我认为这样做有更好的方法。

    我现在用来执行此操作的方法如下(我简化了功能,并删除了基于图像处理的预处理和后续入队,但是处理的核心保持不变!):
    with tf.Session() as sess:
        train_filename_queue = tf.RandomShuffleQueue(capacity=len(trainimgs), min_after_dequeue=0, dtypes=tf.string, seed=0)
        queue_size = train_filename_queue.size()
        trainimgtensor = tf.constant(trainimgs)
        close_queue = train_filename_queue.close()
        epoch = tf.Variable(initial_value=1, trainable=False, dtype=tf.int32)
        incrementepoch = tf.assign(epoch, epoch + 1, use_locking=True)
        supplyimages = train_filename_queue.enqueue_many(trainimgtensor)
        value = train_filename_queue.dequeue()
    
        init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
        sess.run(init_op)
        coord = tf.train.Coordinator()
        tf.train.start_queue_runners(sess, coord)
        sess.run(supplyimages)
        lock = threading.Lock()
        threads = [threading.Thread(target=work, args=(coord, value, sess, epoch, incrementepoch, supplyimages, queue_size, lock, close_queue)) for  i in range(200)]
        for t in threads:
            t.start()
        coord.join(threads)
    

    工作功能如下:
    def work(coord, val, sess, epoch, incrementepoch, supplyimg, q, lock,\
             close_op):
    while not coord.should_stop():
        if sess.run(q) > 0:
            filename, currepoch = sess.run([val, epoch])
            filename = filename.decode(encoding='UTF-8')
            print(filename + ' ' + str(currepoch))
        elif sess.run(epoch) < 2:
            lock.acquire()
            try:
                if sess.run(q) == 0:
                    print("The previous epoch = %d"%(sess.run(epoch)))
                    sess.run([incrementepoch, supplyimg])
                    sz = sess.run(q)
                    print("The new epoch = %d"%(sess.run(epoch)))
                    print("The new queue size = %d"%(sz))
            finally:
                lock.release()
        else:
            try:
                sess.run(close_op)
            except tf.errors.CancelledError:
                print('Queue already closed.')
            coord.request_stop()
    return None
    

    因此,尽管这行得通,但我感觉有一种更好,更清洁的方法可以实现这一目标。因此,简而言之,我的问题是:
  • 在TensorFlow中是否有更简单,更干净的方法来完成此任务?
  • 该代码的逻辑有问题吗?我对多线程方案的经验不是很丰富,因此任何忽略我的明显错误都会对我有很大帮助。

  • P.S:看来这段代码毕竟不是完美的。当我运行120万个图像和200个线程时,它就运行了。但是,当我为10个图像和20个线程运行它时,出现以下错误:
    CancelledError (see above for traceback): RandomShuffleQueue '_0_random_shuffle_queue' is closed.
         [[Node: random_shuffle_queue_EnqueueMany = QueueEnqueueManyV2[Tcomponents=[DT_STRING], timeout_ms=-1, _device="/job:localhost/replica:0/task:0/cpu:0"](random_shuffle_queue, Const)]]
    

    我以为我已经被except tf.errors.CancelledError覆盖了。这到底是怎么回事 ?

    最佳答案

    我终于找到答案了。问题在于,多个线程在work()函数中的各个点上冲突在一起。
    以下work()函数可完美运行。

    def work(coord, val, sess, epoch, maxepochs, incrementepoch, supplyimg, q, lock, close_op):
        print('I am thread number %s'%(threading.current_thread().name))
        print('I can see a queue with size %d'%(sess.run(q)))
        while not coord.should_stop():
            lock.acquire()
            if sess.run(q) > 0:
                filename, currepoch = sess.run([val, epoch])
                filename = filename.decode(encoding='UTF-8')
                tid = threading.current_thread().name
                print(filename + ' ' + str(currepoch) + ' thread ' + str(tid))
            elif sess.run(epoch) < maxepochs:
                print('Thread %s has acquired the lock'%(threading.current_thread().name))
                print("The previous epoch = %d"%(sess.run(epoch)))
                sess.run([incrementepoch, supplyimg])
                sz = sess.run(q)
                print("The new epoch = %d"%(sess.run(epoch)))
                print("The new queue size = %d"%(sz))
        else:
                coord.request_stop()
            lock.release()
    
        return None
    

    关于TensorFlow : Enqueuing and dequeuing a queue from multiple threads,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/42514206/

    10-13 22:06