我正在尝试使用numba在Python中实现quicksort算法。

它似乎比numpy排序功能要慢得多。

我该如何改善?我的代码在这里:

import numba as nb

@nb.autojit
def quick_sort(list_):
    """
    Iterative version of quick sort
    """
    #temp_stack = []
    #temp_stack.append((left,right))

    max_depth = 1000

    left = 0
    right = list_.shape[0]-1

    i_stack_pos = 0
    a_temp_stack = np.ndarray( ( max_depth, 2), dtype=np.int32 )
    a_temp_stack[i_stack_pos,0] = left
    a_temp_stack[i_stack_pos,1] = right
    i_stack_pos+=1
    #Main loop to pop and push items until stack is empty

    while i_stack_pos>0:

        i_stack_pos-=1
        right = a_temp_stack[ i_stack_pos, 1 ]
        left  = a_temp_stack[ i_stack_pos, 0 ]

        piv = partition(list_,left,right)
        #If items in the left of the pivot push them to the stack
        if piv-1 > left:
            #temp_stack.append((left,piv-1))

            a_temp_stack[ i_stack_pos, 0 ] = left
            a_temp_stack[ i_stack_pos, 1 ] = piv-1
            i_stack_pos+=1
        #If items in the right of the pivot push them to the stack
        if piv+1 < right:
            a_temp_stack[ i_stack_pos, 0 ] = piv+1
            a_temp_stack[ i_stack_pos, 1 ] = right
            i_stack_pos+=1

@nb.autojit( nopython=True )
def partition(list_, left, right):
    """
    Partition method
    """
    #Pivot first element in the array
    piv = list_[left]
    i = left + 1
    j = right

    while 1:
        while i <= j  and list_[i] <= piv:
            i +=1
        while j >= i and list_[j] >= piv:
            j -=1
        if j <= i:
            break
        #Exchange items
        list_[i], list_[j] = list_[j], list_[i]
    #Exchange pivot to the right position
    list_[left], list_[j] = list_[j], list_[left]
    return j


我的测试代码在这里:

    x = np.random.random_integers(0,1000,1000000)
    y = x.copy()

    quick_sort( y )

    z = np.sort(x)

    np.testing.assert_array_equal( z, y )

    y = x.copy()
    with Timer( 'nb' ):
        numba_fns.quick_sort( y )

    with Timer( 'np' ):
        x = np.sort(x)


更新:

我已经重写了该函数,以强制代码的循环部分在nopython模式下运行。 while循环似乎并未导致nopython失败。但是,我没有获得任何性能改进:

@nb.autojit
def quick_sort2(list_):
    """
    Iterative version of quick sort
    """

    max_depth = 1000

    left        = 0
    right       = list_.shape[0]-1

    i_stack_pos = 0
    a_temp_stack = np.ndarray( ( max_depth, 2), dtype=np.int32 )
    a_temp_stack[i_stack_pos,0] = left
    a_temp_stack[i_stack_pos,1] = right
    i_stack_pos+=1
    #Main loop to pop and push items until stack is empty

    return _quick_sort2( list_, a_temp_stack, left, right )

@nb.autojit( nopython=True )
def _quick_sort2( list_, a_temp_stack, left, right ):

    i_stack_pos = 1
    while i_stack_pos>0:

        i_stack_pos-=1
        right = a_temp_stack[ i_stack_pos, 1 ]
        left  = a_temp_stack[ i_stack_pos, 0 ]

        piv = partition(list_,left,right)
        #If items in the left of the pivot push them to the stack
        if piv-1 > left:
            a_temp_stack[ i_stack_pos, 0 ] = left
            a_temp_stack[ i_stack_pos, 1 ] = piv-1
            i_stack_pos+=1
        if piv+1 < right:
            a_temp_stack[ i_stack_pos, 0 ] = piv+1
            a_temp_stack[ i_stack_pos, 1 ] = right
            i_stack_pos+=1

@nb.autojit( nopython=True )
def partition(list_, left, right):
    """
    Partition method
    """
    #Pivot first element in the array
    piv = list_[left]
    i = left + 1
    j = right

    while 1:
        while i <= j  and list_[i] <= piv:
            i +=1
        while j >= i and list_[j] >= piv:
            j -=1
        if j <= i:
            break
        #Exchange items
        list_[i], list_[j] = list_[j], list_[i]
    #Exchange pivot to the right position
    list_[left], list_[j] = list_[j], list_[left]
    return j

最佳答案

一个小的建议可能会有所帮助(但是,正如您在问题的注释中正确地告诉您的那样,您将难以克服纯C实现):

您想确保大多数操作都是在“ nopython”模式(@jit(nopython=True))下完成的。在您的函数之前添加它,并查看它在哪里中断。还要在函数上调用inspect_types(),看看它是否能够正确识别它们。

代码中最有可能迫使其进入对象模式(与nopython模式相对)的一件事是分配numpy数组。尽管numba可以在nopython模式下单独编译循环,但我不知道它是否可以对while循环进行编译。呼叫inspect_types会告诉您。

我通常在创建numpy数组的同时确保其余部分处于nopython模式的工作流程是创建包装函数。

@nb.jit(nopython=True) # make sure it can be done in nopython mode
def _quick_sort_impl(list_,output_array):
   ...most of your code goes here...

@nb.jit
def quick_sort(list_):
   # this code won't compile in nopython mode, but it's
   # short and isolated
   max_depth = 1000
   a_temp_stack = np.ndarray( ( max_depth, 2), dtype=np.int32 )
   _quick_sort_impl(list_,a_temp_stack)

关于python - 如何用numba加快quicksort?,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/29200353/

10-09 03:50