PyTorch Model Quantization (II) - An Introduction to FX Graph Mode Quantization

Introduction

Because of recent project needs, I have been studying how to use PyTorch PTQ and QAT quantization. Recent PyTorch versions recommend FX Graph Mode Quantization.


FX Graph Mode Quantization: A Demo

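Main workflow of static Post-Training Quantization (PTQ) in PyTorch FX Graph mode, step1 ~ step5:

  • step1: Configure the quantization scheme, e.g. the QScheme (per-channel or per-tensor/layer) and the range of quantized values (Qmin, Qmax)
  • step2: prepare_fx:
    * a) Convert the input model (nn.Module) into a GraphModule (IR conversion)
    * b) Fuse ops in the graph (e.g. conv + relu --> ConvReLU)
    * c) Insert Observers before and after ops such as Conv and Linear to collect statistics (value ranges) of the activations (feature maps)
  • step3: Feed data through the model to calibrate the activations
  • step4: convert_fx: compute the quantization parameters (e.g. scale, zero_point) of weights and activations, and convert the model FP32 --> INT8
  • step5: Evaluate the accuracy of the INT8 quantized model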
from torchvision.models import resnet18, resnet50
import torch
from torch.ao.quantization import quantize_fx, get_default_qconfig
import os
import copy
import utils  # local helper module (prepare_dataloader / evaluate), not shown in this post


def calibrate(model, data_loader, num_batch, device):
    # Run a few batches through the prepared model so the inserted observers record activation statistics
    utils.evaluate(model=model, data_loader=data_loader, neval_batches=num_batch, n_print=1, device=device)


if __name__ == '__main__':
    device = torch.device('cuda', 0)
    eval_batch_size = 32
    imagenet_data='/media/wei/Document/ImageNet/ILSVRC2012'

    model_fp = resnet50(pretrained=True, progress=True).to(device)
    model_fp.eval()

    _, test_dataloader = utils.prepare_dataloader(data_path=imagenet_data, eval_batch_size=eval_batch_size, num_workers=8)
    utils.evaluate(model=model_fp, criterion=None, data_loader=test_dataloader, device=device)
    # ResNet-18: Tested on imagenet-val: batch:3125 Acc@1  56.25 ( 69.76), Acc@5  75.00 ( 89.08)
    # ResNet-50: batch:1560 Acc@1  59.38 ( 76.18), Acc@5  90.62 ( 92.87)

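    # torch quantization: the fbgemm INT8 kernels run on x86 CPU, so quantize a CPU copy of the model
    cpu_device = torch.device('cpu')
    model_prepare = copy.deepcopy(model_fp).to(cpu_device)
    model_prepare.eval()

    # step1: choose the quantization config; the "" key applies the qconfig to the whole model
    qconfig_dict = {"": get_default_qconfig(backend='fbgemm')}
    # step2: trace the model and insert observers. Recent PyTorch versions require
    # example_inputs; older releases used prepare_fx(model_prepare, qconfig_dict)
    example_inputs = (torch.randn(1, 3, 224, 224),)
    model_prepare = quantize_fx.prepare_fx(model_prepare, qconfig_dict, example_inputs)
    model_prepare.eval()

    # step3: calibrate, i.e. determine the quantization range of the activations
    calibrate(model_prepare, test_dataloader, 10, cpu_device)

    # step4: convert FP32 --> INT8 using the configured scheme and the calibrated statistics
    quantized_model = quantize_fx.convert_fx(graph_module=model_prepare)
    quantized_model.eval()

    # step5: evaluate the accuracy of the quantized model
    utils.evaluate(quantized_model, data_loader=test_dataloader, device=cpu_device)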

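Thanks to the compact design of the PyTorch FX Graph Quantization API, quantization takes very little code and very few changes. Exciting!!! Next, let's dig into how FX Graph quantization actually works under the hood.

Below we analyze the FX Graph quantization process step by step.

PyTorch FX Graph Quantization: Step1. Choosing the Quantization Configuration

This is PyTorch's default PTQ configuration. 'fbgemm' is a matrix-computation library that provides INT8 implementations of ops such as Conv and Linear for server-side x86 CPUs.

qconfig_dict = {"": get_default_qconfig(backend='fbgemm')}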

def get_default_qconfig(backend='fbgemm'):
    """
    Returns the default PTQ qconfig for the specified backend.

    Args:
      * `backend`: a string representing the target backend. Currently supports `fbgemm`
        and `qnnpack`.

    Return:
        qconfig
    """

    if backend == 'fbgemm':
        qconfig = QConfig(activation=HistogramObserver.with_args(reduce_range=True),
                          weight=default_per_channel_weight_observer)
    elif backend == 'qnnpack':
        qconfig = QConfig(activation=HistogramObserver.with_args(reduce_range=False),
                          weight=default_weight_observer)
    else:
        qconfig = default_qconfig
    return qconfig
default_per_channel_weight_observer = PerChannelMinMaxObserver.with_args(
    dtype=torch.qint8, qscheme=torch.per_channel_symmetric
)

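So the qconfig has two parts, configuring how activations and weights are quantized respectively: activations use HistogramObserver, i.e. histogram-based per-tensor (per-layer) asymmetric quantization, while weights use PerChannelMinMaxObserver, i.e. per-channel symmetric quantization.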
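The (Qmin, Qmax) range from step1 depends on the dtype and on reduce_range (the fbgemm qconfig above passes reduce_range=True for activations). As a sketch, assuming the default 8-bit ranges used by PyTorch's observers, the rule looks like this; `default_qrange` is a hypothetical helper, and the strings "quint8"/"qint8" stand in for torch.quint8/torch.qint8.

```python
from typing import Tuple


def default_qrange(dtype: str, reduce_range: bool) -> Tuple[int, int]:
    """Hypothetical helper: default (quant_min, quant_max) for 8-bit quantized dtypes.

    reduce_range=True drops one bit of range. fbgemm enables it for activations,
    which is commonly explained as avoiding overflow in the int16 intermediate
    accumulation of its x86 kernels.
    """
    if dtype == "quint8":    # unsigned, used for activations
        return (0, 127) if reduce_range else (0, 255)
    if dtype == "qint8":     # signed, used for weights
        return (-64, 63) if reduce_range else (-128, 127)
    raise ValueError(f"unsupported dtype: {dtype}")


# fbgemm defaults: quint8 activations with reduce_range=True, qint8 weights without it.
print(default_qrange("quint8", reduce_range=True))    # (0, 127)
print(default_qrange("qint8", reduce_range=False))    # (-128, 127)
```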

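Why are activations and weights quantized differently?

  1. Weight quantization:
  • Weights are distributed differently from activations: they are usually roughly zero-mean and symmetric around 0 (close to a Gaussian), so symmetric quantization fits them well
  • Symmetric quantization fixes zero_point=0, which removes zero-point terms and reduces the computation in quantized ops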
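Why zero_point=0 saves work becomes visible when an integer dot product is expanded. With activations quantized as q_a (zero point z_a) and weights as q_w (zero point z_w), Σ(q_a - z_a)(q_w - z_w) = Σq_a·q_w - z_w·Σq_a - z_a·Σq_w + n·z_a·z_w. The term z_a·Σq_w depends only on the weights and can be precomputed, but z_w·Σq_a depends on the input and must be computed at run time unless z_w = 0. A plain-Python check of the identity (function names are illustrative, not a PyTorch API):

```python
from typing import List


def int_dot_direct(qa: List[int], qw: List[int], za: int, zw: int) -> int:
    """Reference: subtract the zero points first, then multiply-accumulate."""
    return sum((a - za) * (w - zw) for a, w in zip(qa, qw))


def int_dot_expanded(qa: List[int], qw: List[int], za: int, zw: int) -> int:
    """Same value, split into the terms an integer kernel would evaluate."""
    n = len(qa)
    acc = sum(a * w for a, w in zip(qa, qw))   # main integer MACs
    acc -= zw * sum(qa)                        # depends on the input: needed at run time unless zw == 0
    acc -= za * sum(qw)                        # depends only on the weights: can be precomputed
    acc += n * za * zw                         # constant
    return acc


qa = [3, 200, 17, 128]     # quint8 activations, zero point za = 128
qw = [-5, 12, 0, -127]     # qint8 weights
za = 128
for zw in (0, 7):          # symmetric (zw = 0) and a hypothetical asymmetric zero point
    assert int_dot_direct(qa, qw, za, zw) == int_dot_expanded(qa, qw, za, zw)
```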

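See also the discussion in Qualcomm AI's quantization white paper:

[Figure: excerpt from the Qualcomm AI quantization white paper]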

The Role of Observers

In short, an Observer observes the data distribution and computes the quantization parameters scale and zero_point. Let's go through the code.
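For reference, the linear quantization these parameters serve can be written as follows, with s = scale, z = zero_point, [q_min, q_max] the integer range, m^- = min(min_val, 0) and m^+ = max(max_val, 0). The scale/zero_point lines follow the `_calculate_qparams` code analyzed later in this post; in both cases s is additionally clamped from below by a small eps.

```latex
\begin{aligned}
x_q &= \operatorname{clamp}\!\left(\operatorname{round}\!\left(\frac{x}{s}\right) + z,\; q_{\min},\; q_{\max}\right),
\qquad \hat{x} = s\,(x_q - z) \\[4pt]
\text{affine:}\quad s &= \frac{m^{+} - m^{-}}{q_{\max} - q_{\min}},
\qquad z = \operatorname{clamp}\!\left(q_{\min} - \operatorname{round}\!\left(\frac{m^{-}}{s}\right),\; q_{\min},\; q_{\max}\right) \\[4pt]
\text{symmetric:}\quad s &= \frac{\max\left(-m^{-},\, m^{+}\right)}{(q_{\max} - q_{\min})/2},
\qquad z = 0 \ \text{(qint8)} \ \text{or} \ 128 \ \text{(quint8)}
\end{aligned}
```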
Analyzing PerChannelMinMaxObserver

class PerChannelMinMaxObserver(_ObserverBase):
    r"""Observer module for computing the quantization parameters based on the
    running per channel min and max values.

    This observer uses the tensor min/max statistics to compute the per channel
    quantization parameters. The module records the running minimum and maximum
    of incoming tensors, and uses this statistic to compute the quantization
    parameters.

    Args:
        ch_axis: Channel axis
        dtype: Quantized data type
        qscheme: Quantization scheme to be used
        reduce_range: Reduces the range of the quantized data type by 1 bit
        quant_min: Minimum quantization value. If unspecified, it will follow the 8-bit setup.
        quant_max: Maximum quantization value. If unspecified, it will follow the 8-bit setup.
        memoryless: Boolean that controls whether observer removes old data when a new input is seen.
                    This is most useful for simulating dynamic quantization, especially during QAT.

    The quantization parameters are computed the same way as in
    :class:`~torch.ao.quantization.observer.MinMaxObserver`, with the difference
    that the running min/max values are stored per channel.
    Scales and zero points are thus computed per channel as well.

    .. note:: If the running minimum equals to the running maximum, the scales
              and zero_points are set to 1.0 and 0.
    """
    min_val: torch.Tensor
    max_val: torch.Tensor

    def __init__(
        self,
        ch_axis=0,
        dtype=torch.quint8,
        qscheme=torch.per_channel_affine,
        reduce_range=False,
        quant_min=None,
        quant_max=None,
        factory_kwargs=None,
        memoryless=False,
    ) -> None:
        super(PerChannelMinMaxObserver, self).__init__(
            dtype=dtype,
            qscheme=qscheme,
            reduce_range=reduce_range,
            quant_min=quant_min,
            quant_max=quant_max,
            factory_kwargs=factory_kwargs,
        )
        self.memoryless = memoryless
        factory_kwargs = torch.nn.factory_kwargs(factory_kwargs)
        self.ch_axis = ch_axis
        self.register_buffer("min_val", torch.tensor([], **factory_kwargs))
        self.register_buffer("max_val", torch.tensor([], **factory_kwargs))
        if (
            self.qscheme == torch.per_channel_symmetric
            and self.reduce_range
            and self.dtype == torch.quint8
        ):
            raise NotImplementedError(
                "Cannot reduce range for symmetric quantization for quint8"
            )

    def forward(self, x_orig):
        return self._forward(x_orig)

    def _forward(self, x_orig):
        if x_orig.numel() == 0:
            return x_orig
        x = x_orig.detach()  # avoid keeping autograd tape
        min_val = self.min_val
        max_val = self.max_val
        x_dim = x.size()

        new_axis_list = [i for i in range(len(x_dim))]  # noqa: C416
        new_axis_list[self.ch_axis] = 0
        new_axis_list[0] = self.ch_axis
        y = x.permute(new_axis_list)
        # Need to match dtype of min/max because the updates to buffers
        # are done in place and types need to match for comparisons
        y = y.to(self.min_val.dtype)
        y = torch.flatten(y, start_dim=1)
        if min_val.numel() == 0 or max_val.numel() == 0 or self.memoryless:
            min_val, max_val = torch.aminmax(y, dim=1)
        else:
            min_val_cur, max_val_cur = torch.aminmax(y, dim=1)
            min_val = torch.min(min_val_cur, min_val)
            max_val = torch.max(max_val_cur, max_val)
        self.min_val.resize_(min_val.shape)
        self.max_val.resize_(max_val.shape)
        self.min_val.copy_(min_val)
        self.max_val.copy_(max_val)
        return x_orig

    @torch.jit.export
    def calculate_qparams(self):
        return self._calculate_qparams(self.min_val, self.max_val)

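To compute the quantization parameters, PyTorch defines a family of Observers, such as MinMaxObserver and MovingAverageMinMaxObserver. All of these XXXObserver classes derive from a common base class, and each mainly implements two important functions:

  • forward(self, x_orig): observe the min and max values of the input tensor (here, the weight)
  • calculate_qparams(self): compute scale and zero_point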
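To make the forward / calculate_qparams split concrete, here is a toy pure-Python per-tensor min/max observer. It is not PyTorch's implementation: the class name `ToyMinMaxObserver` is hypothetical, lists of floats stand in for tensors, and it uses the affine formulas from `_calculate_qparams` analyzed later.

```python
from typing import Sequence, Tuple

FP32_EPS = 1.1920928955078125e-07  # torch.finfo(torch.float32).eps


class ToyMinMaxObserver:
    """Hypothetical per-tensor affine observer over the quint8 range [0, 255]."""

    def __init__(self, quant_min: int = 0, quant_max: int = 255) -> None:
        self.quant_min = quant_min
        self.quant_max = quant_max
        self.min_val = float("inf")    # same initial values as HistogramObserver's buffers
        self.max_val = float("-inf")

    def forward(self, x: Sequence[float]) -> Sequence[float]:
        # Only record running statistics; the input passes through unchanged.
        if len(x) > 0:
            self.min_val = min(self.min_val, min(x))
            self.max_val = max(self.max_val, max(x))
        return x

    def calculate_qparams(self) -> Tuple[float, int]:
        # Turn the statistics into (scale, zero_point) with the affine formulas.
        min_neg = min(self.min_val, 0.0)   # keep 0.0 inside the range so it is exactly representable
        max_pos = max(self.max_val, 0.0)
        scale = max((max_pos - min_neg) / (self.quant_max - self.quant_min), FP32_EPS)
        zero_point = self.quant_min - round(min_neg / scale)
        zero_point = max(self.quant_min, min(self.quant_max, zero_point))
        return scale, zero_point


obs = ToyMinMaxObserver()
obs.forward([-0.5, 0.2, 1.0])   # first calibration batch
obs.forward([0.1, 2.0])         # second batch widens the running max to 2.0
scale, zero_point = obs.calculate_qparams()   # scale = 2.5 / 255, zero_point = 51
```

forward runs on every calibration batch, while calculate_qparams only needs to run once at conversion time.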

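What forward(self, x_orig) does:

  • Input: x_orig, here the weight tensor; a CNN (Conv2d) weight is usually a 4D tensor of shape Oc * Ic * Kh * Kw
  • Output/result: the observed min and max values, stored in the min_val / max_val buffers (forward itself returns x_orig unchanged)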

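When the Observer is instantiated, the __init__() argument ch_axis=0 specifies the channel dimension: ch_axis=0 means the Observer tracks the min and max of the weight along the Oc (output_channels) dimension. The core code that observes the min and max:
min_val, max_val = torch.aminmax(y, dim=1)
Before this call, _forward permutes the tensor so that the channel axis comes first and flattens it from dim 1 on, so y has shape (Oc, Ic*Kh*Kw). Since Oc is on axis 0, aminmax(dim=1) reduces over axis 1 and yields Oc values of min_val and max_val, i.e. each output channel of the weight gets its own scale and zero_point.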
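The per-channel reduction can be mimicked in plain Python for the ch_axis=0 case, with nested lists standing in for a Conv2d weight tensor. The helper names are hypothetical; the scale formula is the per_channel_symmetric branch of `_calculate_qparams` with the qint8 range [-128, 127].

```python
from typing import List, Sequence, Tuple

FP32_EPS = 1.1920928955078125e-07  # torch.finfo(torch.float32).eps


def flatten(x) -> List[float]:
    """Recursively flatten one channel (like torch.flatten(y, start_dim=1) for a single row)."""
    if isinstance(x, (list, tuple)):
        return [v for item in x for v in flatten(item)]
    return [float(x)]


def per_channel_minmax(weight: Sequence) -> Tuple[List[float], List[float]]:
    """Per-output-channel min/max for a weight laid out as [Oc][Ic][Kh][Kw] (ch_axis=0)."""
    mins, maxs = [], []
    for channel in weight:              # one row per output channel
        values = flatten(channel)       # reduce over Ic, Kh, Kw together
        mins.append(min(values))
        maxs.append(max(values))
    return mins, maxs


def per_channel_symmetric_scales(mins, maxs, quant_min=-128, quant_max=127) -> List[float]:
    """One scale per channel; zero_point is 0 for every channel (qint8)."""
    scales = []
    for lo, hi in zip(mins, maxs):
        amax = max(-min(lo, 0.0), max(hi, 0.0))
        scales.append(max(amax / ((quant_max - quant_min) / 2), FP32_EPS))
    return scales


# Toy Conv2d weight: Oc=2, Ic=1, Kh=Kw=2.
w = [
    [[[0.1, -0.4], [0.3, 0.2]]],   # output channel 0
    [[[1.5, 0.0], [-2.0, 0.7]]],   # output channel 1
]
mins, maxs = per_channel_minmax(w)                  # ([-0.4, -2.0], [0.3, 1.5])
scales = per_channel_symmetric_scales(mins, maxs)   # [0.4 / 127.5, 2.0 / 127.5]
```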

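What calculate_qparams does

As the name suggests, this function computes the quantization parameters scale and zero_point (for linear quantization). Let's look at the implementation:

  • Input: the observed min_val and max_val, plus the configured quant_min and quant_max
  • Output: the computed scale and zero_point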
    def _calculate_qparams(
        self, min_val: torch.Tensor, max_val: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""Calculates the quantization parameters, given min and max
        value tensors. Works for both per tensor and per channel cases

        Args:
            min_val: Minimum values per channel
            max_val: Maximum values per channel

        Returns:
            scales: Scales tensor of shape (#channels,)
            zero_points: Zero points tensor of shape (#channels,)
        """
        if not check_min_max_valid(min_val, max_val):
            return torch.tensor([1.0], device=min_val.device.type), torch.tensor([0], device=min_val.device.type)

        quant_min, quant_max = self.quant_min, self.quant_max
        min_val_neg = torch.min(min_val, torch.zeros_like(min_val))
        max_val_pos = torch.max(max_val, torch.zeros_like(max_val))

        device = min_val_neg.device
        scale = torch.ones(min_val_neg.size(), dtype=torch.float32, device=device)
        zero_point = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device)

        if (
            self.qscheme == torch.per_tensor_symmetric
            or self.qscheme == torch.per_channel_symmetric
        ):
            max_val_pos = torch.max(-min_val_neg, max_val_pos)
            scale = max_val_pos / (float(quant_max - quant_min) / 2)
            scale = torch.max(scale, self.eps)
            if self.dtype == torch.quint8:
                if self.has_customized_qrange:
                    # When customized quantization range is used, down-rounded midpoint of the range is chosen.
                    zero_point = zero_point.new_full(
                        zero_point.size(), (quant_min + quant_max) // 2
                    )
                else:
                    zero_point = zero_point.new_full(zero_point.size(), 128)
        elif self.qscheme == torch.per_channel_affine_float_qparams:
            scale = (max_val - min_val) / float(quant_max - quant_min)
            scale = torch.where(scale > self.eps, scale, torch.ones_like(scale))
            # We use the quantize function
            # xq = Round(Xf * inv_scale + zero_point),
            # setting zero_point to (-1 * min *inv_scale) we get
            # Xq = Round((Xf - min) * inv_scale)
            zero_point = -1 * min_val / scale
        else:
            scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min)
            scale = torch.max(scale, self.eps)
            zero_point = quant_min - torch.round(min_val_neg / scale).to(torch.int)
            zero_point = torch.clamp(zero_point, quant_min, quant_max)

        # For scalar values, cast them to Tensors of size 1 to keep the shape
        # consistent with default values in FakeQuantize.
        if len(scale.shape) == 0:
            # TODO: switch to scale.item() after adding JIT support
            scale = torch.tensor([float(scale)], dtype=scale.dtype, device=device)
        if len(zero_point.shape) == 0:
            # TODO: switch to zero_point.item() after adding JIT support
            zero_point = torch.tensor(
                [int(zero_point)], dtype=zero_point.dtype, device=device
            )
            if self.qscheme == torch.per_channel_affine_float_qparams:
                zero_point = torch.tensor(
                    [float(zero_point)], dtype=zero_point.dtype, device=device
                )

        return scale, zero_point

The core code that computes scale and zero_point:

  • Symmetric quantization (zero_point stays 0 for qint8; for quint8 with the default range it is set to 128):
max_val_pos = torch.max(-min_val_neg, max_val_pos)
scale = max_val_pos / (float(quant_max - quant_min) / 2)
scale = torch.max(scale, self.eps)
zero_point = torch.zeros(min_val_neg.size(), dtype=torch.int64, device=device)
  • Asymmetric (affine) quantization:
scale = (max_val_pos - min_val_neg) / float(quant_max - quant_min)
scale = torch.max(scale, self.eps)
zero_point = quant_min - torch.round(min_val_neg / scale).to(torch.int)
zero_point = torch.clamp(zero_point, quant_min, quant_max)
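These formulas can be checked numerically with a pure-Python sketch of the per-tensor case. It assumes the default (non-customized) ranges and uses Python's round(), which, like torch.round, rounds halves to even; the helper names are hypothetical, and the special-case branches of `_calculate_qparams` (invalid min/max, per_channel_affine_float_qparams) are omitted.

```python
from typing import Tuple

FP32_EPS = 1.1920928955078125e-07  # torch.finfo(torch.float32).eps, the observers' default eps


def affine_qparams(min_val: float, max_val: float,
                   quant_min: int, quant_max: int) -> Tuple[float, int]:
    """per_tensor_affine branch of _calculate_qparams, for one tensor."""
    min_neg, max_pos = min(min_val, 0.0), max(max_val, 0.0)
    scale = max((max_pos - min_neg) / float(quant_max - quant_min), FP32_EPS)
    zero_point = quant_min - round(min_neg / scale)
    return scale, max(quant_min, min(quant_max, zero_point))


def symmetric_qparams(min_val: float, max_val: float, quant_min: int,
                      quant_max: int, unsigned: bool) -> Tuple[float, int]:
    """per_tensor_symmetric branch; zero_point is 128 for quint8 (default range), else 0."""
    min_neg, max_pos = min(min_val, 0.0), max(max_val, 0.0)
    amax = max(-min_neg, max_pos)
    scale = max(amax / (float(quant_max - quant_min) / 2), FP32_EPS)
    return scale, (128 if unsigned else 0)


def quantize(x: float, scale: float, zero_point: int, quant_min: int, quant_max: int) -> int:
    return max(quant_min, min(quant_max, round(x / scale) + zero_point))


def dequantize(q: int, scale: float, zero_point: int) -> float:
    return scale * (q - zero_point)


# Activation-like range [-1.0, 3.0] -> quint8 [0, 255], affine.
s, z = affine_qparams(-1.0, 3.0, 0, 255)       # s = 4 / 255, z = 64
q = quantize(0.5, s, z, 0, 255)                 # round(31.875) + 64 = 96
x_hat = dequantize(q, s, z)                     # about 0.502, error <= s / 2

# Weight-like range [-0.8, 0.5] -> qint8 [-128, 127], symmetric.
s_w, z_w = symmetric_qparams(-0.8, 0.5, -128, 127, unsigned=False)   # s_w = 0.8 / 127.5, z_w = 0
```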

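The above covers how the quantization parameters of an nn.Conv2d layer's weight are computed and how PerChannelMinMaxObserver is implemented. Next, we analyze how the activation quantization parameters are computed.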


Activation Quantization Parameters and HistogramObserver

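When choosing the quantization settings, the default backend='fbgemm' configuration uses HistogramObserver for activations, i.e. the quantization parameters are computed from a histogram of the observed values.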

if backend == 'fbgemm':
    qconfig = QConfig(activation=HistogramObserver.with_args(reduce_range=True),
                      weight=default_per_channel_weight_observer)

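How HistogramObserver Works

  1. Initialization: __init__()
  • bins defaults to 2048: building a histogram requires choosing the number of bins, i.e. the interval from min_val to max_val is split into 2048 equal-width bins.
  • qscheme defaults to per_tensor_affine, i.e. per-tensor (per-layer) affine quantization: the whole tensor shares a single scale and zero_point instead of one set per channel.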
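To make the bins concrete, the sketch below bins a handful of activation values the way torch.histc does in forward(): equal-width bins over [min, max], with the maximum value counted in the last bin. It is a pure-Python stand-in written for illustration, not PyTorch code. `DST_NBINS` mirrors `self.dst_nbins = 2 ** torch.iinfo(self.dtype).bits`, i.e. 256 for quint8; as far as I can tell, this is the number of quantization levels the 2048 source bins are compared against when searching for the best min/max.

```python
from typing import List, Sequence

BINS = 2048            # HistogramObserver default
DST_NBINS = 2 ** 8     # dst_nbins for an 8-bit dtype: 256 quantization levels


def histc(values: Sequence[float], bins: int, lo: float, hi: float) -> List[float]:
    """Simplified stand-in for torch.histc (the degenerate lo == hi case puts everything in bin 0)."""
    hist = [0.0] * bins
    width = (hi - lo) / bins
    for v in values:
        if v < lo or v > hi:
            continue                     # torch.histc ignores out-of-range elements
        idx = int((v - lo) / width) if width > 0 else 0
        hist[min(idx, bins - 1)] += 1    # v == hi falls into the last bin
    return hist


acts = [-1.0, -0.2, 0.0, 0.5, 0.5, 3.0]
hist = histc(acts, BINS, min(acts), max(acts))
bin_width = (max(acts) - min(acts)) / BINS    # 4.0 / 2048 = 0.001953125
```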

class HistogramObserver(_ObserverBase):
    r"""
    The module records the running histogram of tensor values along with
    min/max values. ``calculate_qparams`` will calculate scale and zero_point.

    Args:
        bins: Number of bins to use for the histogram
        upsample_rate: Factor by which the histograms are upsampled, this is
                       used to interpolate histograms with varying ranges across observations
        dtype: Quantized data type
        qscheme: Quantization scheme to be used
        reduce_range: Reduces the range of the quantized data type by 1 bit

    The scale and zero point are computed as follows:

    1. Create the histogram of the incoming inputs.
        The histogram is computed continuously, and the ranges per bin change
        with every new tensor observed.
    2. Search the distribution in the histogram for optimal min/max values.
        The search for the min/max values ensures the minimization of the
        quantization error with respect to the floating point model.
    3. Compute the scale and zero point the same way as in the
        :class:`~torch.ao.quantization.MinMaxObserver`
    """
    histogram: torch.Tensor
    min_val: torch.Tensor
    max_val: torch.Tensor

    def __init__(
        self,
        bins: int = 2048,
        upsample_rate: int = 128,
        dtype: torch.dtype = torch.quint8,
        qscheme=torch.per_tensor_affine,
        reduce_range=False,
        quant_min=None,
        quant_max=None,
        factory_kwargs=None,
    ) -> None:
        # bins: The number of bins used for histogram calculation.
        super(HistogramObserver, self).__init__(
            dtype=dtype,
            qscheme=qscheme,
            reduce_range=reduce_range,
            quant_min=quant_min,
            quant_max=quant_max,
            factory_kwargs=factory_kwargs,
        )
        factory_kwargs = torch.nn.factory_kwargs(factory_kwargs)
        self.bins = bins
        self.register_buffer("histogram", torch.zeros(self.bins, **factory_kwargs))
        self.register_buffer("min_val", torch.tensor(float("inf"), **factory_kwargs))
        self.register_buffer("max_val", torch.tensor(float("-inf"), **factory_kwargs))
        self.dst_nbins = 2 ** torch.iinfo(self.dtype).bits
        self.upsample_rate = upsample_rate

  2. forward(): collecting statistics on the activation tensor
    def forward(self, x_orig: torch.Tensor) -> torch.Tensor:
        if x_orig.numel() == 0:
            return x_orig
        x = x_orig.detach()
        min_val = self.min_val
        max_val = self.max_val
        same_values = min_val.item() == max_val.item()
        is_uninitialized = min_val == float("inf") and max_val == float("-inf")
        if is_uninitialized or same_values:
            min_val, max_val = torch.aminmax(x)
            self.min_val.resize_(min_val.shape)
            self.min_val.copy_(min_val)
            self.max_val.resize_(max_val.shape)
            self.max_val.copy_(max_val)
            assert (
                min_val.numel() == 1 and max_val.numel() == 1
            ), "histogram min/max values must be scalar."
            torch.histc(
                x, self.bins, min=int(min_val), max=int(max_val), out=self.histogram
            )
        else:
            new_min, new_max = torch.aminmax(x)
            combined_min = torch.min(new_min, min_val)
            combined_max = torch.max(new_max, max_val)
            # combine the existing histogram and new histogram into 1 histogram
            # We do this by first upsampling the histogram to a dense grid
            # and then downsampling the histogram efficiently
            (
                combined_min,
                combined_max,
                downsample_rate,
                start_idx,
            ) = self._adjust_min_max(combined_min, combined_max, self.upsample_rate)
            assert (
                combined_min.numel() == 1 and combined_max.numel() == 1
            ), "histogram min/max values must be scalar."
            combined_histogram = torch.histc(
                x, self.bins, min=int(combined_min), max=int(combined_max)
            )
            if combined_min == min_val and combined_max == max_val:
                combined_histogram += self.histogram
            else:
                combined_histogram = self._combine_histograms(
                    combined_histogram,
                    self.histogram,
                    self.upsample_rate,
                    downsample_rate,
                    start_idx,
                    self.bins,
                )

            self.histogram.detach_().resize_(combined_histogram.shape)
            self.histogram.copy_(combined_histogram)
            self.min_val.detach_().resize_(combined_min.shape)
            self.min_val.copy_(combined_min)
            self.max_val.detach_().resize_(combined_max.shape)
            self.max_val.copy_(combined_max)
        return x_orig
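The branch that merges histograms with different ranges is the trickiest part of forward(). The sketch below captures the idea stated in the code comment, upsampling the old histogram to a dense grid and then downsampling it onto the new range, in plain Python. It is a simplification, not PyTorch's `_adjust_min_max` / `_combine_histograms`: those also widen the combined range so that old bins map onto a whole number of new bins (`downsample_rate`, `start_idx`), which is skipped here. The names `observe` and `rebin` are hypothetical.

```python
from typing import List, Optional, Sequence, Tuple

State = Tuple[List[float], float, float]   # (histogram, min_val, max_val)


def histc(values: Sequence[float], bins: int, lo: float, hi: float) -> List[float]:
    """Equal-width histogram over [lo, hi], like torch.histc (simplified)."""
    hist = [0.0] * bins
    width = (hi - lo) / bins
    for v in values:
        if lo <= v <= hi:
            idx = int((v - lo) / width) if width > 0 else 0
            hist[min(idx, bins - 1)] += 1
    return hist


def rebin(hist: List[float], old_lo: float, old_hi: float,
          new_lo: float, new_hi: float, upsample_rate: int) -> List[float]:
    """Upsample each old bin into `upsample_rate` sub-bins (mass spread uniformly),
    then add each sub-bin's mass to the new bin that contains its centre."""
    bins = len(hist)
    out = [0.0] * bins
    old_width = (old_hi - old_lo) / bins
    new_width = (new_hi - new_lo) / bins
    for i, count in enumerate(hist):
        if count == 0:
            continue
        share = count / upsample_rate
        for j in range(upsample_rate):
            centre = old_lo + (i + (j + 0.5) / upsample_rate) * old_width
            idx = int((centre - new_lo) / new_width)
            out[min(max(idx, 0), bins - 1)] += share
    return out


def observe(state: Optional[State], batch: Sequence[float],
            bins: int = 2048, upsample_rate: int = 128) -> State:
    """Simplified forward(): fold a non-empty batch into the running histogram."""
    lo, hi = min(batch), max(batch)
    if state is None:
        return histc(batch, bins, lo, hi), lo, hi
    hist, old_lo, old_hi = state
    new_lo, new_hi = min(lo, old_lo), max(hi, old_hi)
    combined = histc(batch, bins, new_lo, new_hi)
    if (new_lo, new_hi) != (old_lo, old_hi):
        hist = rebin(hist, old_lo, old_hi, new_lo, new_hi, upsample_rate)
    return [a + b for a, b in zip(combined, hist)], new_lo, new_hi


state = observe(None, [0.0, 0.5, 1.0], bins=16, upsample_rate=8)
state = observe(state, [-1.0, 2.0], bins=16, upsample_rate=8)   # range grows to [-1.0, 2.0]
```

The small bins/upsample_rate keep the example cheap; the defaults mirror HistogramObserver's bins=2048 and upsample_rate=128. The total count is preserved across the rebinning, so earlier batches keep their weight in the merged histogram.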
