Most articles about quantization you find online either just run the half-baked official examples or use APIs that are long out of date. The more fashionable way to do it nowadays is with torch.fx. In this post I'll walk you through a fairly standard quantization workflow in about 100 lines of code. The point is not to tell you that torch.fx is useful, everyone already knows that; the real questions are how to use it, where it fits, and where it doesn't, and those are the ones I'll answer. The 100 lines here are small but complete: whatever model you want to quantize, you can apply the same recipe, and if it breaks you can blame me.
Plenty of older articles still insert quantization stubs by hand to mark where quantization happens, which in the 21st century is about as modern as sending letters by carrier pigeon (see the sketch below for what that looks like). This post covers the complete workflow instead: preparing and converting the model with torch.fx, calibrating it, and checking accuracy and speed, all in one place.
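For contrast, here is roughly what that old eager-mode workflow looks like, a minimal sketch using the classic torch.quantization stub API (the toy module and shapes are mine, purely for illustration):

```python
import torch
import torch.nn as nn


class StubbedNet(nn.Module):
    """Eager-mode PTQ: you insert QuantStub/DeQuantStub by hand and must make
    sure every op between them has a quantized implementation."""

    def __init__(self):
        super().__init__()
        self.quant = torch.quantization.QuantStub()
        self.conv = nn.Conv2d(3, 8, 3)
        self.relu = nn.ReLU()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, x):
        x = self.quant(x)        # fp32 -> int8 boundary
        x = self.relu(self.conv(x))
        return self.dequant(x)   # int8 -> fp32 boundary


m = StubbedNet().eval()
m.qconfig = torch.quantization.get_default_qconfig("fbgemm")
prepared = torch.quantization.prepare(m)          # insert observers
prepared(torch.randn(1, 3, 32, 32))               # run some data to collect stats
quantized = torch.quantization.convert(prepared)  # swap in int8 modules
```

With fx, none of this manual surgery is needed; the graph is traced and rewritten for you.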
I'll skip the thirty-thousand-word primer on what quantization is; the background is one search away and there's nothing new for me to add there.
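Just so the post stays self-contained, though, here is the core arithmetic in one toy snippet: plain per-tensor affine int8 quantization, which is what every tool mentioned below is ultimately computing scales and zero points for (fbgemm's real recipe adds per-channel weights, reduce_range and so on, so treat this purely as an illustration):

```python
import torch

# q = clamp(round(x / scale) + zero_point, qmin, qmax);  x_hat = (q - zero_point) * scale
x = torch.randn(8)
qmin, qmax = -128, 127
scale = (x.max() - x.min()) / (qmax - qmin)            # map the observed range onto 256 bins
zero_point = int(qmin - torch.round(x.min() / scale))  # the int8 value that represents real 0.0
q = torch.clamp(torch.round(x / scale) + zero_point, qmin, qmax).to(torch.int8)
x_hat = (q.float() - zero_point) * scale               # dequantize to see the rounding error
print(x, q, x_hat)
```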
If you ask me what the best quantization tool is right now, my answer is: there isn't one. Really. Whether it's nni, NVIDIA's pytorch_quantization, nncf and so on, it's not that these tools are bad, it's that none of them finishes the job. At heart they are all doing the same thing, at least from a quantization point of view, but in the end none of them is truly general. When you notice that the model saved by pytorch_quantization is exactly the same size as the float32 one (it keeps fp32 weights plus fake-quant parameters rather than real int8 weights), you start to question your life choices. It's like someone ordering a medium and insisting on calling it a large.
If the existing wheels don't work, the only option is to build your own. All I can say is: torch.fx is the GOAT. Everyone who has used it agrees.
Talk is cheap, so let's go straight to the code. Note that you should use the latest stable PyTorch for torch.fx; older versions have slightly different APIs. I tested on `1.11`.
Since we have no practical way to calibrate the torchvision ImageNet models here, we'll use the smaller CIFAR10 instead. There's nothing to download by hand, torchvision takes care of it, but it does mean we need to finetune the model ourselves first. So let's get the finetuning code ready; it simply finetunes the model that we are about to quantize and calibrate:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import copy
import torchvision
from torchvision import transforms
from torchvision.models.resnet import resnet50, resnet18
from torch.quantization.quantize_fx import prepare_fx, convert_fx
from torch.ao.quantization.fx.graph_module import ObservedGraphModule
from torch.quantization import (
    get_default_qconfig,
)
from torch import optim
import os
import time


def train_model(model, train_loader, test_loader, device):
    # The training configurations were not carefully selected.
    learning_rate = 1e-2
    num_epochs = 20

    criterion = nn.CrossEntropyLoss()
    model.to(device)

    # It seems that SGD optimizer is better than Adam optimizer for ResNet18 training on CIFAR10.
    optimizer = optim.SGD(
        model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=1e-5
    )
    # optimizer = optim.Adam(model.parameters(), lr=learning_rate, betas=(0.9, 0.999), eps=1e-08, weight_decay=0, amsgrad=False)

    for epoch in range(num_epochs):
        # Training
        model.train()

        running_loss = 0
        running_corrects = 0

        for inputs, labels in train_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # statistics
            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels.data)

        train_loss = running_loss / len(train_loader.dataset)
        train_accuracy = running_corrects / len(train_loader.dataset)

        # Evaluation
        model.eval()
        eval_loss, eval_accuracy = evaluate_model(
            model=model, test_loader=test_loader, device=device, criterion=criterion
        )

        print(
            "Epoch: {:02d} Train Loss: {:.3f} Train Acc: {:.3f} Eval Loss: {:.3f} Eval Acc: {:.3f}".format(
                epoch, train_loss, train_accuracy, eval_loss, eval_accuracy
            )
        )

    return model


def prepare_dataloader(num_workers=8, train_batch_size=128, eval_batch_size=256):
    train_transform = transforms.Compose(
        [
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ]
    )
    test_transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
        ]
    )

    train_set = torchvision.datasets.CIFAR10(
        root="data", train=True, download=True, transform=train_transform
    )
    # We will use test set for validation and test in this project.
    # Do not use test set for validation in practice!
    test_set = torchvision.datasets.CIFAR10(
        root="data", train=False, download=True, transform=test_transform
    )

    train_sampler = torch.utils.data.RandomSampler(train_set)
    test_sampler = torch.utils.data.SequentialSampler(test_set)

    train_loader = torch.utils.data.DataLoader(
        dataset=train_set,
        batch_size=train_batch_size,
        sampler=train_sampler,
        num_workers=num_workers,
    )
    test_loader = torch.utils.data.DataLoader(
        dataset=test_set,
        batch_size=eval_batch_size,
        sampler=test_sampler,
        num_workers=num_workers,
    )

    return train_loader, test_loader
```
Then train the model:
```python
if __name__ == "__main__":
    train_loader, test_loader = prepare_dataloader()

    # first finetune model on cifar, we don't have imagenet so using cifar as test
    model = resnet18(pretrained=True)
    model.fc = nn.Linear(512, 10)
    if os.path.exists("r18_row.pth"):
        model.load_state_dict(torch.load("r18_row.pth", map_location="cpu"))
    else:
        train_model(model, train_loader, test_loader, torch.device("cuda"))
        print("train finished.")
        torch.save(model.state_dict(), "r18_row.pth")
```
Next comes the core code:
```python
def quant_fx(model):
    model.eval()
    qconfig = get_default_qconfig("fbgemm")
    qconfig_dict = {
        "": qconfig,
        # 'object_type': []
    }
    model_to_quantize = copy.deepcopy(model)
    prepared_model = prepare_fx(model_to_quantize, qconfig_dict)
    print("prepared model: ", prepared_model)

    quantized_model = convert_fx(prepared_model)
    print("quantized model: ", quantized_model)
    torch.save(model.state_dict(), "r18.pth")
    torch.save(quantized_model.state_dict(), "r18_quant.pth")
```
See that? Just like that, snap, and an int8 quantized model comes out. In fact you don't even need 100 lines; 15 are enough. That's how good torch.fx is.
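If you want to see what fx actually did, the objects returned by prepare_fx and convert_fx are ordinary torch.fx GraphModules, so you can dump their graphs. This is a small optional addition you could drop at the end of `quant_fx`, my own debugging habit rather than part of the 15 lines (`print_tabular` needs the `tabulate` package installed):

```python
# Inspect the traced graphs: the prepared model has observer nodes inserted,
# the converted model has quantize/dequantize ops and quantized modules.
prepared_model.graph.print_tabular()
quantized_model.graph.print_tabular()

# The converted state dict also carries the chosen scales and zero points.
print([k for k in quantized_model.state_dict() if "scale" in k or "zero_point" in k][:10])
```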
Let's run an evaluation to check how the accuracy looks without any calibration:
```python
def evaluate_model(model, test_loader, device=torch.device("cpu"), criterion=None):
    t0 = time.time()
    model.eval()
    model.to(device)

    running_loss = 0
    running_corrects = 0

    for inputs, labels in test_loader:
        inputs = inputs.to(device)
        labels = labels.to(device)

        outputs = model(inputs)
        _, preds = torch.max(outputs, 1)

        if criterion is not None:
            loss = criterion(outputs, labels).item()
        else:
            loss = 0

        # statistics
        running_loss += loss * inputs.size(0)
        running_corrects += torch.sum(preds == labels.data)

    eval_loss = running_loss / len(test_loader.dataset)
    eval_accuracy = running_corrects / len(test_loader.dataset)
    t1 = time.time()
    print(f"eval loss: {eval_loss}, eval acc: {eval_accuracy}, cost: {t1 - t0}")
    return eval_loss, eval_accuracy
```
Here is the evaluation result:
```text
eval loss: 0.0, eval acc: 0.8476999998092651, cost: 2.8914074897766113
eval loss: 0.0, eval acc: 0.15240000188350677, cost: 1.240293264389038
```
As you can see, accuracy drops badly: the float32 model sits at about 84.8%, while the uncalibrated int8 model collapses to about 15.2%. Time to calibrate. Here is the calibration function:
```python
def calib_quant_model(model, calib_dataloader):
    assert isinstance(
        model, ObservedGraphModule
    ), "model must be a prepared fx ObservedGraphModule."
    model.eval()
    with torch.inference_mode():
        for inputs, labels in calib_dataloader:
            model(inputs)
    print("calib done.")
```
That's all there is to it. If you're quantizing something other than a classifier, just pass its dataloader in the same way; note that the labels are never used here, we only push data through the network so the observers can record the activation statistics (a label-free variant is sketched below). Really, it's that simple.
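For dataloaders that don't yield (inputs, labels) pairs at all, a tiny variant works just as well. This is a sketch of my own; the helper name and batch handling are assumptions, not part of the original function:

```python
def calib_quant_model_any(model, calib_dataloader):
    # Same idea as calib_quant_model, but tolerant of loaders that yield bare
    # tensors or (image, ...) tuples, e.g. detection/segmentation datasets.
    assert isinstance(
        model, ObservedGraphModule
    ), "model must be a prepared fx ObservedGraphModule."
    model.eval()
    with torch.inference_mode():
        for batch in calib_dataloader:
            inputs = batch[0] if isinstance(batch, (tuple, list)) else batch
            model(inputs)  # forward pass only; observers record activation ranges
    print("calib done.")
```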
Finally, let's evaluate once more:
```python
def quant_calib_and_eval(model):
    # test only on CPU
    model.to(torch.device("cpu"))
    model.eval()

    qconfig = get_default_qconfig("fbgemm")
    qconfig_dict = {
        "": qconfig,
        # 'object_type': []
    }

    model2 = copy.deepcopy(model)
    model_prepared = prepare_fx(model2, qconfig_dict)
    model_int8 = convert_fx(model_prepared)
    model_int8.load_state_dict(torch.load("r18_quant.pth"))
    model_int8.eval()

    a = torch.randn([1, 3, 224, 224])
    o1 = model(a)
    o2 = model_int8(a)

    diff = torch.allclose(o1, o2, 1e-4)
    print(diff)
    print(o1.shape, o2.shape)
    print(o1, o2)
    get_output_from_logits(o1)
    get_output_from_logits(o2)

    train_loader, test_loader = prepare_dataloader()
    evaluate_model(model, test_loader)
    evaluate_model(model_int8, test_loader)

    # calib quant model
    model2 = copy.deepcopy(model)
    model_prepared = prepare_fx(model2, qconfig_dict)
    model_int8 = convert_fx(model_prepared)
    torch.save(model_int8.state_dict(), "r18.pth")
    model_int8.eval()

    model_prepared = prepare_fx(model2, qconfig_dict)
    calib_quant_model(model_prepared, test_loader)
    model_int8 = convert_fx(model_prepared)
    torch.save(model_int8.state_dict(), "r18_quant_calib.pth")
    evaluate_model(model_int8, test_loader)
```
And the results:
```text
eval loss: 0.0, eval acc: 0.8476999998092651, cost: 2.8914074897766113
eval loss: 0.0, eval acc: 0.15240000188350677, cost: 1.240293264389038
calib done.
eval loss: 0.0, eval acc: 0.8442999720573425, cost: 1.2966759204864502
```
Accuracy snaps right back, and the int8 model finishes the evaluation in less than half the time of the float32 one. So with a few dozen lines of code we have completed the whole quantization flow and recovered the accuracy through calibration. That is the power of fx.
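One extra sanity check you can run (my addition, not one of the original measurements): compare the on-disk size of the checkpoints saved above. The calibrated int8 state dict should come out at roughly a quarter of the float32 one, which is exactly what the fake-quant tools complained about earlier fail to deliver.

```python
import os

# Paths follow the ones used in this post; adjust if you saved elsewhere.
for name in ["r18_row.pth", "r18_quant_calib.pth"]:
    print(f"{name}: {os.path.getsize(name) / 1e6:.1f} MB")
```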
Let me close with an open question, answers welcome in the comments: how do you export a model quantized with torch.fx to ONNX and run it with another inference engine? If you're interested in quantization, distillation and model compression, you're welcome to join the WeChat group.