随着今年深度学习的发展,基于神经网络的模型在许多任务上取得了比传统方法高很多的准确率指标。但是随之而来的问题是模型越来越庞大,这给它们部署在移动端平台(例如手机、AR/VR)上带来了诸多不便。因此,如何减小模型大小和提高推理速度,成为了一个新的热门研究方向。
目前的模型压缩方法主要分为两大类:1. 设计更高效的网络结构;2.将网络的权重和激活函数从float32量化成更低的bit。
之前的研究有两个问题:
所谓量化,就是在原始的浮点数r和量化后的整数q之间寻找一个仿射变换,使得: r = S ( q − Z ) r=S(q-Z) r=S(q−Z), 其中S和Z是参数。对于每一层权重(weights)和对应激活函数(activation),S和Z是相同的。q的bit数常见的有8,4,2,1等,我们常说的8bit量化就是指q的bit数为8,其他bit数类似。bias一般量化为32bit整数。
S称为放缩因子(scale),Z称为零点(zero-point)。Z的存在很有必要,这样能保证浮点数中的零能跟一个量化后的整数对应上。在神经网络中,经常有用零进行padding的情况,因此如果找不到一个整数对应的话会大大损失量化后的精度。
考虑两个均为 N × N N\times N N×N的方阵 r 1 r_1 r1和 r 2 r_2 r2,我们需要通过它们的矩阵乘法获得方阵 r 3 r_3 r3。用 r α ( i , j ) , α = 1 , 2 , o r , 3 , 1 ≤ i , j ≤ N r_{\alpha}^{(i,j)}, \alpha=1,2,or,3,1\leq i,j\leq N rα(i,j),α=1,2,or,3,1≤i,j≤N表示第 α \alpha α个矩阵的第 i i i行第 j j j列的元素。 S , Z , q S,Z,q S,Z,q的表示方法类似。则有:
S 3 ( q 3 ( i , k ) ) = ∑ j = 1 N S 1 ( q 1 ( i , j ) − Z 1 ) S 2 ( q 2 ( j , k ) − Z 2 ) (1-1) S_3(q_3^{(i,k)})=\sum_{j=1}^{N}S_1(q_1^{(i,j)}-Z_1)S_2(q_2^{(j,k)}-Z_2)\tag{1-1} S3(q3(i,k))=j=1∑NS1(q1(i,j)−Z1)S2(q2(j,k)−Z2)(1-1)
进而得到 q 3 ( i , k ) = Z 3 + M ∑ j = 1 N ( q 1 ( i , j ) − Z 1 ) ( q 2 ( j , k ) − Z 2 ) (1-2) q_3^{(i,k)}=Z_3+M\sum_{j=1}^{N}(q_1^{(i,j)}-Z_1)(q_2^{(j,k)}-Z_2)\tag{1-2} q3(i,k)=Z3+Mj=1∑N(q1(i,j)−Z1)(q2(j,k)−Z2)(1-2)
其中 M = S 1 S 2 S 3 M=\frac{S_1S_2}{S_3} M=S3S1S2
在(1-2)中,除了M外,都是整数。
对于M,经过大量实验统计表明,它总是在区间(0,1]中。因此可以表示为:
M = 2 − n M 0 (1-3) M=2^{-n}M_0\tag{1-3} M=2−nM0(1-3)
其中 M 0 ∈ [ 0.5 , 1 ) M_0\in [0.5,1) M0∈[0.5,1),n是一个非负整数。可以将 M 0 M_0 M0表示成一个定点数,也即如果硬件平台采用的是int32,则可以找出离 2 31 M 0 2^{31}M_0 231M0最近的整数来代替 M 0 M_0 M0(最后记得再还原回去就行),这样 M 0 M_0 M0的精度至少有30bit。这样的话,(1-3)就可以转换为整数的右移计算,大大提高效率。
相比原始的浮点矩阵乘,(1-2)中似乎额外增加了 2 N 3 2N^3 2N3次减法。其实可以简化为:
q 3 ( i , k ) = Z 3 + M ( N Z 1 Z 2 − Z 1 a 2 ( k ) − Z 2 a 1 ( i ) + ∑ j = 1 N q 1 ( i , j ) q 2 ( j , k ) ) (1-4) q_3^{(i,k)}=Z_3+M\left( NZ_1Z_2-Z_1a_2^{(k)}-Z_2a_1^{(i)}+\sum_{j=1}^Nq_1^{(i,j)}q_2^{(j,k)}\right)\tag{1-4} q3(i,k)=Z3+M(NZ1Z2−Z1a2(k)−Z2a1(i)+j=1∑Nq1(i,j)q2(j,k))(1-4)
其中 a 2 ( k ) = ∑ j = 1 N q 2 ( j , k ) , a 1 ( i ) = ∑ j = 1 N q 2 ( i , j ) a_2^{(k)}=\sum_{j=1}^Nq_2^{(j,k)}, a_1^{(i)}=\sum_{j=1}^Nq_2^{(i,j)} a2(k)=∑j=1Nq2(j,k),a1(i)=∑j=1Nq2(i,j)
可见 a 2 ( k ) , a 1 ( i ) a_2^{(k)},a_1^{(i)} a2(k),a1(i)分别只需要N次加法,因此总的加法只需要 2 N 2 2N^2 2N2(注意当i变化时, a 2 ( k ) a_2^{(k)} a2(k)不变;k变化时, a 1 ( i ) a_1^{(i)} a1(i)不变)
因此,(1-4)的主要复杂度在于 ∑ j = 1 N q 1 ( i , j ) q 2 ( j , k ) \sum_{j=1}^Nq_1^{(i,j)}q_2^{(j,k)} ∑j=1Nq1(i,j)q2(j,k),它总共具有 2 N 2N 2N次的算数运算(乘和加),为了得到结果矩阵中的所有元素则需要的复杂度为 2 N 3 2N^3 2N3,这在原始的浮点乘以及其他形式的量化计算下都是无法避免的。其他的复杂度是 O ( N 2 ) O(N^2) O(N2),带有一个小的常数,可以忽略。