一、摘要
- 介绍了三种方法,用于CNN模型的超低比特量化(4bits)和比特数自动选择。
- Analytical Clipping for Integer Quantization(ACIQ),一种阶段阈值选择方法。
- Per-channel bit allocation,一种对feature map各个channel实现不同比特量化的方法
- bias-correction,一种偏移修正方法, 用以提高量化后的精度
二 Analytical Clipping for Integer Quantization (ACIQ)
ACIQ是一种量化阈值选择方法。对于后量化而言,最直接的方法是等价量化,不做截取,但这样损失较为严重(原始连续分布长尾太大)。通常需要找到一个阈值T用来截断,不在区间[-T, T]的值截取为-T或T。这样可以有效提高后量化精度。所以,问题就转变为,如何选择较好的截取值T,传统方法用KL散度,遍历可能的截取值不断计算量化前后分布的KL散度然后选取KL散度最小的T值作为截取值。论文中提出的ACIQ即一种基于优化思想的阈值T选取方法。
该方法用于激活值的量化。
首先,对于一个tensor(feature map),ACIQ假设该tensor的分布服从两种可能:拉普拉斯分布或高斯分布。量化过程就是将服从该分布的tensor中的值量化到离散区间中。其中M表示比特数。
ACIQ定义原始浮点分布的密度函数
,截取值 以及量化函数
。所以,量化前和量化后的L2 loss就等于:
而整个量化问题就被转变为:求解
,使得上述的loss值最小。
从上述表达式不难看出,量化损失一共分为三段:负无穷到截断产生的误差, 到
之间的round量化误差,以及 到正无穷的截断误差。论文用可导函数来表示各个阶段的误差进而方便求解。论文正文里以tensor服从拉普拉斯分布的情况进行推导。
量化误差如下:
截断误差如下:
所以,最终的整体量化损失如下:
此时,量化函数被成功的转换成了一个可以求导的连续函数,只需要对其求偏导,就可以得到使量化误差最小的截断值:
其中, 为截取值,
为拉普拉斯分布的参数。M为量化后的比特数。最后,求解公式在M = 2,3,4时, ,T分别取值2.83, 3.89, 5.03。
上述即是ACIQ的核心原理,利用优化的思路来求解量化过程截断值进而最小化量化损失。注意ACIQ有一个较强的先验假设,即tensor的数据分布要符合拉普拉斯分布或高斯分布(高斯分布的截取值计算在论文的附页中)。
https://github.com/submission2019/cnn-quantization
print("=> using pre-trained model '{}'".format(args.arch))
if args.arch == 'shufflenet':
import models.ShuffleNet as shufflenet
self.model = shufflenet.ShuffleNet(groups=8)
params = torch.load('ShuffleNet_1g8_Top1_67.408_Top5_87.258.pth.tar')
self.model = torch.nn.DataParallel(self.model, args.device_ids)
self.model.load_state_dict(params)
else:
self.model = models.__dict__[args.arch](pretrained=True)
set_node_names(self.model)
# Mark layers before relue for fusing
if 'resnet' in args.arch:
resnet_mark_before_relu(self.model)
# BatchNorm folding
if 'resnet' in args.arch or args.arch == 'vgg16_bn' or args.arch == 'inception_v3':
print("Perform BN folding")
search_absorbe_bn(self.model)
QM().bn_folding = Trueself.model.to(args.device)
QM().quantize_model(self.model)
if args.device_ids and len(args.device_ids) > 1 and args.arch != 'shufflenet' and args.arch != 'mobilenetv2':
if args.arch.startswith('alexnet') or args.arch.startswith('vgg'):
self.model.features = torch.nn.DataParallel(self.model.features, args.device_ids)
else:
self.model = torch.nn.DataParallel(self.model, args.device_ids)
# define loss function (criterion) and optimizer
self.criterion = nn.CrossEntropyLoss()
self.criterion.to(args.device)
cudnn.benchmark = True
# Data loading code
valdir = os.path.join(args.data, 'val')
if args.arch not in models.__dict__ and args.arch in pretrainedmodels.model_names:
dataparallel = args.device_ids is not None and len(args.device_ids) > 1
tfs = [mutils.TransformImage(self.model.module if dataparallel else self.model)]
else:
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
resize = 256 if args.arch != 'inception_v3' else 299
crop_size = 224 if args.arch != 'inception_v3' else 299
tfs = [
transforms.Resize(resize),
transforms.CenterCrop(crop_size),
transforms.ToTensor(),
normalize,
]
self.val_loader = torch.utils.data.DataLoader(
datasets.ImageFolder(valdir, transforms.Compose(tfs)),
batch_size=args.batch_size, shuffle=(True if (args.kld_threshold or args.aciq_cal or args.shuffle) else False),
num_workers=args.workers, pin_memory=True)
def run(self):
if args.eval_precision:
elog = EvalLog(['dtype', 'val_prec1', 'val_prec5'])
print("\nFloat32 no quantization")
QM().disable()
val_loss, val_prec1, val_prec5 = validate(self.val_loader, self.model, self.criterion)
elog.log('fp32', val_prec1, val_prec5)
logging.info('\nValidation Loss {val_loss:.4f} \t'
'Validation Prec@1 {val_prec1:.3f} \t'
'Validation Prec@5 {val_prec5:.3f} \n'
.format(val_loss=val_loss, val_prec1=val_prec1, val_prec5=val_prec5))
print("--------------------------------------------------------------------------")
for q in [8, 7, 6, 5, 4]:
args.qtype = 'int{}'.format(q)
print("\nQuantize to %s" % args.qtype)
QM().quantize = True
QM().reload(args, get_params())
val_loss, val_prec1, val_prec5 = validate(self.val_loader, self.model, self.criterion)
elog.log(args.qtype, val_prec1, val_prec5)
logging.info('\nValidation Loss {val_loss:.4f} \t'
'Validation Prec@1 {val_prec1:.3f} \t'
'Validation Prec@5 {val_prec5:.3f} \n'
.format(val_loss=val_loss, val_prec1=val_prec1, val_prec5=val_prec5))
print("--------------------------------------------------------------------------")
print(elog)
elog.save('results/precision/%s_%s_clipping.csv' % (args.arch, args.threshold))
elif args.custom_test:
log_name = 'results/custom_test/%s_max_mse_%s_cliping_layer_selection.csv' % (args.arch, args.threshold)
elog = EvalLog(['num_8bit_layers', 'indexes', 'val_prec1', 'val_prec5'], log_name, auto_save=True)
for i in range(len(max_mse_order_id)+1):
_8bit_layers = ['conv0_activation'] + max_mse_order_id[0:i]
print("it: %d, 8 bit layers: %d" % (i, len(_8bit_layers)))
QM().set_8bit_list(_8bit_layers)
val_loss, val_prec1, val_prec5 = validate(self.val_loader, self.model, self.criterion)
elog.log(i+1, str(_8bit_layers), val_prec1, val_prec5)
print(elog)
else:
val_loss, val_prec1, val_prec5 = validate(self.val_loader, self.model, self.criterion)
if self.ml_logger is not None and self.ml_logger.mlflow.active_run() is not None:
self.ml_logger.mlflow.log_metric('top1', val_prec1)
self.ml_logger.mlflow.log_metric('top5', val_prec5)
self.ml_logger.mlflow.log_metric('loss', val_loss)
return val_loss, val_prec1, val_prec5
三、Per-channel bit-allocation
Per-channel bit-allocation核心思想是允许一个tensor中的各个channel的量化bits不相同(channel1可能用4bits量化;channel2可能用5bits量化,channel3可能用3bits量化),并找到每个channel的最佳量化bits。同时要求平均每个channel的量化bits值为4。
首先,该方法借用了ACIQ中对连续密度函数和量化后的离散分布之间的L2误差的定义,并在该定义的基础上,
- 引入各个channel的量化比特数作为限定条件。
- 并引入拉格朗日乘子
得到channel变化bits时的量化损失表达式。
其中,
- 第一项为ACIQ中的量化loss表达式,
- 第二项表示拉格朗日乘子引入的约束损失。
- 表示第i个channel的量化比特,B表示所有channel的量化间隔总和。
对拉格朗日表达式求偏导,得:
最终,我们可以得到**各个channel的最佳量化bit和原始浮点数据分布的关系如下:最终,我们可以得到各个channel的最佳量化bit和原始浮点数据分布的关系如下:
总体来说,该方法延续了ACIQ的优化方法求解量化问题思想,从算法角度来看可以在有限的bits量化时通过灵活调整各个channel的量化比特数,达到量化损失最小的情况。但在实际应用中,如此各个channel分比特量化必须要配合非常特殊的硬件加速实现,实际应用价值值得商榷。
四、Bias-Correction
Bias-Correction方法主要用于对weight的量化,作者观察到量化前后权重分布的均值和方差纯在固有的偏差,该方法即通过一种简单的方法补偿weight量化前后偏移的mean和var。
五、实验
作者选用了常见的几种分类模型,进行了组合/消融实验。同时也分别进行了Weight和feature map用不同比特数量化的实验结果对比: