How to switch between training and validation datasets with tf.MonitoredTrainingSession?

Problem description

I want to use the feedable iterator design of the TensorFlow Dataset API so that I can switch to validation data after some training steps. But when I switch to validation data, the whole session ends.

The following code demonstrates what I want to do:

import tensorflow as tf  # TensorFlow 1.x API (tf.placeholder, make_initializable_iterator)


graph = tf.Graph()
with graph.as_default():
    training_ds = tf.data.Dataset.range(32).batch(4)
    validation_ds = tf.data.Dataset.range(8).batch(4)

    handle = tf.placeholder(tf.string, shape=[])
    iterator = tf.data.Iterator.from_string_handle(
        handle, training_ds.output_types, training_ds.output_shapes)
    next_element = iterator.get_next()

    training_iterator = training_ds.make_initializable_iterator()
    validation_iterator = validation_ds.make_initializable_iterator()


with graph.as_default():

    with tf.train.MonitoredTrainingSession() as sess:
        training_handle = sess.run(training_iterator.string_handle())
        validation_handle = sess.run(validation_iterator.string_handle())
        sess.run(training_iterator.initializer)
        count_training = 0
        while not sess.should_stop():
            x = sess.run(next_element, feed_dict={handle: training_handle})
            count_training += 1
            print('{} [training] {}'.format(count_training, x.shape))
            # print(x)

            # we do periodic validation
            if count_training % 4 == 0:
                sess.run(validation_iterator.initializer)
                count_validation = 0
                while not sess.should_stop():
                    y = sess.run(next_element, feed_dict={handle: validation_handle})
                    count_validation += 1
                    print('  {} [validation] {}'.format(count_validation, y.shape))
                    # print(y)

The training data has 32 elements batched by 4, so there are 8 batches (the validation data has 8 elements, so 2 batches). We do validation every 4 steps, so I expect:

# 1 [training]
# 2 [training]
# 3 [training]
# 4 [training]
#      1 [validation]
#      2 [validation]
# 5 [training]
# 6 [training]
# 7 [training]
# 8 [training]
#      1 [validation]
#      2 [validation]

But it stops once the first validation is done:

# 1 [training]
# 2 [training]
# 3 [training]
# 4 [training]
#      1 [validation]
#      2 [validation]
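
The early stop can be reproduced without TensorFlow. The model below is a pure-Python sketch, not TensorFlow code: OutOfRangeError, FakeMonitoredSession and make_get_next are hypothetical stand-ins, and the fake session models only one behavior of tf.train.MonitoredTrainingSession, namely that leaving its with block with an OutOfRangeError suppresses the error and closes the session. The third call to the validation get_next() raises, nothing in the inner loop catches it, and the error unwinds both loops up to the with statement.

```python
class OutOfRangeError(Exception):
    """Hypothetical stand-in for tf.errors.OutOfRangeError."""


class FakeMonitoredSession:
    """Hypothetical stand-in for tf.train.MonitoredTrainingSession that models
    only one behavior: its __exit__ swallows OutOfRangeError."""

    def __init__(self):
        self.closed = False

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.closed = True
        # Returning True suppresses the exception, so the with block just ends.
        return exc_type is OutOfRangeError

    def should_stop(self):
        return False  # nothing in this model ever requests a stop


def make_get_next(num_batches):
    """Return a get_next() that raises OutOfRangeError after num_batches calls."""
    batches = iter(range(num_batches))

    def get_next():
        try:
            return next(batches)
        except StopIteration:
            raise OutOfRangeError() from None

    return get_next


def run_question_loop():
    log = []
    with FakeMonitoredSession() as sess:
        next_training = make_get_next(8)  # 32 elements, batch size 4
        count_training = 0
        while not sess.should_stop():
            next_training()
            count_training += 1
            log.append('{} [training]'.format(count_training))
            if count_training % 4 == 0:
                next_validation = make_get_next(2)  # re-initialized for each pass
                count_validation = 0
                while not sess.should_stop():
                    next_validation()  # 3rd call raises; nothing here catches it
                    count_validation += 1
                    log.append('  {} [validation]'.format(count_validation))
    return log, sess.closed


log, closed = run_question_loop()
print('\n'.join(log))
```

Running it prints the same six lines as the observed output, and closed is True.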

So, how can I switch between training and validation data with tf.MonitoredTrainingSession?

Recommended answer

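I would suggest catching the tf.errors.OutOfRangeError raised at the end of the validation dataset (you can also check the "Processing multiple epochs" section of the official Dataset API guide for another solution using the repeat dataset). An uncaught OutOfRangeError escapes to the with tf.train.MonitoredTrainingSession() block, which treats it as the end of input and closes the session without reporting an error; catching it inside the validation loop ends only the validation pass: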

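# Replaces the loop inside "with tf.train.MonitoredTrainingSession() as sess:" above
while not sess.should_stop():
    # After the 8th batch this raises OutOfRangeError, which ends the session cleanly
    x = sess.run(next_element, feed_dict={handle: training_handle})
    count_training += 1
    print('{} [training] {}'.format(count_training, x.shape))

    # we do periodic validation
    if count_training % 4 == 0:
        sess.run(validation_iterator.initializer)  # restart validation from its first batch
        count_validation = 0
        while True:
            try:
                y = sess.run(next_element, feed_dict={handle: validation_handle})
                count_validation += 1
                print('  {} [validation] {}'.format(count_validation, y.shape))
            except tf.errors.OutOfRangeError:
                # Validation data exhausted: end this pass and go back to training
                break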

This code prints:

1 [training] (4,)
2 [training] (4,)
3 [training] (4,)
4 [training] (4,)
  1 [validation] (4,)
  2 [validation] (4,)
5 [training] (4,)
6 [training] (4,)
7 [training] (4,)
8 [training] (4,)
  1 [validation] (4,)
  2 [validation] (4,)
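
The answer also mentions a repeat-based alternative. A rough sketch of that idea, assuming the validation dataset is built with repeat() so that its iterator never raises OutOfRangeError: each validation pass then runs a fixed number of steps (here 2, the number of validation batches) instead of running until the data is exhausted, and the validation initializer does not need to be re-run before every pass. The model below is pure Python rather than TensorFlow; itertools.cycle stands in for the repeated validation dataset, and run_with_repeat is a hypothetical name.

```python
import itertools


def run_with_repeat(num_train_batches=8, num_val_batches=2, validate_every=4):
    """Pure-Python model of a repeat-based loop (hypothetical, no TensorFlow).

    The validation stream never runs out, so each validation pass takes a
    fixed number of steps instead of stopping on an OutOfRangeError.
    """
    log = []
    train_batches = iter(range(num_train_batches))         # finite, like training_ds
    val_batches = itertools.cycle(range(num_val_batches))  # endless, like a repeated validation_ds
    validation_steps = num_val_batches  # chosen by hand to cover one validation epoch
    count_training = 0
    for _ in train_batches:
        count_training += 1
        log.append(('training', count_training))
        if count_training % validate_every == 0:
            for count_validation in range(1, validation_steps + 1):
                next(val_batches)  # never raises: the stream repeats forever
                log.append(('validation', count_validation))
    return log


for phase, step in run_with_repeat():
    indent = '  ' if phase == 'validation' else ''
    print('{}{} [{}]'.format(indent, step, phase))
```

The trade-off is that the step count has to be kept in sync with the validation set size by hand: if it is not a multiple of the number of validation batches, consecutive passes start at different batches.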

08-23 03:39