问题描述
import numpy as np
a = np.array([.4], dtype='float32')
b = np.array([.4, .6])
print(a > b)
print(a > b[0], a > b[1])
print(a[0] > b[0], a[0] > b[1])
[ True False]
[False] [False]
True False
这是怎么回事?是的,b.dtype == 'float64'
,但其切片b[0]
&;b[1]
,a
剩余'float32'
。
注意:我是问为什么会发生这种情况,而不是问如何规避它,我知道这一点(例如,将两者都转换为'float64'
)。
推荐答案
正如我已经提到的in another answer,NumPy中的类型转换相当复杂,这是您看到的行为的根本原因。答案中链接的文档清楚地表明,标量(/0d数组)和一维数组在类型转换方面不同,因为后者不被视为逐值计算。
您已经知道问题的前半部分:问题是两种情况下类型转换的方式不同:
>>> (a + b).dtype
dtype('float64')
>>> (a + b[0]).dtype
dtype('float32')
>>> (a[0] + b[0]).dtype
dtype('float64')
如果我们考虑类型转换表,我相信我们可以理解您的示例中发生的事情:
>>> from numpy.testing import print_coercion_tables
can cast
[...]
In these tables, ValueError is '!', OverflowError is '@', TypeError is '#'
scalar + scalar
+ ? b h i l q p B H I L Q P e f d g F D G S U V O M m
? ? b h i l q l B H I L Q L e f d g F D G # # # O ! m
b b b h i l q l h i l d d d e f d g F D G # # # O ! m
h h h h i l q l h i l d d d f f d g F D G # # # O ! m
i i i i i l q l i i l d d d d d d g D D G # # # O ! m
l l l l l l q l l l l d d d d d d g D D G # # # O ! m
q q q q q q q q q q q d d d d d d g D D G # # # O ! m
p l l l l l q l l l l d d d d d d g D D G # # # O ! m
B B h h i l q l B H I L Q L e f d g F D G # # # O ! m
H H i i i l q l H H I L Q L f f d g F D G # # # O ! m
I I l l l l q l I I I L Q L d d d g D D G # # # O ! m
L L d d d d d d L L L L Q L d d d g D D G # # # O ! m
Q Q d d d d d d Q Q Q Q Q Q d d d g D D G # # # O ! m
P L d d d d d d L L L L Q L d d d g D D G # # # O ! m
e e e f d d d d e f d d d d e f d g F D G # # # O ! #
f f f f d d d d f f d d d d f f d g F D G # # # O ! #
d d d d d d d d d d d d d d d d d g D D G # # # O ! #
g g g g g g g g g g g g g g g g g g G G G # # # O ! #
F F F F D D D D F F D D D D F F D G F D G # # # O ! #
D D D D D D D D D D D D D D D D D G D D G # # # O ! #
G G G G G G G G G G G G G G G G G G G G G # # # O ! #
S # # # # # # # # # # # # # # # # # # # # # # # O ! #
U # # # # # # # # # # # # # # # # # # # # # # # O ! #
V # # # # # # # # # # # # # # # # # # # # # # # O ! #
O O O O O O O O O O O O O O O O O O O O O O O O O ! #
M ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! !
m m m m m m m m m m m m m m # # # # # # # # # # # ! m
scalar + neg scalar
[...]
array + scalar
+ ? b h i l q p B H I L Q P e f d g F D G S U V O M m
? ? b h i l q l B H I L Q L e f d g F D G # # # O ! m
b b b b b b b b b b b b b b e f d g F D G # # # O ! m
h h h h h h h h h h h h h h f f d g F D G # # # O ! m
i i i i i i i i i i i i i i d d d g D D G # # # O ! m
l l l l l l l l l l l l l l d d d g D D G # # # O ! m
q q q q q q q q q q q q q q d d d g D D G # # # O ! m
p l l l l l l l l l l l l l d d d g D D G # # # O ! m
B B B B B B B B B B B B B B e f d g F D G # # # O ! m
H H H H H H H H H H H H H H f f d g F D G # # # O ! m
I I I I I I I I I I I I I I d d d g D D G # # # O ! m
L L L L L L L L L L L L L L d d d g D D G # # # O ! m
Q Q Q Q Q Q Q Q Q Q Q Q Q Q d d d g D D G # # # O ! m
P L L L L L L L L L L L L L d d d g D D G # # # O ! m
e e e e e e e e e e e e e e e e e e F F F # # # O ! #
f f f f f f f f f f f f f f f f f f F F F # # # O ! #
d d d d d d d d d d d d d d d d d d D D D # # # O ! #
g g g g g g g g g g g g g g g g g g G G G # # # O ! #
F F F F F F F F F F F F F F F F F F F F F # # # O ! #
D D D D D D D D D D D D D D D D D D D D D # # # O ! #
G G G G G G G G G G G G G G G G G G G G G # # # O ! #
S # # # # # # # # # # # # # # # # # # # # # # # O ! #
U # # # # # # # # # # # # # # # # # # # # # # # O ! #
V # # # # # # # # # # # # # # # # # # # # # # # O ! #
O O O O O O O O O O O O O O O O O O O O O O O O O ! #
M ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! ! !
m m m m m m m m m m m m m m # # # # # # # # # # # ! m
[...]
以上是基于价值的促销的当前促销表的一部分。它表示在将给定类型的两个NumPy对象配对时,不同类型对结果类型的贡献(有关特定类型,请参阅第一列和第一行)。类型按照single-character dtype specifications(以下一字符串)来理解,特别是np.dtype('f')
对应np.float32
(对于C样式的浮点类型为f)和np.dtype('d')
(对于C样式的双精度类型为d)到np.float64
(另见np.typename('f')
和'd'
相同)。我在上表中注意到两项黑体:
现在让我们看看你的案例。前提是您有一个'f'
数组a
和一个'd'
数组b
。a
只有一个元素这一事实并不重要:它是长度为1的一维数组,而不是0维数组。
当您执行
a > b
时,您是在比较两个数组,上面的表中没有表示这一点。我不确定这里的行为是什么;我猜测a
被广播到b
的形状,然后它的类型被转换为'd'
。我认为这是因为np.can_cast(a, np.float64)
是True
和np.can_cast(b, np.float32)
是False
。但这只是一个猜测,NumPy中的许多机制对我来说并不直观。当您执行
a > b[0]
时,您是在将'f'
数组与'd'
标量进行比较,因此根据上面的说明,您将得到'f'
数组。这就是(a + b[0]).dtype
告诉我们的。(当您使用a > b[0]
时,您看不到转换步骤,因为结果始终是bool。)当您执行
a[0] > b[0]
操作时,您是在将'f'
标量与'd'
标量进行比较,因此根据上面的说明,您将得到一个'd'
标量。这是(a[0] + b[0]).dtype
告诉我们的。
所以我认为这与NumPy中类型转换的怪癖是一致的。虽然它可能看起来像是双精度和单精度的0.4
值的不幸的转折点,但该功能更深入,该问题充当一个红色大警告,提示您在混合不同的数据类型时应该非常小心。
最安全的做法是自己转换类型,以便控制代码中发生的事情。尤其是因为有关于重新考虑类型提升的某些方面的讨论。
这篇关于对于相同的元素,不同的切片给出了不同的不等式的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持!