我想为我的模型创建一个新的评估指标(均值)。
假设我有:


logits形状为(None, n_class)的张量和
形状为y_target(None, )张量,其中包含从int0n_class-1值。
None是批次大小。


我希望我的输出是形状为(None, )的张量,并具有对应的y_target的倒数。
首先,我需要在logits中对元素进行排名,然后在索引y_target中获取元素的排名,最后,获取其倒数(或x + 1的倒数,具体取决于排名过程)。

一个简单的示例(用于单个观察):
如果我的y_target=1logits=[0.5, -2.0, 1.1, 3.5]
那么排名是logits_rank=[3, 4, 2, 1]
倒数是1.0 / logits_rank[y_target] = 0.25

这里的挑战是要在轴上应用一个函数,因为等级是未知的(在图形级别)。
我已经设法使用tf.nn.top_k(logits, k=n_class, sorted=True).indices取得了一些结果,但仅在session.run(sess, feed_dict)内。

任何帮助将不胜感激!

最佳答案

解决了!


 def tf_get_rank_order(input, reciprocal):
    """
    Returns a tensor of the rank of the input tensor's elements.
    rank(highest element) = 1.
    """
    assert isinstance(reciprocal, bool), 'reciprocal has to be bool'
    size = tf.size(input)
    indices_of_ranks = tf.nn.top_k(-input, k=size)[1]
    indices_of_ranks = size - tf.nn.top_k(-indices_of_ranks, k=size)[1]
    if reciprocal:
        indices_of_ranks = tf.cast(indices_of_ranks, tf.float32)
        indices_of_ranks = tf.map_fn(
            lambda x: tf.reciprocal(x), indices_of_ranks,
            dtype=tf.float32)
        return indices_of_ranks
    else:
        return indices_of_ranks


def get_reciprocal_rank(logits, targets, reciprocal=True):
    """
    Returns a tensor containing the (reciprocal) ranks
    of the logits tensor (wrt the targets tensor).
    The targets tensor should be a 'one hot' vector
    (otherwise apply one_hot on targets, such that index_mask is a one_hot).
    """
    function_to_map = lambda x: tf_get_rank_order(x, reciprocal=reciprocal)
    ordered_array_dtype = tf.float32 if reciprocal is not None else tf.int32
    ordered_array = tf.map_fn(function_to_map, logits,
                              dtype=ordered_array_dtype)

    size = int(logits.shape[1])
    index_mask = tf.reshape(
            targets, [-1,size])
    if reciprocal:
        index_mask = tf.cast(index_mask, tf.float32)

    return tf.reduce_sum(ordered_array * index_mask,1)

# use:
recip_rank = tf.reduce_mean(
                 get_reciprocal_rank(logits[-1],
                                     y_,
                                     True)

09-18 19:03