pytorch模型量化

 

量化是一种加速推理的技术,量化算子并且仅仅支持前向传递。Pytorch支持int8量化,相比于float32,模型的大小减少4倍,内存要求减少4倍。与float32计算相比,对int8计算的硬件支持通常快2到4倍。

大多数情况下,模型需要以float32精度训练,然后将模型转换为int8。如今,PyTorch支持在具有AVX2支持或者更高版本的x86 CPU或者ARM CPU上运行量化运算符。

pytorch提供了三种量化模型的方法:

  1. 训练后动态量化:最简单的量化形式,权重被提前量化,激活在推理过程中被动态量化
  2. 训练后静态量化:最常用的量化形式,权重提前量化,并且基于观察校准过程中模型的行为来预先计算激活张量的比例因子和偏差。
  3. 量化意识训练:在极少数情况下,训练后量化不能提供足够的准确性,可以使用torch.quantization.FakeQuantize函数通过模拟量化来进行训练。

1. 前言

深度学习在移动端的应用越来越广泛,而移动端相对于GPU服务来讲算力较低并且存储空间也相对较小。基于这一点我们需要为移动端定制一些深度学习网络来满足我们的日常续需求,例如SqueezeNet,MobileNet,ShuffleNet等轻量级网络就是专为移动端设计的。但除了在网络方面进行改进,模型剪枝和量化应该算是最常用的优化方法了。剪枝就是将训练好的「大模型」的不重要的通道删除掉,在几乎不影响准确率的条件下对网络进行加速。而量化就是将浮点数(高精度)表示的权重和偏置用低精度整数(常用的有INT8)来近似表示,在量化到低精度之后就可以应用移动平台上的优化技术如NEON对计算过程进行加速,并且原始模型量化后的模型容量也会减少,使其能够更好的应用到移动端环境。但需要注意的问题是,将高精度模型量化到低精度必然会存在一个精度下降的问题,如何获取性能和精度的TradeOff很关键。

这篇文章是介绍使用Pytorch复现这篇论文:https://arxiv.org/abs/1806.08342 的一些细节并给出一些自测实验结果。注意,代码实现的是「Quantization Aware Training」 ,而后量化 「Post Training Quantization」 后面可能会再单独讲一下。代码实现是来自666DZY666博主实现的https://github.com/666DZY666/model-compression

2. 对称量化

在上次的视频中梁德澎作者已经将这些概念讲得非常清楚了,如果不愿意看文字表述可以移步到这个视频链接下观看视频:深度学习量化技术科普 。然后直接跳到第四节,但为了保证本次故事的完整性,我还是会介绍一下这两种量化方式。

对称量化的量化公式如下:

对称量化量化公式

其中表示量化的缩放因子,和分别表示量化前和量化后的数值。这里通过除以缩放因子接取整操作就把原始的浮点数据量化到了一个小区间中,比如对于「有符号的8Bit」 就是(无符号就是0到255了)。

这里有个Trick,即对于权重是量化到,这是为了累加的时候减少溢出的风险。

因为8bit的取值区间是[-2^7, 2^7-1],两个8bit相乘之后取值区间是 (-2^14,2^14],累加两次就到了(-2^15,2^15],所以最多只能累加两次而且第二次也有溢出风险,比如相邻两次乘法结果都恰好是2^14会超过2^15-1(int16正数可表示的最大值)。

所以把量化之后的权值限制在(-127,127)之间,那么一次乘法运算得到结果永远会小于-128*-128 = 2^14

对应的反量化公式为:

对称量化的反量化公式

即将量化后的值乘以就得到了反量化的结果,当然这个过程是有损的,如下图所示,橙色线表示的就是量化前的范围,而蓝色线代表量化后的数据范围,注意权重取。

量化和反量化的示意图

我们看一下上面橙色线的第个「黑色圆点对应的float32值」,将其除以缩放系数就量化为了一个在之间的值,然后取整之后就是,如果是反量化就乘以缩放因子返回上面的「第个黑色圆点」 ,用这个数去代替以前的数继续做网络的Forward。

那么这个缩放系数是怎么取的呢?如下式:

缩放系数Delta

3. 非对称量化

非对称量化相比于对称量化就在于多了一个零点偏移。一个float32的浮点数非对称量化到一个int8的整数(如果是有符号就是,如果是无符号就是)的步骤为 缩放,取整,零点偏移,和溢出保护,如下图所示:

白皮书非对称量化过程

对于8Bit无符号整数Nlevel的取值

然后缩放系数和零点偏移的计算公式如下:

 

 

4. 中部小结

将上面两种算法直接应用到各个网络上进行量化后(训练后量化PTQ)测试模型的精度结果如下:

红色部分即将上面两种量化算法应用到各个网络上做精度测试结果

5. 训练模拟量化

我们要在网络训练的过程中模型量化这个过程,然后网络分前向和反向两个阶段,前向阶段的量化就是第二节和第三节的内容。不过需要特别注意的一点是对于缩放因子的计算,权重和激活值的计算方法现在不一样了。

对于权重缩放因子还是和第2,3节的一致,即:

weight scale = max(abs(weight)) / 127

但是对于激活值的缩放因子计算就不再是简单的计算最大值,而是在训练过程中通过滑动平均(EMA)的方式去统计这个量化范围,更新的公式如下:

moving_max = moving_max * momenta + max(abs(activation)) * (1- momenta)

其中,momenta取接近1的数就可以了,在后面的Pytorch实验中取0.99,然后缩放因子:

activation scale = moving_max /128

然后反向传播阶段求梯度的公式如下:

QAT反向传播阶段求梯度的公式

我们在反向传播时求得的梯度是模拟量化之后权值的梯度,用这个梯度去更新量化前的权值。

这部分的代码如下,注意我们这个实验中是用float32来模拟的int8,不具有真实的板端加速效果,只是为了验证算法的可行性:


   
     
     
     
     
  1. class Quantizer(nn.Module):
  2. def __init__( self, bits, range_tracker):
  3. super().__init__()
  4. self.bits = bits
  5. self.range_tracker = range_tracker
  6. self.register_buffer( 'scale', None) # 量化比例因子
  7. self.register_buffer( 'zero_point', None) # 量化零点
  8. def update_params( self):
  9. raise NotImplementedError
  10. # 量化
  11. def quantize( self, input):
  12. output = input * self.scale - self.zero_point
  13. return output
  14. def round( self, input):
  15. output = Round.apply( input)
  16. return output
  17. # 截断
  18. def clamp( self, input):
  19. output = torch.clamp( input, self.min_val, self.max_val)
  20. return output
  21. # 反量化
  22. def dequantize( self, input):
  23. output = ( input + self.zero_point) / self.scale
  24. return output
  25. def forward( self, input):
  26. if self.bits == 32:
  27. output = input
  28. elif self.bits == 1:
  29. print( '!Binary quantization is not supported !')
  30. assert self.bits != 1
  31. else:
  32. self.range_tracker( input)
  33. self.update_params()
  34. output = self.quantize( input) # 量化
  35. output = self. round(output)
  36. output = self.clamp(output) # 截断
  37. output = self.dequantize(output) # 反量化
  38. return output

6. 代码实现

基于https://github.com/666DZY666/model-compression/blob/master/quantization/WqAq/IAO/models/util_wqaq.py 进行实验,这里实现了对称和非对称量化两种方案。需要注意的细节是,对于权值的量化需要分通道进行求取缩放因子,然后对于激活值的量化整体求一个缩放因子,这样效果最好(论文中提到)。

这部分的代码实现如下:


   
     
     
     
     
  1. # ********************* range_trackers(范围统计器,统计量化前范围) *********************
  2. class RangeTracker(nn.Module):
  3. def __init__( self, q_level):
  4. super().__init__()
  5. self.q_level = q_level
  6. def update_range( self, min_val, max_val):
  7. raise NotImplementedError
  8. @torch.no_grad()
  9. def forward( self, input):
  10. if self.q_level == 'L': # A,min_max_shape=(1, 1, 1, 1),layer级
  11. min_val = torch. min( input)
  12. max_val = torch. max( input)
  13. elif self.q_level == 'C': # W,min_max_shape=(N, 1, 1, 1),channel级
  14. min_val = torch. min(torch. min(torch. min( input, 3, keepdim= True)[ 0], 2, keepdim= True)[ 0], 1, keepdim= True)[ 0]
  15. max_val = torch. max(torch. max(torch. max( input, 3, keepdim= True)[ 0], 2, keepdim= True)[ 0], 1, keepdim= True)[ 0]
  16. self.update_range(min_val, max_val)
  17. class GlobalRangeTracker( RangeTracker): # W,min_max_shape=(N, 1, 1, 1),channel级,取本次和之前相比的min_max —— (N, C, W, H)
  18. def __init__( self, q_level, out_channels):
  19. super().__init__(q_level)
  20. self.register_buffer( 'min_val', torch.zeros(out_channels, 1, 1, 1))
  21. self.register_buffer( 'max_val', torch.zeros(out_channels, 1, 1, 1))
  22. self.register_buffer( 'first_w', torch.zeros( 1))
  23. def update_range( self, min_val, max_val):
  24. temp_minval = self.min_val
  25. temp_maxval = self.max_val
  26. if self.first_w == 0:
  27. self.first_w.add_( 1)
  28. self.min_val.add_(min_val)
  29. self.max_val.add_(max_val)
  30. else:
  31. self.min_val.add_(-temp_minval).add_(torch. min(temp_minval, min_val))
  32. self.max_val.add_(-temp_maxval).add_(torch. max(temp_maxval, max_val))
  33. class AveragedRangeTracker( RangeTracker): # A,min_max_shape=(1, 1, 1, 1),layer级,取running_min_max —— (N, C, W, H)
  34. def __init__( self, q_level, momentum=0.1):
  35. super().__init__(q_level)
  36. self.momentum = momentum
  37. self.register_buffer( 'min_val', torch.zeros( 1))
  38. self.register_buffer( 'max_val', torch.zeros( 1))
  39. self.register_buffer( 'first_a', torch.zeros( 1))
  40. def update_range( self, min_val, max_val):
  41. if self.first_a == 0:
  42. self.first_a.add_( 1)
  43. self.min_val.add_(min_val)
  44. self.max_val.add_(max_val)
  45. else:
  46. self.min_val.mul_( 1 - self.momentum).add_(min_val * self.momentum)
  47. self.max_val.mul_( 1 - self.momentum).add_(max_val * self.momentum)

其中self.register_buffer这行代码可以在内存中定一个常量,同时,模型保存和加载的时候可以写入和读出,即这个变量不会参与反向传播。

❝ pytorch一般情况下,是将网络中的参数保存成orderedDict形式的,这里的参数其实包含两种,一种是模型中各种module含的参数,即nn.Parameter,我们当然可以在网络中定义其他的nn.Parameter参数,另一种就是buffer,前者每次optim.step会得到更新,而不会更新后者。

另外,由于卷积层后面经常会接一个BN层,并且在前向推理时为了加速经常把BN层的参数融合到卷积层的参数中,所以训练模拟量化也要按照这个流程。即,我们首先需要把BN层的参数和卷积层的参数融合,然后再对这个参数做量化,具体过程可以借用德澎的这页PPT来说明:

Made By 梁德澎

因此,代码实现包含两个版本,一个是不融合BN的训练模拟量化,一个是融合BN的训练模拟量化,而关于为什么融合之后是上图这样的呢?请看下面的公式:

 

 

 

 

所以:

 

 

公式中的,和分别表示卷积层的权值与偏置,和分别为卷积层的输入与输出,则根据的计算公式,可以推出融合了batchnorm参数之后的权值与偏置,和。

未融合BN的训练模拟量化代码实现如下(带注释):


   
     
     
     
     
  1. # ********************* 量化卷积(同时量化A/W,并做卷积) *********************
  2. class Conv2d_Q(nn.Conv2d):
  3. def __init__(
  4. self,
  5. in_channels,
  6. out_channels,
  7. kernel_size,
  8. stride=1,
  9. padding=0,
  10. dilation=1,
  11. groups=1,
  12. bias=True,
  13. a_bits=8,
  14. w_bits=8,
  15. q_type=1,
  16. first_layer=0,
  17. ):
  18. super().__init__(
  19. in_channels=in_channels,
  20. out_channels=out_channels,
  21. kernel_size=kernel_size,
  22. stride=stride,
  23. padding=padding,
  24. dilation=dilation,
  25. groups=groups,
  26. bias=bias
  27. )
  28. # 实例化量化器(A-layer级,W-channel级)
  29. if q_type == 0:
  30. self.activation_quantizer = SymmetricQuantizer(bits=a_bits, range_tracker=AveragedRangeTracker(q_level= 'L'))
  31. self.weight_quantizer = SymmetricQuantizer(bits=w_bits, range_tracker=GlobalRangeTracker(q_level= 'C', out_channels=out_channels))
  32. else:
  33. self.activation_quantizer = AsymmetricQuantizer(bits=a_bits, range_tracker=AveragedRangeTracker(q_level= 'L'))
  34. self.weight_quantizer = AsymmetricQuantizer(bits=w_bits, range_tracker=GlobalRangeTracker(q_level= 'C', out_channels=out_channels))
  35. self.first_layer = first_layer
  36. def forward( self, input):
  37. # 量化A和W
  38. if not self.first_layer:
  39. input = self.activation_quantizer( input)
  40. q_input = input
  41. q_weight = self.weight_quantizer(self.weight)
  42. # 量化卷积
  43. output = F.conv2d(
  44. input=q_input,
  45. weight=q_weight,
  46. bias=self.bias,
  47. stride=self.stride,
  48. padding=self.padding,
  49. dilation=self.dilation,
  50. groups=self.groups
  51. )
  52. return output

而考虑了折叠BN的代码实现如下(带注释):


   
     
     
     
     
  1. def reshape_to_activation( input):
  2. return input.reshape( 1, - 1, 1, 1)
  3. def reshape_to_weight( input):
  4. return input.reshape(- 1, 1, 1, 1)
  5. def reshape_to_bias( input):
  6. return input.reshape(- 1)
  7. # ********************* bn融合_量化卷积(bn融合后,同时量化A/W,并做卷积) *********************
  8. class BNFold_Conv2d_Q( Conv2d_Q):
  9. def __init__(
  10. self,
  11. in_channels,
  12. out_channels,
  13. kernel_size,
  14. stride=1,
  15. padding=0,
  16. dilation=1,
  17. groups=1,
  18. bias=False,
  19. eps=1e-5,
  20. momentum=0.01, # 考虑量化带来的抖动影响,对momentum进行调整(0.1 ——> 0.01),削弱batch统计参数占比,一定程度抑制抖动。经实验量化训练效果更好,acc提升1%左右
  21. a_bits=8,
  22. w_bits=8,
  23. q_type=1,
  24. first_layer=0,
  25. ):
  26. super().__init__(
  27. in_channels=in_channels,
  28. out_channels=out_channels,
  29. kernel_size=kernel_size,
  30. stride=stride,
  31. padding=padding,
  32. dilation=dilation,
  33. groups=groups,
  34. bias=bias
  35. )
  36. self.eps = eps
  37. self.momentum = momentum
  38. self.gamma = Parameter(torch.Tensor(out_channels))
  39. self.beta = Parameter(torch.Tensor(out_channels))
  40. self.register_buffer( 'running_mean', torch.zeros(out_channels))
  41. self.register_buffer( 'running_var', torch.ones(out_channels))
  42. self.register_buffer( 'first_bn', torch.zeros( 1))
  43. init.uniform_(self.gamma)
  44. init.zeros_(self.beta)
  45. # 实例化量化器(A-layer级,W-channel级)
  46. if q_type == 0:
  47. self.activation_quantizer = SymmetricQuantizer(bits=a_bits, range_tracker=AveragedRangeTracker(q_level= 'L'))
  48. self.weight_quantizer = SymmetricQuantizer(bits=w_bits, range_tracker=GlobalRangeTracker(q_level= 'C', out_channels=out_channels))
  49. else:
  50. self.activation_quantizer = AsymmetricQuantizer(bits=a_bits, range_tracker=AveragedRangeTracker(q_level= 'L'))
  51. self.weight_quantizer = AsymmetricQuantizer(bits=w_bits, range_tracker=GlobalRangeTracker(q_level= 'C', out_channels=out_channels))
  52. self.first_layer = first_layer
  53. def forward( self, input):
  54. # 训练态
  55. if self.training:
  56. # 先做普通卷积得到A,以取得BN参数
  57. output = F.conv2d(
  58. input= input,
  59. weight=self.weight,
  60. bias=self.bias,
  61. stride=self.stride,
  62. padding=self.padding,
  63. dilation=self.dilation,
  64. groups=self.groups
  65. )
  66. # 更新BN统计参数(batch和running)
  67. dims = [dim for dim in range( 4) if dim != 1]
  68. batch_mean = torch.mean(output, dim=dims)
  69. batch_var = torch.var(output, dim=dims)
  70. with torch.no_grad():
  71. if self.first_bn == 0:
  72. self.first_bn.add_( 1)
  73. self.running_mean.add_(batch_mean)
  74. self.running_var.add_(batch_var)
  75. else:
  76. self.running_mean.mul_( 1 - self.momentum).add_(batch_mean * self.momentum)
  77. self.running_var.mul_( 1 - self.momentum).add_(batch_var * self.momentum)
  78. # BN融合
  79. if self.bias is not None:
  80. bias = reshape_to_bias(self.beta + (self.bias - batch_mean) * (self.gamma / torch.sqrt(batch_var + self.eps)))
  81. else:
  82. bias = reshape_to_bias(self.beta - batch_mean * (self.gamma / torch.sqrt(batch_var + self.eps))) # b融batch
  83. weight = self.weight * reshape_to_weight(self.gamma / torch.sqrt(self.running_var + self.eps)) # w融running
  84. # 测试态
  85. else:
  86. #print(self.running_mean, self.running_var)
  87. # BN融合
  88. if self.bias is not None:
  89. bias = reshape_to_bias(self.beta + (self.bias - self.running_mean) * (self.gamma / torch.sqrt(self.running_var + self.eps)))
  90. else:
  91. bias = reshape_to_bias(self.beta - self.running_mean * (self.gamma / torch.sqrt(self.running_var + self.eps))) # b融running
  92. weight = self.weight * reshape_to_weight(self.gamma / torch.sqrt(self.running_var + self.eps)) # w融running
  93. # 量化A和bn融合后的W
  94. if not self.first_layer:
  95. input = self.activation_quantizer( input)
  96. q_input = input
  97. q_weight = self.weight_quantizer(weight)
  98. # 量化卷积
  99. if self.training: # 训练态
  100. output = F.conv2d(
  101. input=q_input,
  102. weight=q_weight,
  103. bias=self.bias, # 注意,这里不加bias(self.bias为None)
  104. stride=self.stride,
  105. padding=self.padding,
  106. dilation=self.dilation,
  107. groups=self.groups
  108. )
  109. # (这里将训练态下,卷积中w融合running参数的效果转为融合batch参数的效果)running ——> batch
  110. output *= reshape_to_activation(torch.sqrt(self.running_var + self.eps) / torch.sqrt(batch_var + self.eps))
  111. output += reshape_to_activation(bias)
  112. else: # 测试态
  113. output = F.conv2d(
  114. input=q_input,
  115. weight=q_weight,
  116. bias=bias, # 注意,这里加bias,做完整的conv+bn
  117. stride=self.stride,
  118. padding=self.padding,
  119. dilation=self.dilation,
  120. groups=self.groups
  121. )
  122. return output

注意一个点,在训练的时候bias设置为None,即训练的时候不量化bias

7. 实验结果

在CIFAR10做Quantization Aware Training实验,网络结构为:


   
     
     
     
     
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. from .util_wqaq import Conv2d_Q, BNFold_Conv2d_Q
  5. class QuanConv2d(nn.Module):
  6. def __init__(self, input_channels, output_channels,
  7. kernel_size=-1, stride=-1, padding=-1, groups=1, last_relu=0, abits=8, wbits=8, bn_fold=0, q_type=1, first_layer=0):
  8. super(QuanConv2d, self).__init__()
  9. self.last_relu = last_relu
  10. self.bn_fold = bn_fold
  11. self.first_layer = first_layer
  12. if self.bn_fold == 1:
  13. self.bn_q_conv = BNFold_Conv2d_Q(input_channels, output_channels,
  14. kernel_size=kernel_size, stride=stride, padding=padding, groups=groups, a_bits=abits, w_bits=wbits, q_type=q_type, first_layer=first_layer)
  15. else:
  16. self.q_conv = Conv2d_Q(input_channels, output_channels,
  17. kernel_size=kernel_size, stride=stride, padding=padding, groups=groups, a_bits=abits, w_bits=wbits, q_type=q_type, first_layer=first_layer)
  18. self.bn = nn.BatchNorm2d(output_channels, momentum=0.01) # 考虑量化带来的抖动影响,对momentum进行调整(0.1 ——> 0.01),削弱batch统计参数占比,一定程度抑制抖动。经实验量化训练效果更好,acc提升1%左右
  19. self.relu = nn.ReLU(inplace=True)
  20. def forward(self, x):
  21. if not self.first_layer:
  22. x = self.relu(x)
  23. if self.bn_fold == 1:
  24. x = self.bn_q_conv(x)
  25. else:
  26. x = self.q_conv(x)
  27. x = self.bn(x)
  28. if self.last_relu:
  29. x = self.relu(x)
  30. return x
  31. class Net(nn.Module):
  32. def __init__(self, cfg = None, abits=8, wbits=8, bn_fold=0, q_type=1):
  33. super(Net, self).__init__()
  34. if cfg is None:
  35. cfg = [192, 160, 96, 192, 192, 192, 192, 192]
  36. # model - A/W全量化(除输入、输出外)
  37. self.quan_model = nn.Sequential(
  38. QuanConv2d(3, cfg[0], kernel_size=5, stride=1, padding=2, abits=abits, wbits=wbits, bn_fold=bn_fold, q_type=q_type, first_layer=1),
  39. QuanConv2d(cfg[0], cfg[1], kernel_size=1, stride=1, padding=0, abits=abits, wbits=wbits, bn_fold=bn_fold, q_type=q_type),
  40. QuanConv2d(cfg[1], cfg[2], kernel_size=1, stride=1, padding=0, abits=abits, wbits=wbits, bn_fold=bn_fold, q_type=q_type),
  41. nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
  42. QuanConv2d(cfg[2], cfg[3], kernel_size=5, stride=1, padding=2, abits=abits, wbits=wbits, bn_fold=bn_fold, q_type=q_type),
  43. QuanConv2d(cfg[3], cfg[4], kernel_size=1, stride=1, padding=0, abits=abits, wbits=wbits, bn_fold=bn_fold, q_type=q_type),
  44. QuanConv2d(cfg[4], cfg[5], kernel_size=1, stride=1, padding=0, abits=abits, wbits=wbits, bn_fold=bn_fold, q_type=q_type),
  45. nn.MaxPool2d(kernel_size=3, stride=2, padding=1),
  46. QuanConv2d(cfg[5], cfg[6], kernel_size=3, stride=1, padding=1, abits=abits, wbits=wbits, bn_fold=bn_fold, q_type=q_type),
  47. QuanConv2d(cfg[6], cfg[7], kernel_size=1, stride=1, padding=0, abits=abits, wbits=wbits, bn_fold=bn_fold, q_type=q_type),
  48. QuanConv2d(cfg[7], 10, kernel_size=1, stride=1, padding=0, last_relu=1, abits=abits, wbits=wbits, bn_fold=bn_fold, q_type=q_type),
  49. nn.AvgPool2d(kernel_size=8, stride=1, padding=0),
  50. )
  51. def forward(self, x):
  52. x = self.quan_model(x)
  53. x = x.view(x.size(0), -1)
  54. return x

训练Epoch数为30,学习率调整策略为:


   
     
     
     
     
  1. def adjust_learning_rate( optimizer, epoch):
  2. if args.bn_fold == 1:
  3. if args.model_type == 0:
  4. update_list = [ 12, 15, 25]
  5. else:
  6. update_list = [ 8, 12, 20, 25]
  7. else:
  8. update_list = [ 15, 17, 20]
  9. if epoch in update_list:
  10. for param_group in optimizer.param_groups:
  11. param_group[ 'lr'] = param_group[ 'lr'] * 0.1
  12. return

类型Acc备注原模型(nin)91.01%全精度对称量化, bn不融合88.88%INT8对称量化,bn融合86.66%INT8非对称量化,bn不融合88.89%INT8非对称量化,bn融合87.30%INT8

现在不清楚为什么量化后的精度损失了1-2个点,根据德澎在MxNet的实验结果来看,分类任务不会损失精度,所以不知道这个代码是否存在问题,有经验的大佬欢迎来指出问题。

然后白皮书上提供的一些分类网络的训练模拟量化精度情况如下:

QAT方式明显好于Post Train Quantzation

注意前面有一些精度几乎为0的数据是因为MobileNet训练出来之后某些层的权重非常接近0,使用训练后量化方法之后权重也为0,这就导致推理后结果完全错误。

8. 总结

今天介绍了一下基于Pytorch实现QAT量化,并用一个小网络测试了一下效果,但比较遗憾的是并没有获得论文中那么理想的数据,仍需要进一步研究。

你可能感兴趣的:(深度学习,人工智能,模型部署,pytorch,模型量化,量化感知训练,模型轻量化)