我正在将一个numpy代码转换为Tensorflow。
它具有以下行:
netout[..., 5:] *= netout[..., 5:] > obj_threshold
这与Tensorflow语法不同,我很难找到具有相同行为的函数。
首先,我尝试:
netout[..., 5:] * netout[..., 5:] > obj_threshold
但是返回的只是布尔型张量。
在这种情况下,我希望
obj_threshold
以下的所有值都为0。 最佳答案
如果您只想将obj_threshold
以下的所有值设为0,则可以执行以下操作:
netout = tf.where(netout > obj_threshold, netout, tf.zeros_like(netout))
要么:
netout = netout * tf.cast(netout > obj_threshold, netout.dtype)
但是,您的情况要复杂一些,因为您只希望更改影响张量的一部分。因此,您可以做的是为大于
True
的值或最后一个索引低于5的值创建一个布尔掩码为obj_threshold
。mask = (netout > obj_threshold) | (tf.range(tf.shape(netout)[-1]) < 5)
然后,您可以将其与任何以前的方法一起使用:
netout = tf.where(mask, netout, tf.zeros_like(netout))
netout = netout * tf.cast(mask, netout.dtype)