Post training 4-bit quantization of convolutional networks for rapid-deployment

一、摘要

  • 介绍了三种方法,用于CNN模型的超低比特量化(4bits)和比特数自动选择。
  • Analytical Clipping for Integer Quantization(ACIQ),一种阶段阈值选择方法。
  • Per-channel bit allocation,一种对feature map各个channel实现不同比特量化的方法
  • bias-correction,一种偏移修正方法, 用以提高量化后的精度

二 Analytical Clipping for Integer Quantization (ACIQ)

ACIQ是一种量化阈值选择方法。对于后量化而言,最直接的方法是等价量化,不做截取,但这样损失较为严重(原始连续分布长尾太大)。通常需要找到一个阈值T用来截断,不在区间[-T, T]的值截取为-T或T。这样可以有效提高后量化精度。所以,问题就转变为,如何选择较好的截取值T,传统方法用KL散度,遍历可能的截取值不断计算量化前后分布的KL散度然后选取KL散度最小的T值作为截取值。论文中提出的ACIQ即一种基于优化思想的阈值T选取方法。
该方法用于激活值的量化。

首先,对于一个tensor(feature map),ACIQ假设该tensor的分布服从两种可能:拉普拉斯分布或高斯分布。量化过程就是将服从该分布的tensor中的值量化到离散区间中。其中M表示比特数。
ACIQ定义原始浮点分布的密度函数
,截取值 以及量化函数
。所以,量化前和量化后的L2 loss就等于:

image.png

而整个量化问题就被转变为:求解
,使得上述的loss值最小。

从上述表达式不难看出,量化损失一共分为三段:负无穷到截断产生的误差, 到
之间的round量化误差,以及 到正无穷的截断误差。论文用可导函数来表示各个阶段的误差进而方便求解。论文正文里以tensor服从拉普拉斯分布的情况进行推导。
量化误差如下:

image.png

截断误差如下:
image.png

所以,最终的整体量化损失如下:
image.png

此时,量化函数被成功的转换成了一个可以求导的连续函数,只需要对其求偏导,就可以得到使量化误差最小的截断值:
image.png

其中, 为截取值,
为拉普拉斯分布的参数。M为量化后的比特数。最后,求解公式在M = 2,3,4时, ,T分别取值2.83, 3.89, 5.03。

上述即是ACIQ的核心原理,利用优化的思路来求解量化过程截断值进而最小化量化损失。注意ACIQ有一个较强的先验假设,即tensor的数据分布要符合拉普拉斯分布或高斯分布(高斯分布的截取值计算在论文的附页中)。
https://github.com/submission2019/cnn-quantization


        print("=> using pre-trained model '{}'".format(args.arch))
        if args.arch == 'shufflenet':
            import models.ShuffleNet as shufflenet
            self.model = shufflenet.ShuffleNet(groups=8)
            params = torch.load('ShuffleNet_1g8_Top1_67.408_Top5_87.258.pth.tar')
            self.model = torch.nn.DataParallel(self.model, args.device_ids)
            self.model.load_state_dict(params)
        else:
            self.model = models.__dict__[args.arch](pretrained=True)

        set_node_names(self.model)
        # Mark layers before relue for fusing
        if 'resnet' in args.arch:
            resnet_mark_before_relu(self.model)

        # BatchNorm folding
        if 'resnet' in args.arch or args.arch == 'vgg16_bn' or args.arch == 'inception_v3':
            print("Perform BN folding")
            search_absorbe_bn(self.model)
            QM().bn_folding = Trueself.model.to(args.device)
        QM().quantize_model(self.model)

        if args.device_ids and len(args.device_ids) > 1 and args.arch != 'shufflenet' and args.arch != 'mobilenetv2':
            if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
                self.model.features = torch.nn.DataParallel(self.model.features, args.device_ids)
            else:
                self.model = torch.nn.DataParallel(self.model, args.device_ids)

        # define loss function (criterion) and optimizer
        self.criterion = nn.CrossEntropyLoss()
        self.criterion.to(args.device)

        cudnn.benchmark = True
        # Data loading code
        valdir = os.path.join(args.data, 'val')

        if args.arch not in models.__dict__ and args.arch in pretrainedmodels.model_names:
            dataparallel = args.device_ids is not None and len(args.device_ids) > 1
            tfs = [mutils.TransformImage(self.model.module if dataparallel else self.model)]
        else:
            normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                             std=[0.229, 0.224, 0.225])
            resize = 256 if args.arch != 'inception_v3' else 299
            crop_size = 224 if args.arch != 'inception_v3' else 299
            tfs = [
                transforms.Resize(resize),
                transforms.CenterCrop(crop_size),
                transforms.ToTensor(),
                normalize,
            ]

        self.val_loader = torch.utils.data.DataLoader(
            datasets.ImageFolder(valdir, transforms.Compose(tfs)),
            batch_size=args.batch_size, shuffle=(True if (args.kld_threshold or args.aciq_cal or args.shuffle) else False),
            num_workers=args.workers, pin_memory=True)  


def run(self):
        if args.eval_precision:
            elog = EvalLog(['dtype', 'val_prec1', 'val_prec5'])
            print("\nFloat32 no quantization")
            QM().disable()
            val_loss, val_prec1, val_prec5 = validate(self.val_loader, self.model, self.criterion)
            elog.log('fp32', val_prec1, val_prec5)
            logging.info('\nValidation Loss {val_loss:.4f} \t'
                         'Validation Prec@1 {val_prec1:.3f} \t'
                         'Validation Prec@5 {val_prec5:.3f} \n'
                         .format(val_loss=val_loss, val_prec1=val_prec1, val_prec5=val_prec5))
            print("--------------------------------------------------------------------------")

            for q in [8, 7, 6, 5, 4]:
                args.qtype = 'int{}'.format(q)
                print("\nQuantize to %s" % args.qtype)
                QM().quantize = True
                QM().reload(args, get_params())
                val_loss, val_prec1, val_prec5 = validate(self.val_loader, self.model, self.criterion)
                elog.log(args.qtype, val_prec1, val_prec5)
                logging.info('\nValidation Loss {val_loss:.4f} \t'
                             'Validation Prec@1 {val_prec1:.3f} \t'
                             'Validation Prec@5 {val_prec5:.3f} \n'
                             .format(val_loss=val_loss, val_prec1=val_prec1, val_prec5=val_prec5))
                print("--------------------------------------------------------------------------")
            print(elog)
            elog.save('results/precision/%s_%s_clipping.csv' % (args.arch, args.threshold))
        elif args.custom_test:
            log_name = 'results/custom_test/%s_max_mse_%s_cliping_layer_selection.csv' % (args.arch, args.threshold)
            elog = EvalLog(['num_8bit_layers', 'indexes', 'val_prec1', 'val_prec5'], log_name, auto_save=True)
            for i in range(len(max_mse_order_id)+1):
                _8bit_layers = ['conv0_activation'] + max_mse_order_id[0:i]
                print("it: %d, 8 bit layers: %d" % (i, len(_8bit_layers)))
                QM().set_8bit_list(_8bit_layers)
                val_loss, val_prec1, val_prec5 = validate(self.val_loader, self.model, self.criterion)
                elog.log(i+1, str(_8bit_layers), val_prec1, val_prec5)
            print(elog)
        else:
            val_loss, val_prec1, val_prec5 = validate(self.val_loader, self.model, self.criterion)
            if self.ml_logger is not None and self.ml_logger.mlflow.active_run() is not None:
                self.ml_logger.mlflow.log_metric('top1', val_prec1)
                self.ml_logger.mlflow.log_metric('top5', val_prec5)
                self.ml_logger.mlflow.log_metric('loss', val_loss)

            return val_loss, val_prec1, val_prec5

三、Per-channel bit-allocation

Per-channel bit-allocation核心思想是允许一个tensor中的各个channel的量化bits不相同(channel1可能用4bits量化;channel2可能用5bits量化,channel3可能用3bits量化),并找到每个channel的最佳量化bits。同时要求平均每个channel的量化bits值为4。

首先,该方法借用了ACIQ中对连续密度函数和量化后的离散分布之间的L2误差的定义,并在该定义的基础上,

  • 引入各个channel的量化比特数作为限定条件。
  • 并引入拉格朗日乘子

得到channel变化bits时的量化损失表达式。


image.png

其中,

  • 第一项为ACIQ中的量化loss表达式,
  • 第二项表示拉格朗日乘子引入的约束损失。
  • 表示第i个channel的量化比特,B表示所有channel的量化间隔总和。

对拉格朗日表达式求偏导,得:


image.png

最终,我们可以得到**各个channel的最佳量化bit和原始浮点数据分布的关系如下:最终,我们可以得到各个channel的最佳量化bit和原始浮点数据分布的关系如下:


image.png

总体来说,该方法延续了ACIQ的优化方法求解量化问题思想,从算法角度来看可以在有限的bits量化时通过灵活调整各个channel的量化比特数,达到量化损失最小的情况。但在实际应用中,如此各个channel分比特量化必须要配合非常特殊的硬件加速实现,实际应用价值值得商榷。

四、Bias-Correction

Bias-Correction方法主要用于对weight的量化,作者观察到量化前后权重分布的均值和方差纯在固有的偏差,该方法即通过一种简单的方法补偿weight量化前后偏移的mean和var。

image.png

五、实验

作者选用了常见的几种分类模型,进行了组合/消融实验。同时也分别进行了Weight和feature map用不同比特数量化的实验结果对比:


image.png

你可能感兴趣的:(Post training 4-bit quantization of convolutional networks for rapid-deployment)