我用一个函数填充3D数组,该函数取决于其他1D数组的值,如下面的代码所示。涉及我的真实数据的代码将永远占用,因为我的1d数组(以及我的3D数组)的长度大约为100万。有什么方法可以更快地执行此操作,例如在python中不使用循环?

这个想法可能看起来很愚蠢,但是我仍然想知道在程序中以C ++形式填充此对象导入代码是否会更快……我是C ++的新手,所以我没有尝试。

import numpy as np
import time

start_time = time.time()
kx = np.linspace(0,400,100)
ky = np.linspace(0,400,100)
kz = np.linspace(0,400,100)

Kh = np.empty((len(kx),len(ky),len(kz)))

for i in range(len(kx)):
    for j in range(len(ky)):
        for k in range(len(kz)):
            if np.sqrt(kx[i]**2+ky[j]**2) != 0:
                Kh[i][j][k] = np.sqrt(kx[i]**2+ky[j]**2+kz[k]**2)
            else:
                Kh[i][j][k] = 1


print('Finished in %s seconds' % (time.time() - start_time))

最佳答案

您可以使用高性能JIT编译器@njit中的numba装饰器。它将时间减少了一个数量级以上。下面是比较和代码。就像导入njit然后使用@njit作为函数的修饰符一样简单。 This是官方网站。

我还使用1000*1000*1000计算了njit数据点的时间,仅花费了17.856173038482666秒。将并行版本用作@njit(parallel=True)可以进一步将时间减少到9.36257791519165秒。使用正常功能执行相同操作将花费几分钟。

我还对njit和@Bily在下面的answer中建议的矩阵运算做了一些时间比较。在点数最多为700的情况下,时间是可比较的,而njit方法在点数大于700的情况下显然会获胜,如下图所示。

import numpy as np
import time
from numba import njit

kx = np.linspace(0,400,100)
ky = np.linspace(0,400,100)
kz = np.linspace(0,400,100)

Kh = np.empty((len(kx),len(ky),len(kz)))

@njit  # <----- Decorating your function here
def func_njit(kx, ky, kz, Kh):
    for i in range(len(kx)):
        for j in range(len(ky)):
            for k in range(len(kz)):
                if np.sqrt(kx[i]**2+ky[j]**2) != 0:
                    Kh[i][j][k] = np.sqrt(kx[i]**2+ky[j]**2+kz[k]**2)
                else:
                    Kh[i][j][k] = 1
    return Kh

start_time = time.time()
Kh = func_njit(kx, ky, kz, Kh)
print('NJIT Finished in %s seconds' % (time.time() - start_time))

def func_normal(kx, ky, kz, Kh):
    for i in range(len(kx)):
        for j in range(len(ky)):
            for k in range(len(kz)):
                if np.sqrt(kx[i]**2+ky[j]**2) != 0:
                    Kh[i][j][k] = np.sqrt(kx[i]**2+ky[j]**2+kz[k]**2)
                else:
                    Kh[i][j][k] = 1
    return Kh

start_time = time.time()
Kh = func_normal(kx, ky, kz, Kh)
print('Normal function Finished in %s seconds' % (time.time() - start_time))




NJIT Finished in 0.36797094345092773 seconds
Normal function Finished in 5.540749788284302 seconds


njit与矩阵方法的比较

python - 填充3D数组而不在python中使用循环-LMLPHP

08-15 21:37