量化感知训练QAT(Quantification Aware Training)

目录

前言

对称量化

非对称量化

基于Pytorch官方API量化代码实现


前言

为了减少网络模型的空间占用和运行速度,除了在网络方面进行改进,模型剪枝和量化算是最常用的优化方法。剪枝就是将训练好的大模型的不重要的通道删除掉,在几乎不影响准确率的条件下对网络进行加速。而量化就是将浮点数(高精度)表示的权重和偏置用低精度整数(常用的INT8)来近似表示,在量化到低精度之后就可以应用移动平台上的优化技术如NEON对计算过程进行加速,并且原始模型量化后的模型容量也会减少,使其能够更好的应用到移动端环境。

量化感知训练QAT(Quantification Aware Training)_第1张图片

对称量化

对称量化的量化公式如下:

量化感知训练QAT(Quantification Aware Training)_第2张图片

其中 Δ \Delta Δ表示量化的缩放因子,x和 xint​分别表示量化前和量化后的数值。这里通过除以缩放因子接取整操作就把原始的浮点数据量化到了一个小区间中,比如对于有符号的8Bit 就[−128,127](无符号就是0到255了)。

这里有个Trick,即对于权重是量化到 [−127,127],这是为了累加的时候减少溢出的风险。

对应的反量化公式为:

量化感知训练QAT(Quantification Aware Training)_第3张图片

即将量化后的值乘以Δ就得到了反量化的结果,当然这个过程是有损的,如下图所示,橙色线表示的就是量化前的范围 [rmin,rmax],而蓝色线代表量化后的数据范围[−128,127],注意权重−127。

量化感知训练QAT(Quantification Aware Training)_第4张图片

非对称量化

非对称量化相比于对称量化就在于多了一个零点偏移。一个float32的浮点数非对称量化到一个int8的整数(如果是有符号就是,如果是无符号就是)的步骤为 缩放,取整,零点偏移,和溢出保护,如下图所示:

量化感知训练QAT(Quantification Aware Training)_第5张图片

缩放系数 Δ和零点偏移的计算公式如

量化感知训练QAT(Quantification Aware Training)_第6张图片

基于Pytorch官方API量化代码实现

import torch
import torch.nn as nn

#模型量化
class QAT_ASPNET_tpc(nn.Module):
    def __init__(self, model_fp32):
        super(QAT_ASPNET_tpc, self).__init__()
        # QuantStub converts tensors from floating point to quantized.
        # This will only be used for inputs.
        self.quant = torch.quantization.QuantStub()
        # DeQuantStub converts tensors from quantized to floating point.
        # This will only be used for outputs.
        self.dequant = torch.quantization.DeQuantStub()
        # FP32 model
        self.model_fp32 = model_fp32
    def forward(self, x):
        # manually specify where tensors will be converted from floating
        # point to quantized in the quantized model
        x = self.quant(x)
        #print(x)
        x = self.model_fp32(x)
        # manually specify where tensors will be converted from quantized
        # to floating point in the quantized model
        x = self.dequant(x)
        return x

#在正常模型训练代码流程中加入如下模型量化操作

#首先加载正常模型 model

#模型model量化
model = QAT_ASPNET_tpc(model_fp32=model)

#自定义量化配置
MovingAverageMinMaxObserver=torch.quantization.observer.MovingAverageMinMaxObserver
model.qconfig = torch.quantization.QConfig(activation=torch.quantization.fake_quantize.FakeQuantize.with_args(observer=MovingAverageMinMaxObserver,quant_min=-128, quant_max=127,dtype=torch.qint8, qscheme=torch.per_tensor_symmetric, reduce_range=False), weight=torch.quantization.fake_quantize.FakeQuantize.with_args(observer=MovingAverageMinMaxObserver,quant_min=-127, quant_max=127,dtype=torch.qint8, qscheme=torch.per_tensor_symmetric, reduce_range=False))

#也可以用官方默认的配置,有两种方式,fbgemm是per_channel的,qnnpack是逐层的
#model.qconfig = torch.quantization.get_default_qat_qconfig("fbgemm")
#model.qconfig = torch.quantization.get_default_qat_qconfig("qnnpack")

#自定义需要融合的层,cov+bn+relue 或 cov+bn, 或者不融合
torch.quantization.fuse_modules(model,[['model_fp32.features.0','model_fp32.features.1','model_fp32.features.2'],['model_fp32.features.4','model_fp32.features.5','model_fp32.features.6'],['model_fp32.features.8','model_fp32.features.9','model_fp32.features.10'],['model_fp32.features.12','model_fp32.features.13','model_fp32.features.14'],['model_fp32.features.16','model_fp32.features.17','model_fp32.features.18']], inplace=True)#融合层cov+bn+relu

#将原来的浮点模型,替换为插入了伪定点算子的新模型
torch.quantization.prepare_qat(model, inplace=True)

你可能感兴趣的:(深度学习,深度学习,cnn,神经网络,计算机视觉,pytorch)