Getting the qconfig
quantization_config = torch.quantization.get_default_qat_qconfig("fbgemm")
The function that returns this qconfig is defined as follows. There are two backends: fbgemm quantizes weights per channel, while qnnpack quantizes them per tensor (per layer).
def get_default_qat_qconfig(backend='fbgemm'):
    # Histogram observer is too slow for quantization aware training
    if backend == 'fbgemm':
        qconfig = QConfig(activation=FakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
                                                            quant_min=0,
                                                            quant_max=255,
                                                            reduce_range=True),
                          weight=default_per_channel_weight_fake_quant)
    elif backend == 'qnnpack':
        qconfig = QConfig(activation=FakeQuantize.with_args(observer=MovingAverageMinMaxObserver,
                                                            quant_min=0,
                                                            quant_max=255,
                                                            reduce_range=False),
                          weight=default_weight_fake_quant)
    else:
        qconfig = default_qat_qconfig
    return qconfig
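As a usage note, a minimal sketch of picking a backend and attaching this qconfig; the tiny model below is a hypothetical stand-in for a real float model:

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(4, 2))   # stand-in for a real float model

# Match the backend to the deployment target: "fbgemm" for x86 servers
# (per-channel weights), "qnnpack" for ARM/mobile (per-tensor weights).
backend = "fbgemm"
torch.backends.quantized.engine = backend
model.qconfig = torch.quantization.get_default_qat_qconfig(backend)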
This calls the with_args function; in the source code we can see
with_args = classmethod(_with_args)
with_args is a classmethod. Looking at the definition of _with_args:
def _with_args(cls_or_self, **kwargs):
    r"""Wrapper that allows creation of class factories.

    This can be useful when there is a need to create classes with the same
    constructor arguments, but different instances.

    Example::

        >>> Foo.with_args = classmethod(_with_args)
        >>> foo_builder = Foo.with_args(a=3, b=4).with_args(answer=42)
        >>> foo_instance1 = foo_builder()
        >>> foo_instance2 = foo_builder()
        >>> id(foo_instance1) == id(foo_instance2)
        False
    """
    class _PartialWrapper(object):
        def __init__(self, p):
            self.p = p

        def __call__(self, *args, **keywords):
            return self.p(*args, **keywords)

        def __repr__(self):
            return self.p.__repr__()

        with_args = _with_args

    r = _PartialWrapper(partial(cls_or_self, **kwargs))
    return r
As the docstring says, _with_args is a wrapper that allows creating class factories. Following the example:
>>> Foo.with_args = classmethod(_with_args)
>>> foo_builder = Foo.with_args(a=3, b=4).with_args(answer=42)
foo_builder is a class factory: every instance it creates is constructed with the accumulated arguments a=3, b=4, answer=42.
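To make the pattern concrete, here is a small self-contained sketch. The class Foo is hypothetical, and _with_args is assumed importable from torch.quantization.observer, where it lives in the PyTorch versions quoted in this post:

from torch.quantization.observer import _with_args  # module path may differ across versions

class Foo:
    """Hypothetical class used only to illustrate the factory pattern."""
    def __init__(self, a=None, b=None, answer=None):
        self.a, self.b, self.answer = a, b, answer

Foo.with_args = classmethod(_with_args)

# Each with_args call wraps a functools.partial of the constructor, so the
# builder accumulates keyword arguments and creates a new instance per call.
foo_builder = Foo.with_args(a=3, b=4).with_args(answer=42)
foo1, foo2 = foo_builder(), foo_builder()
print(foo1.a, foo1.b, foo1.answer)   # 3 4 42
print(foo1 is foo2)                  # False: distinct instances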
Calling the interface
torch.quantization.prepare_qat(quantized_model, inplace=True)
The prepare_qat function replaces the original floating-point model with a new model in which fake-quantization operators have been inserted. The function is as follows:
def prepare_qat(model, mapping=None, inplace=False):
    r"""
    Prepares a copy of the model for quantization calibration or
    quantization-aware training and converts it to quantized version.

    Quantization configuration should be assigned preemptively
    to individual submodules in `.qconfig` attribute.

    Args:
        model: input model to be modified in-place
        mapping: dictionary that maps float modules to quantized modules to be
                 replaced.
        inplace: carry out model transformations in-place, the original module
                 is mutated
    """
    if mapping is None:
        mapping = get_qat_module_mappings()
    if not inplace:
        model = copy.deepcopy(model)

    propagate_qconfig_(model, qconfig_dict=None)
    convert(model, mapping=mapping, inplace=True, remove_qconfig=False)
    prepare(model, observer_non_leaf_module_list=set(mapping.values()), inplace=True)
    return model
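A minimal end-to-end sketch of how prepare_qat is typically used (the toy model is hypothetical); printing the model before and after makes the module swap visible:

import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(16, 8), nn.ReLU(), nn.Linear(8, 4))  # toy float model
model.train()
model.qconfig = torch.quantization.get_default_qat_qconfig("fbgemm")

torch.quantization.prepare_qat(model, inplace=True)
print(model)   # both nn.Linear layers are now torch.nn.qat.Linear with weight_fake_quant
# ... run the normal training loop here, then convert to a true int8 model:
# model.eval(); quantized = torch.quantization.convert(model)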
Getting the mapping from float modules to the QAT modules that replace them:
The accessor function
def get_qat_module_mappings():
    ''' Get module mapping for quantization aware training
    '''
    return QAT_MODULE_MAPPINGS
The mapping
# Map for swapping float module to qat modules
QAT_MODULE_MAPPINGS = {
    nn.Linear: nnqat.Linear,
    nn.Conv2d: nnqat.Conv2d,
    # Intrinsic modules:
    nni.ConvBn2d: nniqat.ConvBn2d,
    nni.ConvBnReLU2d: nniqat.ConvBnReLU2d,
    nni.ConvReLU2d: nniqat.ConvReLU2d,
    nni.LinearReLU: nniqat.LinearReLU
}
Here, ConvBn2d, ConvBnReLU2d, etc. are intrinsic (operator-fused) modules; operator fusion must be carried out before quantization-aware training.
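In eager mode, fusion is done explicitly with torch.quantization.fuse_modules before calling prepare_qat. A minimal sketch under the API quoted in this post (newer releases add fuse_modules_qat for training-mode fusion); the block and its submodule names 'conv', 'bn', 'relu' are made up for illustration. With the model in training mode, fusion produces the intrinsic nni.ConvBnReLU2d, which the mapping above then swaps for nniqat.ConvBnReLU2d:

import torch
import torch.nn as nn

class ConvBNReLU(nn.Module):
    """Hypothetical block used only to illustrate fusion."""
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, 3)
        self.bn = nn.BatchNorm2d(8)
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(self.bn(self.conv(x)))

m = ConvBNReLU()
m.train()
m_fused = torch.quantization.fuse_modules(m, [["conv", "bn", "relu"]])
print(m_fused.conv)   # fused ConvBnReLU2d; bn and relu are replaced by Identity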
Inside prepare_qat, convert(model, mapping=mapping, inplace=True, remove_qconfig=False) is called, which in turn calls the from_float method of each QAT module. The from_float method of nnqat.Linear is shown below:
@classmethod
def from_float(cls, mod):
    r"""Create a qat module from a float module or qparams_dict

    Args: `mod` a float module, either produced by torch.quantization utilities
    or directly from user
    """
    assert type(mod) == cls._FLOAT_MODULE, ' qat.' + cls.__name__ + '.from_float only works for ' + \
        cls._FLOAT_MODULE.__name__
    assert hasattr(mod, 'qconfig'), 'Input float module must have qconfig defined'
    assert mod.qconfig, 'Input float module must have a valid qconfig'
    if type(mod) == LinearReLU:
        mod = mod[0]

    qconfig = mod.qconfig
    qat_linear = cls(mod.in_features, mod.out_features, bias=mod.bias is not None, qconfig=qconfig)
    qat_linear.weight = mod.weight
    qat_linear.bias = mod.bias
    return qat_linear
This method constructs a qat Linear instance (qat_linear) from the float module, reusing its weight and bias parameters.
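The swap can also be exercised directly, bypassing convert, which makes it easy to see that the float module's parameters are shared rather than copied (a small sketch):

import torch
import torch.nn as nn

float_linear = nn.Linear(16, 8)
float_linear.qconfig = torch.quantization.get_default_qat_qconfig("fbgemm")

# The same call that convert() makes for every module found in the mapping.
qat_linear = torch.nn.qat.Linear.from_float(float_linear)
print(type(qat_linear))                           # torch.nn.qat.modules.linear.Linear
print(qat_linear.weight is float_linear.weight)   # True: the Parameter object is shared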
Now look at the nnqat.Linear module.
The constructor
class Linear(nn.Linear):
    r"""
    A linear module attached with FakeQuantize modules for weight,
    used for quantization aware training.

    We adopt the same interface as `torch.nn.Linear`, please see
    https://pytorch.org/docs/stable/nn.html#torch.nn.Linear
    for documentation.

    Similar to `torch.nn.Linear`, with FakeQuantize modules initialized to
    default.

    Attributes:
        weight: fake quant module for weight
    """
    _FLOAT_MODULE = nn.Linear

    def __init__(self, in_features, out_features, bias=True,
                 qconfig=None):
        super(Linear, self).__init__(in_features, out_features, bias)
        assert qconfig, 'qconfig must be provided for QAT module'
        self.qconfig = qconfig
        self.weight_fake_quant = qconfig.weight()
Here qconfig is a namedtuple with weight and activation fields. The weight field is a constructor (factory): qconfig.weight() builds a FakeQuantize instance, and self.weight_fake_quant points to it.
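A quick sketch of that factory behavior, just inspecting the default fbgemm QAT qconfig:

import torch

qconfig = torch.quantization.get_default_qat_qconfig("fbgemm")
print(type(qconfig.weight).__name__)   # _PartialWrapper: a factory, not a module yet
weight_fq = qconfig.weight()           # calling the factory builds a FakeQuantize module
print(type(weight_fq).__name__)        # FakeQuantize
print(weight_fq.quant_min, weight_fq.quant_max)   # e.g. -128 127 for per-channel weights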
The forward function of the nnqat.Linear module is
def forward(self, input):
    return F.linear(input, self.weight_fake_quant(self.weight), self.bias)
Finally, look at the FakeQuantize class and its forward function:
def forward(self, X):
    if self.observer_enabled[0] == 1:
        self.activation_post_process(X.detach())
        _scale, _zero_point = self.calculate_qparams()
        _scale, _zero_point = _scale.to(self.scale.device), _zero_point.to(self.zero_point.device)
        self.scale.resize_(_scale.shape)
        self.scale.copy_(_scale)
        self.zero_point.resize_(_zero_point.shape)
        self.zero_point.copy_(_zero_point)

    if self.fake_quant_enabled[0] == 1:
        if self.qscheme == torch.per_channel_symmetric or self.qscheme == torch.per_channel_affine:
            X = torch.fake_quantize_per_channel_affine(X, self.scale, self.zero_point,
                                                       self.ch_axis, self.quant_min, self.quant_max)
        else:
            X = torch.fake_quantize_per_tensor_affine(X, float(self.scale),
                                                      int(self.zero_point), self.quant_min,
                                                      self.quant_max)
    return X
If fake_quant_enabled is 1, fake quantization is carried out by fake_quantize_per_channel_affine or fake_quantize_per_tensor_affine.
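Both ops perform a quantize-dequantize round trip in floating point, so the network sees quantization error during training while activations stay in float. A small sketch of the per-tensor case (the scale and zero_point values here are arbitrary):

import torch

x = torch.randn(8)
scale, zero_point, quant_min, quant_max = 0.1, 0, -128, 127

y = torch.fake_quantize_per_tensor_affine(x, scale, zero_point, quant_min, quant_max)

# Equivalent quantize-then-dequantize reference:
q = torch.clamp(torch.round(x / scale) + zero_point, quant_min, quant_max)
y_ref = (q - zero_point) * scale
print(torch.allclose(y, y_ref))   # True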