我正在做一些需要快速协作的事情,我相信numba可以加速我的代码。
这里有一个愚蠢的例子:一个函数平方它的输入,并增加它被调用的次数。
def make_square_plus_count():
i = 0
def square_plus_count(x):
nonlocal i
i += 1
return x**2 + i
return square_plus_count
您甚至不能
nopython=False
jit这个,可能是由于nonlocal
关键字。但如果使用类,则不需要
nonlocal
:def make_square_plus_count():
@numba.jitclass({'i': numba.uint64})
class State:
def __init__(self):
self.i = 0
state = State()
@numba.jit()
def square_plus_count(x):
state.i += 1
return x**2 + state.i
return square_plus_count
这至少是可行的,但如果你这样做的话,它就会中断。
是否有一个可以用
nopython=True
编译的解决方案? 最佳答案
如果无论如何要使用状态类,也可以使用方法而不是闭包(不应编译python):
import numba
@numba.jitclass({'i': numba.uint64})
class State(object):
def __init__(self):
self.i = 0
def square_plus_count(self, x):
self.i += 1
return x**2 + self.i
square_with_call_count = State().square_plus_count # using the method
print([square_with_call_count(i) for i in range(10)])
# [1, 3, 7, 13, 21, 31, 43, 57, 73, 91]
然而,计时显示,这实际上比纯python闭包实现慢我希望只要您不使用
nonlocal
numpy数组或在方法(或闭包)中对数组执行操作,这将降低效率!