MegEngine Python 层模块串讲(下)

前面的文章中,我们简单介绍了在 MegEngine imperative 中的各模块以及它们的作用。对于新用户而言可能不太了解各个模块的使用方法,对于模块的结构和原理也是一头雾水。Python 作为现在深度学习领域的主流编程语言,其相关的模块自然也是深度学习框架的重中之重。

模块串讲将对 MegEngine 的 Python 层相关模块分别进行更加深入的介绍,会涉及到一些原理的解释和代码解读。Python 层模块串讲共分为上、中、下三个部分,本文将介绍 Python 层的 quantization 模块。量化是为了减少模型的存储空间和计算量,从而加速模型的推理过程。在量化中,我们将权重和激活值从浮点数转换为整数,从而减少模型的大小和运算的复杂性。通过本文读者将会对量化的基本原理和使用 MegEngine 得到量化模型有所了解。

降低模型内存占用利器 —— quantization 模块

量化是一种对深度学习模型参数进行压缩以降低计算量的技术。它基于这样一种思想:神经网络是一个近似计算过程,不需要其中每个计算过程的绝对的精确。因此在某些情况下可以把需要较多比特存储的模型参数转为使用较少比特存储,而不影响模型的精度。

量化通过舍弃数值表示上的精度来追求极致的推理速度。直觉上用低精度/比特类型的模型参数会带来较大的模型精度下降(称之为掉点),但在经过一系列精妙的量化处理之后,掉点可以变得微乎其微。

如下图所示,量化通常是将浮点模型(常见神经网络的 Tensor 数据类型一般是 float32)处理为一个量化模型(Tensor 数据类型为 int8 等)。

量化基本流程

MegEngine 中支持工业界的两类主流量化技术,分别是训练后量化(PTQ)和量化感知训练(QAT)。

  1. 训练后量化(Post-Training QuantizationPTQ

    训练后量化,顾名思义就是将训练后的 Float 模型转换成低精度/比特模型。

    比较常见的做法是对模型的权重(weight)和激活值(activation)进行处理,把它们转换成精度更低的类型。虽然是在训练后再进行精度转换,但为了获取到模型转换需要的一些统计信息(比如缩放因子 scale),仍然需要在模型进行前向计算时插入观察者(Observer)。

    使用训练后量化技术通常会导致模型掉点,某些情况下甚至会导致模型不可用。可以使用小批量数据在量化之前对 Observer 进行校准(Calibration),这种方案叫做 Calibration 后量化。也可以使用 QAT 方案。

  2. 量化感知训练(Quantization-Aware TrainingQAT

    QAT 会向 Float 模型中插入一些伪量化(FakeQuantize)算子,在前向计算过程中伪量化算子根据 Observer 观察到的信息进行量化模拟,模拟数值截断的情况下的数值转换,再将转换后的值还原为原类型。让被量化对象在训练时“提前适应”量化操作,减少训练后量化的掉点影响。

    而增加这些伪量化算子模拟量化过程又会增加训练开销,因此模型量化通常的思路是:

    • 按照平时训练模型的流程,设计好 Float 模型并进行训练,得到一个预训练模型;
    • 插入 Observer 和 FakeQuantize 算子,得到 Quantized-Float 模型(QFloat 模型)进行量化感知训练;
    • 训练后量化,得到真正的 Quantized 模型(Q 模型),也就是最终用来进行推理的低比特模型。

    过程如下图所示(实际使用时,量化流程也可能会有变化):

  1. 注意这里的量化感知训练 QAT 是在预训练好的 QFloat 模型上微调(Fine-tune)的(而不是在原来的 Float 模型上),这样减小了训练的开销,得到的微调后的模型再做训练后量化 PTQ(“真量化”),QModel 就是最终部署的模型。

模型(Model)与模块(Module

量化是一个对模型(Model)的转换操作,但其本质其实是对模型中的模块( Module) 进行替换。

在 MegEngine 中,对应与 Float Model 、QFloat Model 和 Q Model 的 Module 分别为:

  1. 进行正常 float 运算的默认 Module
  2. 带有 Observer 和 FakeQuantize 算子的 qat.QATModule
  3. 无法训练、专门用于部署的 quantized.QuantizedModule

以 Conv 算子为例,这些 Module 对应的实现分别在:

量化配置 QConfig

量化配置包括 Observer 和 FakeQuantize 两部分,要设置它们,用户可以使用 MegEngine 预设配置也可以自定义配置。

  1. 使用 MegEngine 预设配置

    MegEngine 提供了多种量化预设配置

    以 ema_fakequant_qconfig 为例,用户可以通过如下代码使用该预设配置:

    import megengine.quantization as Q
    Q.quantize_qat(model, qconfig=Q.ema_fakequant_qconfig)
  2. 用户自定义量化配置

    用户还可以自己选择 Observer 和 FakeQuantize,灵活配置 QConfig 灵活选择 weight_observeract_observerweight_fake_quant 和 act_fake_quant)。

    可选的 Observer 和 FakeQuantize 可参考量化 API 参考页面。

QConfig 提供了一系列用于对模型做量化的接口,要使用这些接口,需要网络的 Module 能够在 forward 时给权重、激活值加上 Observer 以及进行 FakeQuantize

模型转换的作用是:将普通的 Float Module 替换为支持这些操作的 QATModule(可以训练),再替换为 QuantizeModule(无法训练、专用于部署)。

以 Conv2d 为例,模型转换的过程如图:

在量化时常常会用到算子融合(Fusion)。比如一个 Conv2d 算子加上一个 BatchNorm2d 算子,可以用一个 ConvBn2d 算子来等价替代,这里 ConvBn2d 算子就是 Conv2d 和 BatchNorm2d 的融合算子。

MegEngine 中提供了一些预先融合好的 Module,比如 ConvRelu2dConvBn2d 和 ConvBnRelu2d 等。使用融合算子会使用底层实现好的融合算子(kernel),而不会分别调用子模块在底层的 kernel,因此能够加快模型的速度,而且框架还无需根据网络结构进行自动匹配和融合优化,同时存在融合和不需融合的算子也可以让用户能更好的控制网络转换的过程。

实现预先融合的 Module 也有缺点,那就是用户需要在代码中修改原先的网络结构(把可以融合的多个 Module 改为融合后的 Module)。

模型转换的原理是,将父 Module 中的 Quantable (可被量化的)子 Module 替换为新 Module。而这些 Quantable submodule 中可能又包含 Quantable submodule,这些 submodule 不会再进一步转换,因为其父 Module 被替换后的 forward 计算过程已经改变了,不再依赖于这些子 Module

有时候用户不希望对模型的部分 Module 进行转换,而是保留其 Float 状态(比如转换会导致模型掉点),则可以使用 disable_quantize 方法关闭量化。

比如下面这行代码关闭了 fc 层的量化处理:

model.fc.disable_quantize()

由于模型转换过程修改了原网络结构,因此模型保存与加载无法直接适用于转换后的网络,读取新网络保存的参数时,需要先调用转换接口得到转换后的网络,才能用 load_state_dict 将参数进行加载。

量化代码

要从一个 Float 模型得到一个可用于部署的量化模型,大致需要经历三个步骤:

  1. 修改网络结构。将 Float 模型中的普通 Module 替换为已经融合好的 Module,比如 ConvBn2dConvBnRelu2d 等(可以参考 imperative/python/megengine/module/quantized 目录下提供的已融合模块)。然后在正常模式下预训练模型,并且在每轮迭代保存网络检查点。

    以 ResNet18 的 BasicBlock 为例,模块修改前的代码为:

    class BasicBlock(M.Module):
          def __init__(self, in_channels, channels):
             super().__init__()
             self.conv1 = M.Conv2d(in_channels, channels, 3, 1, padding=dilation, bias=False)
             self.bn1 = M.BatchNorm2d
             self.conv2 = M.Conv2d(channels, channels, 3, 1, padding=1, bias=False)
             self.bn2 = M.BatchNorm2d
             self.downsample = (
                M.Identity()
                if in_channels == channels and stride == 1
                else M.Sequential(
                M.Conv2d(in_channels, channels, 1, stride, bias=False)
                M.BatchNorm2d
             )
    ​
          def forward(self, x):
             identity = x
             x = F.relu(self.bn1(self.conv1(x)))
             x = self.bn2(self.conv2(x))
             identity = self.downsample(identity)
             x = F.relu(x + identity)
             return x

    注意到现在的前向中使用的都是普通 Module 拼接在一起,而实际上许多模块是可以融合的。

用可以融合的模块替换掉原先的 Module

class BasicBlock(M.Module):
      def __init__(self, in_channels, channels):
         super().__init__()
         self.conv_bn_relu1 = M.ConvBnRelu2d(in_channels, channels, 3, 1, padding=dilation, bias=False)
         self.conv_bn2 = M.ConvBn2d(channels, channels, 3, 1, padding=1, bias=False)
         self.downsample = (
            M.Identity()
            if in_channels == channels and stride == 1
            else M.ConvBn2d(in_channels, channels, 1, 1, bias=False)
         )
         self.add_relu = M.Elemwise("FUSE_ADD_RELU")
​
      def forward(self, x):
         identity = x
         x = self.conv_bn_relu1(x)
         x = self.conv_bn2(x)
         identity = self.downsample(identity)
         x = self.add_relu(x, identity)
         return x

注意到此时前向中已经有许多模块使用的是融合后的 Module

再对该模型进行若干论迭代训练,并保存检查点:

for step in range(0, total_steps):
    # Linear learning rate decay
    epoch = step // steps_per_epoch
    learning_rate = adjust_learning_rate(step, epoch)
​
    image, label = next(train_queue)
    image = tensor(image.astype("float32"))
    label = tensor(label.astype("int32"))
​
    n = image.shape[0]
​
    loss, acc1, acc5 = train_func(image, label, net, gm)  # traced
    optimizer.step().clear_grad()
​
    # Save checkpoints

完整代码见:

-   [修改前的模型结构](https://github.com/MegEngine/Models/blob/master/official/vision/classification/resnet/model.py)
-   [修改后的模型结构](https://github.com/MegEngine/Models/blob/master/official/quantization/models/resnet.py)
  1. 调用 quantize_qat 方法 将 Float 模型转换为 QFloat 模型,并进行微调(量化感知训练或校准,取决于 QConfig)。

    使用 quantize_qat 方法将 Float 模型转换为 QFloat 模型的代码大致为:

    from megengine.quantization import ema_fakequant_qconfig, quantize_qat
    ​
    model = ResNet18()
    ​
    # QAT
    quantize_qat(model, ema_fakequant_qconfig)
    ​
    # Or Calibration:
    # quantize_qat(model, calibration_qconfig)

    将 Float 模型转换为 QFloat 模型后,加载预训练 Float 模型保存的检查点进行微调 / 校准:

    if args.checkpoint:
        logger.info("Load pretrained weights from %s", args.checkpoint)
        ckpt = mge.load(args.checkpoint)
        ckpt = ckpt["state_dict"] if "state_dict" in ckpt else ckpt
        model.load_state_dict(ckpt, strict=False)
    ​
    # Fine-tune / Calibrate with new traced train_func
    # Save checkpoints

    完整代码见:

  2. 调用 quantize 方法将 QFloat 模型转换为 Q 模型,也就是可用于模型部署的量化模型。

需要在推理的方法中设置 trace 的 capture_as_const=True,以进行模型导出:

from megengine.quantization import quantize
​
@jit.trace(capture_as_const=True)
def infer_func(processed_img):
    model.eval()
    logits = model(processed_img)
    probs = F.softmax(logits)
    return probs
​
quantize(model)
​
processed_img = transform.apply(image)[np.newaxis, :]
processed_img = processed_img.astype("int8")
probs = infer_func(processed_img)
​
infer_func.dump(output_file, arg_names=["data"])

调用了 quantize 后,model 就从 QFloat 模型转换为了 Q 模型,之后便使用这个 Quantized 模型进行推理。

调用 dump 方法将模型导出,便得到了一个可用于部署的量化模型。

完整代码见:

MegEngine Python 层模块串讲系列到这里就结束了,我们介绍了用户在使用 MegEngine 时主要会接触到的 python 层的各个模块的主要功能、结构以及使用方法,此外还有一些原理性的介绍。对于各模块具体实现感兴趣的读者可以参考 MegEngine 官方文档 和 github。之后的文章我们会对 MegEngine 开发相关工具以及 MegEngine 底层的实现做更深入的介绍。

你可能感兴趣的:(MegEngine Python 层模块串讲(下))