mmrazor蒸馏中的无插入式提取蒸馏位点的方式和跨库调用的使用

mmrazor蒸馏中的无插入式提取蒸馏位点的方式:torch模型的子模型module.register_forward_hook(将数据保存下来的一个可以调用的函数)位点的提取特征图

调用的记录数据的函数接口需要有三个参数:model、input、和output

import torch
from typing import Tuple, Any

def forward_hook(module: torch.nn.Module, inputs: Tuple,  outputs: Any):
    data_buffer.append(outputs)


data_buffer = list()
input = torch.randn(1, 3, 50, 50)
model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 10, 3, 1, 1),
    torch.nn.Conv2d(10, 5, 3, 1, 1),
    torch.nn.Conv2d(5, 2, 3, 2, 1),
    torch.nn.Conv2d(2, 4, 3, 1, 1)
)

model_record_name = '2'
for name, module in model.named_modules():
    if name == model_record_name:  # 查找模型中命名的模型部分和记录的蒸馏位点名称是否一致,如果一致,就...
        module.register_forward_hook(forward_hook)  # module.register_forward_hook是pytorch里的东西
        break

output = model(input)
print(data_buffer[0].shape)

可以使用这种方式进行跨库调用
在这里插入图片描述
跨库调用的demo

from mmengine.hub import get_model
from mmengine.config import Config

cfg = {'cfg_path': 'mmseg::all_changed/baseline-convnext-tiny_upernet-rotate.py'}
cfg = Config._dict_to_config_dict(cfg)


if cfg.get('cfg_path', None) and not cfg.get('type', None):
    model = get_model(**cfg)
    
print(type(model)) 

for name, module in model.named_modules():
    print(name)

demo中的model.named_modules()可以用来查看可以使用的蒸馏位点名称,放入到配置文件的source参数中
mmrazor蒸馏中的无插入式提取蒸馏位点的方式和跨库调用的使用_第1张图片

在这里插入图片描述

上下文形式,记录蒸馏位点数据

import torch.nn as nn
import torch
from typing import Type

class Registry:
    def __init__(self) -> None:
        self.data_buffer = list()
        
    def __enter__(self, ):
        self._data_buffer = list()
        
    def record_data_hook(self, model: nn.Module, input: Type, output: Type):
        self.data_buffer.append(output)        
        
    def __exit__(self, *args, **kwargs):
        pass


input = torch.randn(16, 3, 512, 512)
model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 10, 3, 1, 1),
    torch.nn.Conv2d(10, 5, 3, 1, 1),
    torch.nn.Conv2d(5, 2, 3, 2, 1),
    torch.nn.Conv2d(2, 4, 3, 1, 1)
)

registry = Registry()

source = '2'
for name, module in model.named_modules():
    if name == source:
        module.register_forward_hook(registry.record_data_hook)
        break

with registry:              # 进入时清空;前向传播时记录数据到data_buffer
    _ = model(input)
    
    
print("拿到了forward时特定位点的特征图: {}".format(registry.data_buffer[0].shape))

你可能感兴趣的:(MMLab,pytorch,深度学习,python)