SparseTensorDenseMatMul

SparseTensorDenseMatMul

我正在使用while_loop来迭代更新矩阵。循环在密集张量下运行良好,但是当我使用稀疏张量时,出现以下错误:


  InvalidArgumentError:a_indices的行数不匹配
  a_values [[节点:
  while / SparseTensorDenseMatMul / SparseTensorDenseMatMul =
  SparseTensorDenseMatMul [T = DT_FLOAT,Tindices = DT_INT64,
  adjoint_a = false,adjoint_b = false,
  _device =“ / job:localhost / replica:0 / task:0 / device:GPU:0”](while / SparseTensorDenseMatMul / SparseTensorDenseMatMul / Enter,
  while / SparseTensorDenseMatMul / SparseTensorDenseMatMul / Enter_1,
  ConstantFolding / dense_to_sparse / Shape_enter / _1,而/ Switch_1:1)]]]
  [[Node:while / Exit_1 / _5 = _Recvclient_terminated = false,
  recv_device =“ / job:localhost /副本:0 / task:0 / device:CPU:0”,
  send_device =“ / job:localhost /副本:0 / task:0 / device:GPU:0”,
  send_device_incarnation = 1,tensor_name =“ edge_62_while / Exit_1”,
  tensor_type = DT_FLOAT,
  _device =“ / job:localhost /副本:0 / task:0 / device:CPU:0”]]


我在两个版本之间进行的唯一更改是将HH转换为HH = tf.contrib.layers.dense_to_sparse(HH)并使用tf.sparse_tensor_dense_matmul(HH,f)而不是tf.matmul(HH,f)-如图所示下面的注释代码。

with tf.device('/gpu:0'):
    g=tf.constant(g,shape=[np.size(g),1],dtype=tf.float32)
    H=tf.constant(H,dtype=tf.float32);
    Ht=tf.transpose(H)
    HH=tf.matmul(Ht,H)
    #HH=tf.contrib.layers.dense_to_sparse(HH)
    a=tf.matmul(Ht,g)
    i=tf.constant(0,dtype=tf.int32)
    f=tf.constant(f,dtype=tf.float32)
    body = lambda i,f:(tf.add(i,1),tf.divide(tf.multiply(f,a),tf.matmul(HH,f)+10e-9))
    #body = lambda i,f:(tf.add(i,1),tf.divide(tf.multiply(f,a),tf.sparse_tensor_dense_matmul(HH,f)+10e-9))
    cond= lambda i,f:tf.less(i,iterations)
    i,f=tf.while_loop(cond,body,(i,f))
sess=tf.Session()
i,f=sess.run([i,f])


请注意,只要H,g和f足够小,此代码就可以工作。例如,对于H.shape =(8000,3840),g.shape =(8000,1),f.shape =(3840,1)和更大的值,会发生此错误,但对于H.shape =(8000, 3584),g.shape =(8000,1),f.shape =(3584,1)和更小。我需要在while循环中对稀疏张量做一些特殊的事情以确保它们保持其形状吗?

最佳答案

我尝试从tensorflow 1.8更新到1.12,并且tensorflow完全停止工作(ts.Session会无限期挂起)。因此,我改变了anaconda环境,并从tensorflow 1.12重新开始。在此更新/重新安装后,稀疏张量的问题消失了,尽管尚不清楚该问题是否与tensorflow的版本或我的anaconda环境中的其他问题有关。

关于python - 使用稀疏张量的while_loop中的InvalidArgumentError,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/54444626/

10-12 18:50