我具有以下批处理形状:

 [?,227,227]

以及以下权重变量:
 weight_tensor = tf.truncated_normal([227,227],**{'stddev':0.1,'mean':0.0})

 weight_var = tf.Variable(weight_tensor)

但是当我做tf.batch_matmul时:
 matrix = tf.batch_matmul(prev_net_2d,weight_var)

我失败并出现以下错误:



所以我的问题变成:我该怎么做?

我如何在2D中有一个weight_variable乘以每个单独的图片(227x227),以便获得(227x227)输出?此操作的平面版本完全耗尽了资源...加上渐变不会以平面形式正确更改权重...

或者:如何沿批处理维度(?,)拆分传入的张量,以便可以使用我的weight_variable在每个拆分的张量上运行tf.matmul函数?

最佳答案

您可以沿第一个维度平铺权重

weight_tensor = tf.truncated_normal([227,227],**{'stddev':0.1,'mean':0.0})
weight_var = tf.Variable(weight_tensor)
weight_var_batch = tf.tile(tf.expand_dims(weight_var, axis=0), [batch_size, 1, 1])
matrix = tf.matmul(prev_net_2d,weight_var_batch)

虽然batch_matmul不再存在

09-25 18:42