转载地址:https://bbs.huaweicloud.com/forum/thread-112006-1-1.html
作者:李响
量化是以较低的推理精度损失将连续取值或者大量可能的离散取值的浮点型模型权重或流经模型的张量数据定点近似为有限个离散值的过程,它是以更少位数的数据类型用于近似表示32位有限范围浮点型数据的过程,而模型的输入输出依然是浮点型。这样的好处是可以减小模型尺寸大小,减少模型内存占用,加快模型推理速度,降低功耗等。因此,与FP32类型相比,FP16、INT8、INT4等低精度数据表达类型所占用空间更小。使用低精度数据表达类型替换高精度数据表达类型,可以大幅降低存储空间和传输时间。而低比特的计算性能也更高,INT8相对比FP32的加速比可达到3倍甚至更高,对于相同的计算,功耗上也有明显优势。当前业界量化方案主要分为两种:感知量化训练和训练后量化。
其中,伪量化节点是指感知量化训练中插入的节点,用以寻找网络数据分布,并反馈损失精度,具体作用有两点:
1 找到网络数据的分布,即找到待量化参数的最大值和最小值;
2 模拟量化为低比特时的精度损失,把该损失作用到网络模型中,传递给损失函数,让优化器在训练过程中对该损失值进行优化。
感知量化训练模型与一般训练步骤一致,在定义网络和最后生成模型阶段后,需要进行额外的操作,完整流程如下:
1 数据处理加载数据集。
2 定义原始非量化网络。
3 定义融合网络。在完成定义原始非量化网络后,替换指定的算子,完成融合网络的定义。
4 定义优化器和损失函数。
5 转化量化网络。基于融合网络,使用转化接口在融合网络中插入伪量化节点,生成量化网络。
6 进行量化训练。基于量化网络训练,生成量化模型。
注:
融合网络:使用指定算子(nn.Conv2dBnAct
、nn.DenseBnAct
)替换后的网络。
量化网络:融合模型使用转换接口(QuantizationAwareTraining.quantize
)插入伪量化节点后得到的网络。
量化模型:量化网络训练后得到的checkpoint格式的模型。
定义融合网络,在定义网络后,替换指定的算子。
使用nn.Conv2dBnAct
算子替换原网络模型中的2个算子nn.Conv2d
和nn.ReLU
。
使用nn.DenseBnAct
算子替换原网络模型中的2个算子nn.Dense
和nn.ReLU
。
下面用代码阐述:
原网络模型LeNet5的定义如下所示:
class LeNet5(nn.Cell):
"""
Lenet network
Args:
num_class (int): Num classes. Default: 10.
num_channel (int): Num channel. Default: 1.
Returns:
Tensor, output tensor
Examples:
>>> LeNet(num_class=10, num_channel=1)
"""
def __init__(self, num_class=10, num_channel=1):
super(LeNet5, self).__init__()
self.conv1 = nn.Conv2d(num_channel, 6, 5, pad_mode='valid')
self.conv2 = nn.Conv2d(6, 16, 5, pad_mode='valid')
self.fc1 = nn.Dense(16 * 5 * 5, 120, weight_init=Normal(0.02))
self.fc2 = nn.Dense(120, 84, weight_init=Normal(0.02))
self.fc3 = nn.Dense(84, num_class, weight_init=Normal(0.02))
self.relu = nn.ReLU()
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
self.flatten = nn.Flatten()
def construct(self, x):
x = self.max_pool2d(self.relu(self.conv1(x)))
x = self.max_pool2d(self.relu(self.conv2(x)))
x = self.flatten(x)
x = self.relu(self.fc1(x))
x = self.relu(self.fc2(x))
x = self.fc3(x)
return x
替换算子后的融合网络如下:
class LeNet5(nn.Cell):
def __init__(self, num_class=10):
super(LeNet5, self).__init__()
self.num_class = num_class
self.conv1 = nn.Conv2dBnAct(1, 6, kernel_size=5, activation='relu')
self.conv2 = nn.Conv2dBnAct(6, 16, kernel_size=5, activation='relu')
self.fc1 = nn.DenseBnAct(16 * 5 * 5, 120, activation='relu')
self.fc2 = nn.DenseBnAct(120, 84, activation='relu')
self.fc3 = nn.DenseBnAct(84, self.num_class)
self.max_pool2d = nn.MaxPool2d(kernel_size=2, stride=2)
def construct(self, x):
x = self.max_pool2d(self.conv1(x))
x = self.max_pool2d(self.conv2(x))
x = self.flattern(x)
x = self.fc1(x)
x = self.fc2(x)
x = self.fc3(x)
return x
使用QuantizationAwareTraining.quantize
接口自动在融合模型中插入伪量化节点,将融合模型转化为量化网络。
from mindspore.compression.quant import QuantizationAwareTraining
quantizer = QuantizationAwareTraining(quant_delay=900,
bn_fold=False,
per_channel=[True, False],
symmetric=[True, False])
net = quantizer.quantize(network)