Convolution的变形 -- Convolution、 InnerProduct、MatMul、BLAS(GEMM、GEMV) 之间的关联

1 前言

MNN 模型模型转换结束后进行 optimizeNet 的第一项即为 turnInnerProduct2Convolution,抛开各种复杂的优化考虑,单是 减少了概念的维护 一点就已经很让人心动了。

日常开发中,可能大家都感到过 Inner Procuct、MatMul、BLAS、Concolution 等操作存在那么些相似性,但究竟哪里相似,哪里不同,或许没有去仔细地分析过。一方面因为多维操作的分析想象的确比较烧脑;另一方面没有需求啊,推理用就好了,模型用就好了,跑崩了官方提issue啊,找前向框架开发解决啊!

这不失为一个不错的思路。然而,抛开某些紧急场景下我们要自己实现一些模型操作转换设计,又或要实现一些模型操作优化…… 作为一个AI时代的新青年,难道我们不该有一些“AI工程师的好奇心”,培养下 “AI工程师的自我修养” 嘛!

追根究底 是件很有趣的事情, 同事 也是 作为积累 在很多紧急的关键时刻能 “救命”的东西。那么马上开始,让我们一起来看看 类似卷积的相关操作的相互变换 吧。

2 变换

2.1 回忆 - 卷积的关键参数

我们先来回忆一下卷积的相关参数

2.1.1 输入Input

Convolution的变形 -- Convolution、 InnerProduct、MatMul、BLAS(GEMM、GEMV) 之间的关联_第1张图片
输入

input_width : 输入宽,如图设为10
input_height : 输入高,如图设为10
input_channel : 输入维度,如图设为3
input_padX : 水平方向加边 - 本文不讨论,设为0
input_padY : 垂直方向加边 - 本文不讨论,设为0
input_padMode : 加边模式 - 本文不讨论
input_group_number : 输入分组 - 暂不讨论,设为 1

2.1.2 卷积核Kernal

Convolution的变形 -- Convolution、 InnerProduct、MatMul、BLAS(GEMM、GEMV) 之间的关联_第2张图片
卷积核

kernal_width : 卷积核宽,如图设为3
kernel_height : 卷积核高,如图设为3
kernal_dilateX : 水平方向膨胀 - 本文不讨论,设为1
kernal_dilateY : 垂直方向膨胀 - 本文不讨论,设为1
kernal_number = input_channel * output_channel,举例的计算结果为 3 x 4 = 12
(output_channel 的说明在 2.1.5 部分)

2.1.3 偏置Bias

Convolution的变形 -- Convolution、 InnerProduct、MatMul、BLAS(GEMM、GEMV) 之间的关联_第3张图片
偏置

bias_number = output_channel,举例的计算结果为 4
(output_channel 的说明在 2.1.5 部分)

2.1.4 其他配置(步长Stride)

strideX : 水平方向步长 - 本文不讨论,设为1
strideY : 垂直方向步长 - 本文不讨论,设为1

2.1.5 输出Output

Convolution的变形 -- Convolution、 InnerProduct、MatMul、BLAS(GEMM、GEMV) 之间的关联_第4张图片
输出

output_width : (input_width + input_padX * 2 - ((kernal_width - 1) * kernal_dilateX + 1)) / strideX + 1,按上面的设置计算,为8
output_height : (input_height + input_padY * 2 - ((kernal_height - 1) * kernal_dilateY + 1)) / strideY + 1,按上面的设置计算,为8
output_channel : 输出维度,如图设为4

2.1.6 卷积操作示例

按照上面示例的配置,我们得到一份卷积运算,如下图:

Convolution的变形 -- Convolution、 InnerProduct、MatMul、BLAS(GEMM、GEMV) 之间的关联_第5张图片
简单的卷积运算

2.2 Convolution的基础参数变换

我们首先关注如下5个配置的修改变化:

input_channel 输入维度(或输入特征)数量
output_channel 输出维度(或输出特征)数量
kernal_width 卷积核宽度
kernel_height 卷积核高度
group 卷积分组数量

Convolution的变形 -- Convolution、 InnerProduct、MatMul、BLAS(GEMM、GEMV) 之间的关联_第6张图片
卷积基本变换 - 网页中点击图片,选择[查看原图],可查看高清原图

应该蛮直观的吧,相信大家稍微不熟悉的应该只有分组卷积 (Group Convolution) 了吧。

金哥和你一起学模型压缩——结构篇(1)

这篇文章中对分组卷积略有介绍,可以适当参考。
(姚神记得给我支付宝打广告费~!)

2.3 Convolution <--> Inner Product

让我们渐入佳境,看看Convolution与Inner Product的转换关系。

Convolution的变形 -- Convolution、 InnerProduct、MatMul、BLAS(GEMM、GEMV) 之间的关联_第7张图片
Convolution <--> Inner Product - 网页中点击图片,选择[查看原图],可查看高清原图

Convolution 使用 1x1 的 kernal,再将输入 size 变为 1x1 就退化成了 Inner Product!

有些神奇对不对,更精彩的还在后面~

2.4 Convolution(Inner Product)<-> MatMul

Inner Product 退化为一个矩阵乘法很简单,Inner Product 可以看做是一个特殊的 Convolution,所以 MatMul 也可以看做一个特殊的 Convolution

Convolution的变形 -- Convolution、 InnerProduct、MatMul、BLAS(GEMM、GEMV) 之间的关联_第8张图片
Convolution(Inner Product)<-> MatMul - 网页中点击图片,选择[查看原图],可查看高清原图

但细心的小伙伴会发现我们举例的 MatMul 的左矩阵是 1 x 3 的,如果换证一个 2 x 3 的正常矩阵?……貌似转换就不太顺利了……
别急,我们先来了解下 BLAS

2.5 BLAS

BLASBasic Linear Algebra Subprograms,即 基础线性代数子程序。
我们比较常见的 GEMM,即广义矩阵乘法就是 BLAS 的一种高级形式。

Convolution的变形 -- Convolution、 InnerProduct、MatMul、BLAS(GEMM、GEMV) 之间的关联_第9张图片
BLAS - 网页中点击图片,选择[查看原图],可查看高清原图

会到我们 2.4节 遗留的问题,有没有发现 11号GEMM 变换 MatMul 很容易呢?
马上来看看 BLAS & Convolution 的相互变换

2.6 Convolution(InnerProduct) <-> BLAS

我们以退化为 Inner ProductConvolution 为例来描述 Convolution & BLAS 之间的关联

Convolution的变形 -- Convolution、 InnerProduct、MatMul、BLAS(GEMM、GEMV) 之间的关联_第10张图片
Convolution(InnerProduct) <-> BLAS - 网页中点击图片,选择[查看原图],可查看高清原图

GEMM 可以看做一个 每组kernal都完全相同 的分组卷积!
所以转模型的时候想将一个 GEMM 转化为 Convolution 的话,是要 牺牲一些存储空间 咯。

我们可以理解为分组卷积的变种

2.7 总结

把我们 2.1 ~ 2.6 描述的变换整合到一张图

卷积变换 - 网页中点击图片,选择[查看原图],可查看高清原图

还是有点复杂的,不过针对 有不同操作转换硬需求 的小伙伴,相信这张图能为整理思路节约不少的时间!

3 后记

然而,Convolution 的变换场景还不止如此,比如当 Group ConvolutionInput Channel、Output Channel、Group Number 相等时,就变成了一个 DepthWise Convolution !

Convolution的变形 -- Convolution、 InnerProduct、MatMul、BLAS(GEMM、GEMV) 之间的关联_第11张图片
DepthWise Convolution

有机会我们再慢慢研究讨论咯~

你可能感兴趣的:(Convolution的变形 -- Convolution、 InnerProduct、MatMul、BLAS(GEMM、GEMV) 之间的关联)