我试图理解 numpy.einsum() 函数,但是来自 stackoverflow 的文档和 this answer 仍然给我留下了一些问题。
让我们采用爱因斯坦和和答案中定义的矩阵。
A = np.array([0, 1, 2])
B = np.array([[ 0, 1, 2, 3],
[ 4, 5, 6, 7],
[ 8, 9, 10, 11]])
np.einsum('i,ij->i', A, B)
因此,根据我对爱因斯坦总和的理解,我会将这个函数转换为等价于符号 (A_i*B_ij),因此我将获得:
j = 1 : A_1*B_11 + A_2*B_21 + A_3*B_31
j = 2 : A_1*B_12 + A_2*B_22+ A_3*B_32
依此类推,直到 j = 4。这给出
j = 1 : 0 + 4 + 16
j = 2 : 0 + 5 + 18
根据我的理解,这将是爱因斯坦求和。相反,该函数不执行整体求和,而是将单独的项存储在矩阵中,我们可以在其中发现 (A_i * B_ij) 的结果
0 0 0 0
4 5 6 7
16 18 20 22
这实际上是如何由函数控制的?我觉得这是由文档中提到的输出标签控制的:
所以不知何故,我认为放置
->i
会禁用内部总和的求和。但这究竟是如何工作的呢?这对我来说不清楚。放置 ->j
提供了预期的实际爱因斯坦和。 最佳答案
看来你对爱因斯坦求和的理解是不正确的。您写出的下标运算的乘法正确,但求和在错误的轴上。
想想这意味着什么: np.einsum('i,ij->i', A, B)
。
A
的形状为 (i,)
,B
的形状为 (i, j)
。 B
的每一列乘以 A
。 B
的第二个轴上求和,即在标记为 j
的轴上。 这给出了形状
(i,) == (3,)
的输出,而您的下标符号给出了形状 (j,) == (4,)
的输出。你在错误的轴上求和。更多细节:
请记住,乘法总是首先发生。左边的下标告诉
np.einsum
函数输入数组的哪些行/列/等要彼此相乘。此步骤的输出始终与最高维输入数组具有相同的形状。即,此时,假设的“中间”数组的形状为 (3, 4) == B.shape
。乘法之后就是求和。这是由从右侧省略哪些下标来控制的。在这种情况下,省略了
j
,这意味着沿数组的第一个轴求和。 (您正在沿第零进行求和。)如果您改为写道:
np.einsum('i,ij->ij', A, B)
,则不会有求和,因为没有省略下标。因此,您将获得问题末尾的数组。下面是几个例子:
Ex 1:
没有省略下标,所以没有求和。只需将
B
的列乘以 A
。这是您写出的最后一个数组。>>> (np.einsum('i,ij->ij', A, B) == (A[:, None] * B)).all()
True
Ex 2:
与示例相同。乘以列,然后对输出的列求和。
>>> (np.einsum('i,ij->i', A, B) == (A[:, None] * B).sum(axis=-1)).all()
True
例 3:
你上面写的总和。乘以列,然后对输出的行求和。
>>> (np.einsum('i,ij->j', A, B) == (A[:, None] * B).sum(axis=0)).all()
True
Ex 4:
请注意,我们可以在末尾省略所有轴,以获取整个数组的总和。
>>> np.einsum('i,ij->', A, B)
98
Ex 5:
请注意,求和确实发生了,因为我们重复了输入标签
'i'
。如果我们对输入数组的每个轴使用不同的标签,我们可以计算类似于 Kronecker 乘积的东西:>>> np.einsum('i,jk', A, B).shape
(3, 3, 4)
编辑
爱因斯坦求和的 NumPy 实现与传统定义略有不同。从技术上讲,爱因斯坦总和没有“输出标签”的概念。这些总是由重复的输入标签暗示。
来自文档:
"Whenever a label is repeated, it is summed."
所以,传统上,我们会写一些类似 np.einsum('i,ij', A, B)
的东西。这相当于 np.einsum('i,ij->j', A, B)
。 i
被重复,所以它被求和,只留下标记为 j
的轴。您可以将不指定输出标签的总和视为与仅指定在输入中不重复的标签相同。也就是说,标签 'i,ij'
与 'i,ij->j'
相同。输出标签是在 NumPy 中实现的扩展或扩充,它允许调用者强制求和或在轴上强制不求和。来自文档:
"The output can be controlled by specifying output subscript labels as well. This specifies the label order, and allows summing to be disallowed or forced when desired."
关于python - 关于 numpy.einsum() 的附加信息,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/47366812/