我正在尝试重新训练inception-resnet-v2的最后一层。这是我想出的:

  • 获取最后一层
  • 中的变量名称
  • 创建train_op以仅最小化这些变量wrt loss
  • 还原除最后一层以外的整个图,同时仅随机初始化最后一层。

  • 我实现了如下:

    with slim.arg_scope(arg_scope):
        logits = model(images_ph, is_training=True, reuse=None)
    loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(logits, labels_ph))
    accuracy = tf.contrib.metrics.accuracy(tf.argmax(logits, 1), labels_ph)
    
    train_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'InceptionResnetV2/Logits')
    optimizer = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate)
    
    train_op = optimizer.minimize(loss, var_list=train_list)
    
    # restore all variables whose names doesn't contain 'logits'
    restore_list = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='^((?!Logits).)*$')
    
    saver = tf.train.Saver(restore_list, write_version=tf.train.SaverDef.V2)
    
    with tf.Session() as session:
    
    
        init_op = tf.group(tf.local_variables_initializer(), tf.global_variables_initializer())
    
        session.run(init_op)
        saver.restore(session, '../models/inception_resnet_v2_2016_08_30.ckpt')
    
    
    # followed by code for running train_op
    

    这似乎不起作用(训练损耗,误差与初始值相比没有太大改善)。有没有更好/更好的方法来做到这一点?如果您还可以告诉我这里出了什么问题,对我来说是个好学习。

    最佳答案

    有几件事:

  • 学习率如何?太高的值会弄乱所有内容(可能不是原因)
  • 尝试使用随机梯度下降法,
  • 应该少一些问题
  • 是否正确设置了范围?如果您不使用梯度的L2正则化和批量归一化,您可能很快就会陷入局部最小值,并且网络将无法学习
    from nets import inception_resnet_v2 as net
    with net.inception_resnet_v2_arg_scope():
        logits, end_points = net.inception_resnet_v2(images_ph, num_classes=num_classes,
                                                     is_training=True)
    
  • 您应该将正则化变量添加到损失(或至少是最后一层的损失)中:
    regularization_losses = tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES)
    all_losses = [loss] + regularization_losses
    total_loss = tf.add_n(all_losses, name='total_loss')
    
  • 仅训练完整的连接层可能不是一个好主意,我会训练所有网络,因为不一定需要在最后一层中定义类所需的功能,而在前几层中需要定义这些功能,因此您需要进行更改。
  • 再次检查train_op运行后是否丢失:
    with ops.name_scope('train_op'):
        train_op = control_flow_ops.with_dependencies([train_op], total_loss)
    
  • 关于python - 训练Inception-ResNet-v2的最后一层,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/41407124/

    10-12 23:52