我尝试将新的tensorflow函数tf.contrib.data.prefetch加载到设备。

我的简单代码示例

model = build_network()

N=1000

def gen():
    while True:
        batch = np.random.rand(N, 48, 48, 3)
        # Do some heavy calculation
        yield batch

dataset = tf.data.Dataset.from_generator(gen, tf.float32)
dataset = dataset.apply(tf.contrib.data.prefetch_to_device('/gpu:0'))

iterator = dataset.make_one_shot_iterator()
x = iterator.get_next()

output = model(x)

g = gen()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for i in range(100):
        if i == 50:
            start = time.time()
        result = sess.run(output)
        #result = model.predict(next(g))
    end = time.time()
print('\nAverage time of one forward pass: {}\n'.format((end-start)/50))
print('Done')


这给出了错误:


  InvalidArgumentError(请参阅上面的回溯):无法分配设备
  对于操作“ IteratorGetDevice”:无法满足显式设备
  规范'/ device:GPU:0',因为不支持GPU内核
  设备可用。托管调试信息:托管组具有
  以下类型和设备:IteratorToStringHandle:CPU
  IteratorGetDevice:CPU OneShotIterator:CPU
  
  托管主机和用户请求的设备:OneShotIterator
  (OneShotIterator)IteratorGetDevice(IteratorGetDevice)
  /设备:GPU:0 IteratorToStringHandle(IteratorToStringHandle)
  
  注册的内核:device ='CPU'
  
  [[节点:IteratorGetDevice =
  IteratorGetDevice_device =“ / device:GPU:0”]]


这个新功能是不能与from_generator一起使用还是其他功能?

最佳答案

这是TensorFlow 1.8rc0候选版本中的一个错误。感谢您引起我们的注意!

现在它已在master branch中修复,并将在下一个每晚构建中使用。我也提交了cherry-pick to the 1.8 release branch,它应该包含在TensorFlow 1.8的下一个候选版本(和最终版本)中。

关于python - Tensorflow prefetch_to_device,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/49876643/

10-09 03:21