TVM-conv2d_nchw算子理解

一、 TVM topi中关于conv2d_nchw的代码:

def conv2d_nchw(Input, Filter, stride, padding, dilation, out_dtype=None):
    if out_dtype is None:
        out_dtype = Input.dtype
    assert isinstance(stride, int) or len(stride) == 2
    assert isinstance(dilation, int) or len(dilation) == 2
    if isinstance(stride, int):
        stride_h = stride_w = stride
    else:
        stride_h, stride_w = stride

    if isinstance(dilation, int):
        dilation_h = dilation_w = dilation
    else:
        dilation_h, dilation_w = dilation

    batch, in_channel, in_height, in_width = Input.shape
    num_filter, channel, kernel_h, kernel_w = Filter.shape
    # compute the output shape
    dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
    dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
        padding, (dilated_kernel_h, dilated_kernel_w)
    )
    out_channel = num_filter
    out_height = simplify((in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1)
    out_width = simplify((in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1)
    # compute graph
    pad_before = [0, 0, pad_top, pad_left]
    pad_after = [0, 0, pad_down, pad_right]
    temp = pad(Input, pad_before, pad_after, name="pad_temp")
    rc = te.reduce_axis((0, in_channel), name="rc")
    ry = te.reduce_axis((0, kernel_h), name="ry")
    rx = te.reduce_axis((0, kernel_w), name="rx")
    return te.compute(
        (batch, out_channel, out_height, out_width),
        lambda nn, ff, yy, xx: te.sum(
            temp[nn, rc, yy * stride_h + ry * dilation_h, xx * stride_w + rx * dilation_w].astype(
                out_dtype
            )
            * Filter[ff, rc, ry, rx].astype(out_dtype),
            axis=[rc, ry, rx],
        ),
        tag="conv2d_nchw",
    )

二、 conv2d一共有7个for-loop,我们慢慢推到一下:

2.1、 先放出三种常用类型的卷积感受一下

普通CNN.gif

stride=2的CNN.gif

dilation=2的CNN.gif

2.2 推导CNN的7个for-loop

假设Batch为B,输入I大小为N*M,通道数为C,卷积核F大小为X*Y,卷积核的数量为K,输出O大小为H*W,神经网络中stride为s, dilation为d

2.2.1、 单个卷积

输入和卷积核相同维度

for x in X:
  for y in Y:
    O[h][w] += I[x][y] * F[x][y]

2.2.2、 滑动窗口

for n in N:
  for m in M:
    for x in X:
      for y in Y:
        O[h][w] += I[n*s+x*d][m*s+y*d] * F[x][y]

滑动窗口时n、m每+1,就会窗口就会滑动n*s、m*s格,与之相乘的索引x、y会滑动x*d、y*d格。要永远记住n、m、x、y都是索引

2.2.3、 多通道滑动窗口卷积

for c in C:
  for n in N:
    for m in M:
      for x in X:
        for y in Y:
          O[h][w] += I[c][n*s+x*d][m*s+y*d] * F[c][x][y]

输入和卷积核的通道数总是保持一致的,输入有C个通道,那么卷积核就有C和通道,这与卷积核的数量概念不一样,卷积核的数量代表输出的通道数,因此这里的O依然只有两个维度

2.2.4、 多卷积核多通道滑动窗口卷积

for k in K:
  for c in C:
    for n in N:
      for m in M:
        for x in X:
          for y in Y:
            O[k][h][w] += I[c][n*s+x*d][m*s+y*d] * F[k][c][x][y]

k是卷积核的数量,也是O的通道数

2.2.5、 多batch多卷积核多通道滑动窗口卷积(NCHW)

for b in B
  for k in K:
    for c in C:
      for n in N:
        for m in M:
          for x in X:
            for y in Y:
              O[b][k][h][w] += I[b][c][n*s+x*d][m*s+y*d] * F[k][c][x][y]

batch是输入输出的batch,与卷积核F无关

小结:通过7个for-loop可以看到输出O有四个维度b、k、h、w,(也就是NCHW)

三、 TVM topi算子库对NCHW卷积的处理:

3.1、 对7个for-loop进行简化

可以看到O只有4个维度,即4个变量就可以决定O,因此可以对7个for-loop中的其余3个变量进行归约(reduce),即对c、x、y进行归约,因此可以看到TVM中的操作:

rc = te.reduce_axis((0, in_channel), name="rc")
ry = te.reduce_axis((0, kernel_h), name="ry")
rx = te.reduce_axis((0, kernel_w), name="rx")

3.2、 生成TE源语

return te.compute(
        (batch, out_channel, out_height, out_width),
        lambda nn, ff, yy, xx: te.sum(
            temp[nn, rc, yy * stride_h + ry * dilation_h, xx * stride_w + rx * dilation_w].astype(
                out_dtype
            )
            * Filter[ff, rc, ry, rx].astype(out_dtype),
            axis=[rc, ry, rx],
        ),
        tag="conv2d_nchw",
    )

O[b][k][h][w] += I[b][c][n*s][m*s] * F[k][c][x*d][y*d]
nn -> b       ff -> k       yy -> h       xx -> w       rc -> c       rx -> x       ry -> y

可以发现7个for-loop的参数都可以一一对应!

你可能感兴趣的:(TVM-conv2d_nchw算子理解)