本文介绍了TF 2.0 while_loop 和 parallel_iterations的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我正在尝试使用 tf.while_loop 并行运行循环.但是,在以下玩具示例中,循环似乎没有并行运行.

I am trying to use tf.while_loop to run loops in parallel. However, in the following toy examples,loops don't appear to be running in parallel.

iteration = tf.constant(0)
c = lambda i: tf.less(i, 1000)
def print_fun(iteration):
    print(f"This is iteration {iteration}")
    iteration+=1
    return (iteration,)
r = tf.while_loop(c, print_fun, [iteration], parallel_iterations=10)

i = tf.constant(0)
c = lambda i: tf.less(i, 1000)
b = lambda i: (tf.add(i, 1),)
r = tf.while_loop(c, b, [i])

是什么阻止了 tf.while_loop 并行化循环?

What is preventing the tf.while_loop from parallelizing the loop?

此外,如果维护 Tensorflow 文档的任何人看到此页面,他/她应该修复第一个示例中的错误.请参阅此处的讨论.

In addition, if anyone who maintain the Tensorflow documentation see this page, he/she should fix the bug in the first example. See the discussion here.

谢谢.

推荐答案

parallel_iterations 在 Eager 模式下运行时没有任何意义,但您始终可以使用 tf.function 装饰器并获得显着的加速.这可以在这张图片中看到:运行时间

parallel_iterations doesn't mean anything when running in eager mode, but you can always use tf.function decorator and gain significant speedups. This can be seen in this picture: running times

你可以像这样用 tf.function 包裹你的 tf.while_loop

You can wrap your tf.while_loop with tf.function like this

@tf.function
def run_graph():
    iteration = tf.constant(0)
    r = tf.while_loop(c, print_fun, [iteration], parallel_iterations=4)

然后在需要时调用 run_graph.

这篇关于TF 2.0 while_loop 和 parallel_iterations的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持!

09-05 10:18