我正在阅读一本深度学习书(第7章,CNN)中的implementation of im2col,其目的是将4维数组转换为2维数组。我不知道为什么在实现中有一个6维数组。我对作者使用的算法背后的想法很感兴趣。

我尝试搜索过许多有关im2col实现的论文,但没有一个使用这样的高维数组。我发现对im2col过程的可视化有用的当前 Material 是this paper - HAL Id: inria-00112631的图片

def im2col(input_data, filter_h, filter_w, stride=1, pad=0):
    """
    Parameters
    ----------
    input_data : (batch size, channel, height, width), or (N,C,H,W) at below
    filter_h : kernel height
    filter_w : kernel width
    stride : size of stride
    pad : size of padding
    Returns
    -------
    col : two dimensional array
    """
    N, C, H, W = input_data.shape
    out_h = (H + 2*pad - filter_h)//stride + 1
    out_w = (W + 2*pad - filter_w)//stride + 1

    img = np.pad(input_data, [(0,0), (0,0), (pad, pad), (pad, pad)], 'constant')
    col = np.zeros((N, C, filter_h, filter_w, out_h, out_w))

    for y in range(filter_h):
        y_max = y + stride*out_h
        for x in range(filter_w):
            x_max = x + stride*out_w
            col[:, :, y, x, :, :] = img[:, :, y:y_max:stride, x:x_max:stride]

    col = col.transpose(0, 4, 5, 1, 2, 3).reshape(N*out_h*out_w, -1)
    return col

最佳答案

让我们尝试可视化im2col的功能。它以一堆彩色图像作为输入,该堆栈具有尺寸图像ID,颜色 channel ,垂直位置,水平位置。为了简单起见,假设我们只有一个图像:

Python:利用6维数组的优势实现im2col吗?-LMLPHP

首先要做的是填充:

Python:利用6维数组的优势实现im2col吗?-LMLPHP

接下来,将其切成窗口。窗口的大小由filter_h/w控制,重叠部分由strides控制。

Python:利用6维数组的优势实现im2col吗?-LMLPHP

这是六个维度的来源:图像ID(由于只有一幅图像,因此在示例中缺失),网格高度/宽度,颜色 channel 。窗口的高度/宽度。

Python:利用6维数组的优势实现im2col吗?-LMLPHP

目前的算法有点笨拙,它以错误的尺寸顺序组装输出,然后必须使用transpose对其进行纠正。

最好首先将其正确设置:

def im2col_better(input_data, filter_h, filter_w, stride=1, pad=0):
    img = np.pad(input_data, [(0,0), (0,0), (pad, pad), (pad, pad)], 'constant')
    N, C, H, W = img.shape
    out_h = (H - filter_h)//stride + 1
    out_w = (W - filter_w)//stride + 1
    col = np.zeros((N, out_h, out_w, C, filter_h, filter_w))
    for y in range(out_h):
        for x in range(out_w):
            col[:, y, x] = img[
                ..., y*stride:y*stride+filter_h, x*stride:x*stride+filter_w]
    return col.reshape(np.multiply.reduceat(col.shape, (0, 3)))

顺便提一句:我们可以使用stride_tricks甚至做得更好,并避免嵌套的for循环:
def im2col_best(input_data, filter_h, filter_w, stride=1, pad=0):
    img = np.pad(input_data, [(0,0), (0,0), (pad, pad), (pad, pad)], 'constant')
    N, C, H, W = img.shape
    NN, CC, HH, WW = img.strides
    out_h = (H - filter_h)//stride + 1
    out_w = (W - filter_w)//stride + 1
    col = np.lib.stride_tricks.as_strided(img, (N, out_h, out_w, C, filter_h, filter_w), (NN, stride * HH, stride * WW, CC, HH, WW)).astype(float)
    return col.reshape(np.multiply.reduceat(col.shape, (0, 3)))

该算法做的最后一件事是重塑形状,合并前三个维度(在我们的示例中又只有两个维度,因为只有一个图像)。红色箭头显示各个窗口如何排成第一个新维度:

Python:利用6维数组的优势实现im2col吗?-LMLPHP

最后三个维度颜色 channel ,窗口中的y坐标,窗口中的x坐标合并到第二个输出维中。各个像素按黄色箭头所示排列:

Python:利用6维数组的优势实现im2col吗?-LMLPHP

关于Python:利用6维数组的优势实现im2col吗?,我们在Stack Overflow上找到一个类似的问题:https://stackoverflow.com/questions/50292750/

10-11 19:32