我是numpy的新手,因此在可视化numpy.tensordot()
函数的工作时遇到了一些问题。根据tensordot
的文档,在参数中传递轴,其中轴= 0或1表示法线矩阵乘法,而轴= 2表示收缩。
有人可以解释在给定的例子中乘法将如何进行吗?
示例1:a=[1,1] b=[2,2] for axes=0,1
为什么对于轴= 2会引发错误?
示例2:a=[[1,1],[1,1]] b=[[2,2],[2,2]] for axes=0,1,2
最佳答案
编辑:此答案的最初焦点是在axes
是元组的情况下,为每个参数指定一个或多个轴。这种用法使我们能够对常规dot
进行变化,特别是对于大于2d的数组(在链接的问题中,我的答案也是https://stackoverflow.com/a/41870980/901925)。标量轴是一个特例,它被转换为元组版本。因此,它的核心仍然是dot
产品。
轴作为元组
In [235]: a=[1,1]; b=[2,2]
a
和b
是列表; tensordot
将它们转换为数组。In [236]: np.tensordot(a,b,(0,0))
Out[236]: array(4)
由于它们都是一维数组,因此我们将轴值指定为0。
如果我们尝试指定1:
In [237]: np.tensordot(a,b,(0,1))
---------------------------------------------------------------------------
1282 else:
1283 for k in range(na):
-> 1284 if as_[axes_a[k]] != bs[axes_b[k]]:
1285 equal = False
1286 break
IndexError: tuple index out of range
它正在检查
a
的轴0的大小是否与b
的轴1的大小匹配。但是由于b
是1d,因此无法检查。In [239]: np.array(a).shape[0]
Out[239]: 2
In [240]: np.array(b).shape[1]
IndexError: tuple index out of range
您的第二个示例是2d数组:
In [242]: a=np.array([[1,1],[1,1]]); b=np.array([[2,2],[2,2]])
指定
a
的最后一个轴和b
的第一个轴(倒数第二个),将产生常规矩阵(点)乘积:In [243]: np.tensordot(a,b,(1,0))
Out[243]:
array([[4, 4],
[4, 4]])
In [244]: a.dot(b)
Out[244]:
array([[4, 4],
[4, 4]])
更好的诊断价值:
In [250]: a=np.array([[1,2],[3,4]]); b=np.array([[2,3],[2,1]])
In [251]: np.tensordot(a,b,(1,0))
Out[251]:
array([[ 6, 5],
[14, 13]])
In [252]: np.dot(a,b)
Out[252]:
array([[ 6, 5],
[14, 13]])
In [253]: np.tensordot(a,b,(0,1))
Out[253]:
array([[11, 5],
[16, 8]])
In [254]: np.dot(b,a) # same numbers, different layout
Out[254]:
array([[11, 16],
[ 5, 8]])
In [255]: np.dot(b,a).T
Out[255]:
array([[11, 5],
[16, 8]])
另一个配对:
In [256]: np.tensordot(a,b,(0,0))
In [257]: np.dot(a.T,b)
轴的(0,1,2)是错误的。 axis参数应为2个数字或2个元组,分别与2个参数相对应。
tensordot
中的基本处理是对输入进行转置和整形,以便随后可以将结果传递给常规(矩阵a的倒数,b的倒数第二)的np.dot
。标量轴
如果我正确阅读
tensordot
代码,则axes
参数将转换为两个列表,其中包括:def foo(axes):
try:
iter(axes)
except Exception:
axes_a = list(range(-axes, 0))
axes_b = list(range(0, axes))
else:
axes_a, axes_b = axes
try:
na = len(axes_a)
axes_a = list(axes_a)
except TypeError:
axes_a = [axes_a]
na = 1
try:
nb = len(axes_b)
axes_b = list(axes_b)
except TypeError:
axes_b = [axes_b]
nb = 1
return axes_a, axes_b
对于标量值0,1,2,结果为:
In [281]: foo(0)
Out[281]: ([], [])
In [282]: foo(1)
Out[282]: ([-1], [0])
In [283]: foo(2)
Out[283]: ([-2, -1], [0, 1])
axes=1
与在元组中指定的相同:In [284]: foo((-1,0))
Out[284]: ([-1], [0])
对于2:
In [285]: foo(((-2,-1),(0,1)))
Out[285]: ([-2, -1], [0, 1])
在我的最新示例中,
axes=2
与在2个数组的所有轴上指定dot
相同:In [287]: np.tensordot(a,b,axes=2)
Out[287]: array(18)
In [288]: np.tensordot(a,b,axes=((0,1),(0,1)))
Out[288]: array(18)
这与在数组的展平的1d视图上执行
dot
相同:In [289]: np.dot(a.ravel(), b.ravel())
Out[289]: 18
我已经展示了用于这些阵列的常规点积,即
axes=1
盒。axes=0
与axes=((),())
相同,两个数组没有求和轴:In [292]: foo(((),()))
Out[292]: ([], [])
np.tensordot(a,b,((),()))
与np.tensordot(a,b,axes=0)
相同输入数组为1d时,正是
-2
转换中的foo(2)
给您带来了问题。 axes=1
是一维数组的“收缩”。换句话说,不要从字面上太用文字描述。他们只是试图描述代码的作用。它们不是正式规范。等效值
我认为
einsum
的轴规格更清晰,功能更强大。这是0,1,2的等效项In [295]: np.einsum('ij,kl',a,b)
Out[295]:
array([[[[ 2, 3],
[ 2, 1]],
[[ 4, 6],
[ 4, 2]]],
[[[ 6, 9],
[ 6, 3]],
[[ 8, 12],
[ 8, 4]]]])
In [296]: np.einsum('ij,jk',a,b)
Out[296]:
array([[ 6, 5],
[14, 13]])
In [297]: np.einsum('ij,ij',a,b)
Out[297]: 18
axes = 0的情况,等效于:
np.dot(a[:,:,None],b[:,None,:])
它添加了一个新的最后一个轴和新的第二个到最后一个轴,并对它们进行了常规的点积求和。但是我们通常在广播中进行这种“外部”乘法:
a[:,:,None,None]*b[None,None,:,:]
虽然将0,1,2用于轴很有趣,但实际上并没有增加新的计算能力。轴的元组形式更强大,更有用。
代码摘要(重要步骤)
1-按照上述
axes
函数的摘录将axes_a
转换为axes_b
和foo
2-将
a
和b
分成数组,并获得形状和ndim3-检查要累加(收缩)的轴上的匹配尺寸
4-构造一个
newshape_a
和newaxes_a
; b
相同(复杂步骤)5-
at = a.transpose(newaxes_a).reshape(newshape_a)
;与b
相同6-
res = dot(at, bt)
7-将
res
重塑为所需的返回形状5和6是计算核心。 4从概念上讲是最复杂的步骤。对于所有
axes
值,计算结果是相同的(dot
乘积),但设置有所不同。超过0,1,2
虽然文档仅提及标量轴为0,1,2,但代码并不限于这些值
In [331]: foo(3)
Out[331]: ([-3, -2, -1], [0, 1, 2])
如果输入为3,则轴= 3应该起作用:
In [330]: np.tensordot(np.ones((2,2,2)), np.ones((2,2,2)), axes=3)
Out[330]: array(8.)
或更一般而言:
In [325]: np.tensordot(np.ones((2,2,2)), np.ones((2,2,2)), axes=0).shape
Out[325]: (2, 2, 2, 2, 2, 2)
In [326]: np.tensordot(np.ones((2,2,2)), np.ones((2,2,2)), axes=1).shape
Out[326]: (2, 2, 2, 2)
In [327]: np.tensordot(np.ones((2,2,2)), np.ones((2,2,2)), axes=2).shape
Out[327]: (2, 2)
In [328]: np.tensordot(np.ones((2,2,2)), np.ones((2,2,2)), axes=3).shape
Out[328]: ()
如果输入为0d,则轴= 0有效(轴= 1无效):
In [335]: np.tensordot(2,3, axes=0)
Out[335]: array(6)
你能解释一下吗?
In [363]: np.tensordot(np.ones((4,2,3)),np.ones((2,3,4)),axes=2).shape
Out[363]: (4, 4)
我玩过3d数组的其他标量轴值。虽然可以提出可行的形状对,但更明确的元组轴值更易于使用。
0,1,2
选项是仅适用于特殊情况的快捷方式。元组方法更易于使用-尽管我仍然更喜欢einsum
表示法。关于python - numpy.tensordot函数如何逐步工作?,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/51989572/