因此,我想加快在numba jit的帮助下编写的程序的速度。但是jit似乎与许多scipy函数不兼容,因为它们使用了try无法处理的except ... jit ...结构(我对吗?)

我想出的一个相对简单的解决方案是复制所需的scipy源代码并删除try except部分(我已经知道它不会出错,因此try部分始终可以正常工作)

但是,我不喜欢这种解决方案,并且不确定是否可以使用。

我的代码结构如下所示

import scipy.integrate as integrate
from scipy optimize import curve_fit
from numba import jit

def fitfunction():
    ...

@jit
def function(x):
    # do some stuff
    try:
        fit_param, fit_cov = curve_fit(fitfunction, x, y, p0=(0,0,0), maxfev=500)
        for idx in some_list:
            integrated = integrate.quad(lambda x: fitfunction(fit_param), lower, upper)
    except:
        fit_param=(0,0,0)
        ...

现在,这将导致以下错误:



我认为这是由于jit无法处理try except(如果我仅将jit放在curve_fitintegrate.quad部分上,并且围绕我自己的try except结构,这也将不起作用)

import scipy.integrate as integrate
from scipy optimize import curve_fit
from numba import jit

def fitfunction():
    ...

@jit
def integral(lower, upper):
    return integrate.quad(lambda x: fitfunction(fit_param), lower, upper)

@jit
def fitting(x, y, pzero, max_fev)
    return curve_fit(fitfunction, x, y, p0=pzero, maxfev=max_fev)


def function(x):
    # do some stuff
    try:
        fit_param, fit_cov = fitting(x, y, (0,0,0), 500)
        for idx in some_list:
            integrated = integral(lower, upper)
    except:
        fit_param=(0,0,0)
        ...

有没有一种方法可以将jitscipy.integrate.quadcurve_fit结合使用,而无需从scipy代码中手动删除所有try except结构?

甚至会加快代码速度吗?

最佳答案

Numba只是而不是一个通用库,用于加快代码速度。使用numba可以解决一类问题(特别是如果您在数组上循环,进行数字运算)可以更快地解决,但其他所有问题要么是(1)不支持,要么是(2)速度稍快甚至很多慢点。

SciPy已经是一个高性能的库,因此在大多数情况下,我希望numba的性能更差(或者很少:稍微好一点)。您可能需要执行一些profiling来确定瓶颈是否确实存在于jit ted的代码中,那么您可以得到一些改进。但是我怀疑瓶颈将在SciPy的编译代码中,并且该编译代码可能已经进行了优化(因此实际上是,不太可能发现您可以“仅”与该代码竞争的实现)。

正如您正确地假设numba目前不支持tryexcept一样。

因此,这里的答案是

关于python - Numba jit与scipy,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/55317665/

10-12 17:07