Problem description
I want to use the feedable iterator design of the TensorFlow Dataset API so that I can switch to the validation data after some training steps. But when I switch to the validation data, the whole session ends.
The following code demonstrates what I want to do:
import tensorflow as tf

graph = tf.Graph()
with graph.as_default():
    training_ds = tf.data.Dataset.range(32).batch(4)
    validation_ds = tf.data.Dataset.range(8).batch(4)

    handle = tf.placeholder(tf.string, shape=[])
    iterator = tf.data.Iterator.from_string_handle(
        handle, training_ds.output_types, training_ds.output_shapes)
    next_element = iterator.get_next()

    training_iterator = training_ds.make_initializable_iterator()
    validation_iterator = validation_ds.make_initializable_iterator()

with graph.as_default():
    with tf.train.MonitoredTrainingSession() as sess:
        training_handle = sess.run(training_iterator.string_handle())
        validation_handle = sess.run(validation_iterator.string_handle())
        sess.run(training_iterator.initializer)
        count_training = 0
        while not sess.should_stop():
            x = sess.run(next_element, feed_dict={handle: training_handle})
            count_training += 1
            print('{} [training] {}'.format(count_training, x.shape))
            # print(x)

            # we do periodic validation
            if count_training % 4 == 0:
                sess.run(validation_iterator.initializer)
                count_validation = 0
                while not sess.should_stop():
                    y = sess.run(next_element, feed_dict={handle: validation_handle})
                    count_validation += 1
                    print(' {} [validation] {}'.format(count_validation, y.shape))
                    # print(y)
The training data has 32 elements batched by 4, so there are 8 batches; the validation data has 8 elements, so there are 2 batches. I run validation every 4 training steps, so I expect:
# 1 [training]
# 2 [training]
# 3 [training]
# 4 [training]
# 1 [validation]
# 2 [validation]
# 5 [training]
# 6 [training]
# 7 [training]
# 8 [training]
# 1 [validation]
# 2 [validation]
But it stops as soon as the first validation pass is done:
# 1 [training]
# 2 [training]
# 3 [training]
# 4 [training]
# 1 [validation]
# 2 [validation]
So how can I do this with tf.train.MonitoredTrainingSession?
Recommended answer
I would suggest catching the tf.errors.OutOfRangeError raised at the end of the validation dataset, so that it does not escape the with block and quietly close the MonitoredTrainingSession. (For another solution, which uses the repeat transformation on the dataset, see the "Processing multiple epochs" section of the official tf.data guide.)
# Replaces the training loop of the question's code (same graph and session setup).
while not sess.should_stop():
    x = sess.run(next_element, feed_dict={handle: training_handle})
    count_training += 1
    print('{} [training] {}'.format(count_training, x.shape))

    # we do periodic validation
    if count_training % 4 == 0:
        sess.run(validation_iterator.initializer)
        count_validation = 0
        while True:
            try:
                y = sess.run(next_element, feed_dict={handle: validation_handle})
                count_validation += 1
                print(' {} [validation] {}'.format(count_validation, y.shape))
            except tf.errors.OutOfRangeError:
                # Only the validation data is exhausted: leave the inner loop
                # and go back to training instead of ending the session.
                break
This code prints:
1 [training] (4,)
2 [training] (4,)
3 [training] (4,)
4 [training] (4,)
1 [validation] (4,)
2 [validation] (4,)
5 [training] (4,)
6 [training] (4,)
7 [training] (4,)
8 [training] (4,)
1 [validation] (4,)
2 [validation] (4,)
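Why the question's loop ends the session: the inner while not sess.should_stop() loop never sees a stop request, so it keeps calling sess.run until the exhausted validation iterator raises tf.errors.OutOfRangeError. That exception leaves the with block, and MonitoredSession.__exit__ treats OutOfRangeError as a normal end of input and suppresses it. The sketch below is a pure-Python model of that control flow (no TensorFlow needed). OutOfRangeError, MonitoredSessionModel, BatchIterator and run are hypothetical stand-ins, not TensorFlow APIs. The 'repeat' mode is an assumed reading of the repeat-based alternative: the validation data repeats forever and each pass is bounded by a fixed step count instead of an exception.

```python
import math


class OutOfRangeError(Exception):
    """Hypothetical stand-in for tf.errors.OutOfRangeError."""


class MonitoredSessionModel:
    """Hypothetical model of one MonitoredSession behaviour: __exit__
    suppresses OutOfRangeError, so the `with` block just ends quietly."""

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        # Returning True suppresses the exception.
        return exc_type is not None and issubclass(exc_type, OutOfRangeError)

    def should_stop(self):
        # No hooks in this model, so nothing ever requests a stop.
        return False


class BatchIterator:
    """Hypothetical model of range(n).batch(batch_size), optionally repeated."""

    def __init__(self, n, batch_size, repeat=False):
        self.n, self.batch_size, self.repeat = n, batch_size, repeat
        self.pos = 0

    def initialize(self):
        self.pos = 0

    def get_next(self):
        if self.pos >= self.n:
            if not self.repeat:
                raise OutOfRangeError()
            self.pos = 0  # a repeated dataset starts over instead of ending
        batch = list(range(self.pos, min(self.pos + self.batch_size, self.n)))
        self.pos += self.batch_size
        return batch


def run(mode):
    """mode: 'question' (inner should_stop loop), 'catch' (the answer),
    or 'repeat' (repeated validation data, fixed number of steps)."""
    log = []
    train = BatchIterator(32, 4)
    val = BatchIterator(8, 4, repeat=(mode == 'repeat'))
    val_steps = math.ceil(8 / 4)  # 2 validation batches per pass
    with MonitoredSessionModel() as sess:
        count_training = 0
        while not sess.should_stop():
            train.get_next()  # raises after 8 batches; __exit__ swallows it
            count_training += 1
            log.append('{} [training]'.format(count_training))
            if count_training % 4 != 0:
                continue
            if mode == 'question':
                val.initialize()
                count_validation = 0
                while not sess.should_stop():  # never True; get_next raises
                    val.get_next()
                    count_validation += 1
                    log.append('{} [validation]'.format(count_validation))
            elif mode == 'catch':
                val.initialize()
                count_validation = 0
                while True:
                    try:
                        val.get_next()
                    except OutOfRangeError:
                        break  # only the validation data is exhausted
                    count_validation += 1
                    log.append('{} [validation]'.format(count_validation))
            else:  # 'repeat': no re-initialization, bounded by val_steps
                for step in range(1, val_steps + 1):
                    val.get_next()
                    log.append('{} [validation]'.format(step))
    return log


for mode in ('question', 'catch', 'repeat'):
    print(mode, run(mode))
```

In this model, run('question') stops after the first two validation lines, exactly like the question's output, while run('catch') and run('repeat') both produce the full twelve-line schedule. The repeat variant trades the try/except for a step count that has to be derived from the validation set size.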