利用torch.fx进行后量化

torch.fx 量化支持——FX GRAPH MODE QUANTIZATION

torch.fx目前支持的量化方式:

  • Post Training Quantization
    • Weight Only Quantization
    • Dynamic Quantization
    • Static Quantization
  • Quantization Aware Training
    • Static Quantization

其中,Post Training Quantization中的Static Quantization和Dynamic Quantization提供了demo。

与Eager模式对比

简单来说,fx提供一个Graph模式:

  • 可以自动插入量化节点(如quantize和dequantize),不需要手动修改当前的network及forward
  • 这个模式下可以看到forward是怎么被自动构建的,可以进行更精细的调整

Graph模式

局限:只有可以转换为符号的部分(symbolically traceable)可以被量化,Data dependent control flow是不支持的。如果模型有些部分无法被符号化,则量化只能在模型的部分上工作,不能被符号化的部分会被跳过。

如果需要这些部分被量化:

  • 重写代码让这些部分symbolically traceable
  • 将这些部分转换成observed和quantized的子模块

相关的具体操作见(PROTOTYPE) FX GRAPH MODE QUANTIZATION USER GUIDE。

训练后量化尝试

环境准备:

import torch
import copy
from torch.quantization import get_default_qconfig
from torch.quantization.quantize_fx import prepare_fx, convert_fx, fuse_fx

步骤

  1. 准备训练好的权重、数据及网络模型
  2. 初始化网络,加载训练好的权重(一般用copy.deepcopy保留原始模型),并将其置于eval模式:
float_model = load_model(saved_model_dir + float_model_file).to("cpu")
float_model.eval()
model_to_quantize = copy.deepcopy(float_model)
model_to_quantize.eval()
  1. 指定量化模型的qconfig_dict
qconfig = get_default_qconfig("fbgemm")
qconfig_dict = {"": qconfig}

qconfig是QConfig的一个实例,QConfig这个类就是维护了两个observer,一个是activation所使用的observer,一个是op权重所使用的observer。

backend activation weight
fbgemm (x86) HistogramObserver (reduce_range=True) PerChannelMinMaxObserver (default_per_channel_weight_observer)
qnnpack (arm) HistogramObserver (reduce_range=False) MinMaxObserver (default_weight_observer)
default MinMaxObserver (default_observer) MinMaxObserver (default_weight_observer)
  1. 准备模型并打印模型:
prepared_model = prepare_fx(model_to_quantize, qconfig_dict)
print(prepared_model.graph)
  1. 模型较准
def calibrate(model, data_loader):
    model.eval()
    with torch.no_grad():
        for image, target in data_loader:
            model(image)
calibrate(prepared_model, data_loader_test)  # run calibration on sample data
  1. 量化模型
quantized_model = convert_fx(prepared_model)
print(quantized_model)
  1. 对比量化前后,评估量化效果,包括模型大小、性能、时延等

你可能感兴趣的:(python,算法,深度学习,pytorch,深度学习,人工智能)