更新每行中的第一个匹配元素

更新每行中的第一个匹配元素

本文介绍了Tensorflow 更新每行中的第一个匹配元素的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

以此问题为基础 我希望在满足 tf.where 条件的情况下连续第一次更新二维张量的值.这是我用来模拟的示例代码:

Building on this question I am looking to update the values of a 2-D tensor the first time in a row the tf.where condition is met. Here is a sample code I am using to simulate:

tf.reset_default_graph()
graph = tf.Graph()
with graph.as_default():
    val = "hello"
    new_val = "goodbye"
    matrix = tf.constant([["word","hello","hello"],
                          ["word", "other", "hello"],
                          ["hello", "hello","hello"],
                          ["word", "word", "word"]
                         ])
    matching_indices = tf.where(tf.equal(matrix, val))
    first_matching_idx = tf.segment_min(data = matching_indices[:, 1],
                                 segment_ids = matching_indices[:, 0])

sess = tf.InteractiveSession(graph=graph)
print(sess.run(first_matching_idx))

这将输出 [1, 2, 0] 其中 1 是第 1 行第一个 hello 的位置,2 是第 2 行第一个 hello 的位置,0 是第一个 hello 的位置在第 3 行.

This will output [1, 2, 0] where the 1 is the placement of the first hello in row 1, the 2 is the placement of the first hello in row 2, and the 0 is the placement of the first hello in row 3.

但是,我想不出一种方法让第一个匹配的索引更新为新值——基本上我希望第一个hello"变成goodbye".我曾尝试使用 tf.scatter_update() 但它似乎不适用于 2D 张量.有没有办法修改所描述的二维张量?

However, I can't figure out a way to get the first matching index to be updated with the new value -- basically I want the first "hello" to be turned into "goodbye". I have tried using tf.scatter_update() but it does not seem to work on 2D tensors. Is there any way to modify the 2-D tensor as described?

推荐答案

一个简单的解决方法是使用 tf.py_func 和 numpy 数组

One easy workaround is to use tf.py_func with numpy array

def ch_val(array, val, new_val):
    idx = np.array([[s, list(row).index(val)]
                for s, row in enumerate(array) if val in row])
    idx = tuple((idx[:, 0], idx[:, 1]))
    array[idx] = new_val
    return array

...
matrix = tf.Variable([["word","hello","hello"],
                      ["word", "other", "hello"],
                      ["hello", "hello","hello"],
                      ["word", "word", "word"]
                     ])
matrix = tf.py_func(ch_val, [matrix, 'hello', 'goodbye'], tf.string)
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(matrix))
    # results: [['word' 'goodbye' 'hello']
      ['word' 'other' 'goodbye']
      ['goodbye' 'hello' 'hello']
      ['word' 'word' 'word']]

这篇关于Tensorflow 更新每行中的第一个匹配元素的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持!

08-05 20:17