我有一个二维UINT8 numpy数组,大小为(149797,64)。每个元素都是0或1。我想将每行中的这些二进制值打包成UINT64值,这样我就得到了形状为149797的UINT64数组。我使用numpy bitpack函数尝试了以下代码。

test = np.random.randint(0, 2, (149797, 64),dtype=np.uint8)
col_pack=np.packbits(test.reshape(-1, 8, 8)[:, ::-1]).view(np.uint64)


packbits函数大约需要10毫秒才能执行。对数组本身进行简单的整形似乎需要大约7毫秒。我还尝试了通过移位操作在2d numpy数组上进行迭代,以达到相同的结果;但是速度没有提高。

最后,我也想使用numba进行编译。

@njit
def shifting(bitlist):
    x=np.zeros(149797,dtype=np.uint64)  #54
    rows,cols=bitlist.shape
    for i in range(0,rows):             #56
      out=0
      for bit in range(0,cols):
         out = (out << 1) | bitlist[i][bit] # If i comment out bitlist, time=190 microsec
      x[i]=np.uint64(out)  # Reduces time to microseconds if line is commented in njit
    return x


使用njit大约需要6毫秒。

这是并行的njit版本

@njit(parallel=True)
def shifting(bitlist):
    rows,cols=149797,64
    out=0
    z=np.zeros(rows,dtype=np.uint64)
    for i in prange(rows):
      for bit in range(cols):
         z[i] = (z[i] * 2) + bitlist[i,bit] # Time becomes 100 micro if i use 'out' instead of 'z[i] array'

    return z


与3.24ms执行时间相比,它要好一些(Google colab双核2.2Ghz)
目前,使用swapbytes(Paul's)方法的python解决方案似乎是最好的解决方案,即1.74毫秒。

我们如何进一步加快转换速度?是否可以使用任何矢量化(或并行化),位数组等来实现加速?

参考:numpy packbits pack to uint16 array

在12核计算机(Intel®Xeon®CPU E5-1650 v2 @ 3.50GHz)上,

鲍尔斯方法:1595.0微秒(我想它不使用多核)

Numba代码:146.0微秒(上述并行Numba)

即大约10倍加速!

最佳答案

您可以通过使用byteswap而不是重新塑形等方式来获得较大的加速:

test = np.random.randint(0, 2, (149797, 64),dtype=np.uint8)

np.packbits(test.reshape(-1, 8, 8)[:, ::-1]).view(np.uint64)
# array([ 1079982015491401631,   246233595099746297, 16216705265283876830,
#        ...,  1943876987915462704, 14189483758685514703,
       12753669247696755125], dtype=uint64)
np.packbits(test).view(np.uint64).byteswap()
# array([ 1079982015491401631,   246233595099746297, 16216705265283876830,
#        ...,  1943876987915462704, 14189483758685514703,
       12753669247696755125], dtype=uint64)

timeit(lambda:np.packbits(test.reshape(-1, 8, 8)[:, ::-1]).view(np.uint64),number=100)
# 1.1054180909413844

timeit(lambda:np.packbits(test).view(np.uint64).byteswap(),number=100)
# 0.18370431219227612

09-10 08:22
查看更多