嗨,我正在尝试从ndarray派生一个类。我坚持在docs中找到的食谱,但是当我重写__getiem__()函数时遇到一个我不明白的错误。我确定这是应该起作用的方式,但我不知道如何正确执行。我的类基本上添加了“dshape”属性,如下所示:

class Darray(np.ndarray):
    def __new__(cls, input_array, dshape, *args, **kwargs):
        obj = np.asarray(input_array).view(cls)
        obj.SelObj = SelObj
        obj.dshape = dshape
        return obj

    def __array_finalize__(self, obj):
        if obj is None: return
        self.info = getattr(obj, 'dshape', 'N')

    def __getitem__(self, index):
        return self[index]

当我现在尝试做的时候:
D = Darray( ones((10,10)), ("T","N"))

解释器将以最大深度递归失败,因为他一遍又一遍地调用__getitem__

有人可以向我解释为什么以及如何实现getitem函数吗?

干杯,
大卫

最佳答案



对于您当前的代码,不需要__getitem__。当我删除SelObj实现时,您的类(class)工作正常(未定义的__getitem__除外)。

最大递归深度错误的原因是__getitem__的定义,该定义使用self[index]:self.__getitem__(index)的简写形式。如果必须重写__getitem__,请确保调用__getitem__的父类(super class)实现:

def __getitem__(self, index):
    return super(Darray, self).__getitem__(index)

至于您为什么要这样做:有很多原因可以覆盖此功能,例如您可以将名称与数组的行关联:
class NamedRows(np.ndarray):
    def __new__(cls, rows, *args, **kwargs):
        obj = np.asarray(*args, **kwargs).view(cls)
        obj.__row_name_idx = dict((n, i) for i, n in enumerate(rows))
        return obj

    def __getitem__(self, idx):
        if isinstance(idx, basestring):
            idx = self.__row_name_idx[idx]
        return super(NamedRows, self).__getitem__(idx)

演示:
>>> a = NamedRows(["foo", "bar"], [[1,2,3], [4,5,6]])
>>> a["foo"]
NamedRows([1, 2, 3])

关于python - 从ndarray类__getitem__继承,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/19305128/

10-13 03:41