量化(Quantization)是指将高精度浮点数表示为低精度整数的过程,从而提高神经网络的效率和性能。在能够接受一定的精度损失的情况下,可以有以下的好处:
减小内存占用
加速计算
减小功耗和延迟
部署灵活性:由于量化模型更小、更快,它们可以更容易地部署在各种设备上,包括但不限于智能手机、IoT设备和边缘计算设备。
提高能效比:在很多场景下,能效比(性能与功耗的比值)是一个关键指标。通过减少内存和计算需求,量化可以显著提高神经网络的能效比。
降低部署成本:使用小型、低功耗的硬件部署量化模型可以降低整体部署成本。
需要注意的是,虽然量化带来了很多好处,但也可能导致模型精度的损失。因此,在使用量化之前,建议进行详细的评估和测试,以确保模型的效果满足特定应用的需求。
import torch
import torchvision.models as models
model = models.resnet50()
input = torch.randn(1, 3, 224, 224)
torch.onnx.export(model, input, "resnet50.onnx")
可以看出来,没有量化的模型是直接从输出一直传递FP32的数据类型到output,如下图
公式:
scale = float_max − float_min quant_max − quant_min \text{scale} = \frac{\text{float\_max} - \text{float\_min}}{\text{quant\_max} - \text{quant\_min}} scale=quant_max−quant_minfloat_max−float_min
插入数值:
scale = 1.62 − ( − 0.61 ) 127 − ( − 128 ) = 0.00874509 \text{scale} = \frac{1.62 - (-0.61)}{127 - (-128)} = 0.00874509 \ scale=127−(−128)1.62−(−0.61)=0.00874509
[-70, -59, 185]
由于最大整数值为127
,因此185
需要被截断:
[-70, -59, 185]
→ [-70, -59, 127]
[-0.6121563, -0.5199999, 1.1062843]
[-70, -59, 185]
--> [-128, -117, 127]
import numpy as np
# 截断操作
def saturate(x, int_max, int_min):
return np.clip(x, int_min, int_max)
# 计算缩放和偏移量
def scale_z_cal(x, int_max, int_min):
scale = (x.max() - x.min()) / (int_max - int_min)
z = int_max - np.round((x.max() / scale))
return scale, z
# 量化
def quant_float_data(x, scale, z, int_max, int_min):
xq = saturate(np.round(x / scale + z), int_max, int_min)
return xq
# 反量化
def dequant_data(xq, scale, z):
x = ((xq - z) * scale).astype('float32')
return x
if __name__ == '__main__':
# np.random.seed(0)
data_float32 = np.random.randn(3).astype('float32')
# data_float32 = np.random.randn(100).astype('float32')
# data_float32[99] = 100
# data_float32 = np.array([-0.61, -0.52, 1.62], dtype='float32')
print(f"input: {data_float32}")
# uint8 bound
# int_max = 255
# int_min = 0
# int8 bound
int_max = 127
int_min = -128
scale, z = scale_z_cal(data_float32, int_max, int_min)
print(f"scale: {scale}, z: {z}")
data_int8 = quant_float_data(data_float32, scale, z, int_max, int_min)
print(f"quant: {data_int8}")
data_dequant_float = dequant_data(data_int8, scale, z)
print(f"dequant: {data_dequant_float}")
print(f"diff: {data_dequant_float - data_float32}")
import numpy as np
# 截断操作
def saturate(x):
return np.clip(x, -127, 127)
# 缩放
def scale_cal(x):
max_val = np.max(np.abs(x))
return max_val / 127
# 量化
def quant_float_data(x, scale):
xq = saturate(np.round(x / scale))
return xq
# 反量化
def dequant_data(xq, scale):
x = (xq * scale).astype('float32')
return x
if __name__ == '__main__':
np.random.seed(4)
# data_float32 = np.random.randn(3).astype('float32')
data_float32 = np.array([1.62, -1.62, 0, -0.52, 1.62], dtype='float32')
print(f"input: {data_float32}")
scale = scale_cal(data_float32)
print(f"scale: {scale}")
data_int8 = quant_float_data(data_float32, scale)
print(f"quant: {data_int8}")
data_dequant_float = dequant_data(data_int8, scale)
print(f"dequant: {data_dequant_float}")
print(f"diff: {data_dequant_float - data_float32}")