This article looks at the question: does wrapping a tf.data.Dataset in a tf.function improve performance?

Question

Given the two examples below, is there a performance improvement when autographing the tf.data.Dataset?

Dataset not in tf.function

import tensorflow as tf


class MyModel(tf.keras.Model):

    def call(self, inputs):
        return tf.ones([1, 1]) * inputs


model = MyModel()
model2 = MyModel()


@tf.function
def train_step(data):
    output = model(data)
    output = model2(output)
    return output


dataset = tf.data.Dataset.from_tensors(tf.ones([1, 1]))

for data in dataset:
    train_step(data)

Dataset in tf.function

import tensorflow as tf


class MyModel(tf.keras.Model):

    def call(self, inputs):
        return tf.ones([1, 1]) * inputs


model = MyModel()
model2 = MyModel()


@tf.function
def train():
    dataset = tf.data.Dataset.from_tensors(tf.ones([1, 1]))
    def train_step(data):
        output = model(data)
        output = model2(output)
        return output
    for data in dataset:
        train_step(data)


train()

Answer

Adding @tf.function does give a significant speedup. Take a look at this:

import tensorflow as tf

data = tf.random.normal((1000, 10, 10, 1))
# from_tensor_slices yields 1000 elements, which .batch(10) groups into
# 100 batches of 10 (from_tensors would produce a single element).
dataset = tf.data.Dataset.from_tensor_slices(data).batch(10)

def iterate_1(dataset):
    for x in dataset:
        x = x  # no-op body; we only measure iteration overhead

@tf.function
def iterate_2(dataset):
    for x in dataset:
        x = x

%timeit -n 1000 iterate_1(dataset)  # 1.46 ms ± 8.2 µs per loop
%timeit -n 1000 iterate_2(dataset)  # 239 µs ± 10.2 µs per loop

As you can see, iterating inside @tf.function is more than 6 times faster: tf.function traces the Python loop into a TensorFlow graph once, so subsequent calls avoid the per-element Python interpreter overhead of eager iteration.
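Note that the `%timeit` magic above only works in IPython/Jupyter. In a plain Python script you can take comparable measurements with the standard-library `timeit` module. The sketch below uses a plain list as a stand-in for the dataset so it runs without TensorFlow installed; to reproduce the benchmark, substitute the real `dataset` and the `iterate_1`/`iterate_2` functions from above.

```python
import timeit

# Stand-in for the tf.data.Dataset: a plain list of 100 "batches".
dataset = list(range(100))

def iterate(dataset):
    # Same no-op loop body as iterate_1/iterate_2 above.
    for x in dataset:
        x = x

# Run the loop 1000 times and report the average time per call,
# mirroring `%timeit -n 1000`.
total = timeit.timeit(lambda: iterate(dataset), number=1000)
print(f"{total / 1000 * 1e6:.2f} µs per loop")
```

`timeit.timeit` disables garbage collection during the run by default, which keeps the per-call numbers stable for microbenchmarks like this.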

