获取可变批次维度的大小

获取可变批次维度的大小

本文介绍了获取可变批次维度的大小的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

假设网络的输入是一个批大小可变的placeholder,即:

assuming the input to the network is a placeholder with variable batch size, i.e.:

x = tf.placeholder(..., shape=[None, ...])

是否可以在输入后获得 x 的形状?tf.shape(x)[0] 仍然返回 None.

is it possible to get the shape of x after it has been fed? tf.shape(x)[0] still returns None.

推荐答案

如果 x 具有可变的批大小,则获得实际形状的唯一方法是使用 tf.shape() 运算符.此运算符在 tf.张量,因此它可以用作其他TensorFlow操作的输入,但要获得形状的具体Python值,您需要将其传递给Session.run().

If x has a variable batch size, the only way to get the actual shape is to use the tf.shape() operator. This operator returns a symbolic value in a tf.Tensor, so it can be used as the input to other TensorFlow operations, but to get a concrete Python value for the shape, you need to pass it to Session.run().

x = tf.placeholder(..., shape=[None, ...])
batch_size = tf.shape(x)[0]  # Returns a scalar `tf.Tensor`

print x.get_shape()[0]  # ==> "?"

# You can use `batch_size` as an argument to other operators.
some_other_tensor = ...
some_other_tensor_reshaped = tf.reshape(some_other_tensor, [batch_size, 32, 32])

# To get the value, however, you need to call `Session.run()`.
sess = tf.Session()
x_val = np.random.rand(37, 100, 100)
batch_size_val = sess.run(batch_size, {x: x_val})
print x_val  # ==> "37"

这篇关于获取可变批次维度的大小的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持!

08-13 08:57