PyTorch Post-training Static Quantization and Quantization Aware Training: Loading Models

  1. Loading a Post-training Static Quantization model
# The checkpoint holds ordinary float weights, so load it first.
self.model.eval()
checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
load_model_weight(self.model, checkpoint)
# Attach a quantization config, fuse conv/bn/relu patterns, and insert observers.
self.model.qconfig = torch.quantization.get_default_qconfig(self.qconfig_name)
fuse_module(self.model)
torch.quantization.prepare(self.model, inplace=True)
# Calibrate: a forward pass lets the observers record activation ranges.
dummy_input = torch.randn(1, 3, *self.cfg.data.eval.pipeline.input_size).to(self.device)
_ = self.model(dummy_input)
# Freeze the observers and convert to a real int8 model.
self.model.apply(torch.quantization.disable_observer)
torch.quantization.convert(self.model, inplace=True)
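The single random dummy input above only exercises the observers with noise; for usable activation ranges, calibration is normally run over a small set of real, preprocessed inputs. A minimal sketch, assuming a hypothetical calib_loader that yields batches already matching the eval pipeline:

# Sketch: calibrate observers with representative data instead of torch.randn.
# `calib_loader` is a hypothetical DataLoader; a few hundred batches usually suffice.
with torch.no_grad():
    for i, batch in enumerate(calib_loader):
        self.model(batch.to(self.device))  # observers update min/max statistics
        if i >= 100:
            break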

In this case the model was trained in ordinary floating-point mode and quantized only after training. Note that for inference, the module's forward must begin with a QuantStub and end with a DeQuantStub, so tensors are converted into and out of the quantized domain.
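A minimal sketch of such a module (the layer names here are illustrative, not from the original model):

import torch
import torch.nn as nn

class QuantizableNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.quant = torch.quantization.QuantStub()      # float -> quantized at the input
        self.conv = nn.Conv2d(3, 16, 3)
        self.relu = nn.ReLU()
        self.dequant = torch.quantization.DeQuantStub()  # quantized -> float at the output

    def forward(self, x):
        x = self.quant(x)            # head of forward
        x = self.relu(self.conv(x))
        x = self.dequant(x)          # tail of forward
        return x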

  2. Loading a QAT model
# Rebuild the exact QAT graph first: qconfig, fusion, then prepare_qat
# in train mode so the fake-quant/observer modules are inserted.
self.model.qconfig = torch.quantization.get_default_qat_qconfig(self.qconfig_name)
self.model.train()
fuse_module(self.model)
torch.quantization.prepare_qat(self.model, inplace=True)
# A dummy forward pass materializes the observer buffers so their
# shapes match what is stored in the QAT checkpoint.
dummy_input = torch.randn(1, 3, *self.cfg.data.eval.pipeline.input_size).to(self.device)
_ = self.model(dummy_input)
# The state dict keys now line up with the checkpoint, so the weights
# (including observer/fake-quant buffers) can be restored.
self.model.eval()
checkpoint = torch.load(checkpoint_path, map_location=lambda storage, loc: storage)
load_model_weight(self.model, checkpoint)
# Freeze the observers and convert to a real int8 model.
self.model.apply(torch.quantization.disable_observer)
self.model = torch.quantization.convert(self.model)

In this case the model was trained with QAT, so it must be loaded in QAT mode: the checkpoint's state dict was saved from a prepared (fused, fake-quantized) model, which is why prepare_qat has to run before load_model_weight.
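For reference, the checkpoint being loaded here would have come from a training run that applied the same prepare_qat flow. A minimal sketch, where train_loader, criterion, and optimizer are placeholders and the 'state_dict' checkpoint layout is an assumption about what load_model_weight expects:

# Sketch of the QAT training side that produces the checkpoint above.
model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
fuse_module(model)  # same fusion as at load time
torch.quantization.prepare_qat(model, inplace=True)
model.train()
for images, targets in train_loader:          # placeholder data loader
    loss = criterion(model(images), targets)  # placeholder loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
# The saved state dict includes the fake-quant/observer buffers;
# the checkpoint format is assumed to match load_model_weight.
torch.save({'state_dict': model.state_dict()}, checkpoint_path)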
