给定一个大小为 (n, 3)
且 n
约为 1000 的 ndarray,如何快速将每行的所有元素相乘?下面的(不优雅的)第二个解决方案在大约 0.3 毫秒内运行,是否可以改进?
# dummy data
n = 999
a = np.random.uniform(low=0, high=10, size=n).reshape(n/3,3)
# two solutions
def prod1(array):
return [np.prod(row) for row in array]
def prod2(array):
return [row[0]*row[1]*row[2] for row in array]
# benchmark
start = time.time()
prod1(a)
print time.time() - start
# 0.0015
start = time.time()
prod2(a)
print time.time() - start
# 0.0003
最佳答案
进一步提高性能
首先是一般的经验法则。您正在使用数值数组,因此请使用数组而不是列表。列表可能看起来有点像一个通用数组,但在后端完全不同,并且绝对不适用于大多数数值计算。
如果您使用 Numpy-Arrays 编写一个简单的代码,您可以通过简单地对其进行抖动来提高性能,如下所示。如果您使用列表,您可以或多或少地重写您的代码。
import numpy as np
import numba as nb
@nb.njit(fastmath=True)
def prod(array):
assert array.shape[1]==3 #Enable SIMD-Vectorization (adding some performance)
res=np.empty(array.shape[0],dtype=array.dtype)
for i in range(array.shape[0]):
res[i]=array[i,0]*array[i,1]*array[i,2]
return res
使用
np.prod(a, axis=1)
不是一个坏主意,但性能并不是很好。对于只有 1000x3 的数组,函数调用开销非常大。当在另一个 jitted 函数中使用 jitted prod 函数时,可以完全避免这种情况。基准
# The first call to the jitted function takes about 200ms compilation overhead.
#If you use @nb.njit(fastmath=True,cache=True) you can cache the compilation result for every successive call.
n=999
prod1 = 795 µs
prod2 = 187 µs
np.prod = 7.42 µs
prod 0.85 µs
n=9990
prod1 = 7863 µs
prod2 = 1810 µs
np.prod = 50.5 µs
prod 2.96 µs
关于python - 有效地将每一行的元素相乘,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/49290059/