给定一个大小为 (n, 3)n 约为 1000 的 ndarray,如何快速将每行的所有元素相乘?下面的(不优雅的)第二个解决方案在大约 0.3 毫秒内运行,是否可以改进?

# dummy data
n = 999
a = np.random.uniform(low=0, high=10, size=n).reshape(n/3,3)

# two solutions
def prod1(array):
    return [np.prod(row) for row in array]

def prod2(array):
    return [row[0]*row[1]*row[2] for row in array]

# benchmark
start = time.time()
prod1(a)
print time.time() - start
# 0.0015

start = time.time()
prod2(a)
print time.time() - start
# 0.0003

最佳答案

进一步提高性能

首先是一般的经验法则。您正在使用数值数组,因此请使用数组而不是列表。列表可能看起来有点像一个通用数组,但在后端完全不同,并且绝对不适用于大多数数值计算。

如果您使用 Numpy-Arrays 编写一个简单的代码,您可以通过简单地对其进行抖动来提高性能,如下所示。如果您使用列表,您可以或多或少地重写您的代码。

import numpy as np
import numba as nb

@nb.njit(fastmath=True)
def prod(array):
  assert array.shape[1]==3 #Enable SIMD-Vectorization (adding some performance)
  res=np.empty(array.shape[0],dtype=array.dtype)
  for i in range(array.shape[0]):
    res[i]=array[i,0]*array[i,1]*array[i,2]

  return res

使用 np.prod(a, axis=1) 不是一个坏主意,但性能并不是很好。对于只有 1000x3 的数组,函数调用开销非常大。当在另一个 jitted 函数中使用 jitted prod 函数时,可以完全避免这种情况。

基准
# The first call to the jitted function takes about 200ms compilation overhead.
#If you use @nb.njit(fastmath=True,cache=True) you can cache the compilation result for every successive call.
n=999
prod1   = 795  µs
prod2   = 187  µs
np.prod = 7.42 µs
prod      0.85 µs

n=9990
prod1   = 7863 µs
prod2   = 1810 µs
np.prod = 50.5 µs
prod      2.96 µs

关于python - 有效地将每一行的元素相乘,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/49290059/

10-10 18:42