Problem description
I'm currently writing a TensorFlow program that requires multiplying a batch of 2-D tensors (a 3-D tensor of shape [None,...]) with a 2-D matrix W. This requires turning W into a 3-D matrix, which requires knowing the batch size.
I have not been able to do this: tf.batch_matmul is no longer available, and x.get_shape().as_list()[0] returns None, which is invalid for a reshaping or tiling operation. Any suggestions? I've seen some people use config.cfg.batch_size, but I don't know what that is.
Recommended answer
The solution is to combine tf.shape (which returns the shape at runtime) with tf.tile (which accepts a dynamic shape).
import numpy as np
import tensorflow as tf  # TensorFlow 1.x graph-mode API

x = tf.placeholder(shape=[None, 2, 3], dtype=tf.float32)
W = tf.Variable(initial_value=np.ones([3, 4]), dtype=tf.float32)
print(x.shape)  # Static shape: (?, 2, 3); the batch dimension is unknown here
batch_size = tf.shape(x)[0]  # A tensor that gets the batch size at runtime
W_expand = tf.expand_dims(W, axis=0)  # Shape (1, 3, 4)
W_tile = tf.tile(W_expand, multiples=[batch_size, 1, 1])  # Shape (batch, 3, 4)
result = tf.matmul(x, W_tile)  # Can multiply now!

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    feed_dict = {x: np.ones([10, 2, 3])}
    print(sess.run(batch_size, feed_dict=feed_dict))  # 10
    print(sess.run(result, feed_dict=feed_dict).shape)  # (10, 2, 4)
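As a side note, tiling W is not the only way to get this product. The sketch below is not part of the original answer. It uses NumPy rather than TensorFlow, purely to check the arithmetic, and the helper names batch_matmul_tiled and batch_matmul_reshaped are made up for illustration. It compares the tile-then-multiply approach with two alternatives that never need the batch size: folding the batch axis into the rows with a reshape to -1, and an einsum contraction. TensorFlow provides tf.reshape and tf.einsum as counterparts, so the same idea should carry over, but treat that as an assumption to verify against your TensorFlow version.

```python
import numpy as np


def batch_matmul_tiled(x, w):
    """Tile w along a new leading axis, then multiply batch-wise."""
    batch_size = x.shape[0]
    w_tiled = np.tile(w[np.newaxis, :, :], (batch_size, 1, 1))
    return np.matmul(x, w_tiled)


def batch_matmul_reshaped(x, w):
    """Fold the batch axis into the rows, multiply once, then unfold."""
    rows, inner = x.shape[1], x.shape[2]
    flat = x.reshape(-1, inner)  # (batch * rows, inner); -1 absorbs the batch
    return (flat @ w).reshape(-1, rows, w.shape[1])


rng = np.random.default_rng(0)
x = rng.standard_normal((10, 2, 3))  # a batch of ten 2x3 matrices
w = rng.standard_normal((3, 4))      # one shared 3x4 matrix

tiled = batch_matmul_tiled(x, w)
reshaped = batch_matmul_reshaped(x, w)
einsum = np.einsum("bij,jk->bik", x, w)  # contract over the shared axis j

print(tiled.shape)                    # (10, 2, 4)
print(np.allclose(tiled, reshaped))   # True
print(np.allclose(tiled, einsum))     # True
```

The reshape and einsum variants avoid materializing a copy of W for every batch element, which may matter when the batch is large.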