如何检查tf.Tensor是否可变?
我想断言函数的参数具有正确的类型。
tf.Tensor可以是可变的:
import tensorflow as tf
import numpy as np
x = tf.get_variable('x', shape=(2,), dtype=np.float32)
print(x[1]) # x[1] is a tf.Tensor
tf.assign(x[1], 1.0)
最佳答案
这不是公共API的一部分,但查看tf.assign
的实现方式,我认为您可以做到:
import tensorflow as tf
def is_assignable(x):
return x.dtype._is_ref_dtype or (isinstance(x, tf.Tensor) and hasattr(x, 'assign'))