问题描述
以此问题为基础 我希望在满足 tf.where 条件的情况下连续第一次更新二维张量的值.这是我用来模拟的示例代码:
Building on this question I am looking to update the values of a 2-D tensor the first time in a row the tf.where condition is met. Here is a sample code I am using to simulate:
tf.reset_default_graph()
graph = tf.Graph()
with graph.as_default():
val = "hello"
new_val = "goodbye"
matrix = tf.constant([["word","hello","hello"],
["word", "other", "hello"],
["hello", "hello","hello"],
["word", "word", "word"]
])
matching_indices = tf.where(tf.equal(matrix, val))
first_matching_idx = tf.segment_min(data = matching_indices[:, 1],
segment_ids = matching_indices[:, 0])
sess = tf.InteractiveSession(graph=graph)
print(sess.run(first_matching_idx))
这将输出 [1, 2, 0] 其中 1 是第 1 行第一个 hello 的位置,2 是第 2 行第一个 hello 的位置,0 是第一个 hello 的位置在第 3 行.
This will output [1, 2, 0] where the 1 is the placement of the first hello in row 1, the 2 is the placement of the first hello in row 2, and the 0 is the placement of the first hello in row 3.
但是,我想不出一种方法让第一个匹配的索引更新为新值——基本上我希望第一个hello"变成goodbye".我曾尝试使用 tf.scatter_update() 但它似乎不适用于 2D 张量.有没有办法修改所描述的二维张量?
However, I can't figure out a way to get the first matching index to be updated with the new value -- basically I want the first "hello" to be turned into "goodbye". I have tried using tf.scatter_update() but it does not seem to work on 2D tensors. Is there any way to modify the 2-D tensor as described?
推荐答案
一个简单的解决方法是使用 tf.py_func
和 numpy 数组
One easy workaround is to use tf.py_func
with numpy array
def ch_val(array, val, new_val):
idx = np.array([[s, list(row).index(val)]
for s, row in enumerate(array) if val in row])
idx = tuple((idx[:, 0], idx[:, 1]))
array[idx] = new_val
return array
...
matrix = tf.Variable([["word","hello","hello"],
["word", "other", "hello"],
["hello", "hello","hello"],
["word", "word", "word"]
])
matrix = tf.py_func(ch_val, [matrix, 'hello', 'goodbye'], tf.string)
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
print(sess.run(matrix))
# results: [['word' 'goodbye' 'hello']
['word' 'other' 'goodbye']
['goodbye' 'hello' 'hello']
['word' 'word' 'word']]
这篇关于Tensorflow 更新每行中的第一个匹配元素的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持!