Neural Network Quantization ---- An In-Depth Look at TensorRT



 

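Contents

Neural Network Quantization ---- An In-Depth Look at TensorRT

Preface

I. Introduction to TensorRT

II. Difficulties

1. Architecture

2. Functionality

III. Implementation

1. Fusing conv and BN

2. Fusing conv and ReLU

quant_utils.py

3. Usage example

Summary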



Preface

This post focuses on NVIDIA TensorRT's post-training quantization algorithm.
Paper (GTC 2017 slides): https://on-demand.gputechconf.com/gtc/2017/presentation/s7310-8-bit-inference-with-tensorrt.pdf
Code: NVIDIA does not seem to have released official code. Useful references are https://github.com/deepglint/EasyQuant and https://github.com//apache/incubator-mxnet/blob/master/python/mxnet/contrib/quantization.py


I. Introduction to TensorRT

The whole algorithm uses symmetric quantization, mapping (-max, max) -> (-127, 127).
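Here is a minimal illustration of this mapping. The NumPy sketch below quantizes a small tensor with a single scale of 127 / max. The helper names `symmetric_quantize` and `dequantize` are hypothetical. The sketch assumes per-tensor scaling with round-to-nearest.

```python
import numpy as np

def symmetric_quantize(x, max_val, qmax=127):
    """Map [-max_val, max_val] linearly onto the integer grid [-qmax, qmax]."""
    scale = qmax / max_val                          # real value -> integer steps
    q = np.clip(np.round(x * scale), -qmax, qmax)   # round, then saturate
    return q.astype(np.int8), scale

def dequantize(q, scale):
    """Map integer codes back to real values."""
    return q.astype(np.float32) / scale

x = np.array([-2.0, -0.5, 0.0, 1.0, 2.0], dtype=np.float32)
q, scale = symmetric_quantize(x, max_val=np.abs(x).max())
x_hat = dequantize(q, scale)   # error per element is at most half a step (0.5 / scale)
```

Zero always maps exactly to code 0. This is one reason the symmetric scheme needs no zero-point.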

[Figure 1]

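For weights, the maximum used for quantization is taken directly from the weight statistics. For activations, a saturated mapping is used: an optimal threshold replaces the maximum, and values beyond it are clipped.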
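The effect of saturation can be sketched with toy data: most values lie in [-1, 1], plus one outlier at 50. The threshold of 1.0 is picked by hand purely for illustration. TensorRT chooses it through the KL search described next. `quant_dequant` is a hypothetical helper.

```python
import numpy as np

def quant_dequant(x, threshold, qmax=127):
    """Symmetric quantize-dequantize that saturates at +/-threshold."""
    scale = qmax / threshold
    return np.clip(np.round(x * scale), -qmax, qmax) / scale

# bulk of activations in [-1, 1] plus one rare outlier
x = np.concatenate([np.linspace(-1.0, 1.0, 1001), [50.0]])
bulk = np.abs(x) <= 1.0

no_sat = quant_dequant(x, threshold=np.abs(x).max())   # max mapping: T = 50
sat = quant_dequant(x, threshold=1.0)                  # saturated mapping: T = 1

err_no_sat = np.abs(no_sat - x)[bulk].mean()
err_sat = np.abs(sat - x)[bulk].mean()
print(f"mean |error| on the bulk: max mapping {err_no_sat:.4f}, saturated {err_sat:.4f}")
print(f"outlier after saturation: {sat[-1]}")          # clipped to the threshold
```

Saturation sacrifices the single outlier, which is clipped to 1.0. In exchange, the step is much finer on the values that make up almost all of the tensor.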

[Figure 2]

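The optimal threshold is found in two steps. First, collect a histogram of the activations. Second, search exhaustively for the threshold at which the KL divergence between the original and the quantized distribution is smallest. The pseudocode is as follows: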

[Figure 3: calibration pseudocode]
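The pseudocode is only available as an image, so here is a rough Python rendering of the calibration loop as the GTC 2017 slides describe it:

- build a histogram of |x|;
- fold the outliers into the last kept bin;
- quantize the kept range into 128 levels;
- keep the candidate with the smallest KL divergence.

The names are hypothetical. Two details are my own assumptions rather than taken from the slides: splitting bins with `np.array_split` when the count is not divisible, and returning the upper edge of the last kept bin.

```python
import numpy as np

def kl_divergence(p, q):
    """KL(P || Q) of two unnormalised histograms; inf if Q is 0 where P is not."""
    if p.sum() == 0 or q.sum() == 0:
        return np.inf
    p = p / p.sum()
    q = q / q.sum()
    mask = p > 0
    if np.any(q[mask] == 0):
        return np.inf
    return float(np.sum(p[mask] * np.log(p[mask] / q[mask])))

def find_threshold(hist, bin_width, num_quant_bins=128):
    """hist: histogram of |activation|, uniform bins of width bin_width starting at 0."""
    hist = np.asarray(hist, dtype=np.float64)
    assert hist.size >= num_quant_bins
    best_i, best_kl = num_quant_bins, np.inf
    for i in range(num_quant_bins, hist.size + 1):
        ref = hist[:i]                   # candidate range [0, i * bin_width)
        p = ref.copy()
        p[-1] += hist[i:].sum()          # reference P: clipped outliers go into the last bin
        # candidate Q: merge ref into num_quant_bins groups, then spread each
        # group's count evenly over the bins that were non-empty in ref
        q = np.zeros(i)
        for group in np.array_split(np.arange(i), num_quant_bins):
            nonzero = group[ref[group] > 0]
            if nonzero.size:
                q[nonzero] = ref[group].sum() / nonzero.size
        kl = kl_divergence(p, q)
        if kl < best_kl:
            best_i, best_kl = i, kl
    return best_i * bin_width            # upper edge of the last kept bin
```

A larger threshold clips less but quantizes more coarsely. The KL divergence scores both effects with a single number.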

II. Difficulties

1. Architecture

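The reference implementations above use Caffe and MXNet respectively.

- The Caffe code saves and loads the quantization parameters through files, so they are not directly tied to the network architecture.
- The MXNet code is more complete, but MXNet is a niche framework, which makes developing new algorithms on top of it harder.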

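I therefore compute and store the quantization parameters in PyTorch. This keeps them part of the network itself. It only requires defining new modules, which is convenient.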
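Buffers fit this design well. Tensors registered with `register_buffer` are saved in `state_dict()` and moved by `.to(device)` together with the model. They are not returned by `parameters()`, so no optimizer updates them. The hypothetical minimal module below shows quantization parameters surviving a save/load round trip.

```python
import torch
from torch import nn

class QuantParams(nn.Module):
    """Hypothetical module that keeps quantization parameters as buffers."""
    def __init__(self, out_channels):
        super().__init__()
        # buffers: saved in state_dict(), moved by .to(), never trained
        self.register_buffer('weight_scale', torch.zeros(out_channels))
        self.register_buffer('optimal_th', torch.zeros(1))

src = QuantParams(4)
src.weight_scale.fill_(63.5)
src.optimal_th.fill_(3.2)

state = src.state_dict()        # contains both buffers
dst = QuantParams(4)
dst.load_state_dict(state)      # the quantization parameters travel with the model
```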

2. Functionality

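Some convolutions are followed by BN and an activation, and these must be taken into account as well.

Therefore, the convolution, BN and ReLU layers in the network have to be fused before processing.

Every layer is quantized to (-max, max), regardless of whether it is followed by a ReLU. This simplifies the computation, since no case distinction is needed. (DeepGlint's EasyQuant only handles layers followed by a ReLU, mapping (0, max) to (0, 127).)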
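This sketch checks that dropping the ReLU special case loses nothing for ReLU outputs. It compares the symmetric mapping used here with a (0, max) -> (0, 127) mapping on non-negative data. Both helper functions are hypothetical. They only model the two mappings as stated above, not the EasyQuant code itself.

```python
import numpy as np

def quant_symmetric(x, max_val):
    """Scheme used in this post: (-max, max) -> (-127, 127)."""
    return np.clip(np.round(x * 127 / max_val), -127, 127)

def quant_relu_only(x, max_val):
    """ReLU-only scheme: (0, max) -> (0, 127)."""
    return np.clip(np.round(x * 127 / max_val), 0, 127)

relu_out = np.maximum(np.random.default_rng(0).standard_normal(10000), 0.0)
m = relu_out.max()
same = np.array_equal(quant_symmetric(relu_out, m), quant_relu_only(relu_out, m))
```

A non-negative input never rounds to a negative code, so the lower clipping bound never comes into play. One code path therefore covers both cases.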

III. Implementation

1. Fusing conv and BN

from torch import nn
import torch


# identity module that takes the place of a folded-away BN layer
class DummyModule(nn.Module):
    def __init__(self):
        super(DummyModule, self).__init__()

    def forward(self, x):
        return x


# BN folding: merge an eval-mode BatchNorm2d into the Conv2d that precedes it
def bn_folding(conv, bn):
    with torch.no_grad():
        # ******************** BN parameters *********************
        mean = bn.running_mean
        std = torch.sqrt(bn.running_var + bn.eps)
        gamma = bn.weight
        beta = bn.bias
        # ******************* conv parameters ********************
        w = conv.weight
        if conv.bias is not None:
            b = conv.bias
        else:
            b = mean.new_zeros(mean.shape)

        # W' = W * gamma / std (per output channel), b' = beta + (b - mean) * gamma / std
        w_fold = w * (gamma / std).reshape([conv.out_channels, 1, 1, 1])
        b_fold = beta + (b - mean) * (gamma / std)

        bnfold_conv = nn.Conv2d(conv.in_channels,
                                conv.out_channels,
                                conv.kernel_size,
                                conv.stride,
                                conv.padding,
                                dilation=conv.dilation,
                                groups=conv.groups,
                                bias=True).to(w.device)
        bnfold_conv.weight.copy_(w_fold)
        bnfold_conv.bias.copy_(b_fold)
    return bnfold_conv


def model_bn_folding(model):
    '''Recursively fold each BatchNorm2d into the Conv2d directly before it.
    BN must come right after the convolution, inside the same container.'''
    name_temp = None
    child_temp = None
    for name, child in list(model.named_children()):
        if isinstance(child, nn.BatchNorm2d) and child_temp is not None:
            model._modules[name_temp] = bn_folding(child_temp, child)  # fold BN into conv
            model._modules[name] = DummyModule()                       # BN becomes an identity
            child_temp = None
        elif isinstance(child, nn.Conv2d):
            name_temp = name
            child_temp = child
        else:
            # any other layer breaks the conv -> BN adjacency
            child_temp = None
            model_bn_folding(child)
    return model
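The sketch below checks the folding formulas numerically. It re-implements them in a self-contained way; `fold_bn` is a hypothetical stand-in for `bn_folding`. It then compares the fused convolution against `bn(conv(x))` with eval-mode statistics. The random BN statistics are arbitrary test values.

```python
import torch
from torch import nn

def fold_bn(conv, bn):
    """Return a Conv2d equivalent to bn(conv(x)) for an eval-mode BatchNorm2d."""
    with torch.no_grad():
        scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)   # gamma / std
        bias = conv.bias if conv.bias is not None else torch.zeros_like(bn.running_mean)
        fused = nn.Conv2d(conv.in_channels, conv.out_channels, conv.kernel_size,
                          conv.stride, conv.padding, dilation=conv.dilation,
                          groups=conv.groups, bias=True)
        fused.weight.copy_(conv.weight * scale.reshape(-1, 1, 1, 1))
        fused.bias.copy_(bn.bias + (bias - bn.running_mean) * scale)
    return fused

torch.manual_seed(0)
conv = nn.Conv2d(3, 8, 3, padding=1, bias=False)
bn = nn.BatchNorm2d(8)
with torch.no_grad():                     # give BN non-trivial statistics
    bn.running_mean.uniform_(-1, 1)
    bn.running_var.uniform_(0.5, 2.0)
    bn.weight.uniform_(0.5, 1.5)
    bn.bias.uniform_(-0.5, 0.5)
conv.eval()
bn.eval()

x = torch.randn(2, 3, 16, 16)
fused = fold_bn(conv, bn).eval()
with torch.no_grad():
    max_diff = (bn(conv(x)) - fused(x)).abs().max().item()
print(f"max |bn(conv(x)) - fused(x)| = {max_diff:.2e}")
```

The difference should be at the level of float32 rounding. Folding is only exact with the frozen running statistics used in eval mode.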

2. Fusing conv and ReLU

A new module wraps the convolution (or linear layer) together with the ReLU that follows it.

import torch
from torch import nn

import torch.nn.functional as F
from quant_utils import ConvRelu, LinearRelu, DummyModule

# device = torch.device("cpu")
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def model_relu_folding(model):
    '''Recursively wrap every Conv2d/Linear in ConvRelu/LinearRelu and fuse
    the ReLU that follows it (the ReLU must come after the conv/linear layer).'''
    name_temp = None
    child_temp = None
    for name, child in list(model.named_children()):
        if isinstance(child, nn.ReLU):
            if child_temp is None:
                continue  # nothing to fuse with: keep this ReLU as it is
            if isinstance(child_temp, nn.Conv2d):
                model._modules[name_temp] = ConvRelu(child_temp, is_relu=1).to(device)
            else:
                model._modules[name_temp] = LinearRelu(child_temp, is_relu=1).to(device)
            model._modules[name] = DummyModule().to(device)  # the ReLU becomes an identity
            name_temp = None
            child_temp = None
        elif isinstance(child, nn.Conv2d):
            name_temp = name
            child_temp = child
            model._modules[name] = ConvRelu(child, is_relu=0).to(device)
        elif isinstance(child, nn.Linear):
            name_temp = name
            child_temp = child
            model._modules[name] = LinearRelu(child, is_relu=0).to(device)
        else:
            model_relu_folding(child)
    return model
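The fusion pattern can be tried without `quant_utils`. This sketch uses a hypothetical minimal wrapper, `FusedConvReLU`, in place of `ConvRelu`, and only covers `nn.Conv2d`. It checks on a small `nn.Sequential` that the fused network produces identical outputs.

```python
import torch
from torch import nn
import torch.nn.functional as F

class FusedConvReLU(nn.Module):
    """Hypothetical minimal stand-in for ConvRelu: conv plus an optional ReLU."""
    def __init__(self, conv, is_relu=False):
        super().__init__()
        self.conv = conv
        self.is_relu = is_relu

    def forward(self, x):
        x = self.conv(x)
        return F.relu(x) if self.is_relu else x

def fold_relu(model):
    """Merge each ReLU into the Conv2d just before it in the same container."""
    prev_name = None
    for name, child in list(model.named_children()):
        if isinstance(child, nn.Conv2d):
            model._modules[name] = FusedConvReLU(child)
            prev_name = name
        elif isinstance(child, nn.ReLU) and prev_name is not None:
            model._modules[prev_name].is_relu = True    # fuse the activation
            model._modules[name] = nn.Identity()        # and drop the standalone ReLU
            prev_name = None
        else:
            fold_relu(child)
    return model

torch.manual_seed(0)
net = nn.Sequential(nn.Conv2d(3, 4, 3), nn.ReLU(), nn.Conv2d(4, 4, 3), nn.ReLU()).eval()
x = torch.randn(1, 3, 8, 8)
with torch.no_grad():
    before = net(x)
    after = fold_relu(net)(x)
```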

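quant_utils.py

ConvRelu uses register_buffer to allocate the variables needed for weight and activation quantization. In the spirit of model.train(), it also has a `mode` attribute that selects which stage of the TensorRT algorithm the module runs. Its methods are:

- weight_quant(): computes the per-output-channel maximum absolute weight and the resulting quantization scale.
- initial_activate_max(): records the maximum absolute activation. This requires one pass over the calibration set.
- initial_histograms(): builds the activation histogram, which also requires a pass over the calibration set. With a ReLU, the bin containing 0 holds a very large count, which reduces the relative weight of all other values. Since the symmetric mapping represents 0 almost without error, the count of that bin is set to 0.
- get_optimal_threshold(): computes the KL divergence and picks the optimal threshold.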
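The per-channel scale computed by `weight_quant()` can be written more compactly with `Tensor.amax` over the input-channel and kernel dimensions. This sketch is a hypothetical standalone helper. It mirrors the 0.0001 cutoff used in `weight_quant()`, returning scale 0 for channels whose weights are all (near) zero.

```python
import torch

def per_channel_weight_scale(weight, bits=8, eps=1e-4):
    """Per output channel: scale = (2^(bits-1) - 1) / max|w|, or 0 for an ~all-zero channel."""
    w_max = weight.detach().abs().amax(dim=(1, 2, 3))   # shape: [out_channels]
    qmax = (1 << (bits - 1)) - 1                        # 127 for 8 bits
    safe = w_max.clamp(min=eps)                         # avoid dividing by ~0
    scale = torch.where(w_max < eps, torch.zeros_like(w_max), qmax / safe)
    return scale, w_max

w = torch.zeros(3, 2, 1, 1)
w[0, 0] = 0.5
w[1, 1] = -2.0                  # channel 2 stays all zero
scale, w_max = per_channel_weight_scale(w)
q = torch.round(w * scale.view(-1, 1, 1, 1))   # integer codes in [-127, 127]
```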

from torch import nn
import torch
import torch.nn.functional as F
import copy
from collections import OrderedDict
import numpy as np


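INTERVAL_NUM = 4001   # histogram bins over (-th, th); odd, so one bin is centred on 0
QUANTIZE_NUM = 127    # largest magnitude of a signed 8-bit symmetric code
STATISTIC = 1.0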


device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# identity module that replaces a folded ReLU layer
class DummyModule(nn.Module):
    def __init__(self):
        super(DummyModule, self).__init__()

    def forward(self, x):
        return x
    
    
# the module that replaces a conv layer (conv plus an optional fused ReLU)
class ConvRelu(nn.Module):
    def __init__(self, conv, is_relu=0, bits=8, threshold=204800):
        super(ConvRelu, self).__init__()
        
        #self.conv_relu_fold = conv
        self.threshold = threshold
        self.bits = bits
        self.is_relu = is_relu
        self.kernel_size = conv.kernel_size
        self.stride = conv.stride
        self.padding = conv.padding
        self.groups = conv.groups
        self.bias = conv.bias
        self.weight = conv.weight
        '''mode : Normal, TRT_weight_quant, TRT_activate_collection_max, TRT_activate_collection_hist, TRT_activate_KL, Normal_TRT'''
        self.mode = 'TRT_weight_quant'         
        
        
        #self.register_buffer('is_relu', torch.tensor(is_relu))
        self.register_buffer('quant_num', torch.tensor((1 << bits) - 1))
        
        '''activation_para'''
        self.register_buffer('activate_flag', torch.zeros(1))
        self.register_buffer('activate_distubution', torch.zeros(INTERVAL_NUM))
        self.register_buffer('activate_distubution_edges', torch.zeros(INTERVAL_NUM+1))
        self.register_buffer('activate_max', torch.zeros(1))
        self.register_buffer('th', torch.zeros(1))
        self.register_buffer('optimal_th', torch.zeros(1))
        # self.register_buffer('activate_distubution_interval', torch.zeros(1))
        '''weight_para'''
        self.register_buffer('weight_flag', torch.zeros(1))
        self.register_buffer('weight_scale', torch.zeros(conv.weight.data.shape[0]))
        self.register_buffer('weight_zero', torch.zeros(conv.weight.data.shape[0]))
        self.register_buffer('weight_max', torch.zeros(conv.weight.data.shape[0]))
        
    def initial_activate_max(self, input):
        # running maximum of |activation| over all calibration batches
        self.activate_max = torch.max(self.activate_max, input.detach().abs().max())
        # Avoid unusually large activations by clipping activate_max with threshold.
        # torch.clamp keeps th a tensor (a plain int cannot be assigned to a buffer)
        self.th = torch.clamp(self.activate_max, max=self.threshold)
        
    def weight_quant(self):
        '''Avoid multiple operations caused by multiple identification of the module'''
        self.weight_flag = torch.ones(1).to(device)

        # per-output-channel max |w|, shape [out_channels, 1, 1, 1]
        weight_threshold = self.weight.detach().abs().amax(dim=(1, 2, 3), keepdim=True)
        self.weight_max = weight_threshold
        is_zero = weight_threshold < 0.0001
        # scale = 127 / max|w|; channels whose weights are all ~0 get scale 0 and weight_zero = 1
        self.weight_scale = torch.where(is_zero, torch.zeros_like(weight_threshold),
                                        ((1 << (self.bits - 1)) - 1) / weight_threshold)
        self.weight_zero = is_zero.float()
            
        
    # def initial_activate_distubution_interval(self):
    #     self.activate_distubution_interval = (torch.tensor(STATISTIC).to(device)) * self.th / torch.tensor(INTERVAL_NUM).to(device).astype(float)
        
    def initial_histograms(self, input):
        # Histogram over (-th, th): values outside this range are not counted.
        th = self.th.cpu().item()
        input_cpu_numpy = input.detach().cpu().numpy().flatten()
        hist, hist_edges = np.histogram(input_cpu_numpy, bins=INTERVAL_NUM, range=(-th, th))

        # hist = torch.histc(input, bins=INTERVAL_NUM, min=-th, max=th)  # GPU alternative

        self.activate_distubution += torch.from_numpy(hist).to(device)
        # With a ReLU the bin around 0 dominates; 0 is represented almost exactly by the
        # symmetric mapping, so clear that bin to keep it from swamping the other bins.
        self.activate_distubution[INTERVAL_NUM // 2] = 0
        self.activate_distubution_edges = torch.from_numpy(hist_edges).to(device)
        
    def plot_hist(self, optimal_th=None):
        import matplotlib.pyplot as plt
        edges = self.activate_distubution_edges.cpu().numpy()[:-1]   # left edge of each bin
        hist = self.activate_distubution.cpu().numpy()
        print('hist_edge: ', edges)
        print('hist: ', hist)
        plt.plot(edges, hist)
        if optimal_th is not None:
            plt.plot(optimal_th, 0, 'om')
            plt.annotate('optimal_th', xy=(optimal_th, 0), xytext=(optimal_th+1, 10000), arrowprops=dict(arrowstyle='->'))
        plt.ylabel('activation distribution')
        plt.show()
        
    def get_optimal_threshold(self):
        '''Avoid multiple operations caused by multiple identification of the module'''
        self.activate_flag = torch.ones(1).to(device)
        
        length = self.activate_distubution.shape[0]
        assert (length % 2 == 1)
        hist = self.activate_distubution.cpu().numpy()
        hist_edge = self.activate_distubution_edges.cpu().numpy()
        num_quantized_bins = self.quant_num.cpu().item()
        
        optimal_threshold = calibrate(hist, hist_edge, num_quantized_bins)
        self.optimal_th = torch.tensor(optimal_threshold).to(device)
        print('th: ', self.th)
        print('optimal_th: ', self.optimal_th)
        self.plot_hist(optimal_th=optimal_threshold)
    
        
    

    def forward(self, x):
        assert self.training is False

        # Keyword arguments: the 6th positional argument of F.conv2d is dilation, not groups.
        # Dilation is not stored in this module, so dilation=1 is assumed.
        x = F.conv2d(x, self.weight, self.bias, stride=self.stride,
                     padding=self.padding, groups=self.groups)
        if self.is_relu:
            x = F.relu(x)

        if self.mode == 'TRT_activate_collection_max':
            '''collect max, min, threshold'''
            self.initial_activate_max(x)

        elif self.mode == 'TRT_activate_collection_hist':
            '''collect histograms'''
            self.initial_histograms(x)

        elif self.mode == 'TRT_activate_KL':
            '''calibrate for optimal_threshold'''
            pass
        elif self.mode in ('Normal', 'Normal_TRT'):
            pass
        elif self.mode != 'TRT_weight_quant':
            raise ValueError("mode error: %s" % self.mode)
        return x

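The code below is a Python port of the C++ routine called by the second reference implementation. That C++ code has a bug in the "merge hist into num_quantized_bins bins" step: its boundary handling double-counts bins. The bug is fixed here, so keep the difference in mind.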

def calibrate(hist, hist_edge, num_quantized_bins=255):
    '''Return the threshold whose quantized distribution has the smallest KL divergence.'''
    num_bins = hist.size
    assert num_bins + 1 == hist_edge.size
    zero_bin_idx = num_bins // 2
    num_half_quantized_bins = num_quantized_bins // 2
    thresholds = np.zeros(zero_bin_idx + 1 - num_half_quantized_bins)
    divergence = np.zeros(zero_bin_idx + 1 - num_half_quantized_bins)

    for i in range(num_half_quantized_bins, zero_bin_idx + 1):
        p_bin_index_start = zero_bin_idx - i
        p_bin_index_stop = zero_bin_idx + i + 1
        thresholds[i - num_half_quantized_bins] = hist_edge[p_bin_index_stop]

        sliced_nd_hist = np.zeros(p_bin_index_stop - p_bin_index_start)
        p = np.zeros(p_bin_index_stop - p_bin_index_start)
        # as in the C++ original: bins 0..start go into p[0], bins >= stop are added to p[-1]
        p[1:] = hist[p_bin_index_start+1 : p_bin_index_stop]
        sliced_nd_hist[1:] = hist[p_bin_index_start+1 : p_bin_index_stop]
        p[0] = np.sum(hist[:p_bin_index_start+1])
        p[-1] = p[-1] + np.sum(hist[p_bin_index_stop:])

        '''calculate how many bins should be merged to generate quantized distribution q'''
        num_merged_bins = sliced_nd_hist.size // num_quantized_bins
        '''merge hist into num_quantized_bins bins (contiguous, non-overlapping groups)'''
        quantized_bins = np.zeros(num_quantized_bins)
        for j in range(num_quantized_bins):
            start = j * num_merged_bins
            stop = (j + 1) * num_merged_bins
            quantized_bins[j] = np.sum(sliced_nd_hist[start:stop])
        # leftover bins go into the last group
        quantized_bins[-1] += np.sum(sliced_nd_hist[num_quantized_bins * num_merged_bins:])

        '''expand quantized_bins into p.size bins'''
        q = np.zeros(p_bin_index_stop - p_bin_index_start)
        is_nonzeros = (p != 0).astype(np.int64)
        for j in range(num_quantized_bins):
            start = j * num_merged_bins
            stop = q.size if (j == num_quantized_bins - 1) else (j + 1) * num_merged_bins
            norm = is_nonzeros[start:stop].sum()
            if norm != 0:
                q[start:stop] = float(quantized_bins[j]) / float(norm)
        q[p == 0] = 0
        p = _smooth_distribution(p)
        try:
            q = _smooth_distribution(q)
        except ValueError:  # q is all zeros for this candidate, so it cannot win
            divergence[i - num_half_quantized_bins] = np.inf
            continue
        divergence[i - num_half_quantized_bins] = ComputeEntropy(p, q)

    best_idx = np.argmin(divergence)
    return thresholds[best_idx]

def _smooth_distribution(p, eps=0.0001):
    
    is_zeros = (p == 0).astype(np.float32)
    is_nonzeros = (p != 0).astype(np.float32)
    n_zeros = is_zeros.sum()
    n_nonzeros = p.size - n_zeros
    if not n_nonzeros:
        raise ValueError('The discrete probability distribution is malformed. All entries are 0.')
    eps1 = eps * float(n_zeros) / float(n_nonzeros)
    assert eps1 < 1.0, 'n_zeros=%d, n_nonzeros=%d, eps1=%f' % (n_zeros, n_nonzeros, eps1)
    hist = p.astype(np.float32)
    hist += eps * is_zeros + (-eps1) * is_nonzeros
    assert (hist <= 0).sum() == 0
    return hist

def ComputeEntropy(p, q):
    '''KL divergence KL(p || q) after normalising p and q to sum to 1.'''
    assert p.size == q.size
    p = p / np.sum(p)
    q = q / np.sum(q)
    # both inputs are strictly positive after _smooth_distribution, so a real log suffices
    KL_dis = np.sum(p * np.log(p / q))
    return KL_dis
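The non-overlapping merge used in `calibrate()` can also be written with `np.add.reduceat`. The hypothetical `merge_bins` below follows the same rule. It forms contiguous groups of `size // num_quantized_bins` bins and adds the leftovers to the last group. It also checks that every bin is counted exactly once, which is the property the boundary fix restores.

```python
import numpy as np

def merge_bins(sliced_hist, num_quantized_bins):
    """Merge a histogram into num_quantized_bins contiguous, non-overlapping groups;
    bins left over when the size is not divisible go into the last group."""
    stride = sliced_hist.size // num_quantized_bins
    assert stride >= 1
    covered = stride * num_quantized_bins
    merged = np.add.reduceat(sliced_hist[:covered], np.arange(0, covered, stride))
    merged[-1] += sliced_hist[covered:].sum()
    return merged

rng = np.random.default_rng(0)
sliced = rng.integers(0, 100, size=261).astype(np.float64)
merged = merge_bins(sliced, 255)
print(merged.size, merged.sum() == sliced.sum())   # no bin lost or double-counted
```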

3. Usage example

import torch
import sys
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
sys.path.append('vgg')
from VggNet import *  # the model classes must be importable to unpickle the saved model
from datetime import datetime
from torch.utils.data import DataLoader

from torchvision import datasets, transforms

from ConvReluFold import model_relu_folding

from ConvBNFold import model_bn_folding
from quant_utils import ConvRelu, LinearRelu, DummyModule, TRT_Quantizer

# a whole pickled model: PyTorch >= 2.6 defaults to weights_only=True, which cannot load it
model = torch.load('./model/vgg0.904_bnrelufold.pth', weights_only=False)
model.eval()

# train_loader: a DataLoader over the calibration data (its definition is not shown here)


def run_calibration(model, loader, max_batches=21):
    '''Forward max_batches batches so that the ConvRelu modules collect statistics.'''
    model.eval()
    correct = 0.0
    total = 0
    with torch.no_grad():  # calibration needs no backpropagation
        for num, (inputs, labels) in enumerate(loader, start=1):
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            pred = outputs.argmax(dim=1)
            total += inputs.size(0)
            correct += torch.eq(pred, labels).sum().item()
            if num >= max_batches:
                break
    print('Accuracy on %d calibration images: %.2f %%' % (total, 100.0 * correct / total))


'''---------------------------------------------------------------------------------------'''
'''---------------------- TRT_weight_quant ------------------------------------'''
TRT_Quantizer(model, mode='TRT_weight_quant')

'''---------------------------------------------------------------------------------------'''
'''---------------------- TRT_activate_collection_max ------------------------------------'''
TRT_Quantizer(model, mode='TRT_activate_collection_max')
run_calibration(model, train_loader)

'''---------------------------------------------------------------------------------------'''
'''---------------------- TRT_activate_collection_hist ------------------------------------'''
TRT_Quantizer(model, mode='TRT_activate_collection_hist')
run_calibration(model, train_loader)

'''---------------------------------------------------------------------------------------'''
'''---------------------- TRT_activate_KL ------------------------------------'''
TRT_Quantizer(model, mode='TRT_activate_KL')
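Note: before using the fused model, you must import the original model definitions; otherwise loading fails with an error about a missing component. The fusion effectively builds new modules on top of the earlier ones, so the earlier definitions have to be importable too.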
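The post does not show how `TRT_Quantizer` is defined. Judging from the modes and methods of `ConvRelu`, it plausibly sets `mode` on every quantizable module and triggers the one-off steps. The sketch below is a guess along those lines. It uses the hypothetical name `trt_quantizer` and a stub module standing in for `ConvRelu`; the real helper may differ.

```python
import torch
from torch import nn

def trt_quantizer(model, mode):
    """Hypothetical helper: put every quantizable module into `mode` and run
    the one-off steps of the weight-quantization and KL stages."""
    for module in model.modules():      # modules() yields each module once
        # identify quantizable modules by their hooks (e.g. nn.Upsample also has `mode`)
        if not (hasattr(module, 'weight_quant') and hasattr(module, 'get_optimal_threshold')):
            continue
        module.mode = mode
        if mode == 'TRT_weight_quant':
            module.weight_quant()
        elif mode == 'TRT_activate_KL':
            module.get_optimal_threshold()
    return model

class _StubQuantModule(nn.Module):
    """Stand-in exposing the same hooks as ConvRelu, used only to exercise trt_quantizer."""
    def __init__(self):
        super().__init__()
        self.mode = 'Normal'
        self.calls = []

    def weight_quant(self):
        self.calls.append('weight_quant')

    def get_optimal_threshold(self):
        self.calls.append('get_optimal_threshold')

    def forward(self, x):
        return x

net = nn.Sequential(_StubQuantModule(), nn.ReLU(), _StubQuantModule())
for stage in ['TRT_weight_quant', 'TRT_activate_collection_max',
              'TRT_activate_collection_hist', 'TRT_activate_KL']:
    trt_quantizer(net, stage)   # calibration passes would run between these stages
```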


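Summary

Writing this code deepened my understanding of PyTorch, and the C++ I had learned earlier came in handy. Without it, the histogram part would have been hard to write. There is always more to learn.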
