2.1 PPQ quantization: PyTorch -> ONNX

Introduction

Load a model from torchvision, convert it to ONNX format, and export the quantized graph.

Code

from typing import Iterable

import torch
import torchvision
from torch.utils.data import DataLoader

from ppq import BaseGraph, QuantizationSettingFactory, TargetPlatform
from ppq.api import export_ppq_graph, quantize_torch_model

BATCHSIZE = 32
INPUT_SHAPE = [3, 224, 224]
DEVICE = 'cuda' # only CUDA is fully tested; other execution devices may hit bugs.
PLATFORM = TargetPlatform.PPL_CUDA_INT8  # identify a target platform for your network.

def load_calibration_dataset() -> Iterable:
    return [torch.rand(size=INPUT_SHAPE) for _ in range(32)]

def collate_fn(batch: torch.Tensor) -> torch.Tensor:
    return batch.to(DEVICE)

# Load a pretrained mobilenet v2 model
model = torchvision.models.mobilenet.mobilenet_v2(pretrained=True)
model = model.to(DEVICE)

# create a setting for quantizing your network with PPL CUDA.
quant_setting = QuantizationSettingFactory.pplcuda_setting()
quant_setting.equalization = True # use layerwise equalization algorithm.
quant_setting.dispatcher   = 'conservative' # dispatch this network in a conservative way.

# Build the calibration dataloader (random tensors stand in for real data here).
calibration_dataset = load_calibration_dataset()
calibration_dataloader = DataLoader(
    dataset=calibration_dataset,
    batch_size=BATCHSIZE, shuffle=True)

# quantize your model.
quantized = quantize_torch_model(
    model=model, calib_dataloader=calibration_dataloader,
    calib_steps=32, input_shape=[BATCHSIZE] + INPUT_SHAPE,
    setting=quant_setting, collate_fn=collate_fn, platform=PLATFORM,
    onnx_export_file='./onnx.model', device=DEVICE, verbose=0)

# Quantization Result is a PPQ BaseGraph instance.
assert isinstance(quantized, BaseGraph)

# export quantized graph.
export_ppq_graph(graph=quantized, platform=PLATFORM,
                 graph_save_to='./quantized(onnx).onnx',
                 config_save_to='./quantized(onnx).json')

# analyse quantization error brought in by every layer
from ppq.quantization.analyse import layerwise_error_analyse, graphwise_error_analyse
graphwise_error_analyse(
    graph=quantized, # ppq ir graph
    running_device=DEVICE, # cpu or cuda
    method='snr',  # the metric is signal noise ratio by default, adjust it to 'cosine' if that's desired
    steps=32, # how many batches of data will be used for error analysis
    dataloader=calibration_dataloader,
    collate_fn=lambda x: x.to(DEVICE)
)
 
layerwise_error_analyse(
    graph=quantized,
    running_device=DEVICE,
    method='snr',  # the metric is signal noise ratio by default, adjust it to 'cosine' if that's desired
    steps=32,
    dataloader=calibration_dataloader,
    collate_fn=lambda x: x.to(DEVICE)
)
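
Both analysis calls above pass method='snr', and the reports they print are labelled NOISE:SIGNAL POWER RATIO. As a rough illustration of what such a number means, the sketch below computes the ratio of error power to reference power for two flat vectors. It assumes the ratio is sum((q - r)^2) / sum(r^2); PPQ's own implementation may reduce over batches or dimensions differently, and the function name and sample values here are hypothetical.

```python
from typing import Sequence


def noise_to_signal_power_ratio(reference: Sequence[float],
                                quantized: Sequence[float]) -> float:
    """Return sum((q - r)^2) / sum(r^2) for two equally sized flat vectors.

    Assumed definition for illustration only, not PPQ's exact code.
    """
    if len(reference) != len(quantized):
        raise ValueError("both vectors must have the same length")
    # Power of the error introduced by quantization.
    noise_power = sum((q - r) ** 2 for r, q in zip(reference, quantized))
    # Power of the full-precision reference output.
    signal_power = sum(r ** 2 for r in reference)
    if signal_power == 0.0:
        raise ValueError("reference signal has zero power")
    return noise_power / signal_power


# Hypothetical outputs of a layer before and after quantization.
fp32_output = [1.0, -2.0, 3.0, -4.0]
int8_output = [1.1, -2.0, 2.9, -4.0]
ratio = noise_to_signal_power_ratio(fp32_output, int8_output)
print(f"{ratio:.6f}")
```

A value near 0 means the quantized output closely tracks the full-precision one. A value near or above 1 means the error carries as much power as the signal itself.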

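Results

The script loads the pretrained MobileNet V2 model
and ultimately produces three files: ./onnx.model, ./quantized(onnx).onnx, and ./quantized(onnx).json.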
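
One quick way to confirm that a run completed is to check that the output files exist and are not empty. The sketch below assumes the three files are the ones named in the script (the onnx_export_file, graph_save_to and config_save_to arguments). The helper name missing_outputs is hypothetical.

```python
from pathlib import Path
from typing import List

# File names taken from the arguments used in the script above.
EXPECTED_OUTPUTS = ("onnx.model", "quantized(onnx).onnx", "quantized(onnx).json")


def missing_outputs(workdir: str = ".") -> List[str]:
    """Return the expected output files that are absent or empty in workdir."""
    root = Path(workdir)
    return [
        name for name in EXPECTED_OUTPUTS
        if not (root / name).is_file() or (root / name).stat().st_size == 0
    ]


if __name__ == "__main__":
    missing = missing_outputs(".")
    print("all outputs present" if not missing else f"missing: {missing}")
```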

# python example.py

      ____  ____  __   ____                    __              __
     / __ \/ __ \/ /  / __ \__  ______ _____  / /_____  ____  / /
    / /_/ / /_/ / /  / / / / / / / __ `/ __ \/ __/ __ \/ __ \/ /
   / ____/ ____/ /__/ /_/ / /_/ / /_/ / / / / /_/ /_/ / /_/ / /
  /_/   /_/   /_____\___\_\__,_/\__,_/_/ /_/\__/\____/\____/_/


[07:04:00] PPQ Layerwise Equalization Pass Running ... 2 equalization pair(s) was found, ready to run optimization.
Layerwise Equalization: 100%|█████████████████████████████████████████████████████████| 10/10 [00:00<00:00, 1274.01it/s]
Finished.
[07:04:00] PPQ Quantization Config Refine Pass Running ... Finished.
[07:04:00] PPQ Quantization Fusion Pass Running ...        Finished.
[07:04:00] PPQ Quantize Point Reduce Pass Running ...      Finished.
[07:04:00] PPQ Parameter Quantization Pass Running ...     Finished.
Calibration Progress(Phase 1): 100%|████████████████████████████████████████████████████| 32/32 [01:59<00:00,  3.74s/it]
[07:04:00] PPQ Runtime Calibration Pass Running ...        Finished.
[07:06:00] PPQ Quantization Alignment Pass Running ...     Finished.
[07:06:00] PPQ Passive Parameter Quantization Running ...  Finished.
[07:06:00] PPQ Parameter Baking Pass Running ...           Finished.
--------- Network Snapshot ---------
Num of Op:                    [100]
Num of Quantized Op:          [100]
Num of Variable:              [277]
Num of Quantized Var:         [277]
------- Quantization Snapshot ------
Num of Quant Config:          [386]
BAKED:                        [53]
OVERLAPPED:                   [125]
SLAVE:                        [20]
ACTIVATED:                    [65]
PASSIVE_BAKED:                [53]
FP32:                         [70]
Network Quantization Finished.
Analysing Graphwise Quantization Error(Phrase 1):: 100%|██████████████████████████████████| 1/1 [00:00<00:00,  8.19it/s]
Analysing Graphwise Quantization Error(Phrase 2):: 100%|██████████████████████████████████| 1/1 [00:00<00:00,  6.84it/s]
Layer     | NOISE:SIGNAL POWER RATIO 
Conv_8:   | ████████████████████ | 1.678653
Conv_26:  | ████████████████     | 1.313450
Conv_9:   | █████████████        | 1.087763
Conv_13:  | █████████████        | 1.074564
Conv_55:  | ████████████         | 0.991271
Conv_28:  | ██████████           | 0.857988
Conv_17:  | ████████             | 0.730895
Conv_154: | ████████             | 0.676669
Conv_22:  | ████████             | 0.659212
Conv_152: | ███████              | 0.618322
Conv_142: | ███████              | 0.582268
Conv_133: | ██████               | 0.534554
Conv_45:  | ██████               | 0.520580
Conv_51:  | █████                | 0.464549
Conv_144: | █████                | 0.441523
Conv_41:  | █████                | 0.414612
Conv_36:  | █████                | 0.411636
Conv_57:  | ████                 | 0.387930
Conv_113: | ████                 | 0.368550
Conv_148: | ████                 | 0.351853
Conv_123: | ████                 | 0.333928
Conv_104: | ████                 | 0.331407
Conv_134: | ███                  | 0.319796
Conv_4:   | ███                  | 0.309758
Conv_138: | ███                  | 0.274523
Conv_125: | ███                  | 0.272312
Conv_32:  | ███                  | 0.269519
Conv_94:  | ███                  | 0.255700
Conv_47:  | ███                  | 0.255035
Conv_129: | ███                  | 0.246983
Conv_18:  | ██                   | 0.222586
Conv_65:  | ██                   | 0.205310
Conv_162: | ██                   | 0.190181
Conv_84:  | ██                   | 0.189721
Conv_90:  | ██                   | 0.183772
Conv_96:  | ██                   | 0.181663
Conv_70:  | ██                   | 0.174435
Conv_163: | ██                   | 0.167765
Conv_115: | ██                   | 0.164749
Conv_100: | █                    | 0.152931
Conv_86:  | █                    | 0.150768
Conv_105: | █                    | 0.148656
Conv_80:  | █                    | 0.134689
Conv_109: | █                    | 0.131509
Conv_37:  | █                    | 0.124499
Conv_119: | █                    | 0.122543
Conv_74:  | █                    | 0.096819
Conv_76:  |                      | 0.072862
Conv_61:  |                      | 0.071023
Conv_0:   |                      | 0.067830
Conv_66:  |                      | 0.064776
Gemm_169: |                      | 0.035677
Conv_158: |                      | 0.032427
Analysing Layerwise quantization error:: 100%|██████████████████████████████████████████| 53/53 [00:01<00:00, 34.86it/s]
Layer     | NOISE:SIGNAL POWER RATIO 
Conv_4:   | ████████████████████ | 0.007448
Conv_22:  | ███                  | 0.001254
Conv_133: | ███                  | 0.000973
Conv_142: | █                    | 0.000488
Conv_152: | █                    | 0.000487
Conv_162: | █                    | 0.000420
Conv_104: | █                    | 0.000372
Conv_8:   | █                    | 0.000269
Conv_65:  | █                    | 0.000214
Conv_113: | █                    | 0.000204
Conv_123: | █                    | 0.000190
Conv_13:  |                      | 0.000183
Conv_41:  |                      | 0.000142
Conv_17:  |                      | 0.000136
Conv_26:  |                      | 0.000113
Conv_36:  |                      | 0.000108
Conv_70:  |                      | 0.000096
Conv_94:  |                      | 0.000092
Conv_109: |                      | 0.000078
Conv_100: |                      | 0.000068
Conv_125: |                      | 0.000066
Conv_119: |                      | 0.000065
Gemm_169: |                      | 0.000064
Conv_55:  |                      | 0.000060
Conv_84:  |                      | 0.000058
Conv_138: |                      | 0.000053
Conv_80:  |                      | 0.000051
Conv_28:  |                      | 0.000050
Conv_45:  |                      | 0.000043
Conv_57:  |                      | 0.000036
Conv_90:  |                      | 0.000034
Conv_105: |                      | 0.000032
Conv_32:  |                      | 0.000032
Conv_144: |                      | 0.000031
Conv_74:  |                      | 0.000027
Conv_96:  |                      | 0.000027
Conv_115: |                      | 0.000026
Conv_154: |                      | 0.000026
Conv_51:  |                      | 0.000024
Conv_0:   |                      | 0.000023
Conv_148: |                      | 0.000022
Conv_86:  |                      | 0.000021
Conv_134: |                      | 0.000021
Conv_66:  |                      | 0.000018
Conv_18:  |                      | 0.000015
Conv_76:  |                      | 0.000012
Conv_129: |                      | 0.000010
Conv_47:  |                      | 0.000009
Conv_61:  |                      | 0.000008
Conv_9:   |                      | 0.000006
Conv_163: |                      | 0.000005
Conv_37:  |                      | 0.000004
Conv_158: |                      | 0.000002
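
The two reports above are plain text, so the layers with the largest error can be pulled out programmatically, for example to decide which ones deserve a closer look. The sketch below parses rows of the form shown in the log and lists the layers whose ratio exceeds a threshold. The regular expression, the helper names and the threshold of 1.0 are assumptions made for illustration, not part of PPQ.

```python
import re
from typing import List, Tuple

# Matches report rows such as "Conv_8:   | ████████ | 1.678653".
ROW = re.compile(r"^(\S+?):\s*\|[^|]*\|\s*([0-9.]+)\s*$")


def parse_report(text: str) -> List[Tuple[str, float]]:
    """Extract (layer name, ratio) pairs; header and progress lines are skipped."""
    rows = []
    for line in text.splitlines():
        match = ROW.match(line.strip())
        if match:
            rows.append((match.group(1), float(match.group(2))))
    return rows


def layers_above(rows: List[Tuple[str, float]], threshold: float) -> List[str]:
    """Return the names of layers whose ratio is strictly above threshold."""
    return [name for name, ratio in rows if ratio > threshold]


# A few rows copied from the graphwise report above.
sample = """\
Layer     | NOISE:SIGNAL POWER RATIO
Conv_8:   | ████████████████████ | 1.678653
Conv_26:  | ████████████████     | 1.313450
Conv_158: |                      | 0.032427
"""
rows = parse_report(sample)
print(layers_above(rows, threshold=1.0))  # prints ['Conv_8', 'Conv_26']
```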
