在前面的文章中,我们简单介绍了在 MegEngine imperative
中的各模块以及它们的作用。对于新用户而言可能不太了解各个模块的使用方法,对于模块的结构和原理也是一头雾水。Python
作为现在深度学习领域的主流编程语言,其相关的模块自然也是深度学习框架的重中之重。
模块串讲将对 MegEngine
的 Python
层相关模块分别进行更加深入的介绍,会涉及到一些原理的解释和代码解读。Python
层模块串讲共分为上、中、下三个部分,本文将介绍 Python
层的 quantization
模块。量化是为了减少模型的存储空间和计算量,从而加速模型的推理过程。在量化中,我们将权重和激活值从浮点数转换为整数,从而减少模型的大小和运算的复杂性。通过本文读者将会对量化的基本原理和使用 MegEngine
得到量化模型有所了解。
降低模型内存占用利器 —— quantization 模块
量化是一种对深度学习模型参数进行压缩以降低计算量的技术。它基于这样一种思想:神经网络是一个近似计算过程,不需要其中每个计算过程的绝对的精确。因此在某些情况下可以把需要较多比特存储的模型参数转为使用较少比特存储,而不影响模型的精度。
量化通过舍弃数值表示上的精度来追求极致的推理速度。直觉上用低精度/比特类型的模型参数会带来较大的模型精度下降(称之为掉点),但在经过一系列精妙的量化处理之后,掉点可以变得微乎其微。
如下图所示,量化通常是将浮点模型(常见神经网络的 Tensor
数据类型一般是 float32
)处理为一个量化模型(Tensor
数据类型为 int8
等)。
量化基本流程
MegEngine
中支持工业界的两类主流量化技术,分别是训练后量化(PTQ
)和量化感知训练(QAT
)。
训练后量化(
Post-Training Quantization
,PTQ
)训练后量化,顾名思义就是将训练后的
Float
模型转换成低精度/比特模型。比较常见的做法是对模型的权重(
weight
)和激活值(activation
)进行处理,把它们转换成精度更低的类型。虽然是在训练后再进行精度转换,但为了获取到模型转换需要的一些统计信息(比如缩放因子scale
),仍然需要在模型进行前向计算时插入观察者(Observer
)。使用训练后量化技术通常会导致模型掉点,某些情况下甚至会导致模型不可用。可以使用小批量数据在量化之前对
Observer
进行校准(Calibration
),这种方案叫做Calibration
后量化。也可以使用QAT
方案。量化感知训练(
Quantization-Aware Training
,QAT
)QAT
会向Float
模型中插入一些伪量化(FakeQuantize
)算子,在前向计算过程中伪量化算子根据Observer
观察到的信息进行量化模拟,模拟数值截断的情况下的数值转换,再将转换后的值还原为原类型。让被量化对象在训练时“提前适应”量化操作,减少训练后量化的掉点影响。而增加这些伪量化算子模拟量化过程又会增加训练开销,因此模型量化通常的思路是:
- 按照平时训练模型的流程,设计好
Float
模型并进行训练,得到一个预训练模型; - 插入
Observer
和FakeQuantize
算子,得到Quantized-Float
模型(QFloat
模型)进行量化感知训练; - 训练后量化,得到真正的
Quantized
模型(Q
模型),也就是最终用来进行推理的低比特模型。
过程如下图所示(实际使用时,量化流程也可能会有变化):
- 按照平时训练模型的流程,设计好
- 注意这里的量化感知训练
QAT
是在预训练好的QFloat
模型上微调(Fine-tune
)的(而不是在原来的Float
模型上),这样减小了训练的开销,得到的微调后的模型再做训练后量化PTQ
(“真量化”),QModel
就是最终部署的模型。
模型(Model
)与模块(Module
)
量化是一个对模型(Model
)的转换操作,但其本质其实是对模型中的模块( Module
) 进行替换。
在 MegEngine
中,对应与 Float Model
、QFloat Model
和 Q Model
的 Module
分别为:
- 进行正常
float
运算的默认Module
- 带有
Observer
和FakeQuantize
算子的qat.QATModule
- 无法训练、专门用于部署的
quantized.QuantizedModule
以 Conv
算子为例,这些 Module
对应的实现分别在:
Float Module
:imperative/python/megengine/module/conv.pyqat.QATModule
:imperative/python/megengine/module/qat/conv.pyquantized.QuantizedModule
:imperative/python/megengine/module/quantized/conv.py
量化配置 QConfig
量化配置包括 Observer
和 FakeQuantize
两部分,要设置它们,用户可以使用 MegEngine
预设配置也可以自定义配置。
使用
MegEngine
预设配置MegEngine
提供了多种量化预设配置。以
ema_fakequant_qconfig
为例,用户可以通过如下代码使用该预设配置:import megengine.quantization as Q Q.quantize_qat(model, qconfig=Q.ema_fakequant_qconfig)
用户自定义量化配置
用户还可以自己选择
Observer
和FakeQuantize
,灵活配置 QConfig 灵活选择weight_observer
、act_observer
、weight_fake_quant
和act_fake_quant
)。可选的
Observer
和FakeQuantize
可参考量化 API 参考页面。
QConfig
提供了一系列用于对模型做量化的接口,要使用这些接口,需要网络的 Module
能够在 forward
时给权重、激活值加上 Observer
以及进行 FakeQuantize
。
模型转换的作用是:将普通的 Float Module
替换为支持这些操作的 QATModule
(可以训练),再替换为 QuantizeModule
(无法训练、专用于部署)。
以 Conv2d
为例,模型转换的过程如图:
在量化时常常会用到算子融合(Fusion
)。比如一个 Conv2d
算子加上一个 BatchNorm2d
算子,可以用一个 ConvBn2d
算子来等价替代,这里 ConvBn2d
算子就是 Conv2d
和 BatchNorm2d
的融合算子。
MegEngine
中提供了一些预先融合好的 Module
,比如 ConvRelu2d
、ConvBn2d
和 ConvBnRelu2d
等。使用融合算子会使用底层实现好的融合算子(kernel
),而不会分别调用子模块在底层的 kernel
,因此能够加快模型的速度,而且框架还无需根据网络结构进行自动匹配和融合优化,同时存在融合和不需融合的算子也可以让用户能更好的控制网络转换的过程。
实现预先融合的 Module
也有缺点,那就是用户需要在代码中修改原先的网络结构(把可以融合的多个 Module
改为融合后的 Module
)。
模型转换的原理是,将父 Module
中的 Quantable
(可被量化的)子 Module
替换为新 Module
。而这些 Quantable submodule
中可能又包含 Quantable submodule
,这些 submodule
不会再进一步转换,因为其父 Module
被替换后的 forward
计算过程已经改变了,不再依赖于这些子 Module
。
有时候用户不希望对模型的部分 Module
进行转换,而是保留其 Float
状态(比如转换会导致模型掉点),则可以使用 disable_quantize
方法关闭量化。
比如下面这行代码关闭了 fc
层的量化处理:
model.fc.disable_quantize()
由于模型转换过程修改了原网络结构,因此模型保存与加载无法直接适用于转换后的网络,读取新网络保存的参数时,需要先调用转换接口得到转换后的网络,才能用 load_state_dict
将参数进行加载。
量化代码
要从一个 Float
模型得到一个可用于部署的量化模型,大致需要经历三个步骤:
修改网络结构。将
Float
模型中的普通Module
替换为已经融合好的Module
,比如ConvBn2d
、ConvBnRelu2d
等(可以参考 imperative/python/megengine/module/quantized 目录下提供的已融合模块)。然后在正常模式下预训练模型,并且在每轮迭代保存网络检查点。以
ResNet18
的BasicBlock
为例,模块修改前的代码为:class BasicBlock(M.Module): def __init__(self, in_channels, channels): super().__init__() self.conv1 = M.Conv2d(in_channels, channels, 3, 1, padding=dilation, bias=False) self.bn1 = M.BatchNorm2d self.conv2 = M.Conv2d(channels, channels, 3, 1, padding=1, bias=False) self.bn2 = M.BatchNorm2d self.downsample = ( M.Identity() if in_channels == channels and stride == 1 else M.Sequential( M.Conv2d(in_channels, channels, 1, stride, bias=False) M.BatchNorm2d ) def forward(self, x): identity = x x = F.relu(self.bn1(self.conv1(x))) x = self.bn2(self.conv2(x)) identity = self.downsample(identity) x = F.relu(x + identity) return x
注意到现在的前向中使用的都是普通
Module
拼接在一起,而实际上许多模块是可以融合的。
用可以融合的模块替换掉原先的 Module
:
class BasicBlock(M.Module):
def __init__(self, in_channels, channels):
super().__init__()
self.conv_bn_relu1 = M.ConvBnRelu2d(in_channels, channels, 3, 1, padding=dilation, bias=False)
self.conv_bn2 = M.ConvBn2d(channels, channels, 3, 1, padding=1, bias=False)
self.downsample = (
M.Identity()
if in_channels == channels and stride == 1
else M.ConvBn2d(in_channels, channels, 1, 1, bias=False)
)
self.add_relu = M.Elemwise("FUSE_ADD_RELU")
def forward(self, x):
identity = x
x = self.conv_bn_relu1(x)
x = self.conv_bn2(x)
identity = self.downsample(identity)
x = self.add_relu(x, identity)
return x
注意到此时前向中已经有许多模块使用的是融合后的 Module
。
再对该模型进行若干论迭代训练,并保存检查点:
for step in range(0, total_steps):
# Linear learning rate decay
epoch = step // steps_per_epoch
learning_rate = adjust_learning_rate(step, epoch)
image, label = next(train_queue)
image = tensor(image.astype("float32"))
label = tensor(label.astype("int32"))
n = image.shape[0]
loss, acc1, acc5 = train_func(image, label, net, gm) # traced
optimizer.step().clear_grad()
# Save checkpoints
完整代码见:
- [修改前的模型结构](https://github.com/MegEngine/Models/blob/master/official/vision/classification/resnet/model.py)
- [修改后的模型结构](https://github.com/MegEngine/Models/blob/master/official/quantization/models/resnet.py)
调用 quantize_qat 方法 将
Float
模型转换为QFloat
模型,并进行微调(量化感知训练或校准,取决于QConfig
)。使用
quantize_qat
方法将Float
模型转换为QFloat
模型的代码大致为:from megengine.quantization import ema_fakequant_qconfig, quantize_qat model = ResNet18() # QAT quantize_qat(model, ema_fakequant_qconfig) # Or Calibration: # quantize_qat(model, calibration_qconfig)
将
Float
模型转换为QFloat
模型后,加载预训练Float
模型保存的检查点进行微调 / 校准:if args.checkpoint: logger.info("Load pretrained weights from %s", args.checkpoint) ckpt = mge.load(args.checkpoint) ckpt = ckpt["state_dict"] if "state_dict" in ckpt else ckpt model.load_state_dict(ckpt, strict=False) # Fine-tune / Calibrate with new traced train_func # Save checkpoints
完整代码见:
- 调用 quantize 方法将
QFloat
模型转换为Q
模型,也就是可用于模型部署的量化模型。
需要在推理的方法中设置 trace
的 capture_as_const=True
,以进行模型导出:
from megengine.quantization import quantize
@jit.trace(capture_as_const=True)
def infer_func(processed_img):
model.eval()
logits = model(processed_img)
probs = F.softmax(logits)
return probs
quantize(model)
processed_img = transform.apply(image)[np.newaxis, :]
processed_img = processed_img.astype("int8")
probs = infer_func(processed_img)
infer_func.dump(output_file, arg_names=["data"])
调用了 quantize
后,model
就从 QFloat
模型转换为了 Q
模型,之后便使用这个 Quantized
模型进行推理。
调用 dump
方法将模型导出,便得到了一个可用于部署的量化模型。
完整代码见:
MegEngine Python
层模块串讲系列到这里就结束了,我们介绍了用户在使用 MegEngine
时主要会接触到的 python
层的各个模块的主要功能、结构以及使用方法,此外还有一些原理性的介绍。对于各模块具体实现感兴趣的读者可以参考 MegEngine 官方文档 和 github。之后的文章我们会对 MegEngine
开发相关工具以及 MegEngine
底层的实现做更深入的介绍。