
Problem description


A recent paper (here) introduced a secondary loss function that they called center loss. It is based on the distance between the embeddings in a batch and the running average embedding for each of the respective classes. There has been some discussion in the TF Google groups (here) regarding how such embedding centers can be computed and updated. I've put together some code to generate class-average embeddings in my answer below.

Is this the best way to do it?

Answer


The previously posted method is too simple for cases like center loss, where the expected value of the embeddings changes over time as the model becomes more refined. This is because the earlier center-finding routine averages all instances since the start and therefore tracks changes in the expected value very slowly. Instead, a moving-window average is preferred. An exponential moving-window variant is as follows:
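The update performed by the TF routine below can be sketched in plain NumPy: each class center moves a fraction `(1 - decay)` of the way toward the newest embedding of its class. This is a minimal sketch with illustrative variable names, not part of the original answer:

```python
import numpy as np

def update_centers(centers, embed_batch, label_batch, decay=0.95):
    """Exponential moving-average update of per-class centers.

    Each labelled example pulls its class center toward the example's
    embedding by a factor of (1 - decay).
    """
    centers = centers.copy()
    for emb, lab in zip(embed_batch, label_batch):
        centers[lab] -= (1 - decay) * (centers[lab] - emb)
    return centers

centers = np.zeros((2, 3))  # 2 classes, 3-dim embeddings
batch = np.array([[1.0, 1.0, 1.0],
                  [2.0, 2.0, 2.0]])
labels = np.array([0, 1])
centers = update_centers(centers, batch, labels)
# Each center has moved 5% of the way from 0 toward its class embedding.
```

Note that the Python loop applies repeated labels sequentially; the TF version below achieves the same accumulation in one `tf.scatter_sub` call.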

import tensorflow as tf

def get_embed_centers(embed_batch, label_batch):
    '''Exponential moving-window average of class centers.
    Increase decay (in [0.0, 1.0]) for a longer effective window.
    '''
    decay = 0.95
    with tf.variable_scope('embed', reuse=True):
        embed_ctrs = tf.get_variable("ctrs")

    label_batch = tf.reshape(label_batch, [-1])
    # Move each center a fraction (1 - decay) toward its batch embedding.
    old_embed_ctrs_batch = tf.gather(embed_ctrs, label_batch)
    dif = (1 - decay) * (old_embed_ctrs_batch - embed_batch)
    embed_ctrs = tf.scatter_sub(embed_ctrs, label_batch, dif)
    embed_ctrs_batch = tf.gather(embed_ctrs, label_batch)
    return embed_ctrs_batch


nclass, ndims = 10, 64  # illustrative sizes: number of classes / embedding dims

with tf.Session() as sess:
    with tf.variable_scope('embed'):
        embed_ctrs = tf.get_variable("ctrs", [nclass, ndims], dtype=tf.float32,
                        initializer=tf.constant_initializer(0), trainable=False)
    label_batch_ph = tf.placeholder(tf.int32)
    embed_batch_ph = tf.placeholder(tf.float32)
    embed_ctrs_batch = get_embed_centers(embed_batch_ph, label_batch_ph)
    sess.run(tf.global_variables_initializer())
    tf.get_default_graph().finalize()
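Once the per-example centers are available, the center loss itself is just the mean squared distance between each embedding and its class center (the paper uses a factor of 1/2). A minimal sketch in NumPy, with illustrative names, assuming `embed_ctrs_batch` has already been gathered per example:

```python
import numpy as np

def center_loss(embed_batch, ctrs_batch):
    # 0.5 * mean over the batch of ||embedding - class center||^2.
    return 0.5 * np.mean(np.sum((embed_batch - ctrs_batch) ** 2, axis=1))

emb = np.array([[1.0, 0.0],
                [0.0, 2.0]])
ctr = np.zeros_like(emb)
# Squared distances are 1 and 4, so the loss is 0.5 * mean([1, 4]) = 1.25.
loss = center_loss(emb, ctr)
```

In training this term would be weighted and added to the usual softmax cross-entropy loss.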

