目录
神经网络量化----TensorRT深刻解读
前言
一、TensorRT简介
二、难点
1.架构
2.功能
三、实现
1.conv和BN的融合
2.conv和ReLU的融合
quant_utils.py
3.调用示例
总结
本文将聚焦于英伟达TensorRT训练后量化的算法。
论文地址为:https://on-demand.gputechconf.com/gtc/2017/presentation/s7310-8-bit-inference-with-tensorrt.pdf
代码地址为:官方好像没有公布代码,可以参考的有 https://github.com/deepglint/EasyQuant 和 https://github.com/apache/incubator-mxnet/blob/master/python/mxnet/contrib/quantization.py
整个量化算法使用对称量化(-max, max)-> (-127, 127)
权重的量化所需的最值直接统计得出,激活值的量化使用饱和映射的方式,设置最佳的阈值作为最值来进行量化。
计算最佳阈值的方法:1. 统计激活值的直方图,2. 采用遍历的方法找到量化后KL散度最小时对应的最佳阈值。伪代码如下:
以上所给的参考代码语言分别为caffe和mxnet,caffe代码里有保存和加载文件的操作,对于量化后的参数不能直观地和网络架构联系在一起,mxnet代码里的功能较全,但是语言比较小众化,开发新的算法难度较大。
故计划采用pytorch来进行量化参数的计算和保存,且跟网络是一个整体,新建module即可,比较方便。
有的卷积后有激活和BN,所以还要将其考虑在内。
所以在处理之前需要将网络中的卷积、BN、ReLU融合在一起。
在量化时,不考虑是否有ReLU,全部量化在(-max, max)之间。(格林深瞳算法只考虑了带有ReLU的,即将(0,max)量化到(0,127)),这样就简化了运算,不用再分情况了。
from torch import nn
import torch
# the module that replace BN layer
class DummyModule(nn.Module):
    """Identity placeholder that takes the slot of a layer that was folded away.

    Keeping a module under the original name preserves the parent's child
    ordering while making the slot a no-op.
    """

    def forward(self, x):
        """Return ``x`` unchanged."""
        return x
# BN fold
def bn_folding(conv, bn):
    """Fold a BatchNorm2d layer into the Conv2d that feeds it.

    With ``scale = gamma / sqrt(running_var + eps)`` the fused layer uses
    ``w' = w * scale`` (per output channel) and
    ``b' = beta + (b - running_mean) * scale``, so that in eval mode
    ``fused(x) == bn(conv(x))``.

    Args:
        conv: the ``nn.Conv2d`` whose output is normalised by ``bn``.
        bn: the ``nn.BatchNorm2d`` to absorb; its running statistics are used.

    Returns:
        A new ``nn.Conv2d`` (always with a bias) carrying the folded
        parameters and the same stride/padding/dilation/groups as ``conv``.
    """
    # ******************** BN parameter *********************
    mean = bn.running_mean
    std = torch.sqrt(bn.running_var + bn.eps)
    # A BatchNorm built with affine=False has no weight/bias; that is the
    # same as gamma == 1 and beta == 0.
    gamma = bn.weight if bn.weight is not None else torch.ones_like(mean)
    beta = bn.bias if bn.bias is not None else torch.zeros_like(mean)
    scale = gamma / std

    # ******************* conv parameter********************
    w = conv.weight
    b = conv.bias if conv.bias is not None else mean.new_zeros(mean.shape)

    w_fold = w * scale.reshape([conv.out_channels, 1, 1, 1])
    b_fold = beta + (b - mean) * scale

    # Only forward padding_mode when it is non-default, so the call still
    # works on torch versions whose Conv2d does not take that keyword.
    extra_kwargs = {}
    padding_mode = getattr(conv, 'padding_mode', 'zeros')
    if padding_mode != 'zeros':
        extra_kwargs['padding_mode'] = padding_mode

    bnfold_conv = nn.Conv2d(conv.in_channels,
                            conv.out_channels,
                            conv.kernel_size,
                            conv.stride,
                            conv.padding,
                            dilation=conv.dilation,  # previously dropped, which broke dilated convs
                            groups=conv.groups,
                            bias=True,
                            **extra_kwargs)
    # detach(): the folded tensors are derived from Parameters and would
    # otherwise carry autograd history into the new layer's storage.
    bnfold_conv.weight.data = w_fold.detach()
    bnfold_conv.bias.data = b_fold.detach()
    return bnfold_conv
'''BN must be after convolution'''
def model_bn_folding(model):
    """Recursively fold every Conv2d -> BatchNorm2d pair of ``model`` in place.

    Children are visited in registration order. When a ``BatchNorm2d``
    immediately follows a ``Conv2d`` among the children of the same parent,
    the conv is replaced by the folded conv from ``bn_folding`` and the BN
    slot is replaced by a ``DummyModule``.

    A BN that does not directly follow a conv (nothing before it, or another
    layer in between) is left untouched: folding it into an earlier conv
    would change what the network computes.

    Args:
        model: the ``nn.Module`` to rewrite; modified in place.

    Returns:
        The same ``model`` object.
    """
    prev_name = None
    prev_conv = None
    for name, child in list(model.named_children()):
        if isinstance(child, nn.BatchNorm2d):
            if prev_conv is not None:
                model._modules[prev_name] = bn_folding(prev_conv, child)  # fold BN into the conv
                model._modules[name] = DummyModule()
            # Whether folded or skipped, this BN ends the current conv's reach.
            prev_name, prev_conv = None, None
        elif isinstance(child, nn.Conv2d):
            prev_name, prev_conv = name, child
        else:
            # Any other layer sits between the remembered conv and a later
            # BN, so forget the conv before descending into the child.
            prev_name, prev_conv = None, None
            model_bn_folding(child)
    return model
新建一个module将卷积和ReLU包含在内了。
import torch
from torch import nn
import torch.nn.functional as F
from quant_utils import ConvRelu, LinearRelu, DummyModule
# Device the wrapped modules are moved to: the first visible CUDA GPU when one
# is available, otherwise the CPU. Uncomment the next line to force the CPU.
# device = torch.device("cpu")
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
'''BN must be after convolution'''
def model_relu_folding(model):
    """Recursively wrap Conv2d/Linear layers of ``model`` and absorb following ReLUs.

    Every ``nn.Conv2d`` child is replaced by ``ConvRelu(child, is_relu=0)``
    and every ``nn.Linear`` child by ``LinearRelu(child, is_relu=0)``. When
    an ``nn.ReLU`` follows such a layer among the children of the same
    parent, the wrapper is rebuilt with ``is_relu=1`` and the ReLU slot is
    replaced by a ``DummyModule``.

    A ReLU that has no conv/linear directly before it (nothing before it, or
    another module/container in between) is left as a plain ``nn.ReLU``, so
    the network's output is unchanged.

    Args:
        model: the ``nn.Module`` to rewrite; modified in place.

    Returns:
        The same ``model`` object.
    """
    prev_name = None
    prev_layer = None
    prev_is_conv = True
    for name, child in list(model.named_children()):
        if isinstance(child, nn.ReLU):
            if prev_layer is None:
                # Nothing to fuse into; keep the activation as it is.
                continue
            wrapper_cls = ConvRelu if prev_is_conv else LinearRelu
            model._modules[prev_name] = wrapper_cls(prev_layer, is_relu=1).to(device)
            model._modules[name] = DummyModule().to(device)
        elif isinstance(child, nn.Conv2d):
            prev_name, prev_layer, prev_is_conv = name, child, True
            model._modules[name] = ConvRelu(child, is_relu=0).to(device)
        elif isinstance(child, nn.Linear):
            prev_name, prev_layer, prev_is_conv = name, child, False
            model._modules[name] = LinearRelu(child, is_relu=0).to(device)
        else:
            # Another module separates the remembered layer from any later
            # ReLU, so that ReLU must not be fused into it.
            prev_name, prev_layer = None, None
            model_relu_folding(child)
    return model
在ConvRelu中,使用register_buffer申请了权重和激活值量化相关的变量,采用model.train()的形式创建了一些mode,用来进行不同阶段的TensorRT算法。
weight_quant():统计权重的绝对值的最大值,量化的scale
initial_activate_max():统计激活值的最值,这个需要在校准集上跑一遍才能统计出的。
initial_histograms():统计激活值的直方图,这个也需要跑一遍校准集,需要注意的一点,如果有ReLU的话,0值对应的直方图数量很多,会减小其他值的权重,由于采用对称映射,0几乎无误差,所以将0值对应的直方图设置为0.
get_optimal_threshold():计算KL散度,获取最佳的阈值。
from torch import nn
import torch
import torch.nn.functional as F
import copy
from collections import OrderedDict
import numpy as np
# Number of bins of the activation histogram kept by each wrapped layer.
# get_optimal_threshold() asserts that this length is odd, and calibrate()
# treats the middle bin (index num_bins // 2) as the zero bin.
INTERVAL_NUM = 4001
# NOTE(review): not referenced in the visible code; the "7bit" remark
# presumably refers to the magnitude range of signed 8-bit values - confirm.
QUANTIZE_NUM = 127 # 7bit
# NOTE(review): only referenced from commented-out code in the visible source.
STATISTIC = 1.0
# First visible CUDA GPU when one is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# the module that replace relu layer
class DummyModule(nn.Module):
    """Pass-through module left in place of a layer that has been fused away."""

    def forward(self, x):
        """Hand the input back untouched."""
        return x
# the module that replace conv layer
class ConvRelu(nn.Module):
    """Conv2d (optionally followed by ReLU) that gathers calibration statistics.

    The wrapper shares the weight/bias parameters of the wrapped convolution
    and runs it through ``F.conv2d``. Depending on ``self.mode`` the forward
    pass additionally records statistics of its own output:

    * ``'TRT_weight_quant'``            - plain forward (default mode).
    * ``'TRT_activate_collection_max'`` - track the largest absolute output.
    * ``'TRT_activate_collection_hist'``- accumulate the output histogram.
    * ``'TRT_activate_KL'`` / ``'Normal_TRT'`` - plain forward.

    Any other mode raises ``ValueError`` in ``forward``.
    NOTE(review): the original mode list also names ``'Normal'``, but
    ``forward`` has never accepted it; confirm whether it should.

    NOTE(review): only zero padding is reproduced (``F.conv2d`` is called
    with the conv's ``padding``); a non-default ``padding_mode`` on the
    wrapped conv is not honoured.

    Args:
        conv: the ``nn.Conv2d`` to wrap.
        is_relu: truthy to apply ReLU after the convolution.
        bits: quantization bit width.
        threshold: upper bound applied to the recorded activation maximum.
    """

    def __init__(self, conv, is_relu=0, bits=8, threshold=204800):
        super(ConvRelu, self).__init__()
        self.threshold = threshold
        self.bits = bits
        self.is_relu = is_relu
        self.kernel_size = conv.kernel_size
        self.stride = conv.stride
        self.padding = conv.padding
        # Dilation was previously dropped, so dilated convs computed a
        # different function after wrapping.
        self.dilation = getattr(conv, 'dilation', 1)
        self.groups = conv.groups
        self.bias = conv.bias
        self.weight = conv.weight

        # mode : Normal, TRT_weight_quant, TRT_activate_collection_max,
        #        TRT_activate_collection_hist, TRT_activate_KL, Normal_TRT
        self.mode = 'TRT_weight_quant'

        self.register_buffer('quant_num', torch.tensor((1 << bits) - 1))

        # activation_para
        self.register_buffer('activate_flag', torch.zeros(1))
        self.register_buffer('activate_distubution', torch.zeros(INTERVAL_NUM))
        self.register_buffer('activate_distubution_edges', torch.zeros(INTERVAL_NUM+1))
        self.register_buffer('activate_max', torch.zeros(1))
        self.register_buffer('th', torch.zeros(1))
        self.register_buffer('optimal_th', torch.zeros(1))

        # weight_para (one entry per output channel; weight_quant() later
        # stores them with shape [out_channels, 1, 1, 1])
        self.register_buffer('weight_flag', torch.zeros(1))
        self.register_buffer('weight_scale', torch.zeros(conv.weight.data.shape[0]))
        self.register_buffer('weight_zero', torch.zeros(conv.weight.data.shape[0]))
        self.register_buffer('weight_max', torch.zeros(conv.weight.data.shape[0]))

    def initial_activate_max(self, input):
        """Update the running absolute maximum of the activations.

        ``self.activate_max`` becomes the largest ``|value|`` seen so far and
        ``self.th`` is that maximum capped at ``self.threshold``.
        """
        batch_abs_max = input.detach().abs().max()
        self.activate_max = torch.max(self.activate_max, batch_abs_max)
        # Avoid unusually large activations by clipping with threshold.
        # torch.clamp keeps the result a tensor; the built-in min() could
        # return the Python number, which cannot be assigned to a buffer.
        self.th = torch.clamp(self.activate_max, max=self.threshold)

    def weight_quant(self):
        """Compute the per-output-channel weight maximum, scale and zero mask.

        ``weight_max`` is ``max |w|`` per output channel, ``weight_scale`` is
        ``(2**(bits-1) - 1) / weight_max`` (0 where the maximum is below
        1e-4) and ``weight_zero`` is 1 for exactly those channels. All three
        are stored with shape ``[out_channels, 1, 1, 1]``.
        """
        # Avoid multiple operations caused by multiple identification of the module
        weight = self.weight.detach()
        self.weight_flag = torch.ones(1, device=weight.device)

        weight_threshold = weight.abs().flatten(1).max(dim=1)[0].reshape(-1, 1, 1, 1)
        self.weight_max = weight_threshold

        max_level = (1 << (self.bits - 1)) - 1
        near_zero = weight_threshold < 0.0001
        self.weight_scale = torch.where(near_zero,
                                        torch.zeros_like(weight_threshold),
                                        max_level / weight_threshold)
        self.weight_zero = near_zero.float()

    def initial_histograms(self, input):
        """Accumulate the histogram of ``input`` over ``[-th, th]``.

        Values outside the range are not counted. The middle (zero) bin is
        reset to 0 after every update.
        """
        values = input.detach().cpu().numpy().flatten()
        th = self.th.cpu().item()
        num_bins = self.activate_distubution.numel()
        target = self.activate_distubution.device

        hist, hist_edges = np.histogram(values, bins=num_bins, range=(-th, th))
        self.activate_distubution += torch.from_numpy(hist).to(target)
        # Clear the zero bin (previously the hard-coded index 2000).
        self.activate_distubution[num_bins // 2] = 0
        self.activate_distubution_edges = torch.from_numpy(hist_edges).to(target)

    def plot_hist(self, optimal_th=None):
        """Print and plot the collected histogram, optionally marking ``optimal_th``."""
        edges = self.activate_distubution_edges.cpu().numpy()[:-1]
        counts = self.activate_distubution.cpu().numpy()
        # The two labels were swapped before.
        print('hist_edge: ', edges)
        print('hist: ', counts)
        import matplotlib.pyplot as plt
        plt.plot(edges, counts)
        if optimal_th is not None:
            plt.plot(optimal_th, 0, 'om')
            plt.annotate('optimal_th', xy=(optimal_th, 0), xytext=(optimal_th+1, 10000), arrowprops=dict(arrowstyle='->'))
        plt.ylabel('activate distubution')
        plt.show()

    def get_optimal_threshold(self):
        """Run ``calibrate`` on the collected histogram and store ``optimal_th``.

        Also prints both thresholds and shows the histogram plot.
        """
        # Avoid multiple operations caused by multiple identification of the module
        target = self.activate_distubution.device
        self.activate_flag = torch.ones(1, device=target)
        length = self.activate_distubution.shape[0]
        assert (length % 2 == 1)

        hist = self.activate_distubution.cpu().numpy()
        hist_edge = self.activate_distubution_edges.cpu().numpy()
        num_quantized_bins = self.quant_num.cpu().item()

        optimal_threshold = calibrate(hist, hist_edge, num_quantized_bins)
        self.optimal_th = torch.tensor(optimal_threshold).to(target)
        print('th: ', self.th)
        print('optimal_th: ', self.optimal_th)
        self.plot_hist(optimal_th=optimal_threshold)

    def forward(self, x):
        """Apply the convolution (and ReLU), recording statistics per ``self.mode``.

        Raises:
            AssertionError: if the module is in training mode.
            ValueError: if ``self.mode`` is not one of the accepted modes.
        """
        assert self.training is False
        # getattr keeps instances unpickled from before ``dilation`` was
        # stored working with the default dilation of 1.
        dilation = getattr(self, 'dilation', 1)
        x = F.conv2d(x, self.weight, self.bias, self.stride, self.padding, dilation, self.groups)
        if self.is_relu:
            x = F.relu(x)

        if self.mode == 'TRT_activate_collection_max':
            # collect max, min, threshold
            self.initial_activate_max(x)
        elif self.mode == 'TRT_activate_collection_hist':
            # collect histograms
            self.initial_histograms(x)
        elif self.mode not in ('TRT_activate_KL', 'Normal_TRT', 'TRT_weight_quant'):
            raise ValueError("mode error")
        return x
以下代码对应于第二个参考代码(MXNet)中调用的C++代码(此处为Python实现)。原C++代码有一处错误:在“merge hist into num_quantized_bins bins”部分,处理边界时存在重复叠加,请注意区分(此处已修改)。
def calibrate(hist, hist_edge, num_quantized_bins=255):
    """Pick the clipping threshold that minimises KL divergence after quantization.

    For every symmetric window of bins around the middle (zero) bin, from
    ``num_quantized_bins`` bins wide up to the whole histogram:

    1. ``p`` is the histogram inside the window, with everything outside the
       window added to its two end bins (saturation).
    2. ``q`` is the window merged down to ``num_quantized_bins`` bins and
       expanded back to the window size, spreading each merged count evenly
       over the bins that are non-zero in ``p``.
    3. Both are smoothed and their KL divergence is recorded.

    Args:
        hist: 1-D array of bin counts; the middle bin is taken as zero.
            NOTE(review): an even number of bins makes the widest window
            index past ``hist_edge``; the visible caller asserts the length
            is odd.
        hist_edge: 1-D array of bin edges with ``hist.size + 1`` entries.
        num_quantized_bins: number of quantization levels.

    Returns:
        The upper edge of the window with the smallest divergence.

    Raises:
        ValueError: if the histogram has too few bins for
            ``num_quantized_bins``, or (from ``_smooth_distribution``) if a
            distribution is entirely zero.
    """
    num_bins = hist.size
    assert num_bins+1 == hist_edge.size
    zero_bin_idx = num_bins // 2
    num_half_quantized_bins = num_quantized_bins // 2

    num_candidates = zero_bin_idx + 1 - num_half_quantized_bins
    if num_candidates <= 0:
        raise ValueError('histogram with %d bins is too small for %d quantized bins'
                         % (num_bins, num_quantized_bins))
    thresholds = np.zeros(num_candidates)
    divergence = np.zeros(num_candidates)

    for slot, i in enumerate(range(num_half_quantized_bins, zero_bin_idx + 1)):
        p_bin_index_start = zero_bin_idx - i
        p_bin_index_stop = zero_bin_idx + i + 1
        thresholds[slot] = hist_edge[p_bin_index_stop]

        # The in-range part of the histogram. The first bin used to be left
        # at 0 here, so its count never reached the quantized distribution.
        sliced_nd_hist = hist[p_bin_index_start:p_bin_index_stop].astype(np.float64)

        # Reference distribution: in-range bins plus the outliers on each
        # side saturated into the two end bins.
        p = sliced_nd_hist.copy()
        p[0] += np.sum(hist[:p_bin_index_start])
        p[-1] += np.sum(hist[p_bin_index_stop:])

        # calculate how many bins should be merged to generate quantized distribution q
        num_merged_bins = sliced_nd_hist.size // num_quantized_bins

        # merge hist into num_quantized_bins bins; the leftover tail that
        # does not fill a whole group goes into the last quantized bin
        covered = num_quantized_bins * num_merged_bins
        quantized_bins = sliced_nd_hist[:covered].reshape(num_quantized_bins, num_merged_bins).sum(axis=1)
        quantized_bins[-1] += np.sum(sliced_nd_hist[covered:])

        # expand quantized_bins into p.size bins
        q = np.zeros(sliced_nd_hist.size)
        is_nonzeros = (p != 0).astype(np.int64)
        for j in range(num_quantized_bins):
            start = j * num_merged_bins
            stop = q.size if (j == num_quantized_bins-1) else (j+1) * num_merged_bins
            norm = is_nonzeros[start:stop].sum()
            if norm != 0:
                q[start:stop] = float(quantized_bins[j]) / float(norm)
        q[p == 0] = 0

        divergence[slot] = ComputeEntropy(_smooth_distribution(p), _smooth_distribution(q))

    return thresholds[np.argmin(divergence)]
def _smooth_distribution(p, eps=0.0001):
is_zeros = (p == 0).astype(np.float32)
is_nonzeros = (p != 0).astype(np.float32)
n_zeros = is_zeros.sum()
n_nonzeros = p.size - n_zeros
if not n_nonzeros:
raise ValueError('The discrete probability distribution is malformed. All entries are 0.')
eps1 = eps * float(n_zeros) / float(n_nonzeros)
assert eps1 < 1.0, 'n_zeros=%d, n_nonzeros=%d, eps1=%f' % (n_zeros, n_nonzeros, eps1)
hist = p.astype(np.float32)
hist += eps * is_zeros + (-eps1) * is_nonzeros
assert (hist <= 0).sum() == 0
return hist
#from scipy import *
def ComputeEntropy(p, q):
    """Return the KL divergence ``sum(p * log(p / q))`` of two histograms.

    Both inputs are normalised to sum to 1 first. Entries where ``p`` is
    zero contribute nothing (the usual ``0 * log 0 = 0`` convention);
    previously such entries turned the whole result into ``nan``.

    Args:
        p: reference distribution (array of non-negative values).
        q: approximating distribution with the same number of entries.

    Returns:
        The divergence as a NumPy scalar.
    """
    assert p.size == q.size
    p = p / np.sum(p)
    q = q / np.sum(q)
    support = (p != 0)
    KL_dis = np.sum(p[support] * np.lib.scimath.log(p[support] / q[support]))
    return KL_dis
import torch
import sys
import os
# Expose only physical GPU 1 to this process; torch then sees it as "cuda:0".
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
sys.path.append('vgg')
# The star import is kept as in the original listing.
# NOTE(review): presumably needed so torch.load can resolve the pickled model
# classes - confirm.
from VggNet import *
from datetime import datetime
from torch.utils.data import DataLoader
from torchvision import datasets,transforms
from ConvReluFold import model_relu_folding
from ConvBNFold import model_bn_folding
from quant_utils import ConvRelu, LinearRelu, DummyModule, TRT_Quantizer


def _run_calibration_pass(model, loader, max_batches=21):
    """Feed up to ``max_batches`` batches through ``model``; return top-1 accuracy in percent.

    The forward passes are what matter: in the collection modes each wrapped
    layer records its statistics while the batches flow through.
    """
    correct = 0.0
    total = 0
    with torch.no_grad():  # statistics collection needs no gradients
        for batch_idx, (inputs, labels) in enumerate(loader):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            pred = outputs.argmax(dim=1)
            total += inputs.size(0)
            correct += torch.eq(pred, labels).sum().item()
            if batch_idx + 1 >= max_batches:
                break
    if total == 0:
        raise RuntimeError('the calibration loader produced no samples')
    accuracy = 100.0 * correct / total
    # The old message claimed "10000 test images", but only a few batches of
    # the calibration loader are evaluated here.
    print('Accuracy of the network on %d calibration images: %.2f %%' % (total, accuracy))
    return accuracy


# NOTE(security): torch.load unpickles arbitrary objects; only load model
# files from a trusted source.
model = torch.load('./model/vgg0.904_bnrelufold.pth')
model.eval()

# NOTE(review): train_loader is not defined in the visible code. It may come
# from `from VggNet import *` or have been omitted from the listing; it must
# be a DataLoader over the calibration set before the passes below can run.

# ---------------------- TRT_weight_quant ----------------------
TRT_Quantizer(model, mode='TRT_weight_quant')

# ---------------------- TRT_activate_collection_max ----------------------
TRT_Quantizer(model, mode='TRT_activate_collection_max')
model.eval()
_run_calibration_pass(model, train_loader)

# ---------------------- TRT_activate_collection_hist ----------------------
TRT_Quantizer(model, mode='TRT_activate_collection_hist')
_run_calibration_pass(model, train_loader)

# ---------------------- TRT_activate_KL ----------------------
TRT_Quantizer(model, mode='TRT_activate_KL')
注意:在使用融合后的模型时,必须import之前的model,否则会报错:缺少某个组件。(相当于在之前的方法的基础上新建了方法,所以还需要导入之前的方法才行)。
这次编程让我对pytorch的了解又加深了一步,另外之前学的C++现在派上了用场,否则关于直方图那部分还真的不好编写,真的是学无止境呀。