【半精度】Pytorch模型加速和减少显存

如标题所示,这是PyTorch框架提供的一个方便好用的trick:开启半精度。直接可以加快运行速度、减少GPU占用,并且只有不明显的accuracy损失

之前做硬件加速的时候,尝试过多种精度的权重和偏置。在FPGA里用8位精度和16位精度去处理MNIST手写数字识别,完全可以达到差不多的准确率,并且可以节省一半的资源消耗。这一思想用到GPU里也是完全可以行通的。即将pytorch默认的32位浮点型都改成16位浮点型。

只需:

model.half()

 注意1:这一步要放在模型载入GPU之前,即放到model.cuda()之前。大概步骤就是:

model.half()
model.cuda()
model.eval()

注意2:模型改为半精度以后,输入也需要改成半精度。步骤大概是:

model.half()
model.cuda()
model.eval()

img = torch.from_numpy(image).float()
img = img.cuda()
img = img.half()

res = model(img)

本地做的测试结果为:速度提升25%~35%,显存节约40~60%,而accuracy几乎没变。仅供大家参考。

你可能感兴趣的:(python,Pytorch那些事儿,cuda,gpu,pytorch,模型加速,模型压缩)