我正在尝试在cython中实现通用排序算法。因此,我创建了以下模块,该模块在类sorter_t中实现了Heapsort算法:

# file general_sort_c.pyx

from libc.stdint cimport int32_t
cdef bint bint_true = 1
cdef bint bint_false = 0

cdef class sorter_t:

    cdef object sortable_object

    def __init__(self,sortable_object):
        self.sortable_object = sortable_object

    cpdef sort_c(self):

        """
        https://en.wikipedia.org/wiki/Heapsort

        """

        cdef int32_t end
        cdef int32_t count = self.sortable_object.num_elements_int32

        self.heapify_c(count)

        end = count-1
        while end > 0:
            self.sortable_object.swap_c(0,end)
            end = end - 1
            self.siftDown_c(0,end)

    cdef heapify_c(self,int32_t count):

        cdef int32_t start = (count - 2)/2

        while start >= 0:
            self.siftDown_c(start, count-1)
            start -= 1

    cdef siftDown_c(self,int32_t start, int32_t end):

        cdef int32_t root = start
        cdef int32_t swap
        cdef int32_t child

        while root * 2 + 1 <= end:

            child = root * 2 + 1
            swap = root

            # if "swap" < "child" then ...
            if self.sortable_object.lt_c(swap,child) == 1:
                swap = child

            if child+1 <= end and self.sortable_object.lt_c(swap,child+1) == 1:
                swap = child + 1

            if swap != root:
                self.sortable_object.swap_c(root,swap)
                root = swap
            else:
                return


定义类型为sorter_t的对象时,必须提供一个sortable_object,它具有cdef函数lt_c(用于比较一个元素是否小于另一个元素)和swap_c(用于交换元素)的特定实现。 )。

例如,以下代码将从列表中定义并创建sortable_object,并使用该sortable_object测试“ sorter_t”的实现。

import numpy
cimport numpy
from libc.stdint cimport int32_t
import general_sort_c

cdef class sortable_t:

    cdef public int32_t num_elements_int32
    cdef int32_t [:] mv_lista

    def __init__(self,int32_t [:] mv_lista):
        self.num_elements_int32 = mv_lista.shape[0]
        self.mv_lista = mv_lista

    cdef public bint lt_c(self, int32_t left, int32_t right):
        if self.mv_lista[left] < self.mv_lista[right]:
            return 1 # True
        else:
            return 0 # False

    cdef public bint gt_c(self, int32_t left, int32_t right):
        if self.mv_lista[left] > self.mv_lista[right]:
            return 1 # True
        else:
            return 0 # False

    cdef public swap_c(self, int32_t left, int32_t right):
        cdef int32_t tmp
        tmp = self.mv_lista[right]
        self.mv_lista[right] = self.mv_lista[left]
        self.mv_lista[left] = tmp

def probar():

    lista = numpy.array([3,4,1,7],dtype=numpy.int32)
    cdef int32_t [:] mv_lista = lista

    cdef sortable = sortable_t(mv_lista)
    cdef sorter = general_sort_c.sorter_t(sortable)
    sorter.sort_increasing_c()
    print list(lista)


编译两个.pyx文件并在IPython控制台中运行以下命令后,将出现以下错误:

In [1]: import test_general_sort_c as tgs

In [2]: tgs.probar()

...

 general_sort_c.sorter_t.siftDown_increasing_c (general_sort_c.c:1452)()
    132
    133             #if mv_tnet_time[swap] < mv_tnet_time[child]:

--> 134             if self.sortable_object.lt_c(swap,child) == bint_true:
    135                 swap = child
    136

AttributeError: 'test_general_sort_c.sortable_t' object has no attribute 'lt_c'


因此,问题在于,从模块lt_c中的代码看不到函数general_sort_c.pyx的实现。如果我使用lt_c而不是cpdef定义函数cdef,它将起作用,但是您将有很多Python开销。如何以cdef(“纯C”)方式调用此函数?

最佳答案

不幸的是,我不确定如何使它与融合类型一起使用,但是其余的很简单:

test_general_sort_c.pyx需要免费的test_general_sort_c.pxd

from libc.stdint cimport int32_t

cdef class sortable_t:
    cdef public int32_t num_elements_int32
    cdef int32_t [:] mv_lista
    cdef public bint lt_c(self, int32_t left, int32_t right)
    cdef public bint gt_c(self, int32_t left, int32_t right)
    cdef public swap_c(self, int32_t left, int32_t right)


然后general_sort_c.pyx必须cimport test_general_sort_c并键入其self.sortable_objecttest_general_sort_c.sortable_t

当然,如果您可以使用多种受支持的类型,那就更好了。不过,目前还不确定您会怎么做。



另外,内置的TrueFalse可以正常工作。

如果您对Cython的信任度更高,您会意识到您可以编写

cdef public bint lt_c(self, int32_t left, int32_t right):
    return self.mv_lista[left] < self.mv_lista[right]

cdef public bint gt_c(self, int32_t left, int32_t right):
    return self.mv_lista[left] > self.mv_lista[right]

cdef public swap_c(self, int32_t left, int32_t right):
    self.mv_lista[right], self.mv_lista[left] = self.mv_lista[left], self.mv_lista[right]


正好。 :)

关于python - cython cdef类c方法:如何在没有python开销的情况下从另一个cython cdef类调用它?,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/18850671/

10-12 17:02