本文介绍了设置csr_matrix的行的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!
问题描述
我有一个稀疏的csr_matrix,我想将单行的值更改为不同的值.但是,我找不到一种简单有效的实现方式.这是它要做的:
I have a sparse csr_matrix, and I want to change the values of a single row to different values. I can't find an easy and efficient implementation however. This is what it has to do:
A = csr_matrix([[0, 1, 0],
[1, 0, 1],
[0, 1, 0]])
new_row = np.array([-1, -1, -1])
print(set_row_csr(A, 2, new_row).todense())
>>> [[ 0, 1, 0],
[ 1, 0, 1],
[-1, -1, -1]]
这是我目前对set_row_csr
的实现:
def set_row_csr(A, row_idx, new_row):
A[row_idx, :] = new_row
return A
但这给了我SparseEfficiencyWarning
.有没有一种方法可以在不进行手动索引操作的情况下完成此操作,或者这是我唯一的出路吗?
But this gives me a SparseEfficiencyWarning
. Is there a way of getting this done without manual index juggling, or is this my only way out?
推荐答案
最后,我设法通过索引变戏法来完成此操作.
In the end, I managed to get this done with index juggling.
def set_row_csr(A, row_idx, new_row):
'''
Replace a row in a CSR sparse matrix A.
Parameters
----------
A: csr_matrix
Matrix to change
row_idx: int
index of the row to be changed
new_row: np.array
list of new values for the row of A
Returns
-------
None (the matrix A is changed in place)
Prerequisites
-------------
The row index shall be smaller than the number of rows in A
The number of elements in new row must be equal to the number of columns in matrix A
'''
assert sparse.isspmatrix_csr(A), 'A shall be a csr_matrix'
assert row_idx < A.shape[0], \
'The row index ({0}) shall be smaller than the number of rows in A ({1})' \
.format(row_idx, A.shape[0])
try:
N_elements_new_row = len(new_row)
except TypeError:
msg = 'Argument new_row shall be a list or numpy array, is now a {0}'\
.format(type(new_row))
raise AssertionError(msg)
N_cols = A.shape[1]
assert N_cols == N_elements_new_row, \
'The number of elements in new row ({0}) must be equal to ' \
'the number of columns in matrix A ({1})' \
.format(N_elements_new_row, N_cols)
idx_start_row = A.indptr[row_idx]
idx_end_row = A.indptr[row_idx + 1]
additional_nnz = N_cols - (idx_end_row - idx_start_row)
A.data = np.r_[A.data[:idx_start_row], new_row, A.data[idx_end_row:]]
A.indices = np.r_[A.indices[:idx_start_row], np.arange(N_cols), A.indices[idx_end_row:]]
A.indptr = np.r_[A.indptr[:row_idx + 1], A.indptr[(row_idx + 1):] + additional_nnz]
这篇关于设置csr_matrix的行的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持!