PyTorch Quantization: Static Quantization

env:

  • pytorch==1.7.1
  • torchvision==0.8.2
  • python==3.6

Notes:

  • Accuracy gets worse.
  • The procedure itself is simple, but you still have to modify the model.
  • Layer fusion requires some understanding of the model structure.
  • Model size shrinks to roughly 1/4 of the original.
  • Inference speed improves by 20%+ (a quick way to check the size and speed claims is sketched right after this list).
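A minimal sketch for verifying the size and speed claims on your own model; model_fp32, model_int8, and the example input img are assumed to already exist (they are built in the steps below):

import os
import time
import torch

# compare on-disk size of the two state_dicts (roughly 4x expected)
torch.save(model_fp32.state_dict(), "fp32.pt")
torch.save(model_int8.state_dict(), "int8.pt")
print(os.path.getsize("fp32.pt") / os.path.getsize("int8.pt"))

# average latency over repeated CPU forward passes
def bench(m, x, n=50):
    with torch.no_grad():
        t0 = time.perf_counter()
        for _ in range(n):
            m(x)
    return (time.perf_counter() - t0) / n

print("fp32:", bench(model_fp32, img))
print("int8:", bench(model_int8, img))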

Step 1: load the model

Just load the model as usual; nothing special here.

import torch

device = torch.device("cpu")  # fbgemm static quantization runs on x86 CPU
model = Resnet().to(device)   # your float model definition
checkpoint = torch.load(weights, map_location=device)
model.load_state_dict(checkpoint)
model.eval()  # quantization must be applied to a model in eval mode

Step 2: quantize

Follow the boilerplate below. One thing short recipes often gloss over: static quantization needs a calibration pass between prepare and convert.

backend = "fbgemm"  # "fbgemm" for x86; use "qnnpack" on ARM/mobile
model.qconfig = torch.quantization.get_default_qconfig(backend)

# Fuse adjacent layers: valid patterns include conv+bn, conv+relu, conv+bn+relu.
# The strings are the attribute names of the submodules in your model.
# Skip this call if you do not want fusion.
fuse_list = [['conv', 'relu']]
model = torch.quantization.fuse_modules(model, fuse_list)

model_fp32_prepared = torch.quantization.prepare(model)
# run representative data through model_fp32_prepared here so the observers
# can record activation ranges (see the calibration sketch below), then:
model_int8 = torch.quantization.convert(model_fp32_prepared)
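The calibration step that belongs between prepare() and convert(): a minimal sketch, assuming calib_loader is a DataLoader of a few hundred representative inputs (not defined in this post; labels are not needed):

# no backward pass required; the observers just record activation statistics
model_fp32_prepared.eval()
with torch.no_grad():
    for calib_img, _ in calib_loader:  # labels unused
        model_fp32_prepared(calib_img)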

Step 3: persistence (saving the model)

There are two options: save the weights only, or save the weights plus the structure (TorchScript).

Saving weights + structure saves time at load, since you do not need the model definition or a repeat of the quantization setup.

# save: trace with an example input (img) of the right shape, then serialize
traced_model = torch.jit.trace(model_int8, img)
torch.jit.save(traced_model, "traced_int8.pt")

# load: no model definition or quantization setup required
model = torch.jit.load("traced_int8.pt")
model(img)
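To see the load-time difference for yourself, a quick sketch (it assumes the traced file from above exists):

import time

t0 = time.perf_counter()
m = torch.jit.load("traced_int8.pt")
print(f"torch.jit.load: {time.perf_counter() - t0:.3f}s")
# compare against rebuilding the model and repeating the quantization
# setup, as in the weights-only path below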

Saving the weights only

# save: weights only
torch.save(model_int8.state_dict(), "int_8_post.pt")

# load
'''first rebuild the float model structure; do NOT load the int8 weights yet,
   since an int8 state_dict will not match a float model'''
model = YourNet().to(device)
model.eval()

'''repeat the same quantization setup used at save time'''
backend = "fbgemm"
model.qconfig = torch.quantization.get_default_qconfig(backend)

fuse_list = [['conv', 'relu']]  # must match the fusion list used before saving
model = torch.quantization.fuse_modules(model, fuse_list)

model_fp32_prepared = torch.quantization.prepare(model)
model_int8 = torch.quantization.convert(model_fp32_prepared)

'''now the structure matches the int8 checkpoint, so load the weights'''
checkpoint = torch.load("int_8_post.pt", map_location=device)
model_int8.load_state_dict(checkpoint)
model_int8.eval()
model_int8(img)
model_int8(img)

Step 4: quantizing and dequantizing the input

This step requires a small change to the model: a quantized model needs quantized inputs, and the Python code downstream needs float outputs, so a QuantStub/DeQuantStub pair is inserted at the model's boundaries.

import torch.nn as nn
import torch.quantization

class YourNet(nn.Module):

    def __init__(self, cfg, img_size=(416, 416), verbose=False):
        ... ...
        self.quant = torch.quantization.QuantStub()      # fp32 -> int8 at the input
        self.dequant = torch.quantization.DeQuantStub()  # int8 -> fp32 at the output
        ... ...

    def forward(self, input):
        x = self.quant(input)  # quantize the incoming float tensor
        x = self.layer(x)      # quantized ops run in between
        x = self.dequant(x)    # dequantize before handing results back to Python
        ... ...
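Putting the whole recipe together, a self-contained toy sketch (a made-up conv+relu net with random calibration data, purely illustrative, not the model from this post):

import torch
import torch.nn as nn

class TinyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = torch.quantization.QuantStub()
        self.conv = nn.Conv2d(3, 8, 3, padding=1)
        self.relu = nn.ReLU()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)
        x = self.relu(self.conv(x))
        return self.dequant(x)

model = TinyNet().eval()  # eval mode before fusion
model = torch.quantization.fuse_modules(model, [['conv', 'relu']])
model.qconfig = torch.quantization.get_default_qconfig("fbgemm")
prepared = torch.quantization.prepare(model)

with torch.no_grad():  # calibrate on stand-in data
    for _ in range(10):
        prepared(torch.randn(1, 3, 32, 32))

int8_model = torch.quantization.convert(prepared)
out = int8_model(torch.randn(1, 3, 32, 32))  # float in, float out
print(out.shape)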

References:

https://pytorch.org/docs/stable/quantization.html

https://github.com/pytorch/pytorch/issues/43016

https://github.com/pytorch/pytorch/issues/28331
