0基础学习地平线QAT量化感知训练

文章目录

  • 1. 背景
  • 2. 基础理论知识
  • 3. 文件准备与程序运行
  • 4. 代码详解
    • 4.1 导入必要依赖
    • 4.2 主函数
    • 4.3 构建fx模式所需要的float_model
    • 4.4 不同阶段模型的获取
    • 4.5 定义常规模型训练与验证的函数
    • 4.6 float与qat训练代码解读——float_model/qat_model
    • 4.7 模型校准部分的代码解读——calib_model
    • 4.8 定点模型评测精度 代码解读——quantized_model
    • 4.9 编译生成上板模型——script_model/model.hbm
  • 5. 建议or吐槽

1. 背景

首先感谢一下地平线工具链用户手册和官方提供的示例,给了我很大的帮助,特别是代码的注释写了很多的知识点,超赞!要是注释能再详细点,就是超超赞了!下面开始正文。
以前从0开始学习过地平线的PTQ(后量化)方案,写了一些基础知识文章,后来发现地平线的用户手册关于PTQ方面其实挺完善的,东西很多很全,就没再想着写。
最近想着学QAT(量化感知训练)玩玩,大体看了一下地平线的用户手册,不说精度调优、性能调优之类比较复杂的,光一个QAT上手,就感觉对我这种小白不是很友好,比如我这种小白,捣鼓了好久,感觉在用户手册中很多基础概念都没写,不同模块之间的关联性也没有详细地介绍,直到我“精读”用户手册 4.2量化感知训练(QAT) ,发现了这么一句话,

懂了,没用过Pytorch的QAT,直接看手册学起来有点费劲才是正常滴!
那针对只使用过Pytorch在服务器上训练过一些分类、检测模型,没接触过QAT的小白,又不想读PyTorch官方文档,只想简单入个门,怎么办嘞?欢迎看看这篇文章,提供实操代码和运行步骤,如果文章对你有点作用的话,麻烦收藏+点个赞再走~

该文章参考自J5 OE1.1.52中对应的示例以及用户手册,为啥不是用的XJ3 OE,请看第5节吐槽部分

2. 基础理论知识

深度学习量化通常是指以int类型的数据代替浮点float类型的数据进行计算和存储,从而减小模型大小,降低带宽需求,理论上,INT8 量化,与常规的 FP32 模型相比,模型大小减少 4 倍,内存带宽需求减少 4 倍。
量化可以分为PTQ与QAT,

  • PTQ:Post-training Quantization,训练后量化,指浮点模型训练完成后,基于一些校准数据,直接通过工具自动进行模型量化的过程,相比QAT,PTQ更简单一些,这篇文章不介绍PTQ。
  • QAT:Quantization aware training,量化感知训练,指浮点模型训练完成后,在模型中插入伪量化节点再进行量化训练的过程,大体过程如下图所示,相比PTQ,QAT精度更有保障一些,这篇文章介绍QAT
    0基础学习地平线QAT量化感知训练_第1张图片

小白:图中伪量化节点FakeQuantize node是什么?有什么作用?

大黑:从命名看,就是假装量化呗,模拟将数据从float类型量化为int类型,主要作用于网络的权重和激活(节点输出,不是relu这种激活函数的意思)。在QAT中,通过使用伪量化节点,可以在训练期间优化模型以适应后续的真实量化操作,从而提高量化模型的准确性和性能。一旦模型训练完成后,伪量化节点将被替换为真实的量化操作,以生成最终的量化模型。

小白:插入伪量化节点后需要Retraining/Funetuning?感觉很浪费资源的样子…

大黑:通常再多训 1/10 浮点阶段训练的轮数就好了,比如浮点阶段训练了100epoch,QAT训个10epoch就好,为了精度,浪费就浪费点,小问题!

小白:从上面这个图看,感觉QAT还挺简单的,其实目前我就只会用pytorch搭一个卷积网络,然后去训练,那我要经历哪些阶段才能得到最终上板部署的模型呢?

大黑:整个过程会涉及到以下几个模型:
0基础学习地平线QAT量化感知训练_第2张图片
在每个阶段,还有一些需要注意的地方,比如…

小白:停停停,先别急,这里面新名词有点多,先帮我捋捋。float_model和我直接用pytorch搭建的有什么不同吗?fx是什么?calib是什么?qat_model和quantized_model还不是一个意思?script_model又是哪儿冒出来的?板端部署hbm模型我知道,就是可以在板子上推理的模型,类似于PTQ里的bin模型对吧?

大黑:这一连串问题问的挺好,我下面逐个简单解释一下。

  • float_model和我直接用pytorch搭建的有什么不同吗?
    这里float_model浮点模型,其实就是在pytorch搭建的常规网络输入处插入QuantStub节点、输出处插入DeQuantstub节点,在PyTorch中,QuantStub/DequantStub 是一种用于量化的辅助工具,用于标记量化过程中需要量化/反量化的层或操作,前期浮点训练时可以当它不存在,在量化时会自动被替换为对应的量化操作。从普遍意义上说,每个分支都要对应插入QuantStub,别再追问为什么了,问就是甲鱼的臀部——“规定”。
  • fx是什么?
    pytorch中量化方式有两种,分别是Eager Mode Quantization和FX Graph Mode Quantization,它俩各有优劣。对于初学者,Eager模式需要手工修改网络代码,并对很多节点进行替换,比较复杂,而 FX模式不需要这种操作,使用起来比较简单,因此,推荐使用fx模式。
    关于fx与eager两种模式体现在地平线量化训练以及部署层面的差异,大家感兴趣的话,可参考地平线开发者社区专业介绍:QAT - 异构与非异构方案使用简介。
    地平线同时支持fx和eager两种模式,fx模式体现在地平线封装的各种函数中,例如prepare_qat_fx(),就是在函数最后有fx字样。
  • calib是什么?
    calib是校准calibration的缩写,主要作用是确定量化参数,我们知道,合理的初始量化参数能够显著提升模型精度并加快模型的收敛速度。calibration 就是在浮点模型中插入 Observer,使用少量训练数据,在模型 forward 过程中统计各处的数据分布,以确定合理的量化参数的过程。虽然不做 Calibration 也可以进行量化训练,但一般来说,它对量化训练有益无害,所以推荐大家将此步骤作为必选项。
  • qat_model和quantized_model还不是一个意思?
    不一样的。
    qat_model是一种插入了伪量化节点的伪量化模型,简单理解为:它是为了量化训练而存在的模型,里面还“流淌”着浮点的参数,伪量化节点在模拟量化而已。
    quantized_model:模型中的浮点参数转换为定点参数,且把浮点算子转换成定点算子,这种转换后的模型称之为quantized_model /定点模型 / 量化模型。
  • script_model又是哪儿冒出来的?
    scipt_model是一种可以序列化的Torch脚本(TorchScript),方便在不需要Python解释器的环境中使用模型,例如C++应用程序、移动端应用等。scipt_model的获取通过torch.jit.trace实现。torch.jit.trace是PyTorch中的一个静态图转换工具,用于将一个PyTorch模型转换成一个可以序列化的Torch脚本(TorchScript)。其工作流程是,首先使用输入张量对模型进行前向计算,然后将计算图转换为Torch脚本。在这个过程中,PyTorch会执行所有与输入相关的计算,从而记录下计算图的结构和参数的值。
    以下是torch.jit.trace方法的基本语法:script_model = torch.jit.trace(model, example_inputs, optimize=True),其中,model是待转换的PyTorch模型,并不一定需要是quantized_model,普通的也可以,这里是QAT场景,因此是quantized_model。example_inputs是一个输入张量或元组,用于为模型执行前向计算,并记录计算图的结构和参数的值。optimize是一个布尔值,用于指定是否对转换后的计算图进行优化。默认情况下,optimize为True,将对计算图进行常量折叠、运算融合等优化。
  • 板端部署hbm模型我知道,就是可以在板子上推理的模型,类似于PTQ里的bin模型对吧?
    非常对。

小白:这些模型是如何生成的?通过图中那几个函数?是地平线封装好的,直接用?
大黑:是的。

3. 文件准备与程序运行

  • 一共就需要3个文件
(plugin) [xxx plugin_basic]$ tree -L 3
.
|-- data
|   |-- cifar-10-batches-py    					# cifar10数据集
|-- mobilenet_example_release_fx_only.py    	# 代码
|-- model
|   `-- mobilenetv2
|       |-- mobilenet_v2-b0353104.pth       	# 预训练权重

为了方便大家获取,以上文件均存放在网盘链接中:

链接:https://pan.baidu.com/s/1yJjjWEOB9rtHug77yA5mIw 
提取码:zdi5

代码运行,建议在地平线提供的docker里运行,当然,如果大家自己会配置本地环境的话,也可以不用docker,我两种都试了,都是ok的。

  • 运行过程
# 生成float-checkpoint.ckpt
python3 mobilenet_example_release_fx_only.py --stage=float 
# 生成calib-checkpoint.ckpt   
python3 mobilenet_example_release_fx_only.py --stage=calib
# 生成qat-checkpoint.ckpt    
python3 mobilenet_example_release_fx_only.py --stage=qat
# 使用定点quantized model evaluate一次      
python3 mobilenet_example_release_fx_only.py --stage=int_infer    
# 编译生成model.hbm,并对script_model进行可视化
python3 mobilenet_example_release_fx_only.py --stage=compile    

特别是在stage=compile,产出物有点多,在这儿具体介绍一下

模型名称 模型解读
int_model.pt torch.jit.save(script_model, “int_model.pt”)生成的,指 torchscript 模型
model.pt compile_model函数产出的中间产物,和int_model.pt是一回事,指 torchscript 模型
model.hbir compile_model函数产出的中间产物,用于出现问题时提供给地平线技术支持分析,我们不需要关注
model.hbm compile_model函数产出的最终产物,即板端可部署模型
xxx.html perf_model函数的产物,一个html文件,里面提供一些编译器层面分析出的性能信息

运行完全程,所有文件如下图:
0基础学习地平线QAT量化感知训练_第3张图片
跑起来很简单,下面再和大家一起看看代码层面的情况。

4. 代码详解

该章节参考地平线用户手册:XJ3用户手册 4.2.3 快速上手、J5用户手册 4.2.3. 快速入门,由于XJ3 OE包中未提供对应示例,代码参考的是J5 OE ddk/samples/ai_toolchain/horizon_model_train_sample/plugin_basic/mobilenet_example_release.py,OE包中代码是fx模式和eager模式混合在一起的,为了防止大家搞混,我给拆开了,这里只放fx模式的例子,其实XJ3用户手册 4.2.3 快速上手、J5用户手册 4.2.3. 快速入门都有提供fx模式对应ipynb的代码,只是我不太习惯而已,大家可以根据自己偏好使用。

4.1 导入必要依赖

之所以写这一节,主要是希望大家可以从注释中,简单了解各个函数的作用,像torch、os这种导入就省略没写,全部的依赖可以看提供的代码。其中,horizon_plugin_pytorch是地平线基于 PyTorch 开发的 的量化训练工具,可以理解成numpy这种库,里面有很多用于量化训练的的依赖,我们直接用就好了。

# 定义程序需要接收哪些命令行参数,以及这些参数的类型、默认值等信息。
import argparse     
# torch中的一个类,主要用于将量化操作的结果转换回浮点数,也就是对输出数据转换回浮点数
from torch.quantization import DeQuantStub
# 用CIFAR10数据集,简单快速
from torchvision.datasets import CIFAR10
# 导入两个类,用来当父类,目的是构建float_model。model_urls是一个字典
from torchvision.models.mobilenetv2 import (
    InvertedResidual,
    MobileNetV2,
    model_urls,
)      
# 从url中下载预训练权重
from torchvision._internally_replaced_utils import load_state_dict_from_url
# 硬件芯片架构,J5:bayes;XJ3:bernoulli2,具体可看源码
from horizon_plugin_pytorch.march import March, set_march       
from horizon_plugin_pytorch.quantization import (
    QuantStub,      # 类似于torch中的类QuantStub,用于将输入数据量化,使用plugin中的QuantStub是因为它支持通过参数手动固定 scale
    convert_fx,     # 将伪量化模型qat_model转换为定点模型quantized_model
    prepare_qat_fx, # 将float模型转成calib/qat模型,变动表现:进行一些conv+bn等算子融合
    set_fake_quantize,  # 用于设置qat/calib model 伪量化状态,内参包括FakeQuantState
    FakeQuantState,     # 用于设置伪量化状态,有FakeQuantState.QAT用于qat model train,FakeQuantState.VALIDATION用于qat/calib model eval,FakeQuantState.CALIBRATION用于 calib eval
    check_model,        # 用于检查模型是否可以被硬件支持,本例中输入是可序列化的script_model,并给出一些根据硬件对齐规则可以提升性能的建议
    compile_model,      # 用于编译生成可以上板的hbm模型
    perf_model,         # 用于推测模型耗时等信息
    visualize_model,    # 用于可视化算子优化替换后的模型结构
)
from horizon_plugin_pytorch.quantization.qconfig import (
    default_calib_8bit_fake_quant_qconfig,      # 校准时,模型总体伪量化节点的量化配置
    default_qat_8bit_fake_quant_qconfig,        # 量化训练时,模型总体伪量化节点的量化配置
    default_qat_out_8bit_fake_quant_qconfig,    # 模型输出的伪量化节点配置,用于配置输出conv节点高精度int32输出
    default_calib_out_8bit_fake_quant_qconfig,  # 和上一行是一个东西
)

4.2 主函数

看了第2节理论知识部分,主函数部分的代码就是严格执行那几个阶段stage(详见第2节),很easy,关于内部细节,在后面几个小节挨个介绍。

def main(
    stage: str,
    data_path: str,
    model_path: str,
    train_batch_size: int,
    eval_batch_size: int,
    epoch_num: int,
    device_id: int = 0,
    quant_method: str = "fx",
    march: str = March.BAYES,
    compile_opt: int = 0,
):
    # 对应操作几个阶段的模型
    assert stage in ("float", "calib", "qat", "int_infer", "compile")
    assert quant_method in ("fx")

    device = torch.device(
        "cuda:{}".format(device_id) if device_id >= 0 else "cpu"
    )

    if not os.path.exists(model_path):
        os.makedirs(model_path, exist_ok=True)

    # 浮点训练阶段优化器
    def float_optim_config(model: nn.Module):
        # This is an example to illustrate the usage of QAT training tool, so
        # we do not fine tune the training hyper params to get optimized
        # float model accuracy.
        optimizer = torch.optim.Adam(model.parameters(), weight_decay=2e-4)

        return optimizer, None

    # qat训练阶段优化器
    def qat_optim_config(model: nn.Module):
        # QAT training is targeted at fine tuning model params to match the
        # numerical quantization, so the learning rate should not be too large.
        optimizer = torch.optim.SGD(
            model.parameters(), lr=0.0001, weight_decay=2e-4
        )

        return optimizer, None

    default_epoch_num = {
        "float": 20,     
        "qat": 2,       # 通常float训练epoch数量是qat训练epoch数量的10倍
    }

    if stage in ("float", "qat"):
        if epoch_num is None:
            epoch_num = default_epoch_num[stage]

        train(
            data_path,
            model_path,
            train_batch_size,
            eval_batch_size,
            epoch_num,
            device,
            float_optim_config if stage == "float" else qat_optim_config,
            stage,
            march,
            quant_method,
        )

    elif stage == "calib":
        calibrate(
            data_path,
            model_path,
            train_batch_size,
            eval_batch_size,
            device,
            march=march,
            quant_method=quant_method,
        )

    elif stage == "int_infer":
        int_infer(
            data_path,
            model_path,
            eval_batch_size,
            device,
            march=march,
            quant_method=quant_method,
        )

    else:
        compile(
            data_path,
            model_path,
            compile_opt,
            march=march,
            quant_method=quant_method,
        )

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Run mobilenet example.")
    parser.add_argument(
        "--stage",
        type=str,
        choices=("float", "calib", "qat", "int_infer", "compile"),
        help=(
            "Pipeline stage, must be executed in following order: "
            "float -> calib(optional) -> qat(optional) -> int_infer -> compile"
        ),
    )
    parser.add_argument(
        "--data_path",
        type=str,
        default="data",
        help="Path to the cifar-10 dataset",
    )
    parser.add_argument(
        "--model_path",
        type=str,
        default="model/mobilenetv2",
        help="Where to save the model and other results",
    )
    parser.add_argument(
        "--train_batch_size",
        type=int,
        default=256,
        help="Batch size for training",
    )
    parser.add_argument(
        "--eval_batch_size",
        type=int,
        default=256,
        help="Batch size for evaluation",
    )
    parser.add_argument(
        "--epoch_num",
        type=int,
        default=None,
        help=(
            "Rewrite the default training epoch number, pass 0 to skip "
            "training and only do evaluation (in stage 'float' or 'qat')"
        ),
    )
    parser.add_argument(
        "--device_id",
        type=int,
        default=2,
        help="Specify which device to use, pass a negative value to use cpu",
    )
    parser.add_argument(
        "--quant_method",
        type=str,
        choices=["fx"],
        default="fx",
        help=(
            "Specify fx mode quantization."
            " Please do not change quant method "
            "between different stages, or the model may fail to load"
        ),
    )
    parser.add_argument(
        "--opt",
        type=str,
        choices=["0", "1", "2", "3", "ddr", "fast", "balance"],
        default=0,
        help="Specity optimization level for compilation",
    )
    args = parser.parse_args()
    print(args)

    main(
        args.stage,
        args.data_path,
        args.model_path,
        args.train_batch_size,
        args.eval_batch_size,
        args.epoch_num,
        args.device_id,
        args.quant_method,
        compile_opt=args.opt,
    )

4.3 构建fx模式所需要的float_model

从torchvision.models中继承MobileNetV2,微调一下,以支持量化相关操作。模型改造必要的操作有:

  • 在模型所有输入分支前插入 QuantStub
  • 在模型所有输出分支后插入 DequantStub

这部分具体实现过程解读可见代码注释。

# ----------------------------------------------------------------------------#
# At first, we do necessary modify to the MobilenetV2 model from torchvision.
# For FX mode, we need to:
# 1. Insert QuantStub before first layer and DequantStub after last layer.
# Operation replacement and fusion will be carried out automatically (^_^).
# ----------------------------------------------------------------------------#
# 在PyTorch中,QuantStub/DequantStub 是一种用于量化的辅助工具,
# 用于标记量化过程中需要量化/反量化的层或操作,
# 前期浮点训练时当它不存在,在量化时会自动被替换为对应的量化操作
# ----------------------------------------------------------------------------#
# 从torchvision.models中继承MobileNetV2,微调一下
class FxQATReadyMobileNetV2(MobileNetV2):
    def __init__(
        self,
        num_classes: int = 10,      # 实例变量,使用self.来引用变量
        width_mult: float = 0.5,
        inverted_residual_setting: Optional[List[List[int]]] = None,
        round_nearest: int = 8,
    ):
        super().__init__(   # 类变量,使用类名来引用变量,如ClassName.variable_name
            num_classes, width_mult, inverted_residual_setting, round_nearest
        )
        # --------------------------------------------------------------------#
        # 简单理解,在模型首尾部包一层类似于量化反量化操作,每个输入分支都需要包一下
        # --------------------------------------------------------------------#
        # 地平线plugin中的QuantStub可以配置scale
        # 这里的scale=1/128是后面模型输入配置为pyramid必备的
        # pyramid是地平线的芯片上的一个硬件,数据输入可以从这儿来,也可以从DDR来
        # --------------------------------------------------------------------#
        self.quant = QuantStub(scale=1 / 128)   
        self.dequant = DeQuantStub()

    def forward(self, x: Tensor) -> Tensor:
        x = self.quant(x)
        x = super().forward(x)
        x = self.dequant(x)

        return x

关于如何加载预训练权重部分的代码在函数load_pretrain里,详细内容可以看Python文件,这里不再呈现。

def load_pretrain(model: nn.Module, model_path: str):
    state_dict = load_state_dict_from_url(
        model_urls["mobilenet_v2"], model_dir=model_path, progress=True
    )   # model_urls是一个字典,取里面mobilenet_v2的对应url,下载路径到model_dir,progress是下载进度条显示

4.4 不同阶段模型的获取

在代码运行时,有个输入参数stage必须配置,表示拿到哪个model去整后面的事,当stage参数传入(“float”, “calib”, “qat”, “int_infer”)中某一个时,会通过如下函数去获取,具体实现过程解读可见代码注释。

# --------------------------------------------------------------------------#
# Next, we define the model convert pipeline to generate model for each stage.
# --------------------------------------------------------------------------#
def get_model(
    stage: str,
    model_path: str,
    device: torch.device,
    march=March.BAYES,
    quant_method="fx",
) -> nn.Module:
    # 运行代码时,有个输入参数stage必须配置,表示拿到哪个model去整后面的事
    assert stage in ("float", "calib", "qat", "int_infer")
    assert quant_method in ("fx")

    model_kwargs = dict(num_classes=10, width_mult=1.0)
    float_model = FxQATReadyMobileNetV2(**model_kwargs).to(device)

    if stage == "float":
        # Load pretrained model (on ImageNet) to speed up float training.
        load_pretrain(float_model, model_path)

        return float_model      # float的时候,到这儿就退出了

    # 浮点训练完成后的权重
    float_ckpt_path = os.path.join(model_path, "float-checkpoint.ckpt")
    assert os.path.exists(float_ckpt_path)
    float_state_dict = torch.load(float_ckpt_path, map_location=device)

    # A global march indicating the target hardware version must be setted
    # before prepare qat.
    set_march(march)

    # Preserve a clean float_model for calibration and qat training.
    ori_float_model = float_model         
    float_model = copy.deepcopy(ori_float_model)

    float_model.load_state_dict(float_state_dict)
    # -----------------------------------------------------------#
    # The op fusion is included in `prepare_qat_fx`.
    # -----------------------------------------------------------#
    # Make sure the output model is on target device.
    # CAUTION: prepare_qat_fx and convert_fx do not guarantee the
    # output model is on the same device as input model.
    # ----------------------------------------------------------#

    # ----------------从float_model转成calib_model----------------#
    float_model.qconfig = default_calib_8bit_fake_quant_qconfig
    # ----------------------------------------------------------------------#
    #   不配置输出层的qconfig,其输出默认是int8输出
    #   尾部conv/linear,calib和qat配置为
    #   default_{calib/qat}_out_8bit_fake_quant_qconfig时,表示int32高精度输出
    # ----------------------------------------------------------------------#
    float_model.classifier.qconfig = (
        default_calib_out_8bit_fake_quant_qconfig
    )
    calib_model = prepare_qat_fx(float_model).to(device)

    # calib stage时,函数到这儿就会返回了
    if stage == "calib":
        return calib_model

    calib_ckpt_path = os.path.join(model_path, "calib-checkpoint.ckpt")
    assert os.path.exists(calib_ckpt_path)
    calib_state_dict = torch.load(calib_ckpt_path, map_location=device)

    # ---------------------------------------------#
    #   这一行是必须的,上面的float_model已经"变味"了
    # ---------------------------------------------#
    float_model = copy.deepcopy(ori_float_model)

    # 尾部conv/linear,qat配置为default_qat_out_***_qconfig时,可为int32高精度输出
    qat_model = prepare_qat_fx(
        float_model,        # 这儿必须是float_model,不能是calib_model,也不能是"变味"的float_model
        {
            "": default_qat_8bit_fake_quant_qconfig,
            "module_name": {
                "classifier": default_qat_out_8bit_fake_quant_qconfig,
            },
        },
    ).to(device)    # prepare_qat_fx 接口不保证输出模型的 device 和输入模型完全一致

    # qat_model加载的是calib_state_dict!!!
    qat_model.load_state_dict(calib_state_dict)

    if stage == "qat":    # qat阶段到这儿就退出了
        return qat_model

    qat_ckpt_path = os.path.join(model_path, "qat-checkpoint.ckpt")
    assert os.path.exists(qat_ckpt_path)
    qat_model.load_state_dict(torch.load(qat_ckpt_path, map_location=device))

    # 将模型转为定点状态
    # 通过参数转换把伪量化模型中的浮点参数转换成定点参数,
    # 并且把浮点算子转换成定点算子,该转换后的模型称为 Quantized 模型 / 定点模型 / 量化模型
    quantized_model = convert_fx(qat_model).to(device)

    return quantized_model    # int_infer阶段会到这儿才退出

4.5 定义常规模型训练与验证的函数

具体实现,看py代码就行,很常规。

# --------------------------------------------------------------------------#
# Next, we define dataloaders and other helper functions used in training
# and evaluation.
# --------------------------------------------------------------------------#

def prepare_data_loaders(
    data_path: str, train_batch_size: int, eval_batch_size: int
) -> Tuple[data.DataLoader, data.DataLoader]:


class AverageMeter(object):
    """Computes and stores the average and current value"""
    

def accuracy(output: Tensor, target: Tensor, topk=(1,)) -> List[Tensor]:
    """Computes the accuracy over the k top predictions for the specified values of k"""
    

def train_one_epoch(
    model: nn.Module,
    criterion: Callable,
    optimizer: torch.optim.Optimizer,
    scheduler: Optional[torch.optim.lr_scheduler._LRScheduler],
    data_loader: data.DataLoader,
    device: torch.device,
) -> None:


def evaluate(
    model: nn.Module, data_loader: data.DataLoader, device: torch.device
) -> Tuple[AverageMeter, AverageMeter]:

4.6 float与qat训练代码解读——float_model/qat_model

针对float_model和qat_model的参数训练,代码解读如下,

# --------------------------------------------------------------------------#
# Next, we define the main function for each stage.
# --------------------------------------------------------------------------#

# Float and qat share the same training procedure.
def train(
    data_path: str,
    model_path: str,
    train_batch_size: int,
    eval_batch_size: int,
    epoch_num: int,
    device: torch.device,
    optim_config: Callable,
    stage: str,
    march=March.BAYES,
    quant_method="fx",
):
    # --------------------------------------------#
    #   qat模型训练和普通浮点模型训练的不同之处!
    # --------------------------------------------#
    model = get_model(stage, model_path, device, march, quant_method)

    train_data_loader, eval_data_loader = prepare_data_loaders(
        data_path, train_batch_size, eval_batch_size
    )

    optimizer, scheduler = optim_config(model)

    best_acc = 0

    for nepoch in range(epoch_num):
        # Train/Eval state must be setted correctly
        # before `set_fake_quantize`
        model.train()
        # --------------------------------------------#
        #   qat模型训练和普通浮点模型训练的不同之处!
        # --------------------------------------------#
        if stage == "qat":
            set_fake_quantize(model, FakeQuantState.QAT)

        train_one_epoch(
            model,
            nn.CrossEntropyLoss(),
            optimizer,
            scheduler,
            train_data_loader,
            device,
        )

        model.eval()
        # --------------------------------------------#
        #   qat模型训练和普通浮点模型训练的不同之处!
        # --------------------------------------------#
        if stage == "qat":
            set_fake_quantize(model, FakeQuantState.VALIDATION)

        top1, top5 = evaluate(
            model,
            eval_data_loader,
            device,
        )
        print(
            "{} Epoch {}: evaluation Acc@1 {:.3f} Acc@5 {:.3f}".format(
                stage.capitalize(), nepoch, top1.avg, top5.avg
            )
        )

        if top1.avg > best_acc:
            best_acc = top1.avg

            torch.save(
                model.state_dict(),
                os.path.join(model_path, "{}-checkpoint.ckpt".format(stage)),
            )   # 可用于保存 float-checkpoint.ckpt 和 qat-checkpoint.ckpt

    # ----------------------------------------------#
    #   当传入epoch_num=0,用于qat eval
    # ----------------------------------------------#
    if nepoch == 0:
        model.eval()
        if stage == "qat":
            set_fake_quantize(model, FakeQuantState.VALIDATION)

        top1, top5 = evaluate(
            model,
            eval_data_loader,
            device,
        )
        print(
            "{} eval only: evaluation Acc@1 {:.3f} Acc@5 {:.3f}".format(
                stage.capitalize(), top1.avg, top5.avg
            )
        )   # stage.capitalize()表示将字符串首字母大写

    print("Best Acc@1 {:.3f}".format(best_acc))

    return model

4.7 模型校准部分的代码解读——calib_model

float模型训练完成后,需要进行参数校准,得到calib_model,如果calib_model精度满足要求,qat训练就不需要了,即使calib_model精度不行,calib_model_state_dict(校准后的权重)对qat训练收敛也非常有帮助。

def calibrate(
    data_path,
    model_path,
    calib_batch_size,
    eval_batch_size,
    device,
    num_examples=float("inf"),  # float("inf")表示无穷大,主要用于控制使用多少数据进行校准,默认使用所有数据集
    march=March.BAYES,
    quant_method="fx",
):
    calib_model = get_model("calib", model_path, device, march, quant_method)
    # Please note that calibration need the model in eval mode
    # to make BatchNorm act properly.
    calib_model.eval()  # 即使下面用的是train数据集,这儿也是eval
    # set CALIBRATION state will make FakeQuantize in training mode.
    set_fake_quantize(calib_model, FakeQuantState.CALIBRATION)

    train_data_loader, eval_data_loader = prepare_data_loaders(
        data_path, calib_batch_size, eval_batch_size
    )

    with torch.no_grad():
        cnt = 0
        for image, target in train_data_loader:
            image, target = image.to(device), target.to(device)
            calib_model(image)
            print(".", end="", flush=True)
            cnt += image.size(0)
            if cnt >= num_examples:     # 主要用于控制使用多少数据进行校准,默认使用所有数据集
                break
        print()

    # Must set eval mode again before validation, because
    # set CALIBRATION state will make FakeQuantize in training mode.
    calib_model.eval()  
    set_fake_quantize(calib_model, FakeQuantState.VALIDATION)

    top1, top5 = evaluate(
        calib_model,
        eval_data_loader,
        device,
    )
    print(
        "Calibration: evaluation Acc@1 {:.3f} Acc@5 {:.3f}".format(
            top1.avg, top5.avg
        )
    )

    torch.save(
        calib_model.state_dict(),
        os.path.join(model_path, "calib-checkpoint.ckpt"),
    )

    return calib_model

4.8 定点模型评测精度 代码解读——quantized_model

定点模型/quantized模型/量化模型 eval推理一下看看精度

# 定点模型/quantized模型/量化模型 eval推理一下看看精度
def int_infer(
    data_path,
    model_path,
    eval_batch_size,
    device,
    march=March.BAYES,
    quant_method="fx",
):
    # 定点模型/quantized模型/量化模型
    quantized_model = get_model(
        "int_infer", model_path, device, march, quant_method
    )

    _, eval_data_loader = prepare_data_loaders(
        data_path, eval_batch_size, eval_batch_size
    )

    top1, top5 = evaluate(
        quantized_model,
        eval_data_loader,
        device,
    )
    print(
        "Quantized: evaluation Acc@1 {:.3f} Acc@5 {:.3f}".format(
            top1.avg, top5.avg
        )
    )

    return quantized_model

4.9 编译生成上板模型——script_model/model.hbm

编译生成上板模型model.hbm,同时针对script_model预估模型性能,并进行可视化

def compile(
    data_path,
    model_path,
    compile_opt=0,
    march=March.BAYES,
    quant_method="fx",
):
    # It is recommended to do compile on cpu, because associated interfaces
    # do not fully support cuda.
    device = torch.device("cpu")

    # 定点模型
    quantized_model = get_model(
        "int_infer", model_path, device, march, quant_method
    )

    # prepare_data_loaders(data_path: str, train_batch_size: int, eval_batch_size: int)
    _, eval_data_loader = prepare_data_loaders(data_path, 1, 1)

    # We can generate random input data (in proper shape) for
    # tracing and compiling and so on.
    # Use real data in `perf_model` will get more accurate perf result.
    example_input = next(iter(eval_data_loader))[0]     # Tensor

    # ------------------------------------------------------------------#
    #   torch.jit.trace是PyTorch中的一个静态图转换工具,
    #   用于将一个PyTorch模型转换成一个可以序列化的Torch脚本(TorchScript),
    #   以便在不需要Python解释器的环境中使用模型。
    #   model并不一定需要是quantized_model,普通的也可以,这里是QAT场景
    # ------------------------------------------------------------------#
    script_model = torch.jit.trace(quantized_model.cpu(), example_input)    # 单纯为了更保险,这儿再次加上.cpu()
    # 这个.pt结尾,就和手册中术语约定对上了:文档中的 pt 模型指 torchscript 模型
    torch.jit.save(script_model, os.path.join(model_path, "int_model.pt"))  

    check_model(script_model, [example_input], advice=1)

    compile_model(
        script_model,
        [example_input],
        hbm=os.path.join(model_path, "model.hbm"),
        input_source="pyramid",     # 上板时输入的数据来源,通常有ddr/resizer/pyramid,多输入时配置为字符串列表
        opt=compile_opt,
    )

    # hbdk预估模型性能,生成html文件,里面提供一些性能评测信息
    perf_model(
        script_model,
        [example_input],
        out_dir=os.path.join(model_path, "perf_out"),
        input_source="pyramid",
        opt=compile_opt,
        layer_details=True,     # html中会提供逐层算子耗时
    )

    # 可视化torchscript模型,也就是hbdk眼中的模型,会考虑到layout的变换、硬件对齐、算子融合、算子等效替换等情况
    visualize_model(
        script_model,
        [example_input],
        save_path=os.path.join(model_path, "model.svg"),
        show=False,
    )

    return script_model

5. 建议or吐槽

免责声明:纯纯吐槽,如有雷同,请勿当真!

  • 提供用户手册、提供上手示例,很棒!只是说好的快速上手示例,能麻烦大佬们写的基础一点嘛~

  • 一定要善于看源码,里面有函数的作用和使用方法的介绍,很有用!可惜我用vscode在docker里总是无法跳转,馋哭了,其实可以有个笨方法,如下图
    0基础学习地平线QAT量化感知训练_第4张图片

  • 初次上手的例子,建议和我们说一个最标准的流程就好了,像float_model到底选用origin_float_model更好还是FxQATReadyModel更好?calib这一步到底要不要?qat_model到底加载float_state_dict更合适还是calib_state_dict更合适?这些问题在我初次看代码时产生了一些疑惑~

  • X3的OE包里,能否像J5 OE包里一样提供plugin_basic的例子?要不是J5 OE包也对外释放了,都学不到这种好东西,偏心了啊!

  • J5 OE包里提供的plugin_basic例子,能否把fx和eager拆开成两个py文件?放到一起,刚开始学的时候总搞混…(当然,也可能是我水平问题)

  • 用户手册中把快速上手部分全部可执行代码放出来,感觉还挺好的,适合我这种小白,当然,在OE包里还有一份全面的代码,感觉在手册里告诉我它在OE包里的位置,这样也可以接受。其实我想说:手册中更建议多放点需要跟着操作的步骤,或者理论介绍,或者代码多点注释,不是很理解为啥把全部log日志都贴出来了(4.2.3 快速上手)!输出日志部分,放点开头、结尾、关键部分说明意思就行,想看全部的话,我自己会去跑跑试一下的,难道手册有最低字数限制?

  • 想让尾部conv以高精度int32输出,竟然配置的是default_qat_out_8bit_fake_quant_qconfig,大大问号脸?明明是out_8bit啊!后来咨询技术支持,原来这里的8bit是weight的量化方式为8bit。感觉这个命名有点容易造成误解,不知道能否修改为qat_out_int32_weight_8bit_fake_quant_qconfig?(反正都已经很长了…,哦在最新发的版本中已修改为default_qat_8bit_weight_32bit_act_fake_quant_qconfig,这里的act应该是activation的缩写,表示节点输出)

  • OE包中看着提供了很多例子,但例子之间又有很多共用的代码,造成非常多的嵌套,我就参考其中一个,还得下载整个OE包,不知道能否拆开例子,放到github或者gitee上,想参考哪个我就下载哪个多好!

  • 能否给点从浮点训练 到 量化转换编译 到 上板部署(python/c++) 到 可视化 的全流程示例仓库,本来生态就不如英伟达,支持国产总得让我们用起来很顺溜才好吧!建议搞点全流程例子给我们!(理直气不壮)

都看到这儿了,如果对您有帮助的话,麻烦给点个赞呀~

你可能感兴趣的:(地平线开发板相关,QAT,量化感知训练,plugin)