被调用的记录数据的函数(forward hook)接口需要接收三个参数:model(注册该 hook 的子模块)、input(该子模块的输入)和 output(该子模块的输出)。
import torch
from typing import Tuple, Any
def forward_hook(module: torch.nn.Module, inputs: Tuple, outputs: Any):
    """Append the hooked module's ``outputs`` to the module-level ``data_buffer``.

    The (module, inputs, outputs) parameter order matches what
    ``register_forward_hook`` passes; ``module`` and ``inputs`` are ignored.
    """
    sink = data_buffer  # module-level list created next to this demo
    sink.append(outputs)
data_buffer = list()  # forward_hook appends every captured output here
input = torch.randn(1, 3, 50, 50)
model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 10, 3, 1, 1),
    torch.nn.Conv2d(10, 5, 3, 1, 1),
    torch.nn.Conv2d(5, 2, 3, 2, 1),
    torch.nn.Conv2d(2, 4, 3, 1, 1),
)
model_record_name = '2'
# Find the submodule whose name (as reported by named_modules()) matches the
# distillation point to record. Fail loudly on a wrong name instead of
# silently registering nothing and hitting an IndexError on data_buffer[0].
named = dict(model.named_modules())
if model_record_name not in named:
    raise KeyError(
        f"no submodule named {model_record_name!r}; available: {list(named)}"
    )
# register_forward_hook is PyTorch's API; it returns a handle for detaching the hook.
hook_handle = named[model_record_name].register_forward_hook(forward_hook)
try:
    output = model(input)
finally:
    # Detach so later forward passes on `model` do not keep appending to data_buffer.
    hook_handle.remove()
# Output of submodule '2' (stride-2 conv on 50x50): torch.Size([1, 2, 25, 25]).
print(data_buffer[0].shape)
from mmengine.hub import get_model
from mmengine.config import Config
cfg = {'cfg_path': 'mmseg::all_changed/baseline-convnext-tiny_upernet-rotate.py'}
cfg = Config._dict_to_config_dict(cfg)
# Only configs that reference a file via `cfg_path` (and carry no `type`) are
# loaded through mmengine.hub.get_model here. Without this guard a
# non-matching cfg would silently fall through, and the code below would
# inspect whatever `model` was bound to earlier in this file.
if not cfg.get('cfg_path', None) or cfg.get('type', None):
    raise ValueError(
        "expected a config with 'cfg_path' and without 'type', got keys: "
        f"{list(cfg)}"
    )
model = get_model(**cfg)
print(type(model))
# Print every submodule name; these are the candidate distillation points
# that can be used as `source` in a distillation config.
for name, module in model.named_modules():
    print(name)
上面 demo 中的 model.named_modules() 可以用来列出模型中所有可用的蒸馏位点名称;将选定的名称填入配置文件的 source 参数中。
import torch.nn as nn
import torch
from typing import Type
class Registry:
    """Collect the outputs of hooked modules during forward passes.

    Register :meth:`record_data_hook` as a forward hook on a module; each
    forward call of that module appends its output to ``data_buffer``.
    Entering the instance as a context manager starts a fresh, empty buffer.
    """

    def __init__(self) -> None:
        # Outputs captured by record_data_hook, in call order.
        self.data_buffer = list()

    def __enter__(self) -> "Registry":
        """Start a fresh buffer and return this registry.

        The original code assigned ``self._data_buffer`` (typo), so the buffer
        was never reset and ``with ... as r`` bound ``None``.
        """
        # Rebind instead of clear() so a list obtained from an earlier round
        # (e.g. `feats = registry.data_buffer`) is left intact.
        self.data_buffer = list()
        return self

    def record_data_hook(self, model: nn.Module, input: Any, output: Any) -> None:
        """Forward-hook callback: append ``output`` to ``data_buffer``.

        Args:
            model: the module the hook is registered on (unused).
            input: the inputs passed to that module's forward (unused).
            output: the value returned by that module's forward.
        """
        self.data_buffer.append(output)

    def __exit__(self, *args, **kwargs) -> None:
        """No cleanup; returning None lets exceptions from the body propagate."""
        pass
input = torch.randn(16, 3, 512, 512)
model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 10, 3, 1, 1),
    torch.nn.Conv2d(10, 5, 3, 1, 1),
    torch.nn.Conv2d(5, 2, 3, 2, 1),
    torch.nn.Conv2d(2, 4, 3, 1, 1),
)
registry = Registry()
source = '2'  # submodule name as reported by model.named_modules()
# Resolve the distillation point up front so a wrong name fails with a clear
# message instead of an IndexError on data_buffer[0] below.
named = dict(model.named_modules())
if source not in named:
    raise KeyError(f"no submodule named {source!r}; available: {list(named)}")
hook_handle = named[source].register_forward_hook(registry.record_data_hook)
try:
    with registry:
        # record_data_hook appends the output of submodule `source` to
        # registry.data_buffer during this forward pass.
        _ = model(input)
finally:
    # Detach so later forward passes on `model` are not recorded.
    hook_handle.remove()
# Output of submodule '2' (stride-2 conv on 512x512): torch.Size([16, 2, 256, 256]).
print("拿到了forward时特定位点的特征图: {}".format(registry.data_buffer[0].shape))