一、 The conv2d_nchw code in TVM topi:
def conv2d_nchw(Input, Filter, stride, padding, dilation, out_dtype=None):
    if out_dtype is None:
        out_dtype = Input.dtype
    assert isinstance(stride, int) or len(stride) == 2
    assert isinstance(dilation, int) or len(dilation) == 2
    if isinstance(stride, int):
        stride_h = stride_w = stride
    else:
        stride_h, stride_w = stride
    if isinstance(dilation, int):
        dilation_h = dilation_w = dilation
    else:
        dilation_h, dilation_w = dilation
    batch, in_channel, in_height, in_width = Input.shape
    num_filter, channel, kernel_h, kernel_w = Filter.shape
    # compute the output shape
    dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
    dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
    pad_top, pad_left, pad_down, pad_right = get_pad_tuple(
        padding, (dilated_kernel_h, dilated_kernel_w)
    )
    out_channel = num_filter
    out_height = simplify((in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1)
    out_width = simplify((in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1)
    # compute graph
    pad_before = [0, 0, pad_top, pad_left]
    pad_after = [0, 0, pad_down, pad_right]
    temp = pad(Input, pad_before, pad_after, name="pad_temp")
    rc = te.reduce_axis((0, in_channel), name="rc")
    ry = te.reduce_axis((0, kernel_h), name="ry")
    rx = te.reduce_axis((0, kernel_w), name="rx")
    return te.compute(
        (batch, out_channel, out_height, out_width),
        lambda nn, ff, yy, xx: te.sum(
            temp[nn, rc, yy * stride_h + ry * dilation_h, xx * stride_w + rx * dilation_w].astype(
                out_dtype
            )
            * Filter[ff, rc, ry, rx].astype(out_dtype),
            axis=[rc, ry, rx],
        ),
        tag="conv2d_nchw",
    )
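Before deriving those loops by hand in the next section, it helps to see them in the generated IR. The following is a minimal sketch, assuming a working TVM installation; the shapes are made up for illustration:

import tvm
from tvm import te, topi

A = te.placeholder((1, 3, 32, 32), name="A")   # NCHW input
W = te.placeholder((8, 3, 3, 3), name="W")     # 8 filters, 3 input channels, 3x3 kernel
B = topi.nn.conv2d_nchw(A, W, stride=1, padding=1, dilation=1)

s = te.create_schedule(B.op)
# The lowered IR shows loops over batch, out_channel, out_height and out_width,
# plus the three reduction axes rc, ry, rx -- seven loops in total.
print(tvm.lower(s, [A, W, B], simple_mode=True))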
二、 conv2d has 7 for-loops in total; let us derive them step by step:
2.1、 First, a look at three common types of convolution to get an intuition
2.2、 Deriving the 7 for-loops of the CNN convolution
Assume the batch size is B, the input I has size N*M with C channels, the convolution kernel F has size X*Y, there are K kernels, the output O has size H*W, the stride is s, and the dilation is d.
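The output size H*W follows from these quantities in the same way out_height and out_width are computed in the TVM code above. A tiny sketch (the helper name out_size is made up for illustration; padding defaults to 0, matching the derivation below):

def out_size(in_size, kernel, stride, dilation, pad_before=0, pad_after=0):
    # a kernel of size X dilated by d covers (X - 1) * d + 1 input positions
    dilated = (kernel - 1) * dilation + 1
    return (in_size - dilated + pad_before + pad_after) // stride + 1

print(out_size(7, 3, 2, 1))  # N=7, X=3, s=2, d=1 -> H = 3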
2.2.1、 A single convolution
The input and the kernel have the same size.
for x in X:
    for y in Y:
        O[h][w] += I[x][y] * F[x][y]
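A quick NumPy check of the two loops above (shapes made up): when the input and the kernel have the same size, the accumulation produces a single number, i.e. a 1x1 output.

import numpy as np

X, Y = 3, 3
I = np.arange(X * Y, dtype=np.float32).reshape(X, Y)
F = np.ones((X, Y), dtype=np.float32)

o = 0.0
for x in range(X):
    for y in range(Y):
        o += I[x, y] * F[x, y]

assert o == (I * F).sum()  # element-wise multiply then sum gives the same value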
2.2.2、 Sliding window
for h in H:
    for w in W:
        for x in X:
            for y in Y:
                O[h][w] += I[h*s+x*d][w*s+y*d] * F[x][y]
When the window slides, each time h or w increases by 1 the window origin moves by s input positions (to h*s, w*s), and the kernel indices x, y that multiply against it address input elements d positions apart (offsets x*d, y*d). Always keep in mind that h, w, x, y are indices, and that h and w range over the output height H and width W.
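The same loops in runnable NumPy form (shapes, stride and dilation made up for illustration); H and W follow the dilated-kernel formula shown earlier:

import numpy as np

N, M, X, Y, s, d = 7, 7, 3, 3, 1, 1
I = np.random.rand(N, M).astype(np.float32)
F = np.random.rand(X, Y).astype(np.float32)

H = (N - ((X - 1) * d + 1)) // s + 1
W = (M - ((Y - 1) * d + 1)) // s + 1
O = np.zeros((H, W), dtype=np.float32)

for h in range(H):
    for w in range(W):
        for x in range(X):
            for y in range(Y):
                O[h, w] += I[h * s + x * d, w * s + y * d] * F[x, y]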
2.2.3、 Multi-channel sliding-window convolution
for c in C:
    for h in H:
        for w in W:
            for x in X:
                for y in Y:
                    O[h][w] += I[c][h*s+x*d][w*s+y*d] * F[c][x][y]
The input and the kernel always have the same number of channels: if the input has C channels, the kernel also has C channels. This is a different concept from the number of kernels, which determines the number of output channels, so O here still has only two dimensions.
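A NumPy illustration of this point (stride and dilation fixed to 1 for brevity, shapes made up): one C-channel kernel applied to a C-channel input still produces a two-dimensional output, because the channel dimension is summed away together with x and y.

import numpy as np

C, N, M, X, Y = 3, 6, 6, 3, 3
I = np.random.rand(C, N, M).astype(np.float32)
F = np.random.rand(C, X, Y).astype(np.float32)

H, W = N - X + 1, M - Y + 1
O = np.zeros((H, W), dtype=np.float32)
for h in range(H):
    for w in range(W):
        O[h, w] = (I[:, h:h + X, w:w + Y] * F).sum()  # sums over c, x and y at once

print(O.shape)  # (4, 4): two dimensions only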
2.2.4、 Multi-kernel, multi-channel sliding-window convolution
for k in K:
    for c in C:
        for h in H:
            for w in W:
                for x in X:
                    for y in Y:
                        O[k][h][w] += I[c][h*s+x*d][w*s+y*d] * F[k][c][x][y]
Here k indexes the K kernels, and the number of kernels K is also the number of channels of O.
2.2.5、 Multi-batch, multi-kernel, multi-channel sliding-window convolution (NCHW)
for b in B:
    for k in K:
        for c in C:
            for h in H:
                for w in W:
                    for x in X:
                        for y in Y:
                            O[b][k][h][w] += I[b][c][h*s+x*d][w*s+y*d] * F[k][c][x][y]
The batch dimension belongs to the input and the output; it has nothing to do with the kernel F.
Summary: from the 7 for-loops we can see that the output O has four dimensions b, k, h, w (i.e. NCHW).
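As a reference for later, the whole 7-loop computation in plain NumPy (no padding; the function name conv2d_nchw_ref is made up for this sketch):

import numpy as np

def conv2d_nchw_ref(I, F, s, d):
    B, C, N, M = I.shape
    K, _, X, Y = F.shape
    H = (N - ((X - 1) * d + 1)) // s + 1
    W = (M - ((Y - 1) * d + 1)) // s + 1
    O = np.zeros((B, K, H, W), dtype=I.dtype)
    for b in range(B):
        for k in range(K):
            for c in range(C):
                for h in range(H):
                    for w in range(W):
                        for x in range(X):
                            for y in range(Y):
                                O[b, k, h, w] += I[b, c, h * s + x * d, w * s + y * d] * F[k, c, x, y]
    return O

print(conv2d_nchw_ref(np.ones((1, 3, 8, 8)), np.ones((4, 3, 3, 3)), 1, 1).shape)  # (1, 4, 6, 6)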
三、 How the TVM topi operator library handles NCHW convolution:
3.1、 Simplifying the 7 for-loops
As shown above, O has only 4 dimensions, so 4 variables are enough to index O. The remaining 3 of the 7 loop variables can therefore be turned into reduction axes (reduce), i.e. c, x and y are reduced, which is exactly what the TVM code does:
rc = te.reduce_axis((0, in_channel), name="rc")
ry = te.reduce_axis((0, kernel_h), name="ry")
rx = te.reduce_axis((0, kernel_w), name="rx")
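For readers who have not used te.reduce_axis before, here is a small standalone sketch (not taken from the conv2d code): row sums of an n x m matrix, where the column index j is a reduction axis rather than an output dimension. Lowering it shows that j still becomes a loop; it simply does not appear in the output shape.

import tvm
from tvm import te

n, m = te.var("n"), te.var("m")
A = te.placeholder((n, m), name="A")
j = te.reduce_axis((0, m), name="j")                      # reduced, not an output axis
B = te.compute((n,), lambda i: te.sum(A[i, j], axis=j), name="B")

s = te.create_schedule(B.op)
print(tvm.lower(s, [A, B], simple_mode=True))             # the loop over j is still generated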
3.2、 Generating the TE primitive
return te.compute(
    (batch, out_channel, out_height, out_width),
    lambda nn, ff, yy, xx: te.sum(
        temp[nn, rc, yy * stride_h + ry * dilation_h, xx * stride_w + rx * dilation_w].astype(
            out_dtype
        )
        * Filter[ff, rc, ry, rx].astype(out_dtype),
        axis=[rc, ry, rx],
    ),
    tag="conv2d_nchw",
)
O[b][k][h][w] += I[b][c][h*s+x*d][w*s+y*d] * F[k][c][x][y]
nn -> b, ff -> k, yy -> h, xx -> w, rc -> c, ry -> x, rx -> y
As you can see, the variables of the 7 for-loops correspond one-to-one: yy * stride_h + ry * dilation_h in the TE expression is exactly h*s + x*d, and xx * stride_w + rx * dilation_w is exactly w*s + y*d!
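To close the loop, a minimal end-to-end check of this correspondence, assuming a TVM build with the llvm target (shapes made up; stride 1, padding 0, dilation 1 for brevity). It compiles the te.compute above and compares the result against conv2d_nchw_python, the NumPy reference in tvm.topi.testing:

import numpy as np
import tvm
import tvm.topi.testing
from tvm import te, topi

a_np = np.random.rand(1, 3, 8, 8).astype("float32")   # NCHW input
w_np = np.random.rand(4, 3, 3, 3).astype("float32")   # K, C, X, Y filter

A = te.placeholder(a_np.shape, name="A")
W = te.placeholder(w_np.shape, name="W")
O = topi.nn.conv2d_nchw(A, W, stride=1, padding=0, dilation=1)

s = te.create_schedule(O.op)
f = tvm.build(s, [A, W, O], target="llvm")

dev = tvm.cpu()
out = tvm.nd.empty((1, 4, 6, 6), device=dev)           # H = W = (8 - 3) // 1 + 1 = 6
f(tvm.nd.array(a_np, dev), tvm.nd.array(w_np, dev), out)

ref = tvm.topi.testing.conv2d_nchw_python(a_np, w_np, 1, 0)
np.testing.assert_allclose(out.numpy(), ref, rtol=1e-4)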