Pytorch Tutoriais (PROTOTYPE) FX GRAPH MODE QUANTIZATION USER GUIDE

(PROTOTYPE) FX GRAPH MODE QUANTIZATION USER GUIDE

Tutorials > (prototype) FX Graph Mode Quantization User Guide

doc : (prototype) FX Graph Mode Quantization User Guide — PyTorch Tutorials 1.11.0+cu102 documentation

Author: Jerry Zhang

2022年5月24日

tag : 翻译学习

topic : Pytorch 量化


(PROTOTYPE) FX GRAPH MODE QUANTIZATION USER GUIDE

(文章疑似未完成)

​ FX GRAPH MODE QUANTIZATION 需要一个symbolically traceable符号可追溯的模型。我们使用FX框架(TODO:link)来转换一个符号可追溯的nn。模块实例到 IR,我们在 IR 上操作以执行量化传递。在 PyTorch 讨论论坛中发布有关符号跟踪模型的问题。

​ FX GRAPH MODE QUANTIZATION 量化仅适用于模型的符号可追溯部分。数据相关的控制流(如果语句/for循环等使用符号跟踪值)是一种不支持的常见模式。如果您的模型在符号上无法端到端地追踪,则有几个选项可以仅在模型的某个部分上启用 FX 图形模式量化。您可以使用以下选项的任意组合:

  1. 不可追溯的代码不需要量化

    • 仅以符号方式跟踪需要量化的代码

    • 跳过不可跟踪代码的符号跟踪

  2. 不可追溯的代码需要量化

    • 重构代码以使其在符号上可追溯

    • 编写自己的observed和量化子模块

1.a. Symbolically trace only the code that needs to be quantized

​ 当整个模型在符号上不可追溯,但我们想要量化的子模块在符号上是可追溯的时,我们只能在该子模块上运行量化

前:(文章疑似未完成)

后:

量化代码:

qconfig_dict = {"": qconfig}
model_fp32.traceable_submodule = \
  prepare_fx(model_fp32.traceable_submodule, qconfig_dict)

​ 注意,如果需要保留原始模型,则必须在调用量化 API 之前自行复制。(考虑此处Inplace或固定为True)

​ 当我们在模块中有一些不可跟踪的代码,并且这部分代码不需要量化时,我们可以将这部分代码分解成一个子模块,并符号化地跳过跟踪该子模块。

前:

class M(nn.Module):

    def forward(self, x):
        x = self.traceable_code_1(x)
        x = non_traceable_code(x)
        x = self.traceable_code_2(x)
        return x

后,不可追溯的部件移动到模块并标记为叶子。non-traceable parts moved to a module and marked as a leaf

class FP32NonTraceable(nn.Module):

    def forward(self, x):
        x = non_traceable_code(x)
        return x

class M(nn.Module):

    def __init__(self):
        ...
        self.non_traceable_submodule = FP32NonTraceable(...)

    def forward(self, x):
        x = self.traceable_code_1(x)
        # we will configure the quantization call to not trace through
        # this submodule
        x = self.non_traceable_submodule(x)
        x = self.traceable_code_2(x)
        return x

量化代码:

qconfig_dict = {"": qconfig}

prepare_custom_config_dict = {
    # option 1
    "non_traceable_module_name": "non_traceable_submodule",
    # option 2
    "non_traceable_module_class": [MNonTraceable],
}
model_prepared = prepare_fx(
    model_fp32,
    qconfig_dict,
    prepare_custom_config_dict=prepare_custom_config_dict,
)

如果无法进行符号可追溯的代码需要量化,我们有以下两个选项:

​ 如果很容易重构代码并使代码在符号上可跟踪,我们可以重构代码并删除python中不可跟踪构造的使用。

​ 有关符号跟踪支持的详细信息,请参阅:(TODO:链接)

​ 前:

def transpose_for_scores(self, x):
    new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
    x = x.view(*new_x_shape)
    return x.permute(0, 2, 1, 3)

​ 这在符号上是不可追踪的,因为在``x.view([*](https://pytorch.org/tutorials/prototype/fx_graph_mode_quant_guide.html#id1)new_x_shape)中不支持解包,但是,由于x.view还支持list input`列表输入,因此很容易删除解包。

后:

def transpose_for_scores(self, x):
    new_x_shape = x.size()[:-1] + (self.num_attention_heads, self.attention_head_size)
    x = x.view(new_x_shape)
    return x.permute(0, 2, 1, 3)

量化代码:

这可以与其他方法结合使用,量化代码取决于模型。

如果不可跟踪的代码无法重构为符号可跟踪,例如它有一些无法消除的循环,例如nn.LSTM,我们需要将不可追溯的代码分解到子模块中(我们在fx图模式量化中称之为CustomModule),并定义子模块的观察和量化版本(在训练后静态量化或静态量化的量化感知训练中)或定义量化版本(在训练后动态和权重量化中)

前:

class M(nn.Module):

    def forward(self, x):
        x = traceable_code_1(x)
        x = non_traceable_code(x)
        x = traceable_code_1(x)
        return x

后:

  1. 将non_traceable_code分解为FP32不可追踪的不可追溯逻辑,包装在模块中
class FP32NonTraceable:
    ...
  1. 定义 FP32 的观测版本不可追踪
class ObservedNonTraceable:

    @classmethod
    def from_float(cls, ...):
        ...
  1. 定义 FP32NonTraceable 的静态量化版本和一个类方法“from_observed”,以从 ObservedNonTraceable 转换为 StaticQuantNonTraceable
class StaticQuantNonTraceable:

    @classmethod
    def from_observed(cls, ...):
        ...
# refactor parent class to call FP32NonTraceable
class M(nn.Module):

   def __init__(self):
        ...
        self.non_traceable_submodule = FP32NonTraceable(...)

    def forward(self, x):
        x = self.traceable_code_1(x)
        # this part will be quantized manually
        x = self.non_traceable_submodule(x)
        x = self.traceable_code_1(x)
        return x

量化代码:

# post training static quantization or
# quantization aware training (that produces a statically quantized module)v
prepare_custom_config_dict = {
    "float_to_observed_custom_module_class": {
        "static": {
            FP32NonTraceable: ObservedNonTraceable,
        }
    },
}

model_prepared = prepare_fx(
    model_fp32,
    qconfig_dict,
    prepare_custom_config_dict=prepare_custom_config_dict)

校准 / 训练 calibrate / train (not shown)

convert_custom_config_dict = {
    "observed_to_quantized_custom_module_class": {
        "static": {
            ObservedNonTraceable: StaticQuantNonTraceable,
        }
    },
}
model_quantized = convert_fx(
    model_prepared,
    convert_custom_config_dict)

训练后的动态/权重在这两种模式下的量化我们不需要观察原始模型,所以我们只需要定义量化模型

class DynamicQuantNonTraceable: # or WeightOnlyQuantMNonTraceable
   ...
   @classmethod
   def from_observed(cls, ...):
       ...

   prepare_custom_config_dict = {
       "non_traceable_module_class": [
           FP32NonTraceable
       ]
   }
# The example is for post training quantization
model_fp32.eval()
model_prepared = prepare_fx(
    model_fp32,
    qconfig_dict,
    prepare_custom_config_dict=prepare_custom_config_dict)

convert_custom_config_dict = {
    "observed_to_quantized_custom_module_class": {
        "dynamic": {
            FP32NonTraceable: DynamicQuantNonTraceable,
        }
    },
}
model_quantized = convert_fx(
    model_prepared,
    convert_custom_config_dict)

还可以在 中找到测试中的自定义模块的示例。test_custom_module_class torch/test/quantization/test_quantize_fx.py

你可能感兴趣的:(#,Pytorch,相关,pytorch)