PyTorch quantization error: backend mismatch

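Environment: PyTorch 1.7.1
Error description: When running quantization-aware training (QAT) with PyTorch's quantization package (torch.quantization), the last step, convert, fails with: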

Traceback (most recent call last):
  File "train.py", line 136, in <module>
    main()
  File "train.py", line 126, in main
    quantized_model = torch.quantization.convert(model.eval(), inplace=False)
  File "/data/yyl/anaconda3/envs/pytorch-1.7.1/lib/python3.7/site-packages/torch/quantization/quantize.py", line 414, in convert
    _convert(module, mapping, inplace=True)
  File "/data/yyl/anaconda3/envs/pytorch-1.7.1/lib/python3.7/site-packages/torch/quantization/quantize.py", line 458, in _convert
    _convert(mod, mapping, inplace=True)
  File "/data/yyl/anaconda3/envs/pytorch-1.7.1/lib/python3.7/site-packages/torch/quantization/quantize.py", line 459, in _convert
    reassign[name] = swap_module(mod, mapping)
  File "/data/yyl/anaconda3/envs/pytorch-1.7.1/lib/python3.7/site-packages/torch/quantization/quantize.py", line 485, in swap_module
    new_mod = mapping[type(mod)].from_float(mod)
  File "/data/yyl/anaconda3/envs/pytorch-1.7.1/lib/python3.7/site-packages/torch/nn/quantized/modules/conv.py", line 368, in from_float
    return cls.get_qconv(mod, activation_post_process, weight_post_process)
  File "/data/yyl/anaconda3/envs/pytorch-1.7.1/lib/python3.7/site-packages/torch/nn/quantized/modules/conv.py", line 157, in get_qconv
    qweight = _quantize_weight(mod.weight.float(), weight_post_process)
  File "/data/yyl/anaconda3/envs/pytorch-1.7.1/lib/python3.7/site-packages/torch/nn/quantized/modules/utils.py", line 16, in _quantize_weight
    wt_scale.to(torch.double), wt_zp.to(torch.int64), wt_axis, torch.qint8)
RuntimeError: Could not run 'aten::quantize_per_channel' with arguments from the 'CUDA' backend. 'aten::quantize_per_channel' is only available for these backends: [CPU, BackendSelect, Named, AutogradOther, AutogradCPU, AutogradCUDA, AutogradXLA, AutogradPrivateUse1, AutogradPrivateUse2, AutogradPrivateUse3, Tracer, Autocast, Batched, VmapMode].

CPU: registered at /opt/conda/conda-bld/pytorch_1607370141920/work/build/aten/src/ATen/CPUType.cpp:2127 [kernel]
BackendSelect: fallthrough registered at /opt/conda/conda-bld/pytorch_1607370141920/work/aten/src/ATen/core/BackendSelectFallbackKernel.cpp:3 [backend fallback]
Named: registered at /opt/conda/conda-bld/pytorch_1607370141920/work/aten/src/ATen/core/NamedRegistrations.cpp:7 [backend fallback]
AutogradOther: registered at /opt/conda/conda-bld/pytorch_1607370141920/work/torch/csrc/autograd/generated/VariableType_2.cpp:8078 [autograd kernel]
AutogradCPU: registered at /opt/conda/conda-bld/pytorch_1607370141920/work/torch/csrc/autograd/generated/VariableType_2.cpp:8078 [autograd kernel]
AutogradCUDA: registered at /opt/conda/conda-bld/pytorch_1607370141920/work/torch/csrc/autograd/generated/VariableType_2.cpp:8078 [autograd kernel]
AutogradXLA: registered at /opt/conda/conda-bld/pytorch_1607370141920/work/torch/csrc/autograd/generated/VariableType_2.cpp:8078 [autograd kernel]
AutogradPrivateUse1: registered at /opt/conda/conda-bld/pytorch_1607370141920/work/torch/csrc/autograd/generated/VariableType_2.cpp:8078 [autograd kernel]
AutogradPrivateUse2: registered at /opt/conda/conda-bld/pytorch_1607370141920/work/torch/csrc/autograd/generated/VariableType_2.cpp:8078 [autograd kernel]
AutogradPrivateUse3: registered at /opt/conda/conda-bld/pytorch_1607370141920/work/torch/csrc/autograd/generated/VariableType_2.cpp:8078 [autograd kernel]
Tracer: registered at /opt/conda/conda-bld/pytorch_1607370141920/work/torch/csrc/autograd/generated/TraceType_2.cpp:9654 [kernel]
Autocast: fallthrough registered at /opt/conda/conda-bld/pytorch_1607370141920/work/aten/src/ATen/autocast_mode.cpp:254 [backend fallback]
Batched: registered at /opt/conda/conda-bld/pytorch_1607370141920/work/aten/src/ATen/BatchingRegistrations.cpp:511 [backend fallback]
VmapMode: fallthrough registered at /opt/conda/conda-bld/pytorch_1607370141920/work/aten/src/ATen/VmapModeRegistrations.cpp:33 [backend fallback]

(pytorch-1.7.1) ➜  CIFAR-10 python train.py
Files already downloaded and verified
Files already downloaded and verified
/data/yyl/anaconda3/envs/pytorch-1.7.1/lib/python3.7/site-packages/torch/quantization/observer.py:121: UserWarning: Please use quant_min and quant_max to specify the range for observers.                     reduce_range will be deprecated in a future release of PyTorch.
  reduce_range will be deprecated in a future release of PyTorch."
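The last frame of the traceback shows the call that fails. _quantize_weight passes the conv weight to torch.quantize_per_channel together with float64 scales, int64 zero points, a channel axis and torch.qint8. The sketch below calls the same op directly so you can see what convert does for each layer. The weight shape and the symmetric max-abs scale rule are illustrative assumptions, not what the observer in train.py actually computed. According to the error message, in 1.7.1 this op is registered only for the CPU backend, so the sketch moves the tensor to the CPU first.

```python
import torch

# Hypothetical conv weight: 8 output channels, 3 input channels, 3x3 kernel.
# It may live on the GPU after training, as in the original report.
device = "cuda" if torch.cuda.is_available() else "cpu"
trained_weight = torch.randn(8, 3, 3, 3, device=device)

# quantize_per_channel has no CUDA kernel in PyTorch 1.7.1, so move to CPU first.
weight = trained_weight.detach().cpu()

# One scale and one zero point per output channel (axis 0), using the dtypes seen
# in the traceback: scales as float64, zero points as int64.
# Assumed rule: symmetric int8, mapping each channel's max |w| to 127.
scales = (weight.abs().amax(dim=(1, 2, 3)) / 127.0).to(torch.double)
zero_points = torch.zeros(8, dtype=torch.int64)

qweight = torch.quantize_per_channel(weight, scales, zero_points, 0, torch.qint8)
print(qweight.is_quantized, qweight.qscheme(), qweight.q_per_channel_axis())
```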

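Solution: The model was trained on CUDA, so its weights were still on the GPU when convert ran. Converting a QAT model quantizes each conv weight with aten::quantize_per_channel, and in PyTorch 1.7.1 that op is registered only for the CPU backend (see the backend list in the error message). Move the model to the CPU before converting: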

quantized_model = torch.quantization.convert(model.cpu().eval(), inplace=False)
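A minimal end-to-end sketch of the workflow above: QAT on the GPU when one is available, then convert on the CPU. TinyNet, the random data and the three training steps are hypothetical placeholders for the real CIFAR-10 model and training loop in train.py. The fbgemm QAT qconfig is also an assumption; it quantizes conv weights per channel, which is what leads to quantize_per_channel. On a machine without CUDA everything runs on the CPU, so the error will not reproduce there, but the order of calls is the same.

```python
import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical toy model standing in for the real network in train.py."""

    def __init__(self):
        super().__init__()
        self.quant = torch.quantization.QuantStub()      # float -> quantized boundary
        self.conv = nn.Conv2d(3, 8, kernel_size=3)
        self.dequant = torch.quantization.DeQuantStub()  # quantized -> float boundary

    def forward(self, x):
        return self.dequant(self.conv(self.quant(x)))


# Use the fbgemm (x86) quantized engine when this build provides it.
if "fbgemm" in torch.backends.quantized.supported_engines:
    torch.backends.quantized.engine = "fbgemm"

device = "cuda" if torch.cuda.is_available() else "cpu"

model = TinyNet().train()
model.qconfig = torch.quantization.get_default_qat_qconfig("fbgemm")  # assumed qconfig
torch.quantization.prepare_qat(model, inplace=True)  # swap in fake-quant QAT modules
model.to(device)                                      # QAT itself may run on the GPU

optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
for _ in range(3):  # a few dummy training steps on random data
    x = torch.randn(4, 3, 16, 16, device=device)
    loss = model(x).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# The fix from this post: move the model to the CPU before convert(),
# because the weight quantization inside convert() runs on the CPU backend.
quantized_model = torch.quantization.convert(model.cpu().eval(), inplace=False)

out = quantized_model(torch.randn(1, 3, 16, 16))  # inference input must be on the CPU too
print(type(quantized_model.conv).__name__, tuple(out.shape))
```

nn.Module.cpu() moves the parameters in place and returns the same module, so model itself ends up on the CPU. Call model.to(device) again if training is to continue on the GPU afterwards.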
