我正在使用while_loop来迭代更新矩阵。循环在密集张量下运行良好,但是当我使用稀疏张量时,出现以下错误:
InvalidArgumentError:a_indices的行数不匹配
a_values [[节点:
while / SparseTensorDenseMatMul / SparseTensorDenseMatMul =
SparseTensorDenseMatMul [T = DT_FLOAT,Tindices = DT_INT64,
adjoint_a = false,adjoint_b = false,
_device =“ / job:localhost / replica:0 / task:0 / device:GPU:0”](while / SparseTensorDenseMatMul / SparseTensorDenseMatMul / Enter,
while / SparseTensorDenseMatMul / SparseTensorDenseMatMul / Enter_1,
ConstantFolding / dense_to_sparse / Shape_enter / _1,而/ Switch_1:1)]]]
[[Node:while / Exit_1 / _5 = _Recvclient_terminated = false,
recv_device =“ / job:localhost /副本:0 / task:0 / device:CPU:0”,
send_device =“ / job:localhost /副本:0 / task:0 / device:GPU:0”,
send_device_incarnation = 1,tensor_name =“ edge_62_while / Exit_1”,
tensor_type = DT_FLOAT,
_device =“ / job:localhost /副本:0 / task:0 / device:CPU:0”]]
我在两个版本之间进行的唯一更改是将HH转换为HH = tf.contrib.layers.dense_to_sparse(HH)并使用tf.sparse_tensor_dense_matmul(HH,f)而不是tf.matmul(HH,f)-如图所示下面的注释代码。
with tf.device('/gpu:0'):
g=tf.constant(g,shape=[np.size(g),1],dtype=tf.float32)
H=tf.constant(H,dtype=tf.float32);
Ht=tf.transpose(H)
HH=tf.matmul(Ht,H)
#HH=tf.contrib.layers.dense_to_sparse(HH)
a=tf.matmul(Ht,g)
i=tf.constant(0,dtype=tf.int32)
f=tf.constant(f,dtype=tf.float32)
body = lambda i,f:(tf.add(i,1),tf.divide(tf.multiply(f,a),tf.matmul(HH,f)+10e-9))
#body = lambda i,f:(tf.add(i,1),tf.divide(tf.multiply(f,a),tf.sparse_tensor_dense_matmul(HH,f)+10e-9))
cond= lambda i,f:tf.less(i,iterations)
i,f=tf.while_loop(cond,body,(i,f))
sess=tf.Session()
i,f=sess.run([i,f])
请注意,只要H,g和f足够小,此代码就可以工作。例如,对于H.shape =(8000,3840),g.shape =(8000,1),f.shape =(3840,1)和更大的值,会发生此错误,但对于H.shape =(8000, 3584),g.shape =(8000,1),f.shape =(3584,1)和更小。我需要在while循环中对稀疏张量做一些特殊的事情以确保它们保持其形状吗?
最佳答案
我尝试从tensorflow 1.8更新到1.12,并且tensorflow完全停止工作(ts.Session会无限期挂起)。因此,我改变了anaconda环境,并从tensorflow 1.12重新开始。在此更新/重新安装后,稀疏张量的问题消失了,尽管尚不清楚该问题是否与tensorflow的版本或我的anaconda环境中的其他问题有关。
关于python - 使用稀疏张量的while_loop中的InvalidArgumentError,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/54444626/