本文介绍了TensorFlow:Dataset 的 apply 方法的简单自定义transformation_func 的示例实现的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我正在尝试为 transformation_funcnofollow noreferrer">apply 方法在数据集 API 中,但没有发现文档特别有用.

I am trying to implement a simple custom transformation_func for the apply method in the Dataset API, but didn't find the docs particularly helpful.

具体来说,我的dataset包含视频帧和相应的标签:{[frame_0, label_0], [frame_1, label_1], [frame_2, label_2],...}.

Specifically, my dataset contains video frames and corresponding labels: {[frame_0, label_0], [frame_1, label_1], [frame_2, label_2],...}.

我想对它进行转换,使其额外包含每个标签的前一帧:{[frame_0, frame_1, label_1], [frame_1, frame_2, label_2], [frame_2, frame_3, label_3],...}.

I'd like to transform it so that it additionally contains the previous frame for each label: {[frame_0, frame_1, label_1], [frame_1, frame_2, label_2], [frame_2, frame_3, label_3],...}.

这可能可以通过执行类似 tf.data.Dataset.zip(dataset, dataset.skip(1)) 之类的操作来实现,但那样我就会有重复的标签.

This could probably be achieved by doing something like tf.data.Dataset.zip(dataset, dataset.skip(1)), but then I would have duplicated labels.

我找不到 transformation_func 的参考实现.有没有人能让我开始做这件事?

I have not been able to find a reference implementation of a transformation_func. Is anyone able to get me started on this?

推荐答案

apply 只是为了方便与现有的转换函数一起使用,ds.apply(func) 是与 func(ds) 几乎相同,只是以更可链接"的方式.这是一种可能的方法来做你想做的事:

apply is simply a convenience to use with existing transformation functions, ds.apply(func) is pretty much the same as func(ds), only in a more "chainable" way. Here is one possible way to do what you want:

import tensorflow as tf

frames = tf.constant([  1,   2,   3,   4,   5,   6], dtype=tf.int32)
labels = tf.constant(['a', 'b', 'c', 'd', 'e', 'f'], dtype=tf.string)
# Create dataset
ds = tf.data.Dataset.from_tensor_slices((frames, labels))
# Zip it with itself but skipping the first one
ds = tf.data.Dataset.zip((ds, ds.skip(1)))
# Make desired output structure
ds = ds.map(lambda fl1, fl2: (fl1[0], fl2[0], fl2[1]))
# Iterate
it = ds.make_one_shot_iterator()
elem = it.get_next()
# Test
with tf.Session() as sess:
    while True:
        try: print(sess.run(elem))
        except tf.errors.OutOfRangeError: break

输出:

(1, 2, b'b')
(2, 3, b'c')
(3, 4, b'd')
(4, 5, b'e')
(5, 6, b'f')

这篇关于TensorFlow:Dataset 的 apply 方法的简单自定义transformation_func 的示例实现的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持!

08-06 09:54