我正在实现一个涉及交替优化的算法。也就是说,在每次迭代时,该算法获取一个数据批,并使用该数据批按顺序优化两个损失。我目前使用tf.data.Dataasettf.data.Iterator的实现是这样的(如下所述确实不正确):

data_batch = iterator.get_next()
train_op_1 = get_train_op(data_batch)
train_op_2 = get_train_op(data_batch)

for _ in range(num_steps):
    sess.run(train_op_1)
    sess.run(train_op_2)

注意,上面的内容是不正确的,因为sess.run的每个调用都将推进迭代器以获取下一个数据批。因此train_op_1train_op_2确实在使用不同的数据批。
我也不能做类似于sess.run([train_op_1, train_op_2])的事情,因为这两个优化步骤需要是连续的(即,第二个优化步骤取决于第一个优化步骤的最新变量值)
我想知道有什么方法可以“冻结”迭代器,这样它就不会在sess.run调用中前进了吗?

最佳答案

我做了一些类似的事情,所以这是我代码的一部分,去掉了一些不必要的东西。由于它有训练和验证迭代器,所以它做得更多,但是您应该了解使用is_keep_previous标志的想法。基本上以True的形式传递,它将强制重用迭代器的前一个值,如果False它将得到新的值。

iterator_t = ds_t.make_initializable_iterator()
iterator_v = ds_v.make_initializable_iterator()

iterator_handle = tf.placeholder(tf.string, shape=[], name="iterator_handle")
iterator = tf.data.Iterator.from_string_handle(iterator_handle,
                                               iterator_t.output_types,
                                               iterator_t.output_shapes)

def get_next_item():
  # sometimes items need casting
  next_elem = iterator.get_next(name="next_element")
  x, y = tf.cast(next_elem[0], tf.float32), next_elem[1]
  return x, y

def old_data():
        # just forward the existing batch
        return inputs, target

is_keep_previous = tf.placeholder_with_default(tf.constant(False),shape=[], name="keep_previous_flag")

inputs, target =  tf.cond(is_keep_previous, old_data, new_data)

with tf.Session() as sess:
 sess.run([tf.global_variables_initializer(),tf.local_variables_initializer()])
 handle_t = sess.run(iterator_t.string_handle())
 handle_v = sess.run(iterator_v.string_handle())
 # Run data iterator initialisation
 sess.run(iterator_t.initializer)
 sess.run(iterator_v.initializer)
 while True:
   try:
     inputs_, target_ = sess.run([inputs, target], feed_dict={iterator_handle: handle_t, is_keep_previous:False})
     print(inputs_, target_)
     inputs_, target_ = sess.run([inputs, target], feed_dict={iterator_handle: handle_t, is_keep_previous:True})
     print(inputs_, target_)
     inputs_, target_ = sess.run([inputs, target], feed_dict={iterator_handle: handle_v})
     print(inputs_, target_)
   except tf.errors.OutOfRangeError:
     # now we know we run out of elements in the validationiterator
     break

08-25 03:04