我正在尝试使用numba在Python中实现quicksort算法。
它似乎比numpy排序功能要慢得多。
我该如何改善?我的代码在这里:
import numba as nb
@nb.autojit
def quick_sort(list_):
"""
Iterative version of quick sort
"""
#temp_stack = []
#temp_stack.append((left,right))
max_depth = 1000
left = 0
right = list_.shape[0]-1
i_stack_pos = 0
a_temp_stack = np.ndarray( ( max_depth, 2), dtype=np.int32 )
a_temp_stack[i_stack_pos,0] = left
a_temp_stack[i_stack_pos,1] = right
i_stack_pos+=1
#Main loop to pop and push items until stack is empty
while i_stack_pos>0:
i_stack_pos-=1
right = a_temp_stack[ i_stack_pos, 1 ]
left = a_temp_stack[ i_stack_pos, 0 ]
piv = partition(list_,left,right)
#If items in the left of the pivot push them to the stack
if piv-1 > left:
#temp_stack.append((left,piv-1))
a_temp_stack[ i_stack_pos, 0 ] = left
a_temp_stack[ i_stack_pos, 1 ] = piv-1
i_stack_pos+=1
#If items in the right of the pivot push them to the stack
if piv+1 < right:
a_temp_stack[ i_stack_pos, 0 ] = piv+1
a_temp_stack[ i_stack_pos, 1 ] = right
i_stack_pos+=1
@nb.autojit( nopython=True )
def partition(list_, left, right):
"""
Partition method
"""
#Pivot first element in the array
piv = list_[left]
i = left + 1
j = right
while 1:
while i <= j and list_[i] <= piv:
i +=1
while j >= i and list_[j] >= piv:
j -=1
if j <= i:
break
#Exchange items
list_[i], list_[j] = list_[j], list_[i]
#Exchange pivot to the right position
list_[left], list_[j] = list_[j], list_[left]
return j
我的测试代码在这里:
x = np.random.random_integers(0,1000,1000000)
y = x.copy()
quick_sort( y )
z = np.sort(x)
np.testing.assert_array_equal( z, y )
y = x.copy()
with Timer( 'nb' ):
numba_fns.quick_sort( y )
with Timer( 'np' ):
x = np.sort(x)
更新:
我已经重写了该函数,以强制代码的循环部分在nopython模式下运行。 while循环似乎并未导致nopython失败。但是,我没有获得任何性能改进:
@nb.autojit
def quick_sort2(list_):
"""
Iterative version of quick sort
"""
max_depth = 1000
left = 0
right = list_.shape[0]-1
i_stack_pos = 0
a_temp_stack = np.ndarray( ( max_depth, 2), dtype=np.int32 )
a_temp_stack[i_stack_pos,0] = left
a_temp_stack[i_stack_pos,1] = right
i_stack_pos+=1
#Main loop to pop and push items until stack is empty
return _quick_sort2( list_, a_temp_stack, left, right )
@nb.autojit( nopython=True )
def _quick_sort2( list_, a_temp_stack, left, right ):
i_stack_pos = 1
while i_stack_pos>0:
i_stack_pos-=1
right = a_temp_stack[ i_stack_pos, 1 ]
left = a_temp_stack[ i_stack_pos, 0 ]
piv = partition(list_,left,right)
#If items in the left of the pivot push them to the stack
if piv-1 > left:
a_temp_stack[ i_stack_pos, 0 ] = left
a_temp_stack[ i_stack_pos, 1 ] = piv-1
i_stack_pos+=1
if piv+1 < right:
a_temp_stack[ i_stack_pos, 0 ] = piv+1
a_temp_stack[ i_stack_pos, 1 ] = right
i_stack_pos+=1
@nb.autojit( nopython=True )
def partition(list_, left, right):
"""
Partition method
"""
#Pivot first element in the array
piv = list_[left]
i = left + 1
j = right
while 1:
while i <= j and list_[i] <= piv:
i +=1
while j >= i and list_[j] >= piv:
j -=1
if j <= i:
break
#Exchange items
list_[i], list_[j] = list_[j], list_[i]
#Exchange pivot to the right position
list_[left], list_[j] = list_[j], list_[left]
return j
最佳答案
一个小的建议可能会有所帮助(但是,正如您在问题的注释中正确地告诉您的那样,您将难以克服纯C实现):
您想确保大多数操作都是在“ nopython”模式(@jit(nopython=True)
)下完成的。在您的函数之前添加它,并查看它在哪里中断。还要在函数上调用inspect_types()
,看看它是否能够正确识别它们。
代码中最有可能迫使其进入对象模式(与nopython模式相对)的一件事是分配numpy数组。尽管numba可以在nopython模式下单独编译循环,但我不知道它是否可以对while循环进行编译。呼叫inspect_types
会告诉您。
我通常在创建numpy数组的同时确保其余部分处于nopython模式的工作流程是创建包装函数。
@nb.jit(nopython=True) # make sure it can be done in nopython mode
def _quick_sort_impl(list_,output_array):
...most of your code goes here...
@nb.jit
def quick_sort(list_):
# this code won't compile in nopython mode, but it's
# short and isolated
max_depth = 1000
a_temp_stack = np.ndarray( ( max_depth, 2), dtype=np.int32 )
_quick_sort_impl(list_,a_temp_stack)
关于python - 如何用numba加快quicksort?,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/29200353/