问题描述
我正在使用cs231n,并且在理解此索引的工作方式方面遇到了困难.鉴于
I'm working on cs231n and I'm having a difficult time understanding how this indexing works. Given that
x = [[0,4,1], [3,2,4]]
dW = np.zeros(5,6)
dout = [[[ 1.19034710e-01 -4.65005990e-01 8.93743168e-01 -9.78047129e-01
-8.88672957e-01 -4.66605091e-01]
[ -1.38617461e-03 -2.64569728e-01 -3.83712733e-01 -2.61360826e-01
8.07072009e-01 -5.47607277e-01]
[ -3.97087458e-01 -4.25187949e-02 2.57931759e-01 7.49565950e-01
1.37707667e+00 1.77392240e+00]]
[[ -1.20692745e+00 -8.28111550e-01 6.53041092e-01 -2.31247762e+00
-1.72370321e+00 2.44308033e+00]
[ -1.45191870e+00 -3.49328154e-01 6.15445782e-01 -2.84190582e-01
4.85997687e-02 4.81590106e-01]
[ -1.14828583e+00 -9.69055406e-01 -1.00773809e+00 3.63553835e-01
-1.28078363e+00 -2.54448436e+00]]]
他们所做的操作是
np.add.at(dW, x, dout)
x是一个二维数组.索引在这里如何工作?我浏览了np.ufunc.at
文档,但是他们有带有1d数组和常量的简单示例:
x is a two dimensional array. How does indexing work here? I went through np.ufunc.at
documentation but they have simple examples with 1d array and constant:
np.add.at(a, [0, 1, 2, 2], 1)
推荐答案
In [226]: x = [[0,4,1], [3,2,4]]
...: dW = np.zeros((5,6),int)
In [227]: np.add.at(dW,x,1)
In [228]: dW
Out[228]:
array([[0, 0, 0, 1, 0, 0],
[0, 0, 0, 0, 1, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 1, 0, 0, 0]])
使用此x
时,没有任何重复的条目,因此add.at
与使用+=
索引相同.等效地,我们可以通过以下方式读取更改的值:
With this x
there aren't any duplicate entries, so add.at
is the same as using +=
indexing. Equivalently we can read the changed values with:
In [229]: dW[x[0], x[1]]
Out[229]: array([1, 1, 1])
两种索引的工作方式相同,包括广播:
The indices work the same either way, including broadcasting:
In [234]: dW[...]=0
In [235]: np.add.at(dW,[[[1],[2]],[2,4,4]],1)
In [236]: dW
Out[236]:
array([[0, 0, 0, 0, 0, 0],
[0, 0, 1, 0, 2, 0],
[0, 0, 1, 0, 2, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0]])
可能的值
相对于索引,值必须为broadcastable
:
In [112]: np.add.at(dW,[[[1],[2]],[2,4,4]],np.ones((2,3)))
...
In [114]: np.add.at(dW,[[[1],[2]],[2,4,4]],np.ones((2,3)).ravel())
...
ValueError: array is not broadcastable to correct shape
In [115]: np.add.at(dW,[[[1],[2]],[2,4,4]],[1,2,3])
In [117]: np.add.at(dW,[[[1],[2]],[2,4,4]],[[1],[2]])
In [118]: dW
Out[118]:
array([[ 0, 0, 0, 0, 0, 0],
[ 0, 0, 3, 0, 9, 0],
[ 0, 0, 4, 0, 11, 0],
[ 0, 0, 0, 0, 0, 0],
[ 0, 0, 0, 0, 0, 0]])
在这种情况下,索引定义了(2,3)形状,因此(2,3),(3,),(2,1)和标量值起作用. (6,)不会.
In this case the indices define a (2,3) shape, so (2,3),(3,), (2,1), and scalar values work. (6,) does not.
在这种情况下,add.at
将(2,3)数组映射到dW
的(2,2)子数组上.
In this case, add.at
is mapping a (2,3) array onto a (2,2) subarray of dW
.
这篇关于np.add.at与数组建立索引的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持!