我正在使用具有纪元限制的 tf.train.string_input_producer 将数据输入到我的模型中。如何在训练期间获得此操作的当前纪元?

我注意到图中有一些与此操作相关的节点,其中一个包含纪元限制,但我找不到实际当前值的存储位置。这肯定是在某个地方被跟踪?

更一般地说,我如何监视 TFRecords 管道中的当前纪元?

最佳答案

我无法在 TF 的任何地方找到它。

我的解决方案是手动完成,通过对(无限)重复进行批处理,并根据需要随时调用我的节点(通过计算数据集中的项目数,除以批次大小 = 一个时期来预先确定)。

最近发布的 TF 使用 tensorflow.contrib.data.TFRecordDataset 使这变得更容易:

d = TFRecordDataset('some_filename.tfrecords')
d = d.map(function_which_parses_your_protobuf_format)
d = d.repeat()
d = d.shuffle()
d = d.batch(batch_size)

然后,您可以使用以下方法确定数据集的大小
record_count = sum([1 for r in tf.python_io.tf_record_iteration('your_filename.tfrecord')])

这似乎是更多的工作,但它提供了更好的灵活性,因为您可以,例如,使用缓存,因此您不必提前预处理您的数据集,因此可以将原始未触及的数据集存储在 tfrecord 文件中。

关于tensorflow - 从输入管道获取当前纪元,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/44090867/

10-12 19:30